Skip to content

Commit 13e6672

Browse files
authored
feat(wip): completed removeMember() and test case (#18)
* updated * added removeMember test
1 parent 5e6ae75 commit 13e6672

File tree

5 files changed

+118
-63
lines changed

5 files changed

+118
-63
lines changed

package.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818
"devDependencies": {
1919
"@rhinestone/modulekit": "~0.5.4",
2020
"@semaphore-protocol/contracts": "github:jimmychu0807/semaphore#identity-cli&path:/packages/contracts/contracts",
21-
"@semaphore-protocol/core": "github:jimmychu0807/semaphore#identity-cli&path:/packages/core",
2221
"@semaphore-protocol/identity": "github:jimmychu0807/semaphore#identity-cli&path:/packages/identity",
2322
"@semaphore-protocol/proof": "github:jimmychu0807/semaphore#identity-cli&path:/packages/proof",
23+
"@semaphore-protocol/group": "github:jimmychu0807/semaphore#identity-cli&path:/packages/group",
2424
"poseidon-solidity": "github:chancehudson/poseidon-solidity#main",
2525
"rimraf": "^5.0.5",
2626
"solady": "^0.0.287"
@@ -47,7 +47,7 @@
4747
"prepack": "pnpm install && bash ./shell/prepare-artifacts.sh",
4848
"prettier:check": "prettier --no-error-on-unmatched-pattern -c \"{src,test,script}/**/*.{json,md,svg,yml}\"",
4949
"prettier:write": "prettier --no-error-on-unmatched-pattern -w \"{src,test,script}/**/*.{json,md,svg,yml}\"",
50-
"test": "COMPLIANCE=true forge test --ffi",
50+
"test": "forge test --ffi",
5151
"test:lite": "FOUNDRY_PROFILE=lite forge test",
5252
"test:optimized": "pnpm run build:optimized && FOUNDRY_PROFILE=test-optimized forge test"
5353
},

pnpm-lock.yaml

Lines changed: 20 additions & 41 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/SemaphoreMSAValidator.sol

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {
6868
ISemaphoreGroups public groups;
6969
mapping(address account => uint256 groupId) public groupMapping;
7070
mapping(address account => uint8 threshold) public thresholds;
71+
mapping(address account => uint8 count) public memberCount;
7172

7273
// smart account -> hash(call(params)) -> valid proof count
7374
mapping(address account => mapping(bytes32 txHash => ExtCallCount callDataCount)) public
@@ -133,6 +134,7 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {
133134

134135
// Add members to the group
135136
semaphore.addMembers(groupId, cmts);
137+
memberCount[account] = uint8(cmts.length);
136138

137139
emit ModuleInitialized(account);
138140
}
@@ -143,6 +145,7 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {
143145
delete thresholds[account];
144146
delete groupMapping[account];
145147
delete acctSeqNum[account];
148+
delete memberCount[account];
146149

147150
//TODO: what is a good way to delete entries associated with `acctTxCount[account]`,
148151
// The following line will make the compiler fail.
@@ -151,15 +154,9 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {
151154
emit ModuleUninitialized(account);
152155
}
153156

154-
function memberCount(address account) public view returns (uint8 cnt) {
155-
// account doesn't belong to a semaphore group. We return 0
156-
if (thresholds[account] == 0) return 0;
157-
cnt = uint8(groups.getMerkleTreeSize(groupMapping[account]));
158-
}
159-
160157
function setThreshold(uint8 newThreshold) external moduleInstalled {
161158
address account = msg.sender;
162-
if (newThreshold == 0 || newThreshold > memberCount(account)) {
159+
if (newThreshold == 0 || newThreshold > memberCount[account]) {
163160
revert InvalidThreshold(account);
164161
}
165162

@@ -171,14 +168,16 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {
171168
address account = msg.sender;
172169
uint256 groupId = groupMapping[account];
173170

174-
if (memberCount(account) + cmts.length > MAX_MEMBERS) revert MaxMemberReached(account);
171+
if (memberCount[account] + cmts.length > MAX_MEMBERS) revert MaxMemberReached(account);
175172

176173
for (uint256 i = 0; i < cmts.length; ++i) {
177174
if (cmts[i] == uint256(0)) revert InvalidCommitment(account);
178175
if (groups.hasMember(groupId, cmts[i])) revert IsMemberAlready(account, cmts[i]);
179176
}
180177

181178
semaphore.addMembers(groupId, cmts);
179+
memberCount[account] += uint8(cmts.length);
180+
182181
emit AddedMembers(account, cmts.length);
183182
}
184183

@@ -191,12 +190,13 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {
191190
{
192191
address account = msg.sender;
193192

194-
if (memberCount(account) == thresholds[account]) revert MemberCntReachesThreshold(account);
193+
if (memberCount[account] == thresholds[account]) revert MemberCntReachesThreshold(account);
195194

196195
uint256 groupId = groupMapping[account];
197196
if (!groups.hasMember(groupId, cmt)) revert MemberNotExists(account, cmt);
198197

199198
semaphore.removeMember(groupId, cmt, merkleProofSiblings);
199+
memberCount[account] -= 1;
200200

201201
emit RemovedMember(account, cmt);
202202
}
@@ -347,7 +347,7 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {
347347
uint256 cmt = Identity.getCommitment(pubKey);
348348
if (!groups.hasMember(groupId, cmt)) revert MemberNotExists(account, cmt);
349349

350-
// We don't allow call to other contract.
350+
// We don't allow call to other contracts.
351351
address targetAddr = address(bytes20(userOp.callData[100:120]));
352352
if (targetAddr != address(this)) revert NonValidatorCallBanned(targetAddr, address(this));
353353

@@ -356,10 +356,9 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {
356356
bytes memory valAndCallData = userOp.callData[120:];
357357
bytes4 funcSel = bytes4(LibBytes.slice(valAndCallData, 32, 36));
358358

359-
// Allow only these few types on function calls to pass, and reject all other on-chain
360-
// calls. They must be executed via `executeTx()` function.
359+
// We only allow calls to `initiateTx()`, `signTx()`, and `executeTx()` to pass,
360+
// and reject the rest.
361361
if (_isAllowedSelector(funcSel)) return VALIDATION_SUCCESS;
362-
363362
revert NonAllowedSelector(account, funcSel);
364363
}
365364

test/SemaphoreMSAValidator.t.sol

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import { SemaphoreMSAValidator, ERC7579ValidatorBase } from "../src/SemaphoreMSA
3131
import {
3232
getEmptyUserOperation,
3333
getEmptySemaphoreProof,
34+
getGroupRmMerkleProof,
3435
getTestUserOpCallData,
3536
Identity,
3637
IdentityLib
@@ -236,8 +237,14 @@ contract SemaphoreValidatorUnitTest is RhinestoneModuleKit, Test {
236237
uint256[] memory newMembers = new uint256[](1);
237238
newMembers[0] = newCommitment;
238239

239-
vm.prank(smartAcct.account);
240+
// Test: addMembers() is successfully executed
241+
vm.startPrank(smartAcct.account);
242+
vm.expectEmit(true, true, true, true, address(semaphoreValidator));
243+
emit SemaphoreMSAValidator.AddedMembers(smartAcct.account, uint256(1));
240244
semaphoreValidator.addMembers(newMembers);
245+
vm.stopPrank();
246+
247+
assertEq(semaphoreValidator.memberCount(smartAcct.account), 2);
241248

242249
// Test: the userOp should pass
243250
uint256 validationData = ERC7579ValidatorBase.ValidationData.unwrap(
@@ -246,8 +253,40 @@ contract SemaphoreValidatorUnitTest is RhinestoneModuleKit, Test {
246253
assertEq(validationData, VALIDATION_SUCCESS);
247254
}
248255

249-
function test_removeMember() public setupSmartAcctWithMembersThreshold(2, 1) {
250-
revert("to be implemented");
256+
function test_removeMember() public setupSmartAcctWithMembersThreshold(MEMBER_NUM, 1) {
257+
uint256[] memory cmts = _getMemberCmts(MEMBER_NUM);
258+
User storage rmUser = $users[0];
259+
uint256 rmCmt = rmUser.identity.commitment();
260+
261+
(uint256[] memory merkleProof,) = getGroupRmMerkleProof(cmts, rmCmt);
262+
263+
// Test: remove member
264+
vm.startPrank(smartAcct.account);
265+
vm.expectEmit(true, true, true, true, address(semaphoreValidator));
266+
emit SemaphoreMSAValidator.RemovedMember(smartAcct.account, rmCmt);
267+
semaphoreValidator.removeMember(rmCmt, merkleProof);
268+
vm.stopPrank();
269+
270+
assertEq(semaphoreValidator.memberCount(smartAcct.account), MEMBER_NUM - 1);
271+
272+
// Compose a UserOp
273+
PackedUserOperation memory userOp = getEmptyUserOperation();
274+
userOp.sender = smartAcct.account;
275+
userOp.callData = getTestUserOpCallData(
276+
0,
277+
address(semaphoreValidator),
278+
abi.encodeWithSelector(SemaphoreMSAValidator.initiateTx.selector)
279+
);
280+
bytes32 userOpHash = bytes32(keccak256("userOpHash"));
281+
userOp.signature = rmUser.identity.signHash(userOpHash);
282+
283+
// Test: the userOp should fail and revert
284+
vm.expectRevert(
285+
abi.encodeWithSelector(
286+
SemaphoreMSAValidator.MemberNotExists.selector, smartAcct.account, rmCmt
287+
)
288+
);
289+
semaphoreValidator.validateUserOp(userOp, userOpHash);
251290
}
252291

253292
function _getSemaphoreValidatorUserOpData(

test/utils/TestUtils.sol

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ import { ISemaphore } from "../../src/utils/Semaphore.sol";
77
// import { console } from "forge-std/console.sol";
88
import { LibString } from "solady/Milady.sol";
99

10+
// https://github.com/foundry-rs/forge-std/blob/master/src/Base.sol#L9
11+
address constant VM_ADDRESS = 0x7109709ECfa91a80626fF3989D68f67F5b1DD12D;
12+
Vm constant vm = Vm(VM_ADDRESS);
13+
1014
struct ValidationData {
1115
address aggregator;
1216
uint48 validAfter;
@@ -49,13 +53,47 @@ function getTestUserOpCallData(
4953
callData = bytes.concat(new bytes(100), bytes20(targetAddr), bytes32(value), txCallData);
5054
}
5155

56+
function getGroupRmMerkleProof(
57+
uint256[] memory members,
58+
uint256 removal
59+
)
60+
returns (uint256[] memory merkleProof, uint256 root)
61+
{
62+
string[] memory cmd = new string[](5);
63+
cmd[0] = "pnpm";
64+
cmd[1] = "semaphore-group";
65+
cmd[2] = "remove-member";
66+
cmd[3] = _join(members);
67+
cmd[4] = LibString.toString(removal);
68+
69+
bytes memory outBytes = vm.ffi(cmd);
70+
string memory outStr = string(outBytes);
71+
string[] memory retStr = LibString.split(outStr, " ");
72+
73+
merkleProof = _splitToUint(retStr[0]);
74+
root = vm.parseUint(retStr[1]);
75+
}
76+
77+
function _splitToUint(string memory str) pure returns (uint256[] memory retArr) {
78+
string[] memory arr = LibString.split(str, ",");
79+
retArr = new uint256[](arr.length);
80+
for (uint256 i = 0; i < arr.length; i++) {
81+
retArr[i] = vm.parseUint(arr[i]);
82+
}
83+
}
84+
85+
function _join(uint256[] memory members) pure returns (string memory retStr) {
86+
for (uint256 i = 0; i < members.length; i++) {
87+
retStr = string.concat(retStr, LibString.toString(members[i]));
88+
if (i < members.length - 1) {
89+
retStr = string.concat(retStr, ",");
90+
}
91+
}
92+
}
93+
5294
type Identity is bytes32;
5395

5496
library IdentityLib {
55-
// https://github.com/foundry-rs/forge-std/blob/master/src/Base.sol#L9
56-
address internal constant VM_ADDRESS = 0x7109709ECfa91a80626fF3989D68f67F5b1DD12D;
57-
Vm internal constant vm = Vm(VM_ADDRESS);
58-
5997
function genIdentity(uint256 seed) public view returns (Identity) {
6098
return Identity.wrap(keccak256(abi.encodePacked(seed, address(this))));
6199
}

0 commit comments

Comments
 (0)