Skip to content

Commit

Permalink
feat: added tests for native token transfer (#14)
Browse files Browse the repository at this point in the history
* to actually execute the tx

* Added test to test token transfer
  • Loading branch information
jimmychu0807 authored Jan 10, 2025
1 parent fec2771 commit e8bc0bc
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 36 deletions.
39 changes: 28 additions & 11 deletions src/SemaphoreMSAValidator.sol
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ pragma solidity >=0.8.23 <=0.8.29;
import { ERC7579ValidatorBase } from "modulekit/Modules.sol";
import { VALIDATION_SUCCESS } from "modulekit/accounts/common/interfaces/IERC7579Module.sol";
import { PackedUserOperation } from "modulekit/external/ERC4337.sol";
import { LibSort } from "solady/utils/LibSort.sol";
import { LibBytes } from "solady/utils/LibBytes.sol";
import { LibSort, LibBytes } from "solady/Milady.sol";

import { ISemaphore, ISemaphoreGroups } from "./utils/Semaphore.sol";
import { ValidatorLibBytes } from "./utils/ValidatorLibBytes.sol";
Expand Down Expand Up @@ -48,6 +47,9 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {
error InvalidSemaphoreProof(bytes reason);
error NonAllowedSelector(address account, bytes4 funcSel);
error NonValidatorCallBanned(address targetAddr, address selfAddr);
error InitiateTxWithNullAddress(address account);
error InitiateTxWithNullCallDataAndNullValue(address account, address targetAddr);
error ExecuteTxFailure(address account, address targetAddr, uint256 value, bytes callData);

// Events
event ModuleInitialized(address indexed account);
Expand Down Expand Up @@ -196,11 +198,11 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {
emit RemovedMember(account, rmOwner);
}

function getNextSeqNum(address account) external returns (uint256) {
function getNextSeqNum(address account) external view returns (uint256) {
return acctSeqNum[account];
}

function getGroupId(address account) external returns (bool, uint256) {
function getGroupId(address account) external view returns (bool, uint256) {
uint256 groupId = groupMapping[account];
if (thresholds[account] == 0) return (false, 0);
return (true, groupId);
Expand All @@ -221,6 +223,14 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {
address account = msg.sender;
uint256 groupId = groupMapping[account];

// Check:
// 1. targetAddr cannot be 0
// 2. if txCallData is blank, then msg.value must be > 0, else revert
if (targetAddr == address(0)) revert InitiateTxWithNullAddress(account);
if (LibBytes.cmp(txCallData, "") == 0 && msg.value == 0) {
revert InitiateTxWithNullCallDataAndNullValue(account, targetAddr);
}

// By this point, txParams should be validated.
// combine the txParams with the account nonce and compute its hash
uint256 seq = acctSeqNum[account];
Expand Down Expand Up @@ -273,22 +283,29 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {
if (execute && cdc.count >= thresholds[account]) executeTx(txHash);
}

function executeTx(bytes32 txHash) public moduleInstalled {
function executeTx(bytes32 txHash) public moduleInstalled returns (bytes memory) {
// retrieve the group ID
address account = msg.sender;
uint256 groupId = groupMapping[account];
uint8 threshold = thresholds[account];
ExtCallCount storage cdc = acctTxCount[account][txHash];
ExtCallCount storage ecc = acctTxCount[account][txHash];

if (cdc.count == 0) revert TxHashNotFound(account, txHash);
if (cdc.count < threshold) revert ThresholdNotReach(account, threshold, cdc.count);
// console.log("executeTx");
if (ecc.count == 0) revert TxHashNotFound(account, txHash);
if (ecc.count < threshold) revert ThresholdNotReach(account, threshold, ecc.count);

//TODO: make the actual contract call here
// console.log("executeTx - check pass");
// REVIEW: Is there a better way to make external contract call given the target address,
// value, and call data.
address payable targetAddr = payable(ecc.targetAddr);
(bool success, bytes memory returnData) = targetAddr.call{ value: ecc.value }(ecc.callData);
if (!success) revert ExecuteTxFailure(account, targetAddr, ecc.value, ecc.callData);

emit ExecutedTx(account, txHash);

// Clean up the storage
delete acctTxCount[account][txHash];

return returnData;
}

/**
Expand Down Expand Up @@ -339,7 +356,7 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {
// For callData, the first 120 bytes are reserved by ERC-7579 use. Then 32 bytes of value,
// then the remaining as the callData passed in getExecOps
bytes memory valAndCallData = userOp.callData[120:];
(uint256 val) = abi.decode(LibBytes.slice(valAndCallData, 0, 32), (uint256));
// (uint256 val) = abi.decode(LibBytes.slice(valAndCallData, 0, 32), (uint256));
bytes4 funcSel = bytes4(LibBytes.slice(valAndCallData, 32, 36));

// console.log("val: %s", val);
Expand Down
156 changes: 132 additions & 24 deletions test/SemaphoreMSAValidator.t.sol
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.23;

// forge
// forge-std
import { Test } from "forge-std/Test.sol";
// import { console } from "forge-std/console.sol";

Expand Down Expand Up @@ -155,7 +155,7 @@ contract SemaphoreValidatorUnitTest is RhinestoneModuleKit, Test {
assertEq(semaphoreValidator.memberCount(smartAcct.account), 0);
assertEq(semaphoreValidator.isInitialized(smartAcct.account), false);

(bool bExist, uint256 groupId) = semaphoreValidator.getGroupId(smartAcct.account);
(bool bExist,) = semaphoreValidator.getGroupId(smartAcct.account);
assertEq(bExist, false);
}

Expand Down Expand Up @@ -250,47 +250,155 @@ contract SemaphoreValidatorUnitTest is RhinestoneModuleKit, Test {
userOpData.execUserOps();
}

function test_initiateTokensTransferMemberValid() public setupSmartAcctOneMember {
function test_initiateTokensTransferMemberValid()
public
setupSmartAcctOneMember
returns (bytes32 txHash)
{
User storage member = $users[0];

uint256 value = 1 ether;
address targetAddr = $users[1].addr;
uint256 seq = semaphoreValidator.getNextSeqNum(smartAcct.account);
txHash = keccak256(abi.encodePacked(seq, targetAddr, value, ""));

{
// Using scope to limit the number of local variables, work around the `stack too deep`
// error.
// Generate the semaphore proof
(bool bExist, uint256 groupId) = semaphoreValidator.getGroupId(smartAcct.account);
assert(bExist);
uint256[] memory members = new uint256[](1);
members[0] = member.identity.commitment();
ISemaphore.SemaphoreProof memory smProof =
member.identity.generateSempahoreProof(groupId, members, txHash);

// Composing the UserOpData
UserOpData memory userOpData = smartAcct.getExecOps({
target: address(semaphoreValidator),
value: value,
callData: abi.encodeCall(
SemaphoreMSAValidator.initiateTx, (targetAddr, "", smProof, false)
),
txValidator: address(semaphoreValidator)
});
userOpData.userOp.signature = member.identity.signHash(userOpData.userOpHash);

// Expecting `InitiatedTx` event to be emitted
vm.expectEmit(true, true, true, true, address(semaphoreValidator));
emit SemaphoreMSAValidator.InitiatedTx(smartAcct.account, seq, txHash);
userOpData.execUserOps();
}

bytes32 txHash = keccak256(abi.encodePacked(seq, targetAddr, value, ""));
// Test the states are changed accordingly
assertEq(semaphoreValidator.acctSeqNum(smartAcct.account), seq + 1);

// Generate the semaphore proof
(bool bExist, uint256 groupId) = semaphoreValidator.getGroupId(smartAcct.account);
assert(bExist);
uint256[] memory members = new uint256[](1);
members[0] = member.identity.commitment();
ISemaphore.SemaphoreProof memory smProof =
member.identity.generateSempahoreProof(groupId, members, txHash);
(address eccTargetAddr, bytes memory eccCallData, uint256 eccValue, uint8 eccCount) =
semaphoreValidator.acctTxCount(smartAcct.account, txHash);

assertEq(eccTargetAddr, targetAddr);
assertEq(eccValue, value);
assertEq(eccCallData, "");
assertEq(eccCount, 1);
}

function test_initiateTokensTransferMemberValidAndExecuteInvalidTxHash() public {
bytes32 forgedHash = test_initiateTokensTransferMemberValid();
// Changed the last 2 bytes to 0xffff
forgedHash = forgedHash | bytes32(uint256(65_535));

User storage member = $users[0];

// Now execute the token transfer.
// Composing the UserOpData.
UserOpData memory userOpData = smartAcct.getExecOps({
target: address(semaphoreValidator),
value: 0,
callData: abi.encodeCall(SemaphoreMSAValidator.executeTx, (forgedHash)),
txValidator: address(semaphoreValidator)
});
userOpData.userOp.signature = member.identity.signHash(userOpData.userOpHash);

smartAcct.expect4337Revert(
abi.encodeWithSelector(
SemaphoreMSAValidator.TxHashNotFound.selector, smartAcct.account, forgedHash
)
);
userOpData.execUserOps();
}

function test_ExecuteTxFailure() public pure {
revert("to be implemented");
}

// Composing the UserOpData
function test_initiateTokensTransferMemberValidAndExecuteValid() public {
bytes32 txHash = test_initiateTokensTransferMemberValid();

User storage member = $users[0];
address targetAddr = $users[1].addr;
uint256 value = 1 ether;
uint256 beforeBalance = targetAddr.balance;

// Now execute the token transfer.
// Composing the UserOpData.
UserOpData memory userOpData = smartAcct.getExecOps({
target: address(semaphoreValidator),
value: value,
callData: abi.encodeCall(SemaphoreMSAValidator.initiateTx, (targetAddr, "", smProof, false)),
value: 0,
callData: abi.encodeCall(SemaphoreMSAValidator.executeTx, (txHash)),
txValidator: address(semaphoreValidator)
});
userOpData.userOp.signature = member.identity.signHash(userOpData.userOpHash);

// Expecting `InitiatedTx` event to be emitted
// Test event emission
vm.expectEmit(true, true, true, true, address(semaphoreValidator));
emit SemaphoreMSAValidator.InitiatedTx(smartAcct.account, seq, txHash);
emit SemaphoreMSAValidator.ExecutedTx(smartAcct.account, txHash);
userOpData.execUserOps();

// Test the states are changed accordingly
assertEq(semaphoreValidator.acctSeqNum(smartAcct.account), seq + 1);
uint256 afterBalance = targetAddr.balance;
assertEq(afterBalance - beforeBalance, value);
}

(address eccTargetAddr, bytes memory eccCallData, uint256 eccValue, uint8 eccCount) =
semaphoreValidator.acctTxCount(smartAcct.account, txHash);
function test_initiateTokensTransferMemberAndExecuteValid() public setupSmartAcctOneMember {
User storage member = $users[0];
uint256 value = 1 ether;
address targetAddr = $users[1].addr;
uint256 beforeBalance = targetAddr.balance;
uint256 seq = semaphoreValidator.getNextSeqNum(smartAcct.account);
bytes32 txHash = keccak256(abi.encodePacked(seq, targetAddr, value, ""));

assertEq(eccTargetAddr, targetAddr);
assertEq(eccValue, value);
assertEq(eccCallData, "");
assertEq(eccCount, 1);
{
// Using scope to limit the number of local variables, work around the `stack too deep`
// error.
// Generate the semaphore proof
(bool bExist, uint256 groupId) = semaphoreValidator.getGroupId(smartAcct.account);
assert(bExist);
uint256[] memory members = new uint256[](1);
members[0] = member.identity.commitment();
ISemaphore.SemaphoreProof memory smProof =
member.identity.generateSempahoreProof(groupId, members, txHash);

// Composing the UserOpData
UserOpData memory userOpData = smartAcct.getExecOps({
target: address(semaphoreValidator),
value: value,
callData: abi.encodeCall(
SemaphoreMSAValidator.initiateTx, (targetAddr, "", smProof, true)
),
txValidator: address(semaphoreValidator)
});
userOpData.userOp.signature = member.identity.signHash(userOpData.userOpHash);

// Expecting `InitiatedTx` event to be emitted
vm.expectEmit(true, true, true, true, address(semaphoreValidator));
emit SemaphoreMSAValidator.InitiatedTx(smartAcct.account, seq, txHash);
vm.expectEmit(true, true, true, true, address(semaphoreValidator));
emit SemaphoreMSAValidator.ExecutedTx(smartAcct.account, txHash);
userOpData.execUserOps();
}

// Confirm user balance has changed
uint256 afterBalance = targetAddr.balance;
assertEq(afterBalance - beforeBalance, value);
}

function test_initiateTxOneMemberNonValidatorCall()
Expand Down
6 changes: 5 additions & 1 deletion test/utils/TestUtils.sol
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,11 @@ library IdentityLib {
});
}

function _uint256ArrToString(uint256[] memory arr) internal returns (string memory retStr) {
function _uint256ArrToString(uint256[] memory arr)
internal
pure
returns (string memory retStr)
{
for (uint256 i = 0; i < arr.length; i++) {
if (i == arr.length - 1) {
retStr = string.concat(retStr, LibString.toString(arr[i]));
Expand Down

0 comments on commit e8bc0bc

Please sign in to comment.