diff --git a/foundry.toml b/foundry.toml index 52ab18c..0cb6ba9 100644 --- a/foundry.toml +++ b/foundry.toml @@ -5,7 +5,7 @@ libs = ["lib", "dependencies"] fs_permissions = [{ access = "read", path = "out-optimized" }] allow_paths = ["*", "/"] optimizer = true -optimizer_runs = 20_000 +optimizer_runs = 2_000 via_ir = true [fmt] diff --git a/script/Deploy.s.sol b/script/Deploy.s.sol new file mode 100644 index 0000000..436e7da --- /dev/null +++ b/script/Deploy.s.sol @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +import { TransparentUpgradeableProxy } from "@openzeppelin/contracts/proxy/transparent/TransparentUpgradeableProxy.sol"; +import { UpgradeableBeacon } from "@openzeppelin/contracts/proxy/beacon/UpgradeableBeacon.sol"; +import { Script } from "forge-std/Script.sol"; + +import { MSAFactory } from "src/MSAFactory.sol"; +import { EOAKeyValidator } from "src/modules/EOAKeyValidator.sol"; +import { SessionKeyValidator } from "src/modules/SessionKeyValidator.sol"; +import { WebAuthnValidator } from "src/modules/WebAuthnValidator.sol"; +import { ModularSmartAccount } from "src/ModularSmartAccount.sol"; + +contract Deploy is Script { + function run() public { + // TODO: use correct owner address. + address owner = msg.sender; + + address[] memory defaultModules = new address[](3); + defaultModules[0] = address(new TransparentUpgradeableProxy(address(new EOAKeyValidator()), owner, "")); + defaultModules[1] = address(new TransparentUpgradeableProxy(address(new SessionKeyValidator()), owner, "")); + defaultModules[2] = address(new TransparentUpgradeableProxy(address(new WebAuthnValidator()), owner, "")); + + address accountImpl = address(new ModularSmartAccount()); + address beacon = address(new UpgradeableBeacon(accountImpl, owner)); + address factory = address(new TransparentUpgradeableProxy(address(new MSAFactory(beacon)), owner, "")); + } +} diff --git a/src/MSAFactory.sol b/src/MSAFactory.sol new file mode 100644 index 0000000..ca3fb54 --- /dev/null +++ b/src/MSAFactory.sol @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +import { BeaconProxy } from "@openzeppelin/contracts/proxy/beacon/BeaconProxy.sol"; + +import { IMSA } from "./interfaces/IMSA.sol"; + +/// @title MSAFactory +/// @author Matter Labs +/// @custom:security-contact security@matterlabs.dev +/// @dev This contract is used to deploy SSO accounts as beacon proxies. +contract MSAFactory { + /// @dev The address of the beacon contract used for the accounts' beacon proxies. + address public immutable beacon; + + /// @notice A mapping from unique account IDs to their corresponding deployed account addresses. + /// TODO: add versioning for upgradeability + mapping(bytes32 accountId => address deployedAccount) public accountRegistry; + + /// TODO: have this contract be a module registry too? + // address[] public moduleRegistry; + + /// @notice Emitted when a new account is successfully created. + /// @param accountAddress The address of the newly created account. + /// @param accountId A unique identifier for the account. + event AccountCreated(address indexed accountAddress, bytes32 accountId); + + error AccountAlreadyExists(bytes32 accountId); + + constructor(address _beacon) { + beacon = _beacon; + } + + function deployAccount(bytes32 accountId, bytes calldata initData) external returns (address account) { + require(accountRegistry[accountId] == address(0), AccountAlreadyExists(accountId)); + + accountRegistry[accountId] = address(account); + account = address(new BeaconProxy{ salt: accountId }(beacon, initData)); + + emit AccountCreated(account, accountId); + } +} diff --git a/src/ModularSmartAccount.sol b/src/ModularSmartAccount.sol index b2a703a..86b90b2 100644 --- a/src/ModularSmartAccount.sol +++ b/src/ModularSmartAccount.sol @@ -1,18 +1,16 @@ // SPDX-License-Identifier: MIT pragma solidity ^0.8.23; +import { PackedUserOperation } from "account-abstraction/interfaces/PackedUserOperation.sol"; +import { Initializable } from "@openzeppelin/contracts/proxy/utils/Initializable.sol"; +import { ERC1271 } from "solady/accounts/ERC1271.sol"; + import { ExecutionLib } from "./libraries/ExecutionLib.sol"; import { ExecutionHelper } from "./core/ExecutionHelper.sol"; -import { PackedUserOperation } from "account-abstraction/interfaces/PackedUserOperation.sol"; import { IERC7579Account, Execution } from "./interfaces/IERC7579Account.sol"; import { IMSA } from "./interfaces/IMSA.sol"; -import { ModuleManager } from "./core/ModuleManager.sol"; -// import { HookManager } from "./core/HookManager.sol"; +import { ERC1271Handler } from "./core/ERC1271Handler.sol"; import { RegistryAdapter } from "./core/RegistryAdapter.sol"; -import { ECDSA } from "solady/utils/ECDSA.sol"; -import { Initializable } from "./libraries/Initializable.sol"; -// import { ERC7779Adapter } from "./core/ERC7779Adapter.sol"; -// import { PreValidationHookManager } from "./core/PreValidationHookManager.sol"; import { IModule, @@ -37,8 +35,6 @@ import { CALLTYPE_DELEGATECALL, ModeLib } from "./libraries/ModeLib.sol"; -import { AccountBase } from "./core/AccountBase.sol"; -import { console } from "forge-std/console.sol"; /** * @author zeroknots.eth | rhinestone.wtf @@ -47,19 +43,13 @@ import { console } from "forge-std/console.sol"; * This account implements ExecType: DEFAULT and TRY. * Hook support is implemented */ -contract ModularSmartAccount is - IMSA, - ExecutionHelper, - ModuleManager, - // HookManager, - // PreValidationHookManager, - RegistryAdapter -{ - // ERC7779Adapter - +contract ModularSmartAccount is IMSA, ExecutionHelper, ERC1271Handler, RegistryAdapter, Initializable { using ExecutionLib for bytes; using ModeLib for ModeCode; - using ECDSA for bytes32; + + constructor() { + _disableInitializers(); + } /** * @inheritdoc IERC7579Account @@ -200,7 +190,6 @@ contract ModularSmartAccount is external payable onlyEntryPointOrSelf - // withHook withRegistry(module, moduleTypeId) { if (!IModule(module).isModuleType(moduleTypeId)) revert MismatchModuleTypeId(moduleTypeId); @@ -211,17 +200,7 @@ contract ModularSmartAccount is _installExecutor(module, initData); } else if (moduleTypeId == MODULE_TYPE_FALLBACK) { _installFallbackHandler(module, initData); - } - // TODO - // else if (moduleTypeId == MODULE_TYPE_HOOK) { - // _installHook(module, initData); - // } else if ( - // moduleTypeId == MODULE_TYPE_PREVALIDATION_HOOK_ERC1271 - // || moduleTypeId == MODULE_TYPE_PREVALIDATION_HOOK_ERC4337 - // ) { - // _installPreValidationHook(module, moduleTypeId, initData); - // } - else { + } else { revert UnsupportedModuleType(moduleTypeId); } emit ModuleInstalled(moduleTypeId, module); @@ -238,7 +217,6 @@ contract ModularSmartAccount is external payable onlyEntryPointOrSelf - // withHook { if (moduleTypeId == MODULE_TYPE_VALIDATOR) { _uninstallValidator(module, deInitData); @@ -246,17 +224,7 @@ contract ModularSmartAccount is _uninstallExecutor(module, deInitData); } else if (moduleTypeId == MODULE_TYPE_FALLBACK) { _uninstallFallbackHandler(module, deInitData); - } - // TODO - // else if (moduleTypeId == MODULE_TYPE_HOOK) { - // _uninstallHook(module, deInitData); - // } else if ( - // moduleTypeId == MODULE_TYPE_PREVALIDATION_HOOK_ERC1271 - // || moduleTypeId == MODULE_TYPE_PREVALIDATION_HOOK_ERC4337 - // ) { - // _uninstallPreValidationHook(module, moduleTypeId, deInitData); - // } - else { + } else { revert UnsupportedModuleType(moduleTypeId); } emit ModuleUninstalled(moduleTypeId, module); @@ -289,30 +257,21 @@ contract ModularSmartAccount is if (!_isValidatorInstalled(validator)) { return VALIDATION_FAILED; } else { - // TODO - // (userOpHash, userOp.signature) = _withPreValidationHook(userOpHash, userOp, missingAccountFunds); // bubble up the return value of the validator module validSignature = IValidator(validator).validateUserOp(userOp, userOpHash); } } - /** - * @dev ERC-1271 isValidSignature - * This function is intended to be used to validate a smart account signature - * and may forward the call to a validator module - * - * @param hash The hash of the data that is signed - * @param data The data that is signed - */ - function isValidSignature(bytes32 hash, bytes calldata data) external view virtual override returns (bytes4) { - address validator = address(bytes20(data[:20])); - if (!_isValidatorInstalled(validator)) { - revert InvalidModule(validator); - } - // TODO - // bytes memory signature_; - // (hash, signature_) = _withPreValidationHook(hash, data[20:]); - return IValidator(validator).isValidSignatureWithSender(msg.sender, hash, data[20:]); + function isValidSignature( + bytes32 hash, + bytes calldata data + ) + public + view + override(ERC1271, IERC7579Account) + returns (bytes4) + { + return super.isValidSignature(hash, data); } /** @@ -334,17 +293,7 @@ contract ModularSmartAccount is return _isExecutorInstalled(module); } else if (moduleTypeId == MODULE_TYPE_FALLBACK) { return _isFallbackHandlerInstalled(abi.decode(additionalContext, (bytes4)), module); - } - // TODO - // else if (moduleTypeId == MODULE_TYPE_HOOK) { - // return _isHookInstalled(module); - // } else if ( - // moduleTypeId == MODULE_TYPE_PREVALIDATION_HOOK_ERC1271 - // || moduleTypeId == MODULE_TYPE_PREVALIDATION_HOOK_ERC4337 - // ) { - // return _isPreValidationHookInstalled(module, moduleTypeId); - // } - else { + } else { return false; } } @@ -394,15 +343,17 @@ contract ModularSmartAccount is * @dev Initializes the account. Function might be called directly, or by a Factory * @param data. encoded data that can be used during the initialization phase */ - function initializeAccount(address entryPoint, address validator, bytes calldata data) public payable virtual { - // protect this function to only be callable when used with the proxy factory or when - // account calls itself - if (msg.sender != address(this)) { - Initializable.checkInitializable(); + function initializeAccount( + address[] calldata validators, + bytes[] calldata data + ) + external + payable + virtual + initializer + { + for (uint256 i = 0; i < validators.length; i++) { + _installValidator(address(validators[i]), data[i]); } - - ENTRY_POINT = entryPoint; - - _installValidator(address(validator), data); } } diff --git a/src/core/AccountBase.sol b/src/core/AccountBase.sol index c668992..2afcedc 100644 --- a/src/core/AccountBase.sol +++ b/src/core/AccountBase.sol @@ -8,12 +8,7 @@ pragma solidity ^0.8.21; contract AccountBase { error AccountAccessUnauthorized(); - // TODO: custom slot for this? - address public ENTRY_POINT; - - ///////////////////////////////////////////////////// - // Access Control - //////////////////////////////////////////////////// + address public constant ENTRY_POINT = 0x4337084D9E255Ff0702461CF8895CE9E3b5Ff108; modifier onlyEntryPointOrSelf() virtual { if (!(msg.sender == ENTRY_POINT || msg.sender == address(this))) { diff --git a/src/core/ERC1271Handler.sol b/src/core/ERC1271Handler.sol new file mode 100644 index 0000000..723620a --- /dev/null +++ b/src/core/ERC1271Handler.sol @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +import { ERC1271 } from "solady/accounts/ERC1271.sol"; +import { IERC1271 } from "@openzeppelin/contracts/interfaces/IERC1271.sol"; +import { ModuleManager } from "./ModuleManager.sol"; +import { IValidator } from "../interfaces/IERC7579Module.sol"; + +/// @title ERC1271Handler +/// @author Matter Labs +/// @notice Contract which provides ERC1271 signature validation +/// @notice Uses ERC7739 for signature replay protection +abstract contract ERC1271Handler is ERC1271, ModuleManager { + /// @notice Returns the domain name and version for the EIP-712 signature. + /// @return name string - The name of the domain + /// @return version string - The version of the domain + function _domainNameAndVersion() internal pure override returns (string memory name, string memory version) { + return ("zksync-sso-1271", "1.0.0"); + } + + /// @notice Indicates whether or not the contract may cache the domain name and version. + /// @return bool - Whether the domain name and version may change. + function _domainNameAndVersionMayChange() internal pure override returns (bool) { + return true; + } + + // @notice Returns whether the signature provided is valid for the provided hash. + // @dev Does not run validation hooks. Is used internally after ERC7739 unwrapping. + // @param hash bytes32 - Hash of the data that is signed + // @param signature bytes calldata - K1 owner signature OR validator address concatenated to signature + // @return bool - Whether the signature is valid + function _erc1271IsValidSignatureNowCalldata( + bytes32 hash, + bytes calldata data + ) + internal + view + virtual + override + returns (bool) + { + address validator = address(bytes20(data[:20])); + if (!_isValidatorInstalled(validator)) { + revert InvalidModule(validator); + } + return IValidator(validator).isValidSignatureWithSender(msg.sender, hash, data[20:]) + == IERC1271.isValidSignature.selector; + } + + /// @notice This function is not used anywhere in the contract, but is required to be implemented. + function _erc1271Signer() internal pure override returns (address) { + revert(); + } + + /// @dev Returns whether the `msg.sender` is considered safe, such + /// that we don't need to use the nested EIP-712 workflow. + /// @return bool - currently, always returns false + function _erc1271CallerIsSafe() internal pure override returns (bool) { + return false; + } + + function domainSeparator() external view returns (bytes32) { + return _domainSeparator(); + } +} diff --git a/src/core/ModuleManager.sol b/src/core/ModuleManager.sol index 53182f9..a32f8a9 100644 --- a/src/core/ModuleManager.sol +++ b/src/core/ModuleManager.sol @@ -4,7 +4,6 @@ pragma solidity ^0.8.21; import { EnumerableSet } from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; import { CallType, CALLTYPE_SINGLE, CALLTYPE_DELEGATECALL, CALLTYPE_STATIC } from "../libraries/ModeLib.sol"; import "../interfaces/IERC7579Module.sol"; -import "forge-std/interfaces/IERC165.sol"; /** * @title ModuleManager @@ -18,6 +17,9 @@ abstract contract ModuleManager { error InvalidModule(address module); error NoFallbackHandler(bytes4 selector); error CannotRemoveLastValidator(); + error SelectorAlreadyUsed(bytes4 selector); + error AlreadyInstalled(address module); + error NotInstalled(address module); event ValidatorUninstallFailed(address validator, bytes data); event ExecutorUninstallFailed(address executor, bytes data); @@ -64,13 +66,13 @@ abstract contract ModuleManager { // Manage Validators //////////////////////////////////////////////////// function _installValidator(address validator, bytes calldata data) internal virtual { - require($moduleManager().$valdiators.add(validator), "already installed"); + require($moduleManager().$valdiators.add(validator), AlreadyInstalled(validator)); IValidator(validator).onInstall(data); } function _uninstallValidator(address validator, bytes calldata data) internal { - // TODO: check if its the last validator. this might brick the account - require($moduleManager().$valdiators.remove(validator), "not installed"); + require($moduleManager().$valdiators.remove(validator), NotInstalled(validator)); + require($moduleManager().$valdiators.length() > 1, CannotRemoveLastValidator()); IValidator(validator).onUninstall(data); } // TODO: unlink validator @@ -88,12 +90,12 @@ abstract contract ModuleManager { //////////////////////////////////////////////////// function _installExecutor(address executor, bytes calldata data) internal { - require($moduleManager().$executors.add(executor), "already installed"); + require($moduleManager().$executors.add(executor), AlreadyInstalled(executor)); IExecutor(executor).onInstall(data); } function _uninstallExecutor(address executor, bytes calldata data) internal { - require($moduleManager().$executors.remove(executor), "not installed"); + require($moduleManager().$executors.remove(executor), NotInstalled(executor)); IExecutor(executor).onUninstall(data); } @@ -114,10 +116,7 @@ abstract contract ModuleManager { CallType calltype = CallType.wrap(bytes1(params[4])); bytes memory initData = params[5:]; - if (_isFallbackHandlerInstalled(selector)) { - // TODO: convert all errors to custom errors - revert("Function selector already used"); - } + require(!_isFallbackHandlerInstalled(selector), SelectorAlreadyUsed(selector)); $moduleManager().$fallbacks[selector] = FallbackHandler(handler, calltype); IFallback(handler).onInstall(initData); } @@ -126,18 +125,12 @@ abstract contract ModuleManager { bytes4 selector = bytes4(deInitData[0:4]); bytes memory _deInitData = deInitData[4:]; - if (!_isFallbackHandlerInstalled(selector)) { - revert("Function selector not used"); - } + require(_isFallbackHandlerInstalled(selector), NoFallbackHandler(selector)); FallbackHandler memory activeFallback = $moduleManager().$fallbacks[selector]; - if (activeFallback.handler != handler) { - revert("Function selector not used by this handler"); - } - + require(activeFallback.handler == handler, NotInstalled(handler)); $moduleManager().$fallbacks[selector] = FallbackHandler(address(0), CallType.wrap(0x00)); - IFallback(handler).onUninstall(_deInitData); } @@ -199,12 +192,12 @@ abstract contract ModuleManager { switch calltype case 0xFE { // CALLTYPE_STATIC - // Add 20 bytes for the address appended add the end + // Add 20 bytes for the address appended at the end success := staticcall(gas(), handler, calldataPtr, add(calldatasize(), 20), 0, 0) } case 0x00 { // CALLTYPE_SINGLE - // Add 20 bytes for the address appended add the end + // Add 20 bytes for the address appended at the end success := call(gas(), handler, 0, calldataPtr, add(calldatasize(), 20), 0, 0) } default { return(0, 0) } // Unsupported calltype diff --git a/src/interfaces/IMSA.sol b/src/interfaces/IMSA.sol index a4dca38..235586a 100644 --- a/src/interfaces/IMSA.sol +++ b/src/interfaces/IMSA.sol @@ -18,9 +18,6 @@ interface IMSA is IERC7579Account, IERC4337Account { // Error thrown when account installs/unistalls module with mismatched input `moduleTypeId` error MismatchModuleTypeId(uint256 moduleTypeId); - /** - * @dev Initializes the account. Function might be called directly, or by a Factory - * @param data. encoded data that can be used during the initialization phase - */ - function initializeAccount(address entryPoint, address validator, bytes calldata data) external payable; + /// @dev Initializes the account. Function might be called directly, or by a Factory + function initializeAccount(address[] calldata validators, bytes[] calldata data) external payable; } diff --git a/src/libraries/Initializable.sol b/src/libraries/Initializable.sol deleted file mode 100644 index 20f1b8f..0000000 --- a/src/libraries/Initializable.sol +++ /dev/null @@ -1,28 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8.23; - -bytes32 constant INIT_SLOT = - keccak256(abi.encode(uint256(keccak256("initializable.transient.msa")) - 1)) & ~bytes32(uint256(0xff)); - -library Initializable { - error NotInitializable(); - - function checkInitializable() internal view { - bytes32 slot = INIT_SLOT; - // Load the current value from the slot, revert if 0 - assembly { - let isInitializable := tload(slot) - if iszero(isInitializable) { - mstore(0x0, 0xaed59595) // NotInitializable() - revert(0x1c, 0x04) - } - } - } - - function setInitializable() internal { - bytes32 slot = INIT_SLOT; - assembly { - tstore(slot, 0x01) - } - } -} diff --git a/src/libraries/SessionLib.sol b/src/libraries/SessionLib.sol index 2c3f58e..fb4e1e5 100644 --- a/src/libraries/SessionLib.sol +++ b/src/libraries/SessionLib.sol @@ -6,9 +6,9 @@ import { UserOperationLib } from "account-abstraction/core/UserOperationLib.sol" import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; import { LibBytes } from "solady/utils/LibBytes.sol"; +import { IERC7579Account } from "../interfaces/IERC7579Account.sol"; import { ExecutionLib } from "../libraries/ExecutionLib.sol"; import { CallType, ModeCode, ExecType, CALLTYPE_SINGLE, ModeLib } from "../libraries/ModeLib.sol"; -import { console } from "forge-std/console.sol"; /// @title Session Library /// @author Matter Labs @@ -24,6 +24,7 @@ library SessionLib { error ZeroSigner(); error InvalidSigner(address recovered, address expected); error InvalidCallType(CallType callType, CallType expected); + error InvalidTopLevelSelector(bytes4 selector, bytes4 expected); error SessionAlreadyExists(bytes32 sessionHash); error UnlimitedFees(); error SessionExpiresTooSoon(uint256 expiresAt); @@ -38,6 +39,8 @@ library SessionLib { error SignerAlreadyUsed(address signer); error CallPolicyBanned(address target, bytes4 selector); error SessionActionsNotAllowed(bytes32 sessionActionsHash); + error InvalidNonceKey(uint192 nonceKey, uint192 expectedNonceKey); + error ActionsNotAllowed(bytes32 actionsHash); /// @notice We do not permit opening multiple identical sessions (even after one is closed, /// e.g.). @@ -263,6 +266,11 @@ library SessionLib { } } + function shrinkRange(uint48[2] memory range, uint48 newAfter, uint48 newUntil) internal pure { + range[0] = newAfter > range[0] ? newAfter : range[0]; + range[1] = newUntil < range[1] ? newUntil : range[1]; + } + /// @notice Validates the transaction against the session spec and updates the usage trackers. /// @param state The session storage to update. /// @param userOp The user operation to validate. @@ -272,10 +280,7 @@ library SessionLib { /// otherwise (which will be ignored). /// periodIds[0] is for fee limit (not used in this function), /// periodIds[1] is for value limit, - /// peroidIds[2:2+n] are for `ERC20.approve()` constraints, where `n` is the number of - /// constraints in the `ERC20.approve()` policy - /// if an approval-based paymaster is used, 0 otherwise. - /// periodIds[2+n:] are for call constraints, if there are any. + /// periodIds[2:] are for call constraints, if there are any. /// It is required to pass them in (instead of computing via block.timestamp) since during /// validation /// we can only assert the range of the timestamp, but not access its value. @@ -286,20 +291,26 @@ library SessionLib { uint48[] memory periodIds ) internal - returns (uint48 validAfter, uint48 validUntil) + returns (uint48, uint48) { require(state.status[msg.sender] == Status.Active, SessionNotActive()); + bytes4 topLevelSelector = bytes4(userOp.callData[:4]); CallType callType = CallType.wrap(userOp.callData[4]); + require(callType == CALLTYPE_SINGLE, InvalidCallType(callType, CALLTYPE_SINGLE)); - // require topLevelSelector == IMSA.execute.selector TODO + require( + topLevelSelector == IERC7579Account.execute.selector, + InvalidTopLevelSelector(topLevelSelector, IERC7579Account.execute.selector) + ); + + // TODO: put a comment about why this exact slice uint256 length = uint256(bytes32(userOp.callData[68:100])); (address target, uint256 value, bytes calldata callData) = ExecutionLib.decodeSingle(userOp.callData[100:100 + length]); - // TODO - validAfter = 0; - validUntil = spec.expiresAt; + // Time range whithin which the transaction is valid. + uint48[2] memory timeRange = [0, spec.expiresAt]; if (callData.length >= 4) { bytes4 selector = bytes4(callData[:4]); @@ -318,15 +329,13 @@ library SessionLib { require(value <= callPolicy.maxValuePerUse, MaxValueExceeded(value, callPolicy.maxValuePerUse)); (uint48 newValidAfter, uint48 newValidUntil) = callPolicy.valueLimit.checkAndUpdate(state.callValue[target][selector], value, periodIds[1]); - validAfter = newValidAfter > validAfter ? newValidAfter : validAfter; - validUntil = newValidUntil < validUntil ? newValidUntil : validUntil; + shrinkRange(timeRange, newValidAfter, newValidUntil); for (uint256 i = 0; i < callPolicy.constraints.length; i++) { (newValidAfter, newValidUntil) = callPolicy.constraints[i].checkAndUpdate( state.params[target][selector][i], callData, periodIds[2 + i] ); - validAfter = newValidAfter > validAfter ? newValidAfter : validAfter; - validUntil = newValidUntil < validUntil ? newValidUntil : validUntil; + shrinkRange(timeRange, newValidAfter, newValidUntil); } } else { TransferSpec memory transferPolicy; @@ -344,9 +353,10 @@ library SessionLib { require(value <= transferPolicy.maxValuePerUse, MaxValueExceeded(value, transferPolicy.maxValuePerUse)); (uint48 newValidAfter, uint48 newValidUntil) = transferPolicy.valueLimit.checkAndUpdate(state.transferValue[target], value, periodIds[1]); - validAfter = newValidAfter > validAfter ? newValidAfter : validAfter; - validUntil = newValidUntil < validUntil ? newValidUntil : validUntil; + shrinkRange(timeRange, newValidAfter, newValidUntil); } + + return (timeRange[0], timeRange[1]); } /// @notice Getter for the remainder of a usage limit. @@ -375,6 +385,8 @@ library SessionLib { uint64 period = uint64(block.timestamp / limit.period); return limit.limit - tracker.allowanceUsage[period][account]; } + // Unreachable, but silences warning + return 0; } /// @notice Getter for the session state. diff --git a/src/modules/EOAKeyValidator.sol b/src/modules/EOAKeyValidator.sol index cc5827b..4dc5605 100644 --- a/src/modules/EOAKeyValidator.sol +++ b/src/modules/EOAKeyValidator.sol @@ -4,10 +4,9 @@ pragma solidity ^0.8.24; import { IValidator, MODULE_TYPE_VALIDATOR } from "../interfaces/IERC7579Module.sol"; import { PackedUserOperation } from "account-abstraction/interfaces/PackedUserOperation.sol"; import { SIG_VALIDATION_FAILED, SIG_VALIDATION_SUCCESS } from "account-abstraction/core/Helpers.sol"; - -import { EnumerableSet } from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; import { ECDSA } from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; -import { console } from "forge-std/console.sol"; +import { IERC1271 } from "@openzeppelin/contracts/interfaces/IERC1271.sol"; +import { EnumerableSet } from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; contract EOAKeyValidator is IValidator { using EnumerableSet for EnumerableSet.AddressSet; @@ -16,6 +15,12 @@ contract EOAKeyValidator is IValidator { mapping(address => bool) internal _initialized; mapping(address => EnumerableSet.AddressSet) owners; + event OwnerAdded(address indexed smartAccount, address indexed owner); + event OwnerRemoved(address indexed smartAccount, address indexed owner); + + error OwnerAlreadyExists(address smartAccount, address owner); + error OwnerDoesNotExist(address smartAccount, address owner); + function onInstall(bytes calldata data) external override { if (isInitialized(msg.sender)) revert AlreadyInitialized(msg.sender); _initialized[msg.sender] = true; @@ -25,12 +30,7 @@ contract EOAKeyValidator is IValidator { } } - function onUninstall( - bytes calldata // data - ) - external - override - { + function onUninstall(bytes calldata) external override { if (!isInitialized(msg.sender)) revert NotInitialized(msg.sender); _initialized[msg.sender] = false; // TODO: clear owners? @@ -55,21 +55,19 @@ contract EOAKeyValidator is IValidator { } function addOwner(address owner) public { - if (!owners[msg.sender].add(owner)) { - revert("Owner already exists"); - } - // TODO emit event? + require(isInitialized(msg.sender), NotInitialized(msg.sender)); + require(owners[msg.sender].add(owner), OwnerAlreadyExists(msg.sender, owner)); + emit OwnerAdded(msg.sender, owner); } function removeOwner(address owner) public { - if (!owners[msg.sender].remove(owner)) { - revert("Owner does not exist"); - } - // TODO emit event? + require(isInitialized(msg.sender), NotInitialized(msg.sender)); + require(owners[msg.sender].remove(owner), OwnerDoesNotExist(msg.sender, owner)); + emit OwnerRemoved(msg.sender, owner); } function isValidSignatureWithSender( - address sender, + address, // sender bytes32 hash, bytes calldata data ) @@ -78,6 +76,10 @@ contract EOAKeyValidator is IValidator { override returns (bytes4) { - // TODO + // slither-disable-next-line unused-return + (address signer, ECDSA.RecoverError err,) = ECDSA.tryRecover(hash, data); + return err == ECDSA.RecoverError.NoError && owners[msg.sender].contains(signer) + ? IERC1271.isValidSignature.selector + : bytes4(0xffffffff); } } diff --git a/src/modules/SessionKeyValidator.sol b/src/modules/SessionKeyValidator.sol index 7dab19e..bc77633 100644 --- a/src/modules/SessionKeyValidator.sol +++ b/src/modules/SessionKeyValidator.sol @@ -76,7 +76,7 @@ contract SessionKeyValidator is IValidator { /// @notice This module should not be used to validate signatures (including EIP-1271), /// as a signature by itself does not have enough information to validate it against a session. - function isValidSignatureWithSender(address, bytes32, bytes memory) external pure returns (bytes4) { + function isValidSignatureWithSender(address, bytes32, bytes calldata) external pure returns (bytes4) { return 0x00000000; } @@ -92,21 +92,29 @@ contract SessionKeyValidator is IValidator { /// + batchCall /// @dev can be extended by derived contracts. /// @param target The target address of the call - /// @param _selector The function selector of the call; currently unused /// @return true if the call is banned, false otherwise - function isBannedCall(address target, bytes4 _selector) internal view virtual returns (bool) { + function isBannedCall(address target, bytes4 /* selector */ ) internal view virtual returns (bool) { return target == address(this) // this line is technically unnecessary - || target == address(msg.sender) || IMSA(msg.sender).isModuleInstalled(MODULE_TYPE_VALIDATOR, target, ""); // TODO: - // make one - // call to check any module type + || target == address(msg.sender) || IMSA(msg.sender).isModuleInstalled(MODULE_TYPE_VALIDATOR, target, ""); + // TODO: make one call to check any module type } /// @notice Create a new session for an account /// @param sessionSpec The session specification to create a session with + /// @dev In the sessionSpec, callPolicies should not have duplicated instances of + /// (target, selector) pairs. Only the first one is considered when validating transactions. function createSession(SessionLib.SessionSpec memory sessionSpec) public virtual { bytes32 sessionHash = keccak256(abi.encode(sessionSpec)); - // TODO error - require(isInitialized(msg.sender), "not initialized"); + require(isInitialized(msg.sender), NotInitialized(msg.sender)); + + uint256 totalCallPolicies = sessionSpec.callPolicies.length; + for (uint256 i = 0; i < totalCallPolicies; i++) { + require( + !isBannedCall(sessionSpec.callPolicies[i].target, sessionSpec.callPolicies[i].selector), + SessionLib.CallPolicyBanned(sessionSpec.callPolicies[i].target, sessionSpec.callPolicies[i].selector) + ); + } + require(sessionSpec.signer != address(0), SessionLib.ZeroSigner()); // Avoid using same session key for multiple sessions, contract-wide require(sessionSigner[sessionSpec.signer] == bytes32(0), SessionLib.SignerAlreadyUsed(sessionSpec.signer)); @@ -118,14 +126,6 @@ contract SessionKeyValidator is IValidator { // Sessions should expire in no less than 60 seconds. require(sessionSpec.expiresAt >= block.timestamp + 60, SessionLib.SessionExpiresTooSoon(sessionSpec.expiresAt)); - uint256 totalCallPolicies = sessionSpec.callPolicies.length; - for (uint256 i = 0; i < totalCallPolicies; i++) { - require( - !isBannedCall(sessionSpec.callPolicies[i].target, sessionSpec.callPolicies[i].selector), - SessionLib.CallPolicyBanned(sessionSpec.callPolicies[i].target, sessionSpec.callPolicies[i].selector) - ); - } - sessions[sessionHash].status[msg.sender] = SessionLib.Status.Active; sessionSigner[sessionSpec.signer] = sessionHash; emit SessionCreated(msg.sender, sessionHash, sessionSpec); @@ -171,7 +171,7 @@ contract SessionKeyValidator is IValidator { /// @notice Validate a session transaction for an account /// @param userOp User operation to validate /// @param userOpHash The hash of the userOp - /// TODO @return + /// @return uint256 Validation data, according to ERC-4337 (EntryPoint v0.8) /// @dev Session spec and period IDs must be provided as validator data function validateUserOp(PackedUserOperation calldata userOp, bytes32 userOpHash) public virtual returns (uint256) { (, bytes memory transactionSignature, bytes memory validatorData) = @@ -183,7 +183,8 @@ contract SessionKeyValidator is IValidator { require(spec.signer != address(0), SessionLib.ZeroSigner()); bytes32 sessionHash = keccak256(abi.encode(spec)); uint192 nonceKey = uint192(userOp.nonce >> 64); - require(nonceKey == uint192(uint160(spec.signer)), "invalid nonce key"); + uint192 expectedNonceKey = uint192(uint160(spec.signer)); + require(nonceKey == expectedNonceKey, SessionLib.InvalidNonceKey(nonceKey, expectedNonceKey)); // this will revert if session spec is violated (uint48 validAfter, uint48 validUntil) = sessions[sessionHash].validate(userOp, spec, periodIds); diff --git a/src/modules/WebAuthnValidator.sol b/src/modules/WebAuthnValidator.sol index 3d30c1b..ec5a1c2 100644 --- a/src/modules/WebAuthnValidator.sol +++ b/src/modules/WebAuthnValidator.sol @@ -1,4 +1,4 @@ -// SPDX-License-Identifier: GPL-3.0 +// SPDX-License-Identifier: MIT pragma solidity ^0.8.24; import { IERC165 } from "@openzeppelin/contracts/utils/introspection/IERC165.sol"; @@ -6,6 +6,7 @@ import { Strings } from "@openzeppelin/contracts/utils/Strings.sol"; import { Base64 } from "solady/utils/Base64.sol"; import { JSONParserLib } from "solady/utils/JSONParserLib.sol"; import { PackedUserOperation } from "account-abstraction/interfaces/PackedUserOperation.sol"; +import { IERC1271 } from "@openzeppelin/contracts/interfaces/IERC1271.sol"; import { IMSA } from "../interfaces/IMSA.sol"; import { IValidator, IModule, MODULE_TYPE_VALIDATOR } from "../interfaces/IERC7579Module.sol"; @@ -164,7 +165,7 @@ contract WebAuthnValidator is IValidator { /// @param signature The signature to validate // TODO return function isValidSignatureWithSender( - address sender, + address, // sender bytes32 signedHash, bytes calldata signature ) @@ -172,7 +173,7 @@ contract WebAuthnValidator is IValidator { view returns (bytes4) { - return webAuthVerify(signedHash, signature) ? bytes4(0x1626ba7e) : bytes4(0x00000000); + return webAuthVerify(signedHash, signature) ? IERC1271.isValidSignature.selector : bytes4(0xffffffff); } /// @notice Validates a transaction signed with a passkey @@ -180,7 +181,7 @@ contract WebAuthnValidator is IValidator { /// due to the modular format /// @param signedHash The hash of the signed transaction /// @param userOp The user operation to validate - // TODO return + /// @return 0 if the signature is valid, 1 if invalid, otherwise reverts function validateUserOp(PackedUserOperation calldata userOp, bytes32 signedHash) external view returns (uint256) { (, bytes memory signature,) = abi.decode(userOp.signature, (address, bytes, bytes)); return webAuthVerify(signedHash, signature) ? 0 : 1; @@ -198,6 +199,8 @@ contract WebAuthnValidator is IValidator { (bytes memory authenticatorData, string memory clientDataJSON, bytes32[2] memory rs, bytes memory credentialId) = _decodeFatSignature(fatSignature); + // TODO: this call should revert in all cases except invalid signature. Format should be correct regardless. + // prevent signature replay https://yondon.blog/2019/01/01/how-not-to-use-ecdsa/ if (uint256(rs[0]) == 0 || rs[0] > HIGH_R_MAX || uint256(rs[1]) == 0 || rs[1] > LOW_S_MAX) { return false; diff --git a/src/modules/contrib/AllowedSessionsValidator.sol b/src/modules/contrib/AllowedSessionsValidator.sol new file mode 100644 index 0000000..6df58ce --- /dev/null +++ b/src/modules/contrib/AllowedSessionsValidator.sol @@ -0,0 +1,128 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +import { IERC165 } from "@openzeppelin/contracts/utils/introspection/IERC165.sol"; +import { ECDSA } from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; +import { AccessControl } from "@openzeppelin/contracts/access/AccessControl.sol"; +import { IAccessControl } from "@openzeppelin/contracts/access/IAccessControl.sol"; +import { PackedUserOperation } from "account-abstraction/interfaces/PackedUserOperation.sol"; + +import { SessionLib } from "src/libraries/SessionLib.sol"; +import { SessionKeyValidator } from "../SessionKeyValidator.sol"; +import { IValidator, IModule, MODULE_TYPE_VALIDATOR } from "src/interfaces/IERC7579Module.sol"; + +/// @title AllowedSessionsValidator +/// @author Oleg Bedrin - - Xsolla Special Initiatives +/// @custom:security-contact security@matterlabs.dev and o.bedrin@xsolla.com +/// @notice This contract is used to manage allowed sessions for a smart account. +/// @notice This module is controlled by a single entity, which has the power +/// to close all current sessions and disallow any future sessions on this module. +contract AllowedSessionsValidator is SessionKeyValidator, AccessControl { + using SessionLib for SessionLib.SessionStorage; + + /// @notice Emitted when session actions are allowed or disallowed. + /// @param sessionActionsHash The hash of the session actions. + /// @param allowed Boolean indicating if the session actions are allowed. + event SessionActionsAllowed(bytes32 indexed sessionActionsHash, bool indexed allowed); + + /// @notice Role identifier for session registry managers. + bytes32 public constant SESSION_REGISTRY_MANAGER_ROLE = keccak256("SESSION_REGISTRY_MANAGER_ROLE"); + + /// @notice Mapping to track whether a session actions is allowed. + /// @dev The key is the hash of session actions, and the value indicates if the actions are allowed. + mapping(bytes32 sessionActionsHash => bool allowed) public areSessionActionsAllowed; + + constructor() { + _grantRole(SESSION_REGISTRY_MANAGER_ROLE, msg.sender); + _grantRole(DEFAULT_ADMIN_ROLE, msg.sender); + } + + /// @notice Set whether a session actions hash is allowed or not. + /// @param sessionActionsHash The hash of the session actions. + /// @param allowed Boolean indicating if the session actions are allowed. + /// @dev Session actions represent the set of operations, such as fee limits, call policies, and transfer policies, + /// that define the behavior and constraints of a session. + function setSessionActionsAllowed( + bytes32 sessionActionsHash, + bool allowed + ) + external + virtual + onlyRole(SESSION_REGISTRY_MANAGER_ROLE) + { + if (areSessionActionsAllowed[sessionActionsHash] != allowed) { + areSessionActionsAllowed[sessionActionsHash] = allowed; + emit SessionActionsAllowed(sessionActionsHash, allowed); + } + } + + /// @notice Get the hash of session actions from a session specification. + /// @param sessionSpec The session specification. + /// @return The hash of the session actions. + /// @dev The session actions hash is derived from the session's fee limits, call policies, and transfer policies. + function getSessionActionsHash(SessionLib.SessionSpec memory sessionSpec) public view virtual returns (bytes32) { + uint256 callPoliciesLength = sessionSpec.callPolicies.length; + bytes memory callPoliciesEncoded; + + for (uint256 i = 0; i < callPoliciesLength; ++i) { + SessionLib.CallSpec memory policy = sessionSpec.callPolicies[i]; + callPoliciesEncoded = abi.encodePacked( + callPoliciesEncoded, + bytes20(policy.target), // Address cast to bytes20 + policy.selector, // Selector + policy.maxValuePerUse, // Max value per use + uint256(policy.valueLimit.limitType), // Limit type + policy.valueLimit.limit, // Limit + policy.valueLimit.period // Period + ); + } + + return keccak256(abi.encode(sessionSpec.feeLimit, sessionSpec.transferPolicies, callPoliciesEncoded)); + } + + /// @notice Create a new session for an account. + /// @param sessionSpec The session specification to create a session with. + /// @dev A session is a temporary authorization for an account to perform specific actions, defined by the session + /// specification. + function createSession(SessionLib.SessionSpec memory sessionSpec) public virtual override(SessionKeyValidator) { + bytes32 sessionActionsHash = getSessionActionsHash(sessionSpec); + require(areSessionActionsAllowed[sessionActionsHash], SessionLib.ActionsNotAllowed(sessionActionsHash)); + SessionKeyValidator.createSession(sessionSpec); + } + + /// @inheritdoc SessionKeyValidator + function supportsInterface(bytes4 interfaceId) + public + pure + override(SessionKeyValidator, AccessControl) + returns (bool) + { + return interfaceId == type(IERC165).interfaceId || interfaceId == type(IValidator).interfaceId + || interfaceId == type(IAccessControl).interfaceId; + } + + /// @notice Validate a session transaction for an account. + /// @param userOp The user operation to validate. + /// @param userOpHash The hash of the operation. + /// @return true if the transaction is valid. + /// @dev Session spec and period IDs must be provided as validator data. + function validateUserOp( + PackedUserOperation calldata userOp, + bytes32 userOpHash + ) + public + override + returns (uint256) + { + // slither-disable-next-line unused-return + (,, bytes memory validatorData) = abi.decode(userOp.signature, (address, bytes, bytes)); + // slither-disable-next-line unused-return + (SessionLib.SessionSpec memory spec,) = abi.decode( + validatorData, // this is passed by the signature builder + (SessionLib.SessionSpec, uint48[]) + ); + bytes32 sessionActionsHash = getSessionActionsHash(spec); + require(areSessionActionsAllowed[sessionActionsHash], SessionLib.ActionsNotAllowed(sessionActionsHash)); + return SessionKeyValidator.validateUserOp(userOp, userOpHash); + } +} diff --git a/src/utils/MSAProxy.sol b/src/utils/MSAProxy.sol deleted file mode 100644 index f380488..0000000 --- a/src/utils/MSAProxy.sol +++ /dev/null @@ -1,17 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8.23; - -import { Proxy } from "@openzeppelin/contracts/proxy/Proxy.sol"; -import { ERC1967Utils } from "@openzeppelin/contracts/proxy/ERC1967/ERC1967Utils.sol"; -import { Initializable } from "src/libraries/Initializable.sol"; - -contract MSAProxy is Proxy { - constructor(address implementation, bytes memory _data) payable { - Initializable.setInitializable(); - ERC1967Utils.upgradeToAndCall(implementation, _data); - } - - function _implementation() internal view virtual override returns (address) { - return ERC1967Utils.getImplementation(); - } -} diff --git a/test/Basic.t.sol b/test/Basic.t.sol index 73ba32c..adab24c 100644 --- a/test/Basic.t.sol +++ b/test/Basic.t.sol @@ -1,81 +1,193 @@ // SPDX-License-Identifier: MIT + pragma solidity ^0.8.24; -import { Test } from "forge-std/Test.sol"; -import { EntryPoint } from "account-abstraction/core/EntryPoint.sol"; -import { ModularSmartAccount } from "../src/ModularSmartAccount.sol"; -import { MSAProxy } from "../src/utils/MSAProxy.sol"; -import { EOAKeyValidator } from "../src/modules/EOAKeyValidator.sol"; import { PackedUserOperation } from "account-abstraction/interfaces/PackedUserOperation.sol"; -import { IMSA } from "../src/interfaces/IMSA.sol"; -import { ExecutionLib } from "../src/libraries/ExecutionLib.sol"; -import { ModeLib } from "../src/libraries/ModeLib.sol"; - -contract Basic is Test { - EntryPoint public entryPoint; - ModularSmartAccount public account; - IMSA public accountProxy; - EOAKeyValidator public eoaValidator; - Account public owner; - - function setUp() public { - owner = makeAccount("owner"); - address[] memory owners = new address[](1); - owners[0] = owner.addr; - - entryPoint = new EntryPoint(); - account = new ModularSmartAccount(); - eoaValidator = new EOAKeyValidator(); - accountProxy = IMSA( - address( - new MSAProxy( - address(account), - abi.encodeCall( - ModularSmartAccount.initializeAccount, - (address(entryPoint), address(eoaValidator), abi.encode(owners)) - ) - ) +import { LibString } from "solady/utils/LibString.sol"; + +import { IERC7579Account } from "src/interfaces/IERC7579Account.sol"; +import { ExecutionLib } from "src/libraries/ExecutionLib.sol"; +import { Execution } from "src/interfaces/IERC7579Account.sol"; +import "src/libraries/ModeLib.sol"; + +import { MockTarget } from "./mocks/MockTarget.sol"; +import { MockDelegateTarget } from "./mocks/MockDelegateTarget.sol"; +import { MockERC1271Caller, MockMessage } from "./mocks/MockERC1271Caller.sol"; +import { MSATest } from "./MSATest.sol"; + +contract BasicTest is MSATest { + MockTarget public target; + MockDelegateTarget public delegateTarget; + MockERC1271Caller public erc1271Caller; + + function setUp() public override { + super.setUp(); + + target = new MockTarget(); + delegateTarget = new MockDelegateTarget(); + erc1271Caller = new MockERC1271Caller(); + } + + function test_transfer() public { + address recipient = makeAddr("recipient"); + bytes memory execution = ExecutionLib.encodeSingle(recipient, 1 ether, ""); + bytes memory callData = abi.encodeCall(IERC7579Account.execute, (ModeLib.encodeSimpleSingle(), execution)); + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = makeSignedUserOp(callData, owner.key, address(eoaValidator)); + + entryPoint.handleOps(userOps, bundler); + vm.assertEq(recipient.balance, 1 ether); + } + + function test_execSingle() public { + bytes memory execution = + ExecutionLib.encodeSingle(address(target), 0, abi.encodeCall(MockTarget.setValue, 1337)); + bytes memory callData = abi.encodeCall(IERC7579Account.execute, (ModeLib.encodeSimpleSingle(), execution)); + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = makeSignedUserOp(callData, owner.key, address(eoaValidator)); + + entryPoint.handleOps(userOps, bundler); + vm.assertEq(target.value(), 1337); + } + + function test_execBatch() public { + bytes memory setValueOnTarget = abi.encodeCall(MockTarget.setValue, 1337); + address target2 = makeAddr("target2"); + uint256 target2Amount = 1 wei; + + Execution[] memory executions = new Execution[](2); + executions[0] = Execution({ target: address(target), value: 0, callData: setValueOnTarget }); + executions[1] = Execution({ target: target2, value: target2Amount, callData: "" }); + + bytes memory callData = + abi.encodeCall(IERC7579Account.execute, (ModeLib.encodeSimpleBatch(), ExecutionLib.encodeBatch(executions))); + + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = makeSignedUserOp(callData, owner.key, address(eoaValidator)); + + entryPoint.handleOps(userOps, bundler); + vm.assertEq(target.value(), 1337); + vm.assertEq(target2.balance, target2Amount); + } + + function test_delegateCall() public { + address valueTarget = makeAddr("valueTarget"); + uint256 value = 1 ether; + bytes memory sendValue = abi.encodeWithSelector(MockDelegateTarget.sendValue.selector, valueTarget, value); + + bytes memory callData = abi.encodeCall( + IERC7579Account.execute, + ( + ModeLib.encode(CALLTYPE_DELEGATECALL, EXECTYPE_DEFAULT, MODE_DEFAULT, ModePayload.wrap(0x00)), + abi.encodePacked(address(delegateTarget), sendValue) ) ); + + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = makeSignedUserOp(callData, owner.key, address(eoaValidator)); + + entryPoint.handleOps(userOps, bundler); + vm.assertEq(valueTarget.balance, value); } - function makeUserOp( - address target, - uint256 value, - bytes memory data - ) - public - view - returns (PackedUserOperation memory userOp) - { - bytes memory callData = ExecutionLib.encodeSingle(target, value, data); - - userOp = PackedUserOperation({ - sender: address(accountProxy), - nonce: 0, - initCode: "", - callData: abi.encodeCall(ModularSmartAccount.execute, (ModeLib.encodeSimpleSingle(), callData)), - accountGasLimits: bytes32(uint256((100_000 << 128) | 100_000)), - preVerificationGas: 0, - gasFees: bytes32(0), - paymasterAndData: "", - signature: "" - }); - - bytes32 userOpHash = entryPoint.getUserOpHash(userOp); - (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner.key, userOpHash); - userOp.signature = abi.encode(address(eoaValidator), abi.encodePacked(r, s, v), ""); + function test_signatureTypedData() public view { + MockMessage memory mockMessage = MockMessage({ message: "Hello, world!", value: 42 }); + bytes memory contentsDescription = "MockMessage(string message,uint256 value)"; + + bytes32 structHash = keccak256( + abi.encode(keccak256(contentsDescription), keccak256(bytes(mockMessage.message)), mockMessage.value) + ); + + (, string memory name, string memory version, uint256 chainId, address verifyingContract, bytes32 salt,) = + account.eip712Domain(); + + bytes32 typedDataSignTypehash = keccak256( + abi.encodePacked( + "TypedDataSign(", + "MockMessage contents,", + "string name,", + "string version,", + "uint256 chainId,", + "address verifyingContract,", + "bytes32 salt)", + contentsDescription + ) + ); + + bytes32 wrapperStructHash = keccak256( + abi.encode( + typedDataSignTypehash, + structHash, + keccak256(bytes(name)), + keccak256(bytes(version)), + uint256(chainId), + uint256(uint160(verifyingContract)), + bytes32(salt) + ) + ); + + bytes32 finalHash = keccak256(abi.encodePacked(hex"1901", erc1271Caller.domainSeparator(), wrapperStructHash)); + + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner.key, finalHash); + bytes memory originalSignature = abi.encodePacked(r, s, v); + + bytes memory signature = abi.encodePacked( + address(eoaValidator), + originalSignature, + erc1271Caller.domainSeparator(), + structHash, + contentsDescription, + uint16(contentsDescription.length) + ); + + bool success = erc1271Caller.validateStruct(mockMessage, address(account), signature); + vm.assertTrue(success, "Signature validation failed"); } - function test_Transfer() public { - vm.deal(address(accountProxy), 10 ether); + function test_signaturePersonalSign() public view { + bytes memory message = "Hello, world!"; + bytes32 messageHash = + keccak256(abi.encodePacked("\x19Ethereum Signed Message:\n", LibString.toString(message.length), message)); - address recipient = makeAddr("recipient"); - PackedUserOperation[] memory userOps = new PackedUserOperation[](1); - userOps[0] = makeUserOp(recipient, 1 ether, ""); + bytes32 finalHash = keccak256( + abi.encodePacked( + hex"1901", + account.domainSeparator(), + keccak256(abi.encode(keccak256("PersonalSign(bytes prefixed)"), messageHash)) + ) + ); - address bundler = makeAddr("bundler"); - entryPoint.handleOps(userOps, payable(bundler)); - vm.assertEq(recipient.balance, 1 ether); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner.key, finalHash); + bytes memory originalSignature = abi.encodePacked(r, s, v); + bytes memory signature = abi.encodePacked(address(eoaValidator), originalSignature); + + bytes4 magic = account.isValidSignature(messageHash, signature); + vm.assertEq(magic, account.isValidSignature.selector); + } + + function test_signatureTypedDataUnnested() public { + // This test is skipped because solady's implementation of ERC1271 + // does some weird checks with gas limit for RPC calls + vm.skip(true); + + MockMessage memory mockMessage = MockMessage({ message: "Hello, world!", value: 42 }); + + bytes32 structHash = keccak256( + abi.encode( + keccak256("MockMessage(string message,uint256 value)"), + keccak256(bytes(mockMessage.message)), + mockMessage.value + ) + ); + + bytes32 finalHash = keccak256(abi.encodePacked(hex"1901", erc1271Caller.domainSeparator(), structHash)); + + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner.key, finalHash); + bytes memory originalSignature = abi.encodePacked(r, s, v); + + bytes memory signature = abi.encodePacked(address(eoaValidator), originalSignature); + + bool success = erc1271Caller.validateStruct(mockMessage, address(account), signature); + vm.assertTrue(success, "Signature validation failed"); } } diff --git a/test/MSATest.sol b/test/MSATest.sol new file mode 100644 index 0000000..00c5055 --- /dev/null +++ b/test/MSATest.sol @@ -0,0 +1,109 @@ +// SPDX-License-Identifier: MIT + +pragma solidity ^0.8.24; + +import { EntryPoint } from "account-abstraction/core/EntryPoint.sol"; +import { PackedUserOperation } from "account-abstraction/interfaces/PackedUserOperation.sol"; +import { UpgradeableBeacon } from "@openzeppelin/contracts/proxy/beacon/UpgradeableBeacon.sol"; +import { Test } from "forge-std/Test.sol"; + +import { ModularSmartAccount } from "src/ModularSmartAccount.sol"; +import { MSAFactory } from "src/MSAFactory.sol"; +import { EOAKeyValidator } from "src/modules/EOAKeyValidator.sol"; +import { IMSA } from "src/interfaces/IMSA.sol"; + +contract MSATest is Test { + EntryPoint public entryPoint; + ModularSmartAccount public account; + MSAFactory public factory; + EOAKeyValidator public eoaValidator; + Account public owner; + address payable bundler; + + function setUp() public virtual { + bundler = payable(makeAddr("bundler")); + owner = makeAccount("owner"); + + ModularSmartAccount accountImplementation = new ModularSmartAccount(); + + address entryPointAddress = accountImplementation.ENTRY_POINT(); + vm.etch(entryPointAddress, address(new EntryPoint()).code); + entryPoint = EntryPoint(payable(entryPointAddress)); + + eoaValidator = new EOAKeyValidator(); + address[] memory modules = new address[](1); + modules[0] = address(eoaValidator); + + address[] memory owners = new address[](1); + owners[0] = owner.addr; + + bytes[] memory initData = new bytes[](1); + initData[0] = abi.encode(owners); + + UpgradeableBeacon beacon = new UpgradeableBeacon(address(accountImplementation), address(this)); + factory = new MSAFactory(address(beacon)); + + bytes memory data = abi.encodeCall(IMSA.initializeAccount, (modules, initData)); + account = ModularSmartAccount(payable(factory.deployAccount(keccak256("my-account-id"), data))); + vm.deal(address(account), 2 ether); + } + + function makeUserOp(bytes memory callData) public view returns (PackedUserOperation memory userOp) { + userOp = PackedUserOperation({ + sender: address(account), + nonce: entryPoint.getNonce(address(account), 0), + initCode: "", + callData: callData, + accountGasLimits: bytes32(abi.encodePacked(uint128(2e6), uint128(2e6))), + preVerificationGas: 2e6, + gasFees: bytes32(abi.encodePacked(uint128(2e6), uint128(2e6))), + paymasterAndData: "", + signature: "" + }); + } + + function makeSignedUserOp( + bytes memory callData, + uint256 key, + address validator, + bytes memory validatorData + ) + public + view + returns (PackedUserOperation memory userOp) + { + userOp = makeUserOp(callData); + signUserOp(userOp, key, validator, validatorData); + } + + function makeSignedUserOp( + bytes memory callData, + uint256 key, + address validator + ) + public + view + returns (PackedUserOperation memory userOp) + { + userOp = makeUserOp(callData); + signUserOp(userOp, key, validator); + } + + function signUserOp( + PackedUserOperation memory userOp, + uint256 key, + address validator, + bytes memory validatorData + ) + public + view + { + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(key, userOpHash); + userOp.signature = abi.encode(validator, abi.encodePacked(r, s, v), validatorData); + } + + function signUserOp(PackedUserOperation memory userOp, uint256 key, address validator) public view { + signUserOp(userOp, key, validator, ""); + } +} diff --git a/test/Sessions.t.sol b/test/Sessions.t.sol index 4175795..7309b39 100644 --- a/test/Sessions.t.sol +++ b/test/Sessions.t.sol @@ -1,168 +1,110 @@ // SPDX-License-Identifier: MIT pragma solidity ^0.8.24; -import { Test } from "forge-std/Test.sol"; -import { EntryPoint } from "account-abstraction/core/EntryPoint.sol"; -import { ModularSmartAccount } from "../src/ModularSmartAccount.sol"; -import { MSAProxy } from "../src/utils/MSAProxy.sol"; -import { EOAKeyValidator } from "../src/modules/EOAKeyValidator.sol"; -import { SessionKeyValidator } from "../src/modules/SessionKeyValidator.sol"; import { PackedUserOperation } from "account-abstraction/interfaces/PackedUserOperation.sol"; -import { IMSA } from "../src/interfaces/IMSA.sol"; -import { ExecutionLib } from "../src/libraries/ExecutionLib.sol"; -import { ModeLib } from "../src/libraries/ModeLib.sol"; -import { MODULE_TYPE_VALIDATOR } from "../src/interfaces/IERC7579Module.sol"; -import { IERC7579Account } from "../src/interfaces/IERC7579Account.sol"; -import { SessionLib } from "../src/libraries/SessionLib.sol"; -import { console } from "forge-std/console.sol"; - -contract Basic is Test { - EntryPoint public entryPoint; - ModularSmartAccount public account; - IMSA public accountProxy; - uint256 accountNonce = 0; - EOAKeyValidator public eoaValidator; +import { ModularSmartAccount } from "src/ModularSmartAccount.sol"; +import { MSAFactory } from "src/MSAFactory.sol"; +import { EOAKeyValidator } from "src/modules/EOAKeyValidator.sol"; +import { SessionKeyValidator } from "src/modules/SessionKeyValidator.sol"; +import { IMSA } from "src/interfaces/IMSA.sol"; +import { ExecutionLib } from "src/libraries/ExecutionLib.sol"; +import { ModeLib } from "src/libraries/ModeLib.sol"; +import { MODULE_TYPE_VALIDATOR } from "src/interfaces/IERC7579Module.sol"; +import { IERC7579Account } from "src/interfaces/IERC7579Account.sol"; +import { SessionLib } from "src/libraries/SessionLib.sol"; + +import { MSATest } from "./MSATest.sol"; + +contract SessionsTest is MSATest { SessionKeyValidator public sessionKeyValidator; - Account public owner; + uint256 accountNonce = 0; Account public sessionOwner; address recipient; - address bundler; SessionLib.SessionSpec public spec; - function setUp() public { - owner = makeAccount("owner"); - sessionOwner = makeAccount("sessionOwner"); - recipient = makeAddr("sessionRecipient"); - bundler = makeAddr("bundler"); + function setUp() public override { + super.setUp(); - address[] memory owners = new address[](1); - owners[0] = owner.addr; - - entryPoint = new EntryPoint(); - account = new ModularSmartAccount(); - eoaValidator = new EOAKeyValidator(); + recipient = makeAddr("sessionRecipient"); + sessionOwner = makeAccount("sessionOwner"); sessionKeyValidator = new SessionKeyValidator(); - accountProxy = IMSA( - address( - new MSAProxy( - address(account), - abi.encodeCall( - ModularSmartAccount.initializeAccount, - (address(entryPoint), address(eoaValidator), abi.encode(owners)) - ) - ) - ) - ); - } - - function makeUserOp( - bytes memory data, - uint256 signerKey, - address validator, - bytes memory validatorData - ) - public - returns (PackedUserOperation memory userOp) - { - userOp = PackedUserOperation({ - sender: address(accountProxy), - nonce: accountNonce++, - initCode: "", - callData: data, - accountGasLimits: bytes32(uint256((100_000 << 128) | 100_000)), - preVerificationGas: 0, - gasFees: bytes32(0), - paymasterAndData: "", - signature: "" - }); - - bytes32 userOpHash = entryPoint.getUserOpHash(userOp); - (uint8 v, bytes32 r, bytes32 s) = vm.sign(signerKey, userOpHash); - userOp.signature = abi.encode(address(validator), abi.encodePacked(r, s, v), validatorData); - } - - function makeUserOp( - address target, - uint256 value, - bytes memory data, - uint256 signerKey, - address validator, - bytes memory validatorData - ) - public - returns (PackedUserOperation memory) - { - bytes memory callData = ExecutionLib.encodeSingle(target, value, data); - bytes memory executeData = abi.encodeCall(ModularSmartAccount.execute, (ModeLib.encodeSimpleSingle(), callData)); - return makeUserOp(executeData, signerKey, validator, validatorData); } - function test_InstallValidator() public { + function test_installValidator() public { bytes memory data = abi.encodeCall(ModularSmartAccount.installModule, (MODULE_TYPE_VALIDATOR, address(sessionKeyValidator), "")); PackedUserOperation[] memory userOps = new PackedUserOperation[](1); - userOps[0] = makeUserOp(data, owner.key, address(eoaValidator), ""); + userOps[0] = makeSignedUserOp(data, owner.key, address(eoaValidator)); vm.expectEmit(true, false, false, false); emit IERC7579Account.ModuleInstalled(MODULE_TYPE_VALIDATOR, address(sessionKeyValidator)); - entryPoint.handleOps(userOps, payable(bundler)); + entryPoint.handleOps(userOps, bundler); } - function test_CreateSession() public { - test_InstallValidator(); - - SessionLib.UsageLimit memory feeLimit = - SessionLib.UsageLimit({ limitType: SessionLib.LimitType.Lifetime, limit: 0.15 ether, period: 0 }); - - SessionLib.UsageLimit memory transferLimit = - SessionLib.UsageLimit({ limitType: SessionLib.LimitType.Unlimited, limit: 0, period: 0 }); + function test_createSession() public { + test_installValidator(); SessionLib.TransferSpec[] memory transferPolicies = new SessionLib.TransferSpec[](1); - transferPolicies[0] = - SessionLib.TransferSpec({ target: recipient, maxValuePerUse: 0.1 ether, valueLimit: transferLimit }); + transferPolicies[0] = SessionLib.TransferSpec({ + target: recipient, + maxValuePerUse: 0.1 ether, + valueLimit: SessionLib.UsageLimit({ limitType: SessionLib.LimitType.Unlimited, limit: 0, period: 0 }) + }); spec = SessionLib.SessionSpec({ signer: sessionOwner.addr, expiresAt: uint48(block.timestamp + 1000), transferPolicies: transferPolicies, callPolicies: new SessionLib.CallSpec[](0), - feeLimit: feeLimit + feeLimit: SessionLib.UsageLimit({ limitType: SessionLib.LimitType.Lifetime, limit: 0.15 ether, period: 0 }) }); - PackedUserOperation[] memory userOps = new PackedUserOperation[](1); - userOps[0] = makeUserOp( - address(sessionKeyValidator), - 0, - abi.encodeCall(SessionKeyValidator.createSession, (spec)), - owner.key, - address(eoaValidator), - "" + bytes memory call = ExecutionLib.encodeSingle( + address(sessionKeyValidator), 0, abi.encodeCall(SessionKeyValidator.createSession, (spec)) ); + bytes memory callData = abi.encodeCall(IERC7579Account.execute, (ModeLib.encodeSimpleSingle(), call)); + + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = makeSignedUserOp(callData, owner.key, address(eoaValidator)); bytes32 sessionHash = keccak256(abi.encode(spec)); vm.expectEmit(true, true, true, true); - emit SessionKeyValidator.SessionCreated(address(accountProxy), sessionHash, spec); - entryPoint.handleOps(userOps, payable(bundler)); + emit SessionKeyValidator.SessionCreated(address(account), sessionHash, spec); + entryPoint.handleOps(userOps, bundler); - SessionLib.Status status = sessionKeyValidator.sessionStatus(address(accountProxy), sessionHash); + SessionLib.Status status = sessionKeyValidator.sessionStatus(address(account), sessionHash); vm.assertTrue(status == SessionLib.Status.Active); } - function test_UseSession() public { - test_CreateSession(); - - vm.deal(address(accountProxy), 0.2 ether); + function test_useSession() public { + test_createSession(); - accountNonce = uint256(uint160(sessionOwner.addr)) << 64; + bytes memory call = ExecutionLib.encodeSingle(recipient, 0.05 ether, ""); + bytes memory callData = abi.encodeCall(IERC7579Account.execute, (ModeLib.encodeSimpleSingle(), call)); PackedUserOperation[] memory userOps = new PackedUserOperation[](1); - userOps[0] = makeUserOp( - recipient, 0.05 ether, "", sessionOwner.key, address(sessionKeyValidator), abi.encode(spec, new uint48[](2)) - ); + userOps[0] = makeUserOp(callData); + userOps[0].nonce = uint256(uint160(sessionOwner.addr)) << 64; + signUserOp(userOps[0], sessionOwner.key, address(sessionKeyValidator), abi.encode(spec, new uint48[](2))); - entryPoint.handleOps(userOps, payable(bundler)); + entryPoint.handleOps(userOps, bundler); vm.assertEq(recipient.balance, 0.05 ether); } + + function testRevert_useSession() public { + test_createSession(); + + bytes memory call = ExecutionLib.encodeSingle(recipient, 0.11 ether, ""); // more than maxValuePerUse + bytes memory callData = abi.encodeCall(IERC7579Account.execute, (ModeLib.encodeSimpleSingle(), call)); + + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = makeUserOp(callData); + userOps[0].nonce = uint256(uint160(sessionOwner.addr)) << 64; + signUserOp(userOps[0], sessionOwner.key, address(sessionKeyValidator), abi.encode(spec, new uint48[](2))); + + vm.expectRevert(); + entryPoint.handleOps(userOps, bundler); + } } diff --git a/test/Utils.t.sol b/test/Utils.t.sol new file mode 100644 index 0000000..75601ec --- /dev/null +++ b/test/Utils.t.sol @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.23; + +import "forge-std/Test.sol"; +import "src/libraries/ExecutionLib.sol"; +import "src/libraries/ModeLib.sol"; + +contract UtilsTest is Test { + function setUp() public { } + + function decode(bytes calldata encoded) public pure returns (address, uint256, bytes calldata) { + return ExecutionLib.decodeSingle(encoded); + } + + function test_encodeDecodeExecution(address target, uint256 value, bytes memory callData) public view { + bytes memory encoded = ExecutionLib.encodeSingle(target, value, callData); + (address _target, uint256 _value, bytes memory _callData) = this.decode(encoded); + + vm.assertTrue(_target == target); + vm.assertTrue(_value == value); + vm.assertTrue(keccak256(_callData) == keccak256(callData)); + } + + function test_encodeDecodeMode() public pure { + CallType callType = CALLTYPE_SINGLE; + ExecType execType = EXECTYPE_DEFAULT; + ModeSelector modeSelector = MODE_DEFAULT; + ModePayload payload = ModePayload.wrap(bytes22(hex"01")); + ModeCode enc = ModeLib.encode(callType, execType, modeSelector, payload); + + (CallType _calltype, ExecType _execType, ModeSelector _mode, ModePayload _payload) = ModeLib.decode(enc); + vm.assertTrue(_calltype == callType); + vm.assertTrue(_execType == execType); + vm.assertTrue(_mode == modeSelector); + vm.assertTrue(ModePayload.unwrap(_payload) == ModePayload.unwrap(payload)); + } +} diff --git a/test/mocks/MockDelegateTarget.sol b/test/mocks/MockDelegateTarget.sol new file mode 100644 index 0000000..8162887 --- /dev/null +++ b/test/mocks/MockDelegateTarget.sol @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.23; + +contract MockDelegateTarget { + function sendValue(address target, uint256 _value) public { + (bool success,) = target.call{ value: _value }(""); + require(success, "Call failed"); + } +} diff --git a/test/mocks/MockERC1271Caller.sol b/test/mocks/MockERC1271Caller.sol new file mode 100644 index 0000000..b07a987 --- /dev/null +++ b/test/mocks/MockERC1271Caller.sol @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.0; + +import "@openzeppelin/contracts/utils/cryptography/EIP712.sol"; +import "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; +import "@openzeppelin/contracts/interfaces/IERC1271.sol"; +import "@openzeppelin/contracts/utils/Address.sol"; + +struct MockMessage { + string message; + uint256 value; +} + +contract MockERC1271Caller is EIP712 { + constructor() EIP712("ERC1271Caller", "1.0.0") { } + + function validateStruct( + MockMessage calldata mockMessage, + address signer, + bytes calldata signature + ) + external + view + returns (bool) + { + require(signer != address(0), "Invalid signer address"); + + bytes32 structHash = keccak256( + abi.encode( + keccak256("MockMessage(string message,uint256 value)"), + keccak256(bytes(mockMessage.message)), + mockMessage.value + ) + ); + + bytes32 digest = _hashTypedDataV4(structHash); + + if (signer.code.length > 0) { + // Call the ERC1271 contract + bytes4 magic = IERC1271(signer).isValidSignature(digest, signature); + return magic == IERC1271.isValidSignature.selector; + } else { + return ECDSA.recover(digest, signature) == signer; + } + } + + function domainSeparator() external view returns (bytes32) { + return _domainSeparatorV4(); + } +} diff --git a/test/mocks/MockERC20.sol b/test/mocks/MockERC20.sol new file mode 100644 index 0000000..2f051a2 --- /dev/null +++ b/test/mocks/MockERC20.sol @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.0; + +import { ERC20 } from "@openzeppelin/contracts/token/ERC20/ERC20.sol"; + +contract MockERC20 is ERC20 { + constructor(address mintTo) ERC20("Mock ERC20", "MOCK") { + _mint(mintTo, 10 ** 18); + } +} diff --git a/test/mocks/MockTarget.sol b/test/mocks/MockTarget.sol new file mode 100644 index 0000000..6e8f4b7 --- /dev/null +++ b/test/mocks/MockTarget.sol @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.23; + +contract MockTarget { + uint256 public value; + + function setValue(uint256 _value) public returns (uint256) { + value = _value; + return _value; + } +}