implement muldiv without overflow and underflow

Signed-off-by: Uncle Stretch <uncle.stretch@ghostchain.io>
This commit is contained in:
Uncle Stretch 2025-08-10 16:41:14 +03:00
parent 9240f424e1
commit e2c75ca558
Signed by: str3tch
GPG Key ID: 84F3190747EE79AA
6 changed files with 323 additions and 26 deletions

3
Cargo.lock generated
View File

@ -3649,13 +3649,14 @@ dependencies = [
[[package]] [[package]]
name = "ghost-networks" name = "ghost-networks"
version = "0.1.13" version = "0.1.14"
dependencies = [ dependencies = [
"frame-benchmarking", "frame-benchmarking",
"frame-support", "frame-support",
"frame-system", "frame-system",
"ghost-core-primitives", "ghost-core-primitives",
"ghost-traits", "ghost-traits",
"num-traits",
"pallet-balances", "pallet-balances",
"pallet-staking", "pallet-staking",
"pallet-staking-reward-curve", "pallet-staking-reward-curve",

View File

@ -69,6 +69,7 @@ bs58 = { version = "0.5.0" }
prometheus-parse = { version = "0.2.2" } prometheus-parse = { version = "0.2.2" }
rustc-hex = { version = "2.1.0", default-features = false } rustc-hex = { version = "2.1.0", default-features = false }
log = { version = "0.4", default-features = false } log = { version = "0.4", default-features = false }
num-traits = { version = "0.2.17", default-features = false }
libsecp256k1 = { version = "0.7", default-features = false } libsecp256k1 = { version = "0.7", default-features = false }
bip39 = { package = "parity-bip39", version = "2.0.1" } bip39 = { package = "parity-bip39", version = "2.0.1" }
sha3 = { version = "0.10", default-features = false } sha3 = { version = "0.10", default-features = false }

View File

@ -1,6 +1,6 @@
[package] [package]
name = "ghost-networks" name = "ghost-networks"
version = "0.1.13" version = "0.1.14"
license.workspace = true license.workspace = true
authors.workspace = true authors.workspace = true
edition.workspace = true edition.workspace = true
@ -10,6 +10,7 @@ repository.workspace = true
[dependencies] [dependencies]
scale-info = { workspace = true, features = ["derive"] } scale-info = { workspace = true, features = ["derive"] }
codec = { workspace = true, features = ["max-encoded-len"] } codec = { workspace = true, features = ["max-encoded-len"] }
num-traits = { workspace = true }
frame-benchmarking = { workspace = true, optional = true } frame-benchmarking = { workspace = true, optional = true }
frame-support = { workspace = true } frame-support = { workspace = true }
@ -30,6 +31,7 @@ default = ["std"]
std = [ std = [
"scale-info/std", "scale-info/std",
"codec/std", "codec/std",
"num-traits/std",
"frame-support/std", "frame-support/std",
"frame-system/std", "frame-system/std",
"frame-benchmarking?/std", "frame-benchmarking?/std",

View File

@ -21,9 +21,11 @@ pub use ghost_traits::networks::{
NetworkDataBasicHandler, NetworkDataInspectHandler, NetworkDataMutateHandler, NetworkDataBasicHandler, NetworkDataInspectHandler, NetworkDataMutateHandler,
}; };
mod math;
mod weights; mod weights;
pub use crate::weights::WeightInfo; pub use crate::weights::WeightInfo;
use math::MulDiv;
pub use module::*; pub use module::*;
#[cfg(any(feature = "runtime-benchmarks", test))] #[cfg(any(feature = "runtime-benchmarks", test))]
@ -73,7 +75,15 @@ pub struct BridgedInflationCurve<RewardCurve, T>(core::marker::PhantomData<(Rewa
impl<Balance, RewardCurve, T> pallet_staking::EraPayout<Balance> impl<Balance, RewardCurve, T> pallet_staking::EraPayout<Balance>
for BridgedInflationCurve<RewardCurve, T> for BridgedInflationCurve<RewardCurve, T>
where where
Balance: Default + AtLeast32BitUnsigned + Clone + Copy + From<u128>, Balance: Default
+ Copy
+ From<BalanceOf<T>>
+ AtLeast32BitUnsigned
+ num_traits::ops::wrapping::WrappingAdd
+ num_traits::ops::overflowing::OverflowingAdd
+ sp_std::ops::AddAssign
+ sp_std::ops::Not<Output = Balance>
+ sp_std::ops::BitAnd<u128, Output = Balance>,
RewardCurve: Get<&'static PiecewiseLinear<'static>>, RewardCurve: Get<&'static PiecewiseLinear<'static>>,
T: Config, T: Config,
{ {
@ -82,30 +92,27 @@ where
total_issuance: Balance, total_issuance: Balance,
_era_duration_in_millis: u64, _era_duration_in_millis: u64,
) -> (Balance, Balance) { ) -> (Balance, Balance) {
let piecewise_linear = RewardCurve::get(); let reward_curve = RewardCurve::get();
let bridge_adjustment = BridgedImbalance::<T>::get(); let bridged_imbalance = BridgedImbalance::<T>::get();
let accumulated_commission = AccumulatedCommission::<T>::get(); let accumulated_commission = AccumulatedCommission::<T>::get();
let bridged_out: u128 = bridge_adjustment.bridged_out.try_into().unwrap_or_default(); let accumulated_commission: Balance = accumulated_commission.into();
let bridged_in: u128 = bridge_adjustment.bridged_in.try_into().unwrap_or_default(); let adjusted_issuance: Balance = total_issuance
let accumulated_commission: u128 = accumulated_commission.try_into().unwrap_or_default(); .saturating_add(bridged_imbalance.bridged_out.into())
.saturating_sub(bridged_imbalance.bridged_in.into());
let accumulated_balance = Balance::from(accumulated_commission);
let adjusted_issuance = match bridged_out > bridged_in {
true => total_issuance.saturating_add(Balance::from(bridged_out - bridged_in)),
false => total_issuance.saturating_sub(Balance::from(bridged_in - bridged_out)),
};
NullifyNeeded::<T>::set(true); NullifyNeeded::<T>::set(true);
match piecewise_linear let estimated_reward =
.calculate_for_fraction_times_denominator(total_staked, adjusted_issuance) reward_curve.calculate_for_fraction_times_denominator(total_staked, adjusted_issuance);
.checked_mul(&accumulated_balance) let payout = MulDiv::<Balance>::calculate(
.and_then(|product| product.checked_div(&adjusted_issuance)) estimated_reward,
{ accumulated_commission,
Some(payout) => (payout, accumulated_balance.saturating_sub(payout)), adjusted_issuance,
None => (Balance::default(), Balance::default()), );
} let rest_payout = accumulated_commission.saturating_sub(payout);
(payout, rest_payout)
} }
} }

View File

@ -0,0 +1,120 @@
use crate::AtLeast32BitUnsigned;
pub struct MulDiv<Balance>(core::marker::PhantomData<Balance>);
impl<Balance> MulDiv<Balance>
where
Balance: Copy
+ AtLeast32BitUnsigned
+ num_traits::ops::wrapping::WrappingAdd
+ num_traits::ops::overflowing::OverflowingAdd
+ sp_std::ops::AddAssign
+ sp_std::ops::Not<Output = Balance>
+ sp_std::ops::BitAnd<u128, Output = Balance>,
{
fn least_significant_bits(&self, a: Balance) -> Balance {
a & ((1 << 64) - 1)
}
fn most_significant_bits(&self, a: Balance) -> Balance {
a >> 64
}
fn two_complement(&self, a: Balance) -> Balance {
let one: Balance = 1u32.into();
(!a).wrapping_add(&one)
}
fn adjusted_ratio(&self, a: Balance) -> Balance {
let one: Balance = 1u32.into();
(self.two_complement(a) / a).wrapping_add(&one)
}
fn modulo(&self, a: Balance) -> Balance {
self.two_complement(a) % a
}
fn overflow_resistant_addition(
&self,
a0: Balance,
a1: Balance,
b0: Balance,
b1: Balance,
) -> (Balance, Balance) {
let (r0, overflow) = a0.overflowing_add(&b0);
let overflow: Balance = overflow.then(|| 1u32).unwrap_or_default().into();
let r1 = a1.wrapping_add(&b1).wrapping_add(&overflow);
(r0, r1)
}
fn overflow_resistant_multiplication(&self, a: Balance, b: Balance) -> (Balance, Balance) {
let (a0, a1) = (
self.least_significant_bits(a),
self.most_significant_bits(a),
);
let (b0, b1) = (
self.least_significant_bits(b),
self.most_significant_bits(b),
);
let (x, y) = (a1 * b0, b1 * a0);
let (r0, r1) = (a0 * b0, a1 * b1);
let (r0, r1) = self.overflow_resistant_addition(
r0,
r1,
self.least_significant_bits(x) << 64,
self.most_significant_bits(x),
);
let (r0, r1) = self.overflow_resistant_addition(
r0,
r1,
self.least_significant_bits(y) << 64,
self.most_significant_bits(y),
);
(r0, r1)
}
fn overflow_resistant_division(
&self,
mut a0: Balance,
mut a1: Balance,
b: Balance,
) -> (Balance, Balance) {
if b == 1u32.into() {
return (a0, a1);
}
let zero: Balance = 0u32.into();
let (q, r) = (self.adjusted_ratio(b), self.modulo(b));
let (mut x0, mut x1) = (zero, zero);
while a1 != zero {
let (t0, t1) = self.overflow_resistant_multiplication(a1, q);
let (new_x0, new_x1) = self.overflow_resistant_addition(x0, x1, t0, t1);
x0 = new_x0;
x1 = new_x1;
let (t0, t1) = self.overflow_resistant_multiplication(a1, r);
let (new_a0, new_a1) = self.overflow_resistant_addition(t0, t1, a0, zero);
a0 = new_a0;
a1 = new_a1;
}
self.overflow_resistant_addition(x0, x1, a0 / b, zero)
}
fn mul_div(&self, a: Balance, b: Balance, c: Balance) -> Balance {
let (t0, t1) = self.overflow_resistant_multiplication(a, b);
self.overflow_resistant_division(t0, t1, c).0
}
pub fn calculate(a: Balance, b: Balance, c: Balance) -> Balance {
let zero: Balance = 0u32.into();
if a == zero || b == zero || c == zero {
return zero;
}
let inner = MulDiv(core::marker::PhantomData);
inner.mul_div(a, b, c)
}
}

View File

@ -1295,8 +1295,10 @@ fn accumulated_commission_could_be_nullified() {
#[test] #[test]
fn bridged_inlation_reward_works() { fn bridged_inlation_reward_works() {
ExtBuilder::build().execute_with(|| { ExtBuilder::build().execute_with(|| {
let amount: u128 = 1337 * 1_000_000_000; let amount_full: u128 = 1337 * 1_000_000_000;
let commission: u128 = amount / 100; // 1% commission let commission: u128 = amount_full / 100; // 1% commission
let amount: u128 = amount_full - commission;
let total_staked_ideal: u128 = 69; let total_staked_ideal: u128 = 69;
let total_staked_not_ideal: u128 = 68; let total_staked_not_ideal: u128 = 68;
let total_issuance: u128 = 100; let total_issuance: u128 = 100;
@ -1538,8 +1540,10 @@ fn bridged_inlation_reward_works() {
#[test] #[test]
fn bridged_inflation_era_payout_triggers_need_of_nullification() { fn bridged_inflation_era_payout_triggers_need_of_nullification() {
ExtBuilder::build().execute_with(|| { ExtBuilder::build().execute_with(|| {
let amount: u128 = 1337 * 1_000_000_000; let amount_full: u128 = 1337 * 1_000_000_000;
let commission: u128 = amount / 100; // 1% commission let commission: u128 = amount_full / 100; // 1% commission
let amount: u128 = amount_full - commission;
let total_staked_ideal: u128 = 69; let total_staked_ideal: u128 = 69;
let total_issuance: u128 = 100; let total_issuance: u128 = 100;
@ -1570,3 +1574,165 @@ fn trigger_nullification_works_as_expected() {
assert_eq!(NullifyNeeded::<Test>::get(), false); assert_eq!(NullifyNeeded::<Test>::get(), false);
}); });
} }
#[test]
fn check_substrate_guarantees_not_to_overflow() {
ExtBuilder::build().execute_with(|| {
let reward_curve = RewardCurve::get();
let mut n: u128 = 69;
let mut d: u128 = 100;
loop {
n = match n.checked_mul(1_000_000) {
Some(value) => value,
None => break,
};
d = match d.checked_mul(1_000_000) {
Some(value) => value,
None => break,
};
assert_eq!(
reward_curve.calculate_for_fraction_times_denominator(n, d),
d
);
}
});
}
#[test]
fn check_muldiv_guarantees_not_to_overflow() {
ExtBuilder::build().execute_with(|| {
let mut a: u128 = 2;
let mut b: u128 = 3;
let mut c: u128 = 6;
let mut result: u128 = 1;
loop {
a = match a.checked_mul(1_000_000) {
Some(value) => value,
None => break,
};
b = match b.checked_mul(1_000_000) {
Some(value) => value,
None => break,
};
c = match c.checked_mul(1_000_000) {
Some(value) => value,
None => break,
};
result = match result.checked_mul(1_000_000) {
Some(value) => value,
None => break,
};
assert_eq!(MulDiv::<u128>::calculate(a, b, c), result);
}
assert_eq!(
MulDiv::<u128>::calculate(u128::MAX, u128::MAX, u128::MAX),
u128::MAX
);
assert_eq!(MulDiv::<u128>::calculate(u128::MAX, 0, 0), 0);
assert_eq!(MulDiv::<u128>::calculate(0, u128::MAX, 0), 0);
assert_eq!(MulDiv::<u128>::calculate(0, 0, u128::MAX), 0);
});
}
#[test]
fn check_bridged_inflation_curve_for_overflow() {
ExtBuilder::build().execute_with(|| {
let amount_full: u128 = 1337 * 1_000_000_000;
let commission: u128 = amount_full / 100; // 1% commission
let amount: u128 = amount_full - commission;
let tollerance: u128 = commission / 100; // 1% tollerance
let precomputed_payout: u128 = 13177568884;
let precomputed_rest: u128 = 192431116;
assert_eq!(precomputed_payout + precomputed_rest, commission);
let mut total_staked_ideal: u128 = 69_000;
let mut total_staked_not_ideal: u128 = 68_000;
let mut total_issuance: u128 = 100_000;
assert_ok!(GhostNetworks::accumulate_commission(&commission));
assert_ok!(GhostNetworks::accumulate_incoming_imbalance(&amount));
loop {
total_staked_ideal = match total_staked_ideal.checked_mul(1_000_000) {
Some(value) => value,
None => break,
};
total_staked_not_ideal = match total_staked_not_ideal.checked_mul(1_000_000) {
Some(value) => value,
None => break,
};
total_issuance = match total_issuance.checked_mul(1_000_000) {
Some(value) => value,
None => break,
};
assert_eq!(
BridgedInflationCurve::<RewardCurve, Test>::era_payout(
total_staked_ideal,
total_issuance + amount,
0
),
(commission, 0)
);
let (payout, rest) = BridgedInflationCurve::<RewardCurve, Test>::era_payout(
total_staked_not_ideal,
total_issuance + amount,
0,
);
let payout_deviation = if precomputed_payout > payout {
precomputed_payout - payout
} else {
payout - precomputed_payout
};
let rest_deviation = if precomputed_rest > rest {
precomputed_rest - rest
} else {
rest - precomputed_rest
};
assert!(payout_deviation < tollerance);
assert!(rest_deviation < tollerance);
}
});
}
#[test]
fn check_bridged_inflation_curve_for_big_commissions() {
ExtBuilder::build().execute_with(|| {
let mut amount_full: u128 = 1337;
let total_staked_ideal: u128 = 69_000_000;
let total_issuance: u128 = 100_000_000;
loop {
amount_full = match amount_full.checked_mul(1_000_000) {
Some(value) => value,
None => break,
};
let commission: u128 = amount_full / 100; // 1% commission
let amount: u128 = amount_full - commission;
AccumulatedCommission::<Test>::set(commission);
BridgedImbalance::<Test>::set(BridgeAdjustment {
bridged_in: amount,
bridged_out: 0,
});
assert_eq!(
BridgedInflationCurve::<RewardCurve, Test>::era_payout(
total_staked_ideal,
total_issuance + amount,
0
),
(commission, 0)
);
}
});
}