Skip to content

Commit

Permalink
feat: make the validator ERC-7780 & ERC-1271 compatible (#20)
Browse files Browse the repository at this point in the history
* added gas limit in test and remove gas meter tinkering in verifySignature

* Removed LibBytes lib inclusion

* updated

* updated

* Some fix on solhint
  • Loading branch information
jimmychu0807 authored Jan 16, 2025
1 parent daecd5c commit 4842f2a
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 83 deletions.
146 changes: 97 additions & 49 deletions src/SemaphoreMSAValidator.sol
Original file line number Diff line number Diff line change
@@ -1,26 +1,28 @@
// SPDX-License-Identifier: MIT
pragma solidity >=0.8.23 <=0.8.29;

// Rhinestone module-kit
import { ERC7579ValidatorBase } from "modulekit/Modules.sol";
import { VALIDATION_SUCCESS } from "modulekit/accounts/common/interfaces/IERC7579Module.sol";
import { IStatelessValidator } from "modulekit/module-bases/interfaces/IStatelessValidator.sol";
import { PackedUserOperation } from "modulekit/external/ERC4337.sol";
import { LibSort, LibBytes } from "solady/Milady.sol";

import { LibSort } from "solady/Milady.sol";

import { ISemaphore, ISemaphoreGroups } from "./utils/Semaphore.sol";
import { ValidatorLibBytes } from "./utils/ValidatorLibBytes.sol";
import { Identity } from "./utils/Identity.sol";
// import { console } from "forge-std/console.sol";

contract SemaphoreMSAValidator is ERC7579ValidatorBase {
contract SemaphoreMSAValidator is ERC7579ValidatorBase, IStatelessValidator {
using LibSort for *;
using ValidatorLibBytes for bytes;

// Constants
uint8 public constant MAX_MEMBERS = 32;
uint8 internal constant CMT_BYTELEN = 32;
uint8 public constant CMT_BYTELEN = 32;

// Ensure the following match with the 3 function calls.
bytes4[3] internal ALLOWED_SELECTORS =
bytes4[3] public ALLOWED_SELECTORS =
[this.initiateTx.selector, this.signTx.selector, this.executeTx.selector];

struct ExtCallCount {
Expand All @@ -45,7 +47,6 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {
error InvalidSignatureLen(address account, uint256 len);
error InvalidSignature(address account, bytes signature);
error InvalidSemaphoreProof(bytes reason);
error NonAllowedSelector(address account, bytes4 funcSel);
error NonValidatorCallBanned(address targetAddr, address selfAddr);
error InitiateTxWithNullAddress(address account);
error InitiateTxWithNullCallDataAndNullValue(address account, address targetAddr);
Expand Down Expand Up @@ -230,7 +231,7 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {
// 1. targetAddr cannot be 0
// 2. if txCallData is blank, then msg.value must be > 0, else revert
if (targetAddr == address(0)) revert InitiateTxWithNullAddress(account);
if (LibBytes.cmp(txCallData, "") == 0 && msg.value == 0) {
if (txCallData.length == 0 && msg.value == 0) {
revert InitiateTxWithNullCallDataAndNullValue(account, targetAddr);
}

Expand Down Expand Up @@ -322,80 +323,127 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {
bytes32 userOpHash
)
external
// view
virtual
override
returns (ValidationData)
{
// you want to exclude initiateTx, signTx, executeTx from needing tx count.
// you just need to ensure they are a valid proof from the semaphore group members
address account = userOp.sender;
uint256 groupId = groupMapping[account];

// The userOp.signature is 160 bytes containing:
// (uint256 pubX (32 bytes), uint256 pubY (32 bytes), bytes[96] signature (96 bytes))
if (userOp.signature.length != 160) {
revert InvalidSignatureLen(account, userOp.signature.length);
bytes calldata targetCallData = userOp.callData[100:];
if (_validateSignatureWithConfig(account, userOpHash, userOp.signature, targetCallData)) {
return VALIDATION_SUCCESS;
}

// Verify signature using the public key
if (!Identity.verifySignature(userOpHash, userOp.signature)) {
revert InvalidSignature(account, userOp.signature);
}

// Verify if the identity commitment is one of the semaphore group members
bytes memory pubKey = LibBytes.slice(userOp.signature, 0, 66);
uint256 cmt = Identity.getCommitment(pubKey);
if (!groups.hasMember(groupId, cmt)) revert MemberNotExists(account, cmt);

// We don't allow call to other contracts.
address targetAddr = address(bytes20(userOp.callData[100:120]));
if (targetAddr != address(this)) revert NonValidatorCallBanned(targetAddr, address(this));

// For callData, the first 120 bytes are reserved by ERC-7579 use. Then 32 bytes of value,
// then the remaining as the callData passed in getExecOps
bytes memory valAndCallData = userOp.callData[120:];
bytes4 funcSel = bytes4(LibBytes.slice(valAndCallData, 32, 36));

// We only allow calls to `initiateTx()`, `signTx()`, and `executeTx()` to pass,
// and reject the rest.
if (_isAllowedSelector(funcSel)) return VALIDATION_SUCCESS;
revert NonAllowedSelector(account, funcSel);
return VALIDATION_FAILED;
}

/**
* Validates an ERC-1271 signature with the sender
*
* @param hash bytes32 hash of the data
* @param data bytes data containing the signatures, and target calldata
*
* @return bytes4 EIP1271_SUCCESS if the signature is valid, EIP1271_FAILED otherwise
*/
function isValidSignatureWithSender(
address sender,
bytes32 hash,
bytes calldata signature
bytes calldata data
)
external
view
virtual
override
returns (bytes4 sugValidationResult)
returns (bytes4)
{
return EIP1271_SUCCESS;
bytes calldata signature = data[0:160];
bytes calldata targetCallData = data[160:];
if (_validateSignatureWithConfig(sender, hash, signature, targetCallData)) {
return EIP1271_SUCCESS;
}
return EIP1271_FAILED;
}

/**
* Validates a signature given some data
* For [ERC-7780](https://eips.ethereum.org/EIPS/eip-7780) Stateless Validator
*
* @param hash The data that was signed over
* @param signature The signature to verify
* @param data The data to validate the verified signature agains
*
* MUST validate that the signature is a valid signature of the hash
* MUST compare the validated signature against the data provided
* MUST return true if the signature is valid and false otherwise
*/
function validateSignatureWithData(
bytes32,
bytes calldata,
bytes calldata
bytes32 hash,
bytes calldata signature,
bytes calldata data
)
external
view
virtual
returns (bool validSig)
returns (bool)
{
return true;
address account = address(bytes20(data[0:20]));
bytes calldata targetCallData = data[20:];
return _validateSignatureWithConfig(account, hash, signature, targetCallData);
}

/*//////////////////////////////////////////////////////////////////////////
INTERNAL FUNCTIONS
//////////////////////////////////////////////////////////////////////////*/

function _isAllowedSelector(bytes4 sel) internal view returns (bool allowed) {
for (uint256 i = 0; i < ALLOWED_SELECTORS.length; ++i) {
if (sel == ALLOWED_SELECTORS[i]) return true;
}
return false;
}

function _validateSignatureWithConfig(
address account,
bytes32 hash,
bytes calldata signature,
bytes calldata targetCallData
)
internal
view
returns (bool)
{
// you want to exclude initiateTx, signTx, executeTx from needing tx count.
// you just need to ensure they are a valid proof from the semaphore group members
uint256 groupId = groupMapping[account];

// The userOp.signature is 160 bytes containing:
// (uint256 pubX (32 bytes), uint256 pubY (32 bytes), bytes[96] signature (96 bytes))
if (signature.length != 160) {
revert InvalidSignatureLen(account, signature.length);
}

// Verify signature using the public key
if (!Identity.verifySignature(hash, signature)) {
revert InvalidSignature(account, signature);
}

// Verify if the identity commitment is one of the semaphore group members
bytes memory pubKey = signature[0:64];
uint256 cmt = Identity.getCommitment(pubKey);
if (!groups.hasMember(groupId, cmt)) revert MemberNotExists(account, cmt);

// We don't allow call to other contracts.
address targetAddr = address(bytes20(targetCallData[0:20]));
if (targetAddr != address(this)) revert NonValidatorCallBanned(targetAddr, address(this));

// For callData, the first 120 bytes are reserved by ERC-7579 use. Then 32 bytes of value,
// then the remaining as the callData passed in getExecOps
bytes calldata valAndCallData = targetCallData[20:];
bytes4 funcSel = bytes4(valAndCallData[32:36]);

// We only allow calls to `initiateTx()`, `signTx()`, and `executeTx()` to pass,
// and reject the rest.
return _isAllowedSelector(funcSel);
}

/*//////////////////////////////////////////////////////////////////////////
METADATA
//////////////////////////////////////////////////////////////////////////*/
Expand Down Expand Up @@ -426,6 +474,6 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {
* @return true if the module is of the given type, false otherwise
*/
function isModuleType(uint256 typeID) external pure override returns (bool) {
return typeID == TYPE_VALIDATOR;
return typeID == TYPE_VALIDATOR || typeID == TYPE_STATELESS_VALIDATOR;
}
}
3 changes: 2 additions & 1 deletion src/utils/CurveBabyJubJub.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
pragma solidity >=0.8.23 <=0.8.29;

// ref: https://github.com/yondonfu/sol-baby-jubjub
// with: https://github.com/yondonfu/sol-baby-jubjub/pull/1
// with PR#1: https://github.com/yondonfu/sol-baby-jubjub/pull/1

library CurveBabyJubJub {
// Curve parameters
Expand Down Expand Up @@ -138,6 +138,7 @@ library CurveBabyJubJub {
* @dev Helper function to call the bigModExp precompile
*/
function expmod(uint256 _b, uint256 _e, uint256 _m) internal view returns (uint256 o) {
// solhint-disable-next-line no-inline-assembly
assembly {
let memPtr := mload(0x40)
mstore(memPtr, 0x20) // Length of base _b
Expand Down
48 changes: 21 additions & 27 deletions src/utils/Identity.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@ pragma solidity >=0.8.23 <=0.8.29;

import { PoseidonT3 } from "poseidon-solidity/PoseidonT3.sol";
import { PoseidonT6 } from "poseidon-solidity/PoseidonT6.sol";
import { Vm } from "forge-std/Vm.sol";
// import { console } from "forge-std/console.sol";
import { LibString } from "solady/Milady.sol";
import { CurveBabyJubJub } from "./CurveBabyJubJub.sol";

Vm constant vm = Vm(address(uint160(uint256(keccak256("hevm cheat code")))));
// import { LibString } from "solady/Milady.sol";
// import { Vm } from "forge-std/Vm.sol";
// import { console } from "forge-std/console.sol";

library Identity {
uint256 internal constant base8x = CurveBabyJubJub.Base8x;
Expand All @@ -19,48 +17,44 @@ library Identity {
cmt = PoseidonT3.hash([pkX, pkY]);
}

function verifySignatureFFI(bytes32 message, bytes memory signature) public returns (bool) {
(uint256 pkX, uint256 pkY, uint256 s0, uint256 s1, uint256 s2) =
abi.decode(signature, (uint256, uint256, uint256, uint256, uint256));
// function verifySignatureFFI(bytes32 message, bytes memory signature) public returns (bool) {
// (uint256 pkX, uint256 pkY, uint256 s0, uint256 s1, uint256 s2) =
// abi.decode(signature, (uint256, uint256, uint256, uint256, uint256));

string[] memory inputs = new string[](6);
inputs[0] = "pnpm";
inputs[1] = "semaphore-identity";
inputs[2] = "verify";
inputs[3] = vm.toString(abi.encodePacked(pkX, pkY));
inputs[4] = vm.toString(message);
inputs[5] = vm.toString(abi.encodePacked(s0, s1, s2));
// Vm constant vm = Vm(address(uint160(uint256(keccak256("hevm cheat code")))));

bytes memory res = vm.ffi(inputs);
string memory resStr = string(res);
return LibString.eq(resStr, "true");
}
// string[] memory inputs = new string[](6);
// inputs[0] = "pnpm";
// inputs[1] = "semaphore-identity";
// inputs[2] = "verify";
// inputs[3] = vm.toString(abi.encodePacked(pkX, pkY));
// inputs[4] = vm.toString(message);
// inputs[5] = vm.toString(abi.encodePacked(s0, s1, s2));

// bytes memory res = vm.ffi(inputs);
// string memory resStr = string(res);
// return LibString.eq(resStr, "true");
// }

function verifySignature(bytes32 message, bytes memory signature) public returns (bool) {
function verifySignature(bytes32 message, bytes memory signature) public view returns (bool) {
// Implement eddsa-poseidon verifySignature() method in solidity.
// https://github.com/privacy-scaling-explorations/zk-kit/blob/388f72b7a029a14bf5c20861d5f54bdaa98b3ac7/packages/eddsa-poseidon/src/eddsa-poseidon-factory.ts#L127-L158

Check warning on line 41 in src/utils/Identity.sol

View workflow job for this annotation

GitHub Actions / lint

Line length must be no more than 120 but current length is 175
(uint256 pkX, uint256 pkY, uint256 s0, uint256 s1, uint256 s2) =
abi.decode(signature, (uint256, uint256, uint256, uint256, uint256));

uint256 hm = PoseidonT6.hash([s0, s1, pkX, pkY, uint256(message)]);

// TODO: remove this after you can increase gas limit in getExecOps()
vm.pauseGasMetering();

(uint256 pLeftx, uint256 pLefty) = CurveBabyJubJub.pointMul(base8x, base8y, s2);

// This is suppose to be: CurveBabyJubJub.pointMul(pkX, pkY, mulmod(8, hm, FM)),
// but I'm not sure the field modulus to use. No, not `CurveBabyJubJub.Q`.
// but I'm not sure the field modulo to use. No, not `CurveBabyJubJub.Q`.
(uint256 pRightx, uint256 pRighty) = CurveBabyJubJub.pointMul(pkX, pkY, hm);
(pRightx, pRighty) = CurveBabyJubJub.pointAdd(pRightx, pRighty, pRightx, pRighty);
(pRightx, pRighty) = CurveBabyJubJub.pointAdd(pRightx, pRighty, pRightx, pRighty);
(pRightx, pRighty) = CurveBabyJubJub.pointAdd(pRightx, pRighty, pRightx, pRighty);

(uint256 pSumx, uint256 pSumy) = CurveBabyJubJub.pointAdd(s0, s1, pRightx, pRighty);

// TODO: remove this after you can increase gas limit in getExecOps()
vm.resumeGasMetering();

return (pLeftx == pSumx && pLefty == pSumy);
}
}
3 changes: 2 additions & 1 deletion test/Identity.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ contract IdentityTest is Test {
assertEq(true, IdentityT.verifySignature(hash, signature));
}

function test_verifySignatureAcceptCorrectSignature2() public {
function test_verifySignatureAcceptCorrectSignature2() public view {
bytes32 hash = hex"00b917632b69261f21d20e0cabdf9f3fa1255c6e500021997a16cf3a46d80297";
bytes memory signature =
// solhint-disable-next-line max-line-length
hex"26c3a847609100b3fd926d3c0a61324a32479d5989f01383aca537869cb23a851d67a417abb29f71e1f7c3d0bcd93cb68f89203b046174f03c3822a9139b512611b5289e52e9f70ff4a30cb9a19d66de49266887d3d17ed35f2dfc30f44573dc0c44756c4e4c5a5e5eeacc68f39b4e2238041e70ca926139ea039e260ea7ca5000b8d0dfc37fc5de7b0f80b722f8966a43caa10c8068cf863e5d06f82ae7c9d8";

assertEq(true, IdentityT.verifySignature(hash, signature));
Expand Down
12 changes: 7 additions & 5 deletions test/SemaphoreMSAValidator.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -304,11 +304,11 @@ contract SemaphoreValidatorUnitTest is RhinestoneModuleKit, Test {
txValidator: address(semaphoreValidator)
});

// TODO: We need to increase the accountGasLimits, default 2e6 is not enough to verify
// signature, for all those elliptic curve computation.
// userOpData.userOp.accountGasLimits = bytes32(uint256(2e7));
// userOpData.userOpHash = smartAcct.aux.entrypoint.getUserOpHash(userOpData.userOp);

// We need to increase the accountGasLimits, default 2e6 is not enough to verify
// signature, for all those elliptic curve computation.
// Encoding two fields here, validation and execution gas
userOpData.userOp.accountGasLimits = bytes32(abi.encodePacked(uint128(2e7), uint128(2e7)));
userOpData.userOpHash = smartAcct.aux.entrypoint.getUserOpHash(userOpData.userOp);
userOpData.userOp.signature = id.signHash(userOpData.userOpHash);
}

Expand Down Expand Up @@ -486,6 +486,8 @@ contract SemaphoreValidatorUnitTest is RhinestoneModuleKit, Test {
callData: abi.encodeCall(SimpleContract.setVal, (testVal)),
txValidator: address(semaphoreValidator)
});
userOpData.userOp.accountGasLimits = bytes32(abi.encodePacked(uint128(2e7), uint128(2e7)));
userOpData.userOpHash = smartAcct.aux.entrypoint.getUserOpHash(userOpData.userOp);
userOpData.userOp.signature = member.identity.signHash(userOpData.userOpHash);

smartAcct.expect4337Revert(SemaphoreMSAValidator.NonValidatorCallBanned.selector);
Expand Down

0 comments on commit 4842f2a

Please sign in to comment.