Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 18 additions & 11 deletions src/example/ERC20TransferModule.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,18 @@ pragma solidity ^0.8.20;

import {IERC20} from "@openzeppelin/contracts/token/ERC20/IERC20.sol";
import {SafeERC20} from "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol";
import {Initializable} from "@openzeppelin/contracts/proxy/utils/Initializable.sol";
import {ContractModuleBase} from "../core/ContractModuleBase.sol";
import {CrossContext} from "../core/IContractModule.sol";

abstract contract ERC20TransferModule is ContractModuleBase {
abstract contract ERC20TransferModule is Initializable, ContractModuleBase {
using SafeERC20 for IERC20;

error ERC20TransferModuleInvalidCallInfo();
error ERC20TransferModuleTxAlreadyPending();
error ERC20TransferModuleUnauthorized();
error ERC20TransferModuleNotInitialized();
error ERC20TransferModuleInvalidAddress();

struct PendingTx {
address from;
Expand All @@ -22,17 +25,21 @@ abstract contract ERC20TransferModule is ContractModuleBase {
// txID => PendingTx
mapping(bytes32 => PendingTx) public pendingTxs;

address public immutable CROSS_MODULE;
IERC20 public immutable TOKEN;
address public crossModule;
IERC20 public token;

modifier onlyCrossModule() {
if (msg.sender != CROSS_MODULE) revert ERC20TransferModuleUnauthorized();
if (crossModule == address(0)) revert ERC20TransferModuleNotInitialized();
if (msg.sender != crossModule) revert ERC20TransferModuleUnauthorized();
_;
}

constructor(address _crossModule, address _token) {
CROSS_MODULE = _crossModule;
TOKEN = IERC20(_token);
function initialize(address _crossModule, address _token) external initializer {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

initialize can be called by anyone. Isn’t there a risk that it could be initialized maliciously? It might be better to restrict it using Ownable or AccessControl.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix: cd171a6

if (_crossModule == address(0) || _token == address(0)) {
revert ERC20TransferModuleInvalidAddress();
}
crossModule = _crossModule;
token = IERC20(_token);
}

function decodeCallInfo(bytes calldata callInfo)
Expand Down Expand Up @@ -60,7 +67,7 @@ abstract contract ERC20TransferModule is ContractModuleBase {

// IMPORTANT: The implementing contract MUST ensure in `_authorize` that the `from` address corresponds to the authenticated signer.
// slither-disable-next-line arbitrary-send-erc20
TOKEN.safeTransferFrom(from, to, amount);
token.safeTransferFrom(from, to, amount);

return "";
}
Expand All @@ -82,7 +89,7 @@ abstract contract ERC20TransferModule is ContractModuleBase {

// IMPORTANT: The implementing contract MUST ensure in `_authorize` that the `from` address corresponds to the authenticated signer.
// slither-disable-next-line arbitrary-send-erc20
TOKEN.safeTransferFrom(from, address(this), amount);
token.safeTransferFrom(from, address(this), amount);

return "";
}
Expand All @@ -94,7 +101,7 @@ abstract contract ERC20TransferModule is ContractModuleBase {
if (pending.from != address(0)) {
delete pendingTxs[txID];
// Transfer locked tokens to the destination
TOKEN.safeTransfer(pending.to, pending.amount);
token.safeTransfer(pending.to, pending.amount);
}
}

Expand All @@ -105,7 +112,7 @@ abstract contract ERC20TransferModule is ContractModuleBase {
if (pending.from != address(0)) {
delete pendingTxs[txID];
// Refund tokens to the sender
TOKEN.safeTransfer(pending.from, pending.amount);
token.safeTransfer(pending.from, pending.amount);
}
}
}
74 changes: 71 additions & 3 deletions test/ERC20TransferModule.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {Account as AuthAccount, AuthType} from "../src/proto/cross/core/auth/Aut
import {GoogleProtobufAny} from "@hyperledger-labs/yui-ibc-solidity/contracts/proto/GoogleProtobufAny.sol";
import {ERC20} from "@openzeppelin/contracts/token/ERC20/ERC20.sol";
import {IERC20Errors} from "@openzeppelin/contracts/interfaces/draft-IERC6093.sol";
import {Initializable} from "@openzeppelin/contracts/proxy/utils/Initializable.sol";

contract MockERC20 is ERC20 {
constructor() ERC20("Mock Token", "MCK") {}
Expand All @@ -19,8 +20,6 @@ contract MockERC20 is ERC20 {
}

contract ERC20TransferModuleHarness is ERC20TransferModule {
constructor(address _crossModule, address _token) ERC20TransferModule(_crossModule, _token) {}

function _authorize(CrossContext calldata, bytes calldata) internal pure override {
// allow all for testing
}
Expand All @@ -45,7 +44,8 @@ contract ERC20ModuleTest is Test {

token = new MockERC20();

harness = new ERC20TransferModuleHarness(address(this), address(token));
harness = new ERC20TransferModuleHarness();
harness.initialize(address(this), address(token));

token.mint(sender, INITIAL_BALANCE);
}
Expand Down Expand Up @@ -74,6 +74,40 @@ contract ERC20ModuleTest is Test {
return abi.encode(_from, _to, _amount);
}

// --- Initialization Tests ---

function test_initialize_Success() public {
ERC20TransferModuleHarness newHarness = new ERC20TransferModuleHarness();
address crossModule = makeAddr("newCrossModule");
address newToken = makeAddr("newToken");

assertEq(newHarness.crossModule(), address(0));
assertEq(address(newHarness.token()), address(0));

newHarness.initialize(crossModule, newToken);

assertEq(newHarness.crossModule(), crossModule);
assertEq(address(newHarness.token()), newToken);
}

function test_initialize_RevertWhen_InvalidAddress() public {
ERC20TransferModuleHarness newHarness = new ERC20TransferModuleHarness();
address validAddress = makeAddr("valid");

// 1. CrossModule is zero
vm.expectRevert(ERC20TransferModule.ERC20TransferModuleInvalidAddress.selector);
newHarness.initialize(address(0), validAddress);

// 2. Token is zero
vm.expectRevert(ERC20TransferModule.ERC20TransferModuleInvalidAddress.selector);
newHarness.initialize(validAddress, address(0));
}

function test_initialize_RevertWhen_AlreadyInitialized() public {
vm.expectRevert(Initializable.InvalidInitialization.selector);
harness.initialize(address(this), address(token));
}

// --- Decode Tests ---

function test_decodeCallInfo_Success() public view {
Expand Down Expand Up @@ -136,6 +170,15 @@ contract ERC20ModuleTest is Test {
harness.onContractCommitImmediately(context, callInfo);
}

function test_onContractCommitImmediately_RevertWhen_NotInitialized() public {
ERC20TransferModuleHarness uninitHarness = new ERC20TransferModuleHarness();
CrossContext memory context = _createContext(sender);
bytes memory callInfo = _createCallInfo(sender, receiver, AMOUNT);

vm.expectRevert(ERC20TransferModule.ERC20TransferModuleNotInitialized.selector);
uninitHarness.onContractCommitImmediately(context, callInfo);
}

// --- Prepare Tests ---

function test_onContractPrepare_Success() public {
Expand Down Expand Up @@ -216,6 +259,15 @@ contract ERC20ModuleTest is Test {
harness.onContractPrepare(context, callInfo);
}

function test_onContractPrepare_RevertWhen_NotInitialized() public {
ERC20TransferModuleHarness uninitHarness = new ERC20TransferModuleHarness();
CrossContext memory context = _createContext(sender);
bytes memory callInfo = _createCallInfo(sender, receiver, AMOUNT);

vm.expectRevert(ERC20TransferModule.ERC20TransferModuleNotInitialized.selector);
uninitHarness.onContractPrepare(context, callInfo);
}

// --- Commit Tests ---

function test_onCommit_Success() public {
Expand Down Expand Up @@ -260,6 +312,14 @@ contract ERC20ModuleTest is Test {
harness.onCommit(context);
}

function test_onCommit_RevertWhen_NotInitialized() public {
ERC20TransferModuleHarness uninitHarness = new ERC20TransferModuleHarness();
CrossContext memory context = _createContext(sender);

vm.expectRevert(ERC20TransferModule.ERC20TransferModuleNotInitialized.selector);
uninitHarness.onCommit(context);
}

// --- Abort Tests ---

function test_onAbort_Success() public {
Expand Down Expand Up @@ -305,4 +365,12 @@ contract ERC20ModuleTest is Test {
vm.expectRevert(ERC20TransferModule.ERC20TransferModuleUnauthorized.selector);
harness.onAbort(context);
}

function test_onAbort_RevertWhen_NotInitialized() public {
ERC20TransferModuleHarness uninitHarness = new ERC20TransferModuleHarness();
CrossContext memory context = _createContext(sender);

vm.expectRevert(ERC20TransferModule.ERC20TransferModuleNotInitialized.selector);
uninitHarness.onAbort(context);
}
}
Loading