136 lines
3.9 KiB
Rust
136 lines
3.9 KiB
Rust
use crate::AtLeast32BitUnsigned;
|
|
|
|
/// Zero-sized helper that computes `a * b / c` for a generic unsigned `Balance`,
/// carrying the intermediate product as a (low, high) pair of `Balance` words.
///
/// The struct stores only a `PhantomData<Balance>` marker; the public entry
/// point is `MulDiv::calculate`.
pub struct MulDiv<Balance>(core::marker::PhantomData<Balance>);
|
|
impl<Balance> MulDiv<Balance>
|
|
where
|
|
Balance: Copy
|
|
+ AtLeast32BitUnsigned
|
|
+ num_traits::ops::wrapping::WrappingAdd
|
|
+ num_traits::ops::overflowing::OverflowingAdd
|
|
+ sp_std::ops::AddAssign
|
|
+ sp_std::ops::Not<Output = Balance>
|
|
+ sp_std::ops::Shl<Output = Balance>
|
|
+ sp_std::ops::Shr<Output = Balance>
|
|
+ sp_std::ops::BitAnd<Balance, Output = Balance>,
|
|
{
|
|
|
|
fn zero(&self) -> Balance {
|
|
0u32.into()
|
|
}
|
|
|
|
fn one(&self) -> Balance {
|
|
1u32.into()
|
|
}
|
|
|
|
fn bit_shift(&self) -> Balance {
|
|
let u32_shift: u32 = core::mem::size_of::<Balance>()
|
|
.saturating_mul(4)
|
|
.try_into()
|
|
.unwrap_or_default();
|
|
u32_shift.into()
|
|
}
|
|
|
|
fn least_significant_bits(&self, a: Balance) -> Balance {
|
|
a & ((self.one() << self.bit_shift()) - self.one())
|
|
}
|
|
|
|
fn most_significant_bits(&self, a: Balance) -> Balance {
|
|
a >> self.bit_shift()
|
|
}
|
|
|
|
fn two_complement(&self, a: Balance) -> Balance {
|
|
(!a).wrapping_add(&self.one())
|
|
}
|
|
|
|
fn adjusted_ratio(&self, a: Balance) -> Balance {
|
|
(self.two_complement(a) / a).wrapping_add(&self.one())
|
|
}
|
|
|
|
fn modulo(&self, a: Balance) -> Balance {
|
|
self.two_complement(a) % a
|
|
}
|
|
|
|
fn overflow_resistant_addition(
|
|
&self,
|
|
a0: Balance,
|
|
a1: Balance,
|
|
b0: Balance,
|
|
b1: Balance,
|
|
) -> (Balance, Balance) {
|
|
let (r0, overflow) = a0.overflowing_add(&b0);
|
|
let overflow: Balance = overflow.then(|| 1u32).unwrap_or_default().into();
|
|
let r1 = a1.wrapping_add(&b1).wrapping_add(&overflow);
|
|
(r0, r1)
|
|
}
|
|
|
|
fn overflow_resistant_multiplication(&self, a: Balance, b: Balance) -> (Balance, Balance) {
|
|
let (a0, a1) = (
|
|
self.least_significant_bits(a),
|
|
self.most_significant_bits(a),
|
|
);
|
|
let (b0, b1) = (
|
|
self.least_significant_bits(b),
|
|
self.most_significant_bits(b),
|
|
);
|
|
let (x, y) = (a1 * b0, b1 * a0);
|
|
|
|
let (r0, r1) = (a0 * b0, a1 * b1);
|
|
let (r0, r1) = self.overflow_resistant_addition(
|
|
r0,
|
|
r1,
|
|
self.least_significant_bits(x) << self.bit_shift(),
|
|
self.most_significant_bits(x),
|
|
);
|
|
let (r0, r1) = self.overflow_resistant_addition(
|
|
r0,
|
|
r1,
|
|
self.least_significant_bits(y) << self.bit_shift(),
|
|
self.most_significant_bits(y),
|
|
);
|
|
|
|
(r0, r1)
|
|
}
|
|
|
|
fn overflow_resistant_division(
|
|
&self,
|
|
mut a0: Balance,
|
|
mut a1: Balance,
|
|
b: Balance,
|
|
) -> (Balance, Balance) {
|
|
if b == self.one() {
|
|
return (a0, a1);
|
|
}
|
|
|
|
let zero: Balance = 0u32.into();
|
|
let (q, r) = (self.adjusted_ratio(b), self.modulo(b));
|
|
let (mut x0, mut x1) = (zero, zero);
|
|
|
|
while a1 != zero {
|
|
let (t0, t1) = self.overflow_resistant_multiplication(a1, q);
|
|
let (new_x0, new_x1) = self.overflow_resistant_addition(x0, x1, t0, t1);
|
|
x0 = new_x0;
|
|
x1 = new_x1;
|
|
|
|
let (t0, t1) = self.overflow_resistant_multiplication(a1, r);
|
|
let (new_a0, new_a1) = self.overflow_resistant_addition(t0, t1, a0, zero);
|
|
a0 = new_a0;
|
|
a1 = new_a1;
|
|
}
|
|
|
|
self.overflow_resistant_addition(x0, x1, a0 / b, zero)
|
|
}
|
|
|
|
fn mul_div(&self, a: Balance, b: Balance, c: Balance) -> Balance {
|
|
let (t0, t1) = self.overflow_resistant_multiplication(a, b);
|
|
self.overflow_resistant_division(t0, t1, c).0
|
|
}
|
|
|
|
pub fn calculate(a: Balance, b: Balance, c: Balance) -> Balance {
|
|
let inner = MulDiv(core::marker::PhantomData);
|
|
if c == inner.zero() {
|
|
return c;
|
|
}
|
|
inner.mul_div(a, b, c)
|
|
}
|
|
}
|