diff --git a/src/Staking.sol b/src/Staking.sol index 99ab6b5..cb31205 100644 --- a/src/Staking.sol +++ b/src/Staking.sol @@ -31,6 +31,8 @@ contract GhostStaking is IStaking, GhostAccessControlled { address public gatekeeper; address public warmup; + uint256 private _lastRebaseBlock; + mapping(address => bool) public locks; constructor( @@ -55,6 +57,7 @@ contract GhostStaking is IStaking, GhostAccessControlled { GhostWarmup newWarmup = new GhostWarmup(_ghst); warmup = address(newWarmup); + _lastRebaseBlock = 1; } function stake( @@ -165,7 +168,7 @@ contract GhostStaking is IStaking, GhostAccessControlled { } function rebase() public override returns (uint256 bounty) { - if (epoch.end <= block.timestamp) { + if (epoch.end <= block.timestamp && block.number > _lastRebaseBlock) { ISTNK(stnk).rebase(epoch.distribute, epoch.number); unchecked { @@ -181,6 +184,7 @@ contract GhostStaking is IStaking, GhostAccessControlled { uint256 balance = IERC20(ftso).balanceOf(address(this)); uint256 extra = ISTNK(stnk).circulatingSupply() + bounty; + _lastRebaseBlock = block.number; epoch.distribute = balance > extra ? balance - extra : 0; } } diff --git a/test/bonding/BondDepositorty.t.sol b/test/bonding/BondDepositorty.t.sol index 8f0e6aa..1ea0702 100644 --- a/test/bonding/BondDepositorty.t.sol +++ b/test/bonding/BondDepositorty.t.sol @@ -335,6 +335,7 @@ contract GhostBondDepositoryTest is Test { depository.deposit(0, amount, type(uint256).max, BOB, BOB); skip(DEPOSIT_INTERVAL); + vm.roll(block.number + 1); staking.rebase(); vm.startPrank(ALICE); @@ -410,6 +411,7 @@ contract GhostBondDepositoryTest is Test { assertEq(ghst.balanceOf(address(depository)), 0); skip(DEPOSIT_INTERVAL); + vm.roll(block.number + 1); staking.rebase(); depository.redeemAll(ALICE, true); diff --git a/test/staking/Staking.t.sol b/test/staking/Staking.t.sol index a4aec64..1d7ebbd 100644 --- a/test/staking/Staking.t.sol +++ b/test/staking/Staking.t.sol @@ -16,6 +16,27 @@ import {GhostBondingCalculator} from "../../src/StandardBondingCalculator.sol"; import {ITreasury} from "../../src/interfaces/ITreasury.sol"; import {SafeERC20} from "@openzeppelin-contracts/token/ERC20/utils/SafeERC20.sol"; +contract RebaseBatcher { + address public immutable STAKING; + + constructor(address _staking) { + STAKING = _staking; + } + + function multipleRebases(uint256 times) external returns (uint256) { + return _rebase(times); + } + + function _rebase(uint256 times) internal returns (uint256) { + if (times == 0) return 0; + + (bool success, ) = STAKING.call(abi.encodeWithSignature("rebase()")); + require(success, "Batch rebase failed"); + + return 1 + _rebase(times - 1); + } +} + contract StakingTest is Test { using SafeERC20 for Stinky; @@ -92,6 +113,7 @@ contract StakingTest is Test { gatekeeper = new Gatekeeper(address(staking), 0, 0, 0, 0, 0); calculator = new GhostBondingCalculator(address(ftso), 1, 1); vm.stopPrank(); + vm.roll(block.number + 1); } function test_correctAfterConstruction() public view { @@ -537,11 +559,13 @@ contract StakingTest is Test { (,, uint48 end,) = staking.epoch(); skip(end); + vm.roll(block.number + 1); vm.startPrank(ALICE); ftso.approve(address(staking), type(uint256).max); staking.stake(AMOUNT, ALICE, true, true); vm.stopPrank(); + vm.roll(block.number + 1); uint256 postBounty = AMOUNT + bounty; @@ -666,7 +690,6 @@ contract StakingTest is Test { vm.prank(BOB); staking.breakout(receiver, payout); - uint256 requestedPayout = 1; uint256 expectedPayout = 0; while (expectedPayout < payout / 2) { @@ -703,6 +726,30 @@ contract StakingTest is Test { assertEq(newDeposit, ftso.balanceOf(ALICE)); } + function test_revertDuringReentrancyRebase() public { + RebaseBatcher rebaseBatcher = new RebaseBatcher(address(staking)); + uint256 startBlock = block.number; + uint256 passedRebases = 69; + skip((1 + passedRebases) * EPOCH_LENGTH); + + (, uint48 prevNumber,,) = staking.epoch(); + rebaseBatcher.multipleRebases(passedRebases); + (, uint48 currNumber,,) = staking.epoch(); + assertEq(currNumber, prevNumber + 1); + + uint256 i; + for (; i < passedRebases;) { + vm.roll(block.number + 1); + (,uint48 number,,) = staking.epoch(); + rebaseBatcher.multipleRebases(1); + (,currNumber,,) = staking.epoch(); + assertEq(currNumber, number + 1); + unchecked { ++i; } + } + + assertEq(currNumber, startBlock + passedRebases); + } + function _mintAndApprove(address who, uint256 value) internal { vm.prank(VAULT); ftso.mint(who, value); diff --git a/test/staking/StakingDistributor.t.sol b/test/staking/StakingDistributor.t.sol index f9906a3..bdcc617 100644 --- a/test/staking/StakingDistributor.t.sol +++ b/test/staking/StakingDistributor.t.sol @@ -210,6 +210,7 @@ contract StakingDistributorTest is Test { assertEq(distributor.rewardRate(), TEN_PERCENT); (,, uint256 end,) = staking.epoch(); skip(end); + vm.roll(block.number + 1); staking.rebase(); assertEq(distributor.rewardRate(), TEN_PERCENT + 69); } @@ -234,6 +235,7 @@ contract StakingDistributorTest is Test { assertEq(distributor.rewardRate(), TEN_PERCENT); (,, uint256 end,) = staking.epoch(); skip(end); + vm.roll(block.number + 1); staking.rebase(); assertEq(distributor.rewardRate(), 1337); } @@ -246,6 +248,7 @@ contract StakingDistributorTest is Test { assertEq(distributor.rewardRate(), TEN_PERCENT); (,, uint256 end,) = staking.epoch(); skip(end); + vm.roll(block.number + 1); staking.rebase(); assertEq(distributor.rewardRate(), 1337); }