Skip to content

Commit 8f9934f

Browse files
authored
Merge pull request #44 from datachainlab/initializable-erc20transfermodule
fix: Initializable ERC20TransferModule
2 parents e578b4a + 6d9a51e commit 8f9934f

File tree

2 files changed

+111
-14
lines changed

2 files changed

+111
-14
lines changed

src/example/ERC20TransferModule.sol

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,21 @@ pragma solidity ^0.8.20;
33

44
import {IERC20} from "@openzeppelin/contracts/token/ERC20/IERC20.sol";
55
import {SafeERC20} from "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol";
6+
import {Initializable} from "@openzeppelin/contracts/proxy/utils/Initializable.sol";
7+
import {Ownable} from "@openzeppelin/contracts/access/Ownable.sol";
68
import {ContractModuleBase} from "../core/ContractModuleBase.sol";
79
import {CrossContext} from "../core/IContractModule.sol";
810

9-
abstract contract ERC20TransferModule is ContractModuleBase {
11+
abstract contract ERC20TransferModule is Initializable, ContractModuleBase, Ownable {
1012
using SafeERC20 for IERC20;
1113

1214
error ERC20TransferModuleInvalidCallInfo();
1315
error ERC20TransferModuleTxAlreadyPending();
1416
error ERC20TransferModuleUnauthorized();
17+
error ERC20TransferModuleNotInitialized();
18+
error ERC20TransferModuleInvalidAddress();
19+
20+
event ERC20TransferModuleInitialized(address indexed crossModule, address indexed token);
1521

1622
struct PendingTx {
1723
address from;
@@ -22,17 +28,25 @@ abstract contract ERC20TransferModule is ContractModuleBase {
2228
// txID => PendingTx
2329
mapping(bytes32 => PendingTx) public pendingTxs;
2430

25-
address public immutable CROSS_MODULE;
26-
IERC20 public immutable TOKEN;
31+
address public crossModule;
32+
IERC20 public token;
33+
34+
constructor() Ownable(msg.sender) {}
2735

2836
modifier onlyCrossModule() {
29-
if (msg.sender != CROSS_MODULE) revert ERC20TransferModuleUnauthorized();
37+
if (crossModule == address(0)) revert ERC20TransferModuleNotInitialized();
38+
if (msg.sender != crossModule) revert ERC20TransferModuleUnauthorized();
3039
_;
3140
}
3241

33-
constructor(address _crossModule, address _token) {
34-
CROSS_MODULE = _crossModule;
35-
TOKEN = IERC20(_token);
42+
function initialize(address _crossModule, address _token) external initializer onlyOwner {
43+
if (_crossModule == address(0) || _token == address(0)) {
44+
revert ERC20TransferModuleInvalidAddress();
45+
}
46+
crossModule = _crossModule;
47+
token = IERC20(_token);
48+
49+
emit ERC20TransferModuleInitialized(_crossModule, _token);
3650
}
3751

3852
function decodeCallInfo(bytes calldata callInfo)
@@ -60,7 +74,7 @@ abstract contract ERC20TransferModule is ContractModuleBase {
6074

6175
// IMPORTANT: The implementing contract MUST ensure in `_authorize` that the `from` address corresponds to the authenticated signer.
6276
// slither-disable-next-line arbitrary-send-erc20
63-
TOKEN.safeTransferFrom(from, to, amount);
77+
token.safeTransferFrom(from, to, amount);
6478

6579
return "";
6680
}
@@ -82,7 +96,7 @@ abstract contract ERC20TransferModule is ContractModuleBase {
8296

8397
// IMPORTANT: The implementing contract MUST ensure in `_authorize` that the `from` address corresponds to the authenticated signer.
8498
// slither-disable-next-line arbitrary-send-erc20
85-
TOKEN.safeTransferFrom(from, address(this), amount);
99+
token.safeTransferFrom(from, address(this), amount);
86100

87101
return "";
88102
}
@@ -94,7 +108,7 @@ abstract contract ERC20TransferModule is ContractModuleBase {
94108
if (pending.from != address(0)) {
95109
delete pendingTxs[txID];
96110
// Transfer locked tokens to the destination
97-
TOKEN.safeTransfer(pending.to, pending.amount);
111+
token.safeTransfer(pending.to, pending.amount);
98112
}
99113
}
100114

@@ -105,7 +119,7 @@ abstract contract ERC20TransferModule is ContractModuleBase {
105119
if (pending.from != address(0)) {
106120
delete pendingTxs[txID];
107121
// Refund tokens to the sender
108-
TOKEN.safeTransfer(pending.from, pending.amount);
122+
token.safeTransfer(pending.from, pending.amount);
109123
}
110124
}
111125
}

test/ERC20TransferModule.t.sol

Lines changed: 86 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import {Account as AuthAccount, AuthType} from "../src/proto/cross/core/auth/Aut
99
import {GoogleProtobufAny} from "@hyperledger-labs/yui-ibc-solidity/contracts/proto/GoogleProtobufAny.sol";
1010
import {ERC20} from "@openzeppelin/contracts/token/ERC20/ERC20.sol";
1111
import {IERC20Errors} from "@openzeppelin/contracts/interfaces/draft-IERC6093.sol";
12+
import {Initializable} from "@openzeppelin/contracts/proxy/utils/Initializable.sol";
13+
import {Ownable} from "@openzeppelin/contracts/access/Ownable.sol";
1214

1315
contract MockERC20 is ERC20 {
1416
constructor() ERC20("Mock Token", "MCK") {}
@@ -19,8 +21,6 @@ contract MockERC20 is ERC20 {
1921
}
2022

2123
contract ERC20TransferModuleHarness is ERC20TransferModule {
22-
constructor(address _crossModule, address _token) ERC20TransferModule(_crossModule, _token) {}
23-
2424
function _authorize(CrossContext calldata, bytes calldata) internal pure override {
2525
// allow all for testing
2626
}
@@ -45,7 +45,8 @@ contract ERC20ModuleTest is Test {
4545

4646
token = new MockERC20();
4747

48-
harness = new ERC20TransferModuleHarness(address(this), address(token));
48+
harness = new ERC20TransferModuleHarness();
49+
harness.initialize(address(this), address(token));
4950

5051
token.mint(sender, INITIAL_BALANCE);
5152
}
@@ -74,6 +75,54 @@ contract ERC20ModuleTest is Test {
7475
return abi.encode(_from, _to, _amount);
7576
}
7677

78+
// --- Initialization Tests ---
79+
80+
function test_initialize_Success() public {
81+
ERC20TransferModuleHarness newHarness = new ERC20TransferModuleHarness();
82+
address crossModule = makeAddr("newCrossModule");
83+
address newToken = makeAddr("newToken");
84+
85+
vm.expectEmit(address(newHarness));
86+
emit ERC20TransferModule.ERC20TransferModuleInitialized(crossModule, newToken);
87+
88+
newHarness.initialize(crossModule, newToken);
89+
90+
assertEq(newHarness.crossModule(), crossModule);
91+
assertEq(address(newHarness.token()), newToken);
92+
}
93+
94+
function test_initialize_RevertWhen_InvalidAddress() public {
95+
ERC20TransferModuleHarness newHarness = new ERC20TransferModuleHarness();
96+
address validAddress = makeAddr("valid");
97+
98+
// 1. CrossModule is zero
99+
vm.expectRevert(ERC20TransferModule.ERC20TransferModuleInvalidAddress.selector);
100+
newHarness.initialize(address(0), validAddress);
101+
102+
// 2. Token is zero
103+
vm.expectRevert(ERC20TransferModule.ERC20TransferModuleInvalidAddress.selector);
104+
newHarness.initialize(validAddress, address(0));
105+
}
106+
107+
function test_initialize_RevertWhen_AlreadyInitialized() public {
108+
vm.expectRevert(Initializable.InvalidInitialization.selector);
109+
harness.initialize(address(this), address(token));
110+
}
111+
112+
function test_initialize_RevertWhen_CallerIsNotOwner() public {
113+
ERC20TransferModuleHarness newHarness = new ERC20TransferModuleHarness();
114+
115+
address attacker = makeAddr("attacker");
116+
address crossModule = makeAddr("crossModule");
117+
address newToken = makeAddr("newToken");
118+
119+
vm.prank(attacker);
120+
121+
vm.expectRevert(abi.encodeWithSelector(Ownable.OwnableUnauthorizedAccount.selector, attacker));
122+
123+
newHarness.initialize(crossModule, newToken);
124+
}
125+
77126
// --- Decode Tests ---
78127

79128
function test_decodeCallInfo_Success() public view {
@@ -136,6 +185,15 @@ contract ERC20ModuleTest is Test {
136185
harness.onContractCommitImmediately(context, callInfo);
137186
}
138187

188+
function test_onContractCommitImmediately_RevertWhen_NotInitialized() public {
189+
ERC20TransferModuleHarness uninitHarness = new ERC20TransferModuleHarness();
190+
CrossContext memory context = _createContext(sender);
191+
bytes memory callInfo = _createCallInfo(sender, receiver, AMOUNT);
192+
193+
vm.expectRevert(ERC20TransferModule.ERC20TransferModuleNotInitialized.selector);
194+
uninitHarness.onContractCommitImmediately(context, callInfo);
195+
}
196+
139197
// --- Prepare Tests ---
140198

141199
function test_onContractPrepare_Success() public {
@@ -216,6 +274,15 @@ contract ERC20ModuleTest is Test {
216274
harness.onContractPrepare(context, callInfo);
217275
}
218276

277+
function test_onContractPrepare_RevertWhen_NotInitialized() public {
278+
ERC20TransferModuleHarness uninitHarness = new ERC20TransferModuleHarness();
279+
CrossContext memory context = _createContext(sender);
280+
bytes memory callInfo = _createCallInfo(sender, receiver, AMOUNT);
281+
282+
vm.expectRevert(ERC20TransferModule.ERC20TransferModuleNotInitialized.selector);
283+
uninitHarness.onContractPrepare(context, callInfo);
284+
}
285+
219286
// --- Commit Tests ---
220287

221288
function test_onCommit_Success() public {
@@ -260,6 +327,14 @@ contract ERC20ModuleTest is Test {
260327
harness.onCommit(context);
261328
}
262329

330+
function test_onCommit_RevertWhen_NotInitialized() public {
331+
ERC20TransferModuleHarness uninitHarness = new ERC20TransferModuleHarness();
332+
CrossContext memory context = _createContext(sender);
333+
334+
vm.expectRevert(ERC20TransferModule.ERC20TransferModuleNotInitialized.selector);
335+
uninitHarness.onCommit(context);
336+
}
337+
263338
// --- Abort Tests ---
264339

265340
function test_onAbort_Success() public {
@@ -305,4 +380,12 @@ contract ERC20ModuleTest is Test {
305380
vm.expectRevert(ERC20TransferModule.ERC20TransferModuleUnauthorized.selector);
306381
harness.onAbort(context);
307382
}
383+
384+
function test_onAbort_RevertWhen_NotInitialized() public {
385+
ERC20TransferModuleHarness uninitHarness = new ERC20TransferModuleHarness();
386+
CrossContext memory context = _createContext(sender);
387+
388+
vm.expectRevert(ERC20TransferModule.ERC20TransferModuleNotInitialized.selector);
389+
uninitHarness.onAbort(context);
390+
}
308391
}

0 commit comments

Comments
 (0)