Skip to content

Commit f28d5c0

Browse files
committed
fix: unnaccounted for gas costs in withdraw tests established and set
1 parent 2e526f0 commit f28d5c0

File tree

1 file changed

+45
-27
lines changed

1 file changed

+45
-27
lines changed

test/TokenVault.t.sol

Lines changed: 45 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,16 @@ import {TokenVault} from "../src/TokenVault.sol";
55
contract TokenVaultTest is Test {
66
TokenVault public vault;
77
address public user1 = address(0xBEEF);
8+
address public user2 = address(0xBEEF11);
89
uint256 public constant INITIAL_BALANCE = 10 ether;
9-
10+
uint256 public constant WITHDRAW_AMOUNT = 0.1 ether;
1011
function setUp() public {
1112
// 1. Deploy the contract
1213
vault = new TokenVault();
1314

1415
// 2. Fund the test account (user1) for transactions
1516
vm.deal(user1, INITIAL_BALANCE);
17+
vm.deal(user2, INITIAL_BALANCE);
1618
}
1719

1820
function testDeposit() public {
@@ -72,32 +74,48 @@ contract TokenVaultTest is Test {
7274
}
7375

7476
// Invariant: The totalValue must always equal the sum of all user balances
75-
function invariant_TotalValueEqualsSumOfBalances() public view {
76-
// This invariant function is run after every random call in the sequence.
77-
uint256 sumOfBalances =
78-
vault.balances(user1) +
79-
vault.balances(address(0xDEAD)); // Imagine tracking a few key addresses
77+
function invariant_TotalValueEqualsSumOfBalances() public view {
78+
// This invariant function is run after every random call in the sequence.
79+
uint256 sumOfBalances = vault.balances(user1) +
80+
vault.balances(address(0xDEAD)); // Imagine tracking a few key addresses
8081

81-
assertEq(vault.totalValue(), sumOfBalances, "Invariant violated: totalValue != sum of balances");
82-
}
82+
assertEq(
83+
vault.totalValue(),
84+
sumOfBalances,
85+
"Invariant violated: totalValue != sum of balances"
86+
);
87+
}
8388

84-
function testCompareWithdrawGas() public {
85-
uint256 depositAmount = 1 ether;
86-
uint256 withdrawAmount = 0.1 ether;
87-
88-
vm.deal(user1, depositAmount);
89-
vm.prank(user1);
90-
vault.deposit{value: depositAmount}(depositAmount);
91-
92-
// 1. Call the optimized version
93-
vm.prank(user1);
94-
vault.withdrawOptimized(withdrawAmount);
95-
96-
// 2. Call the unoptimized version (requires resetting the state, usually in a separate test)
97-
// To properly compare, you should call them in separate test functions:
98-
// function testWithdrawOptimized() public { /* ... call vault.withdrawOptimized ... */ }
99-
// function testWithdrawUnoptimized() public { /* ... call vault.withdrawUnoptimized ... */ }
100-
101-
// In this example, we'll just run the report on all functions.
102-
}
89+
function testCompareWithdrawGas_Optimized() public {
90+
// This test runs completely independently of the unoptimized one,
91+
// ensuring a fresh state (cold storage) for the withdrawal.
92+
93+
// We're withdrawing only part of the deposit, so the vault's state
94+
// (e.g., balance, shares) will be written to for the first time
95+
// in this specific test function's execution.
96+
uint256 preUserBalance = vault.balances(user1);
97+
uint256 amount = 1 ether;
98+
// 2. Execute with random 'amount'
99+
vm.startPrank(user1);
100+
vault.deposit{value: amount}(amount);
101+
vault.withdrawOptimized(WITHDRAW_AMOUNT);
102+
vm.stopPrank();
103+
104+
// Note: No need for assertions here unless you are specifically testing
105+
// correctness. The primary goal is the gas report.
106+
}
107+
108+
// 2. Test for the Unoptimized version
109+
function testCompareWithdrawGas_Unoptimized() public {
110+
// This test also runs independently, ensuring a fresh state (cold storage)
111+
// for the unoptimized withdrawal.
112+
uint256 preUserBalance = vault.balances(user2);
113+
uint256 amount = 1 ether;
114+
// 2. Execute with random 'amount'
115+
vm.startPrank(user2);
116+
vault.deposit{value: amount}(amount);
117+
// If your Vault contract has an unoptimized function named withdrawUnoptimized:
118+
vault.withdrawUnoptimized(WITHDRAW_AMOUNT);
119+
vm.stopPrank();
120+
}
103121
}

0 commit comments

Comments
 (0)