Skip to content

Commit

Permalink
feat(wip): completed removeMember() and test case (#18)
Browse files Browse the repository at this point in the history
* updated

* added removeMember test
  • Loading branch information
jimmychu0807 authored Jan 13, 2025
1 parent 5e6ae75 commit 13e6672
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 63 deletions.
4 changes: 2 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
"devDependencies": {
"@rhinestone/modulekit": "~0.5.4",
"@semaphore-protocol/contracts": "github:jimmychu0807/semaphore#identity-cli&path:/packages/contracts/contracts",
"@semaphore-protocol/core": "github:jimmychu0807/semaphore#identity-cli&path:/packages/core",
"@semaphore-protocol/identity": "github:jimmychu0807/semaphore#identity-cli&path:/packages/identity",
"@semaphore-protocol/proof": "github:jimmychu0807/semaphore#identity-cli&path:/packages/proof",
"@semaphore-protocol/group": "github:jimmychu0807/semaphore#identity-cli&path:/packages/group",
"poseidon-solidity": "github:chancehudson/poseidon-solidity#main",
"rimraf": "^5.0.5",
"solady": "^0.0.287"
Expand All @@ -47,7 +47,7 @@
"prepack": "pnpm install && bash ./shell/prepare-artifacts.sh",
"prettier:check": "prettier --no-error-on-unmatched-pattern -c \"{src,test,script}/**/*.{json,md,svg,yml}\"",
"prettier:write": "prettier --no-error-on-unmatched-pattern -w \"{src,test,script}/**/*.{json,md,svg,yml}\"",
"test": "COMPLIANCE=true forge test --ffi",
"test": "forge test --ffi",
"test:lite": "FOUNDRY_PROFILE=lite forge test",
"test:optimized": "pnpm run build:optimized && FOUNDRY_PROFILE=test-optimized forge test"
},
Expand Down
61 changes: 20 additions & 41 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 12 additions & 13 deletions src/SemaphoreMSAValidator.sol
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {
ISemaphoreGroups public groups;
mapping(address account => uint256 groupId) public groupMapping;
mapping(address account => uint8 threshold) public thresholds;
mapping(address account => uint8 count) public memberCount;

// smart account -> hash(call(params)) -> valid proof count
mapping(address account => mapping(bytes32 txHash => ExtCallCount callDataCount)) public
Expand Down Expand Up @@ -133,6 +134,7 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {

// Add members to the group
semaphore.addMembers(groupId, cmts);
memberCount[account] = uint8(cmts.length);

emit ModuleInitialized(account);
}
Expand All @@ -143,6 +145,7 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {
delete thresholds[account];
delete groupMapping[account];
delete acctSeqNum[account];
delete memberCount[account];

//TODO: what is a good way to delete entries associated with `acctTxCount[account]`,
// The following line will make the compiler fail.
Expand All @@ -151,15 +154,9 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {
emit ModuleUninitialized(account);
}

function memberCount(address account) public view returns (uint8 cnt) {
// account doesn't belong to a semaphore group. We return 0
if (thresholds[account] == 0) return 0;
cnt = uint8(groups.getMerkleTreeSize(groupMapping[account]));
}

function setThreshold(uint8 newThreshold) external moduleInstalled {
address account = msg.sender;
if (newThreshold == 0 || newThreshold > memberCount(account)) {
if (newThreshold == 0 || newThreshold > memberCount[account]) {
revert InvalidThreshold(account);
}

Expand All @@ -171,14 +168,16 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {
address account = msg.sender;
uint256 groupId = groupMapping[account];

if (memberCount(account) + cmts.length > MAX_MEMBERS) revert MaxMemberReached(account);
if (memberCount[account] + cmts.length > MAX_MEMBERS) revert MaxMemberReached(account);

for (uint256 i = 0; i < cmts.length; ++i) {
if (cmts[i] == uint256(0)) revert InvalidCommitment(account);
if (groups.hasMember(groupId, cmts[i])) revert IsMemberAlready(account, cmts[i]);
}

semaphore.addMembers(groupId, cmts);
memberCount[account] += uint8(cmts.length);

emit AddedMembers(account, cmts.length);
}

Expand All @@ -191,12 +190,13 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {
{
address account = msg.sender;

if (memberCount(account) == thresholds[account]) revert MemberCntReachesThreshold(account);
if (memberCount[account] == thresholds[account]) revert MemberCntReachesThreshold(account);

uint256 groupId = groupMapping[account];
if (!groups.hasMember(groupId, cmt)) revert MemberNotExists(account, cmt);

semaphore.removeMember(groupId, cmt, merkleProofSiblings);
memberCount[account] -= 1;

emit RemovedMember(account, cmt);
}
Expand Down Expand Up @@ -347,7 +347,7 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {
uint256 cmt = Identity.getCommitment(pubKey);
if (!groups.hasMember(groupId, cmt)) revert MemberNotExists(account, cmt);

// We don't allow call to other contract.
// We don't allow call to other contracts.
address targetAddr = address(bytes20(userOp.callData[100:120]));
if (targetAddr != address(this)) revert NonValidatorCallBanned(targetAddr, address(this));

Expand All @@ -356,10 +356,9 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {
bytes memory valAndCallData = userOp.callData[120:];
bytes4 funcSel = bytes4(LibBytes.slice(valAndCallData, 32, 36));

// Allow only these few types on function calls to pass, and reject all other on-chain
// calls. They must be executed via `executeTx()` function.
// We only allow calls to `initiateTx()`, `signTx()`, and `executeTx()` to pass,
// and reject the rest.
if (_isAllowedSelector(funcSel)) return VALIDATION_SUCCESS;

revert NonAllowedSelector(account, funcSel);
}

Expand Down
45 changes: 42 additions & 3 deletions test/SemaphoreMSAValidator.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import { SemaphoreMSAValidator, ERC7579ValidatorBase } from "../src/SemaphoreMSA
import {
getEmptyUserOperation,
getEmptySemaphoreProof,
getGroupRmMerkleProof,
getTestUserOpCallData,
Identity,
IdentityLib
Expand Down Expand Up @@ -236,8 +237,14 @@ contract SemaphoreValidatorUnitTest is RhinestoneModuleKit, Test {
uint256[] memory newMembers = new uint256[](1);
newMembers[0] = newCommitment;

vm.prank(smartAcct.account);
// Test: addMembers() is successfully executed
vm.startPrank(smartAcct.account);
vm.expectEmit(true, true, true, true, address(semaphoreValidator));
emit SemaphoreMSAValidator.AddedMembers(smartAcct.account, uint256(1));
semaphoreValidator.addMembers(newMembers);
vm.stopPrank();

assertEq(semaphoreValidator.memberCount(smartAcct.account), 2);

// Test: the userOp should pass
uint256 validationData = ERC7579ValidatorBase.ValidationData.unwrap(
Expand All @@ -246,8 +253,40 @@ contract SemaphoreValidatorUnitTest is RhinestoneModuleKit, Test {
assertEq(validationData, VALIDATION_SUCCESS);
}

function test_removeMember() public setupSmartAcctWithMembersThreshold(2, 1) {
revert("to be implemented");
function test_removeMember() public setupSmartAcctWithMembersThreshold(MEMBER_NUM, 1) {
uint256[] memory cmts = _getMemberCmts(MEMBER_NUM);
User storage rmUser = $users[0];
uint256 rmCmt = rmUser.identity.commitment();

(uint256[] memory merkleProof,) = getGroupRmMerkleProof(cmts, rmCmt);

// Test: remove member
vm.startPrank(smartAcct.account);
vm.expectEmit(true, true, true, true, address(semaphoreValidator));
emit SemaphoreMSAValidator.RemovedMember(smartAcct.account, rmCmt);
semaphoreValidator.removeMember(rmCmt, merkleProof);
vm.stopPrank();

assertEq(semaphoreValidator.memberCount(smartAcct.account), MEMBER_NUM - 1);

// Compose a UserOp
PackedUserOperation memory userOp = getEmptyUserOperation();
userOp.sender = smartAcct.account;
userOp.callData = getTestUserOpCallData(
0,
address(semaphoreValidator),
abi.encodeWithSelector(SemaphoreMSAValidator.initiateTx.selector)
);
bytes32 userOpHash = bytes32(keccak256("userOpHash"));
userOp.signature = rmUser.identity.signHash(userOpHash);

// Test: the userOp should fail and revert
vm.expectRevert(
abi.encodeWithSelector(
SemaphoreMSAValidator.MemberNotExists.selector, smartAcct.account, rmCmt
)
);
semaphoreValidator.validateUserOp(userOp, userOpHash);
}

function _getSemaphoreValidatorUserOpData(
Expand Down
46 changes: 42 additions & 4 deletions test/utils/TestUtils.sol
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ import { ISemaphore } from "../../src/utils/Semaphore.sol";
// import { console } from "forge-std/console.sol";
import { LibString } from "solady/Milady.sol";

// https://github.com/foundry-rs/forge-std/blob/master/src/Base.sol#L9
address constant VM_ADDRESS = 0x7109709ECfa91a80626fF3989D68f67F5b1DD12D;
Vm constant vm = Vm(VM_ADDRESS);

struct ValidationData {
address aggregator;
uint48 validAfter;
Expand Down Expand Up @@ -49,13 +53,47 @@ function getTestUserOpCallData(
callData = bytes.concat(new bytes(100), bytes20(targetAddr), bytes32(value), txCallData);
}

function getGroupRmMerkleProof(
uint256[] memory members,
uint256 removal
)
returns (uint256[] memory merkleProof, uint256 root)
{
string[] memory cmd = new string[](5);
cmd[0] = "pnpm";
cmd[1] = "semaphore-group";
cmd[2] = "remove-member";
cmd[3] = _join(members);
cmd[4] = LibString.toString(removal);

bytes memory outBytes = vm.ffi(cmd);
string memory outStr = string(outBytes);
string[] memory retStr = LibString.split(outStr, " ");

merkleProof = _splitToUint(retStr[0]);
root = vm.parseUint(retStr[1]);
}

function _splitToUint(string memory str) pure returns (uint256[] memory retArr) {
string[] memory arr = LibString.split(str, ",");
retArr = new uint256[](arr.length);
for (uint256 i = 0; i < arr.length; i++) {
retArr[i] = vm.parseUint(arr[i]);
}
}

function _join(uint256[] memory members) pure returns (string memory retStr) {
for (uint256 i = 0; i < members.length; i++) {
retStr = string.concat(retStr, LibString.toString(members[i]));
if (i < members.length - 1) {
retStr = string.concat(retStr, ",");
}
}
}

type Identity is bytes32;

library IdentityLib {
// https://github.com/foundry-rs/forge-std/blob/master/src/Base.sol#L9
address internal constant VM_ADDRESS = 0x7109709ECfa91a80626fF3989D68f67F5b1DD12D;
Vm internal constant vm = Vm(VM_ADDRESS);

function genIdentity(uint256 seed) public view returns (Identity) {
return Identity.wrap(keccak256(abi.encodePacked(seed, address(this))));
}
Expand Down

0 comments on commit 13e6672

Please sign in to comment.