Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion src/SsoAccount.sol
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ import { IERC165 } from "@openzeppelin/contracts/utils/introspection/IERC165.sol
import { IERC1271 } from "@openzeppelin/contracts/interfaces/IERC1271.sol";
import { IERC5267 } from "@openzeppelin/contracts/interfaces/IERC5267.sol";

import { ISsoAccount } from "./interfaces/ISsoAccount.sol";
import { IBatchCaller } from "./interfaces/IBatchCaller.sol";
import { INoHooksCaller } from "./interfaces/INoHooksCaller.sol";
import { IHookManager } from "./interfaces/IHookManager.sol";
import { IOwnerManager } from "./interfaces/IOwnerManager.sol";
import { IValidatorManager } from "./interfaces/IValidatorManager.sol";

import { HookManager } from "./managers/HookManager.sol";
import { SsoUtils } from "./helpers/SsoUtils.sol";

Expand Down Expand Up @@ -161,12 +168,18 @@ contract SsoAccount is
_transaction.processPaymasterInput();
}

/// @dev type(ISsoAccount).interfaceId indicates SSO accounts
/// @inheritdoc TokenCallbackHandler
function supportsInterface(bytes4 interfaceId) public view override returns (bool) {
return
interfaceId == type(IERC5267).interfaceId ||
interfaceId == type(IERC1271).interfaceId ||
interfaceId == type(IAccount).interfaceId ||
interfaceId == type(ISsoAccount).interfaceId ||
interfaceId == type(IBatchCaller).interfaceId ||
interfaceId == type(INoHooksCaller).interfaceId ||
interfaceId == type(IHookManager).interfaceId ||
interfaceId == type(IOwnerManager).interfaceId ||
interfaceId == type(IValidatorManager).interfaceId ||
super.supportsInterface(interfaceId);
}

Expand Down
2 changes: 1 addition & 1 deletion src/interfaces/IGuardianRecoveryValidator.sol
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ interface IGuardianRecoveryValidator is IModuleValidator {
function initRecovery(
address accountToRecover,
bytes32 hashedCredentialId,
bytes32[2] memory rawPublicKey,
bytes32[2] calldata rawPublicKey,
bytes32 hashedOriginDomain
) external;

Expand Down
2 changes: 1 addition & 1 deletion src/interfaces/IOidcRecoveryValidator.sol
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ interface IOidcRecoveryValidator is IModuleValidator {
/// @param iss The OIDC issuer.
/// @param readyToRecover Indicating if recovery is active (true after `startRecovery` and false once recovery is completed).
/// @param pendingPasskeyHash The hash of the pending passkey.
/// @param recoveryStartedAt The timestamp when the recovery process was started.
/// @param recoverNonce The value is used to build the jwt nonce, and gets incremented each time a zk proof is successfully verified to prevent replay attacks.
/// @param addedOn The timestamp when the OIDC account was added.
struct OidcData {
Expand Down Expand Up @@ -62,7 +63,6 @@ interface IOidcRecoveryValidator is IModuleValidator {

/// @notice The data for starting a recovery process.
/// @param zkProof The zk proof.
/// @param issHash The hash of the OIDC issuer.
/// @param kid The key id (kid) of the OIDC key.
/// @param pendingPasskeyHash The hash of the pending passkey to be added.
/// @param timeLimit If the recovery process is started after this moment it will fail.
Expand Down
2 changes: 1 addition & 1 deletion src/libraries/SessionLib.sol
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ library SessionLib {
}

// shrink array to actual size
assembly {
assembly ("memory-safe") {
mstore(callParams, paramLimitIndex)
}

Expand Down
13 changes: 5 additions & 8 deletions src/validators/AllowedSessionsValidator.sol
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,13 @@ pragma solidity ^0.8.24;
import { Transaction } from "@matterlabs/zksync-contracts/l2/system-contracts/libraries/TransactionHelper.sol";

import { IERC165 } from "@openzeppelin/contracts/utils/introspection/IERC165.sol";
import { ECDSA } from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol";
import { AccessControl } from "@openzeppelin/contracts/access/AccessControl.sol";
import { IAccessControl } from "@openzeppelin/contracts/access/IAccessControl.sol";

import { IAllowedSessionsValidator } from "../interfaces/IAllowedSessionsValidator.sol";
import { ISessionKeyValidator } from "../interfaces/ISessionKeyValidator.sol";
import { IModuleValidator } from "../interfaces/IModuleValidator.sol";
import { IModule } from "../interfaces/IModule.sol";
import { IValidatorManager } from "../interfaces/IValidatorManager.sol";
import { SessionLib } from "../libraries/SessionLib.sol";
import { Errors } from "../libraries/Errors.sol";
import { SsoUtils } from "../helpers/SsoUtils.sol";
Expand All @@ -33,7 +31,7 @@ contract AllowedSessionsValidator is SessionKeyValidator, AccessControl, IAllowe

/// @notice Mapping to track whether a session actions is allowed.
/// @dev The key is the hash of session actions, and the value indicates if the actions are allowed.
mapping(bytes32 sessionActionsHash => bool active) public areSessionActionsAllowed;
mapping(bytes32 sessionActionsHash => bool allowed) public areSessionActionsAllowed;

constructor() {
_grantRole(SESSION_REGISTRY_MANAGER_ROLE, msg.sender);
Expand Down Expand Up @@ -97,10 +95,9 @@ contract AllowedSessionsValidator is SessionKeyValidator, AccessControl, IAllowe
bytes4 interfaceId
) public pure override(SessionKeyValidator, AccessControl, IERC165) returns (bool) {
return
interfaceId == type(IERC165).interfaceId ||
interfaceId == type(IModuleValidator).interfaceId ||
interfaceId == type(IModule).interfaceId ||
interfaceId == type(IAccessControl).interfaceId;
interfaceId == type(IAllowedSessionsValidator).interfaceId ||
interfaceId == type(IAccessControl).interfaceId ||
SessionKeyValidator.supportsInterface(interfaceId);
}

/// @notice Validate a session transaction for an account.
Expand All @@ -122,6 +119,6 @@ contract AllowedSessionsValidator is SessionKeyValidator, AccessControl, IAllowe
if (!areSessionActionsAllowed[sessionActionsHash]) {
revert Errors.SESSION_ACTIONS_NOT_ALLOWED(sessionActionsHash);
}
return super.validateTransaction(signedHash, transaction);
return SessionKeyValidator.validateTransaction(signedHash, transaction);
}
}
4 changes: 4 additions & 0 deletions src/validators/GuardianRecoveryValidator.sol
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,10 @@ contract GuardianRecoveryValidator is Initializable, IGuardianRecoveryValidator
if (!accountsRemovalSuccessful) {
revert Errors.ACCOUNT_NOT_GUARDED_BY_ADDRESS(msg.sender, guardianToRemove);
}

// In case an ongoing recovery was started by this guardian, discard it to prevent a potential
// account overtake by a second malicious guardian.
discardRecovery(hashedOriginDomain);
}

if (accountGuardians[hashedOriginDomain][msg.sender].length() == 0) {
Expand Down
1 change: 1 addition & 0 deletions src/validators/OidcRecoveryValidator.sol
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ contract OidcRecoveryValidator is IOidcRecoveryValidator, Initializable {
/// @inheritdoc IERC165
function supportsInterface(bytes4 interfaceId) external pure override returns (bool) {
return
interfaceId == type(IOidcRecoveryValidator).interfaceId ||
interfaceId == type(IERC165).interfaceId ||
interfaceId == type(IModuleValidator).interfaceId ||
interfaceId == type(IModule).interfaceId;
Expand Down
32 changes: 18 additions & 14 deletions src/validators/SessionKeyValidator.sol
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import { IValidatorManager } from "../interfaces/IValidatorManager.sol";
import { SessionLib } from "../libraries/SessionLib.sol";
import { Errors } from "../libraries/Errors.sol";
import { SsoUtils } from "../helpers/SsoUtils.sol";
import { TimestampAsserterLocator } from "../helpers/TimestampAsserterLocator.sol";
import { ISsoAccount } from "../interfaces/ISsoAccount.sol";

/// @title SessionKeyValidator
Expand All @@ -22,8 +21,9 @@ import { ISsoAccount } from "../interfaces/ISsoAccount.sol";
contract SessionKeyValidator is ISessionKeyValidator {
using SessionLib for SessionLib.SessionStorage;

mapping(address signer => bytes32 sessionHash) public sessionSigner;
mapping(address account => uint256 openSessions) private __DEPRECATED__sessionCounter;
mapping(bytes32 sessionHash => SessionLib.SessionStorage sessionState) internal sessions;
mapping(address signer => bytes32 sessionHash) public sessionSigner;

/// @notice Get the session state for an account
/// @param account The account to fetch the session state for
Expand Down Expand Up @@ -72,7 +72,7 @@ contract SessionKeyValidator is ISessionKeyValidator {
/// @notice This module should not be used to validate signatures (including EIP-1271),
/// as a signature by itself does not have enough information to validate it against a session.
/// @return false
function validateSignature(bytes32, bytes memory) external pure returns (bool) {
function validateSignature(bytes32, bytes calldata) external pure returns (bool) {
return false;
}

Expand Down Expand Up @@ -100,11 +100,24 @@ contract SessionKeyValidator is ISessionKeyValidator {

/// @notice Create a new session for an account
/// @param sessionSpec The session specification to create a session with
/// @dev In the sessionSpec, callPolicies should not have duplicated instances of
/// (target, selector) pairs. Only the first one is considered when validating transactions.
function createSession(SessionLib.SessionSpec memory sessionSpec) public virtual {
bytes32 sessionHash = keccak256(abi.encode(sessionSpec));
if (!isInitialized(msg.sender)) {
revert Errors.NOT_FROM_INITIALIZED_ACCOUNT(msg.sender);
}

uint256 totalCallPolicies = sessionSpec.callPolicies.length;
for (uint256 i = 0; i < totalCallPolicies; i++) {
if (isBannedCall(sessionSpec.callPolicies[i].target, sessionSpec.callPolicies[i].selector)) {
revert Errors.SESSION_CALL_POLICY_BANNED(
sessionSpec.callPolicies[i].target,
sessionSpec.callPolicies[i].selector
);
}
}

if (sessionSpec.signer == address(0)) {
revert Errors.SESSION_ZERO_SIGNER();
}
Expand All @@ -123,16 +136,6 @@ contract SessionKeyValidator is ISessionKeyValidator {
revert Errors.SESSION_EXPIRES_TOO_SOON(sessionSpec.expiresAt);
}

uint256 totalCallPolicies = sessionSpec.callPolicies.length;
for (uint256 i = 0; i < totalCallPolicies; i++) {
if (isBannedCall(sessionSpec.callPolicies[i].target, sessionSpec.callPolicies[i].selector)) {
revert Errors.SESSION_CALL_POLICY_BANNED(
sessionSpec.callPolicies[i].target,
sessionSpec.callPolicies[i].selector
);
}
}

sessions[sessionHash].status[msg.sender] = SessionLib.Status.Active;
sessionSigner[sessionSpec.signer] = sessionHash;
emit SessionCreated(msg.sender, sessionHash, sessionSpec);
Expand All @@ -147,8 +150,9 @@ contract SessionKeyValidator is ISessionKeyValidator {
}

/// @inheritdoc IERC165
function supportsInterface(bytes4 interfaceId) external pure virtual returns (bool) {
function supportsInterface(bytes4 interfaceId) public pure virtual returns (bool) {
return
interfaceId == type(ISessionKeyValidator).interfaceId ||
interfaceId == type(IERC165).interfaceId ||
interfaceId == type(IModuleValidator).interfaceId ||
interfaceId == type(IModule).interfaceId;
Expand Down
1 change: 1 addition & 0 deletions src/validators/WebAuthValidator.sol
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ contract WebAuthValidator is IWebAuthValidator {
/// @inheritdoc IERC165
function supportsInterface(bytes4 interfaceId) external pure override returns (bool) {
return
interfaceId == type(IWebAuthValidator).interfaceId ||
interfaceId == type(IERC165).interfaceId ||
interfaceId == type(IModuleValidator).interfaceId ||
interfaceId == type(IModule).interfaceId;
Expand Down
Loading