use crate::AtLeast32BitUnsigned; pub struct MulDiv(core::marker::PhantomData); impl MulDiv where Balance: Copy + AtLeast32BitUnsigned + num_traits::ops::wrapping::WrappingAdd + num_traits::ops::overflowing::OverflowingAdd + sp_std::ops::AddAssign + sp_std::ops::Not + sp_std::ops::Shl + sp_std::ops::Shr + sp_std::ops::BitAnd, { fn zero(&self) -> Balance { 0u32.into() } fn one(&self) -> Balance { 1u32.into() } fn bit_shift(&self) -> Balance { let u32_shift: u32 = core::mem::size_of::() .saturating_mul(4) .try_into() .unwrap_or_default(); u32_shift.into() } fn least_significant_bits(&self, a: Balance) -> Balance { a & ((self.one() << self.bit_shift()) - self.one()) } fn most_significant_bits(&self, a: Balance) -> Balance { a >> self.bit_shift() } fn two_complement(&self, a: Balance) -> Balance { (!a).wrapping_add(&self.one()) } fn adjusted_ratio(&self, a: Balance) -> Balance { (self.two_complement(a) / a).wrapping_add(&self.one()) } fn modulo(&self, a: Balance) -> Balance { self.two_complement(a) % a } fn overflow_resistant_addition( &self, a0: Balance, a1: Balance, b0: Balance, b1: Balance, ) -> (Balance, Balance) { let (r0, overflow) = a0.overflowing_add(&b0); let overflow: Balance = overflow.then(|| 1u32).unwrap_or_default().into(); let r1 = a1.wrapping_add(&b1).wrapping_add(&overflow); (r0, r1) } fn overflow_resistant_multiplication(&self, a: Balance, b: Balance) -> (Balance, Balance) { let (a0, a1) = ( self.least_significant_bits(a), self.most_significant_bits(a), ); let (b0, b1) = ( self.least_significant_bits(b), self.most_significant_bits(b), ); let (x, y) = (a1 * b0, b1 * a0); let (r0, r1) = (a0 * b0, a1 * b1); let (r0, r1) = self.overflow_resistant_addition( r0, r1, self.least_significant_bits(x) << self.bit_shift(), self.most_significant_bits(x), ); let (r0, r1) = self.overflow_resistant_addition( r0, r1, self.least_significant_bits(y) << self.bit_shift(), self.most_significant_bits(y), ); (r0, r1) } fn overflow_resistant_division( &self, mut a0: Balance, mut a1: Balance, b: Balance, ) -> (Balance, Balance) { if b == self.one() { return (a0, a1); } let zero: Balance = 0u32.into(); let (q, r) = (self.adjusted_ratio(b), self.modulo(b)); let (mut x0, mut x1) = (zero, zero); while a1 != zero { let (t0, t1) = self.overflow_resistant_multiplication(a1, q); let (new_x0, new_x1) = self.overflow_resistant_addition(x0, x1, t0, t1); x0 = new_x0; x1 = new_x1; let (t0, t1) = self.overflow_resistant_multiplication(a1, r); let (new_a0, new_a1) = self.overflow_resistant_addition(t0, t1, a0, zero); a0 = new_a0; a1 = new_a1; } self.overflow_resistant_addition(x0, x1, a0 / b, zero) } fn mul_div(&self, a: Balance, b: Balance, c: Balance) -> Balance { let (t0, t1) = self.overflow_resistant_multiplication(a, b); self.overflow_resistant_division(t0, t1, c).0 } pub fn calculate(a: Balance, b: Balance, c: Balance) -> Balance { let inner = MulDiv(core::marker::PhantomData); if c == inner.zero() { return c; } inner.mul_div(a, b, c) } }