diff --git a/snapshots/Hub.Operations.json b/snapshots/Hub.Operations.json index 16573542b..d7cf01488 100644 --- a/snapshots/Hub.Operations.json +++ b/snapshots/Hub.Operations.json @@ -4,7 +4,7 @@ "draw": "105931", "eliminateDeficit: full": "59781", "eliminateDeficit: partial": "69429", - "mintFeeShares": "86130", + "mintFeeShares": "86218", "payFee": "72302", "refreshPremium": "71999", "remove: full": "76993", diff --git a/src/access/AccessManagerEnumerable.sol b/src/access/AccessManagerEnumerable.sol index 94d30b2d3..ddff64c75 100644 --- a/src/access/AccessManagerEnumerable.sol +++ b/src/access/AccessManagerEnumerable.sol @@ -12,15 +12,77 @@ import {IAccessManagerEnumerable} from 'src/access/interfaces/IAccessManagerEnum contract AccessManagerEnumerable is AccessManager, IAccessManagerEnumerable { using EnumerableSet for EnumerableSet.AddressSet; using EnumerableSet for EnumerableSet.Bytes32Set; + using EnumerableSet for EnumerableSet.UintSet; + + /// @dev Set of all role identifiers. + EnumerableSet.UintSet private _rolesSet; + + /// @dev Set of all admin role identifiers. + EnumerableSet.UintSet private _adminRolesSet; /// @dev Map of role identifiers to their respective member sets. mapping(uint64 roleId => EnumerableSet.AddressSet) private _roleMembers; + /// @dev Map of admin role identifiers to their respective role identifier sets. + mapping(uint64 roleId => EnumerableSet.UintSet) private _adminOfRoles; + + /// @dev Map of role identifiers to their respective target contract addresses. + mapping(uint64 roleId => EnumerableSet.AddressSet) private _roleTargets; + + /// @dev Map of target contract addresses and function selectors to their assigned role identifier. + mapping(address target => mapping(bytes4 selector => uint64 roleId)) private _targetSelectorRoles; + /// @dev Map of role identifiers and target contract addresses to their respective set of function selectors. mapping(uint64 roleId => mapping(address target => EnumerableSet.Bytes32Set)) - private _roleTargetFunctions; + private _roleTargetSelectors; + + /// @dev Constructor. + /// @param initialAdmin_ The address of the initial admin. + constructor(address initialAdmin_) AccessManager(initialAdmin_) { + // Track the ADMIN_ROLE by default. + // (already tracked as a default role via AccessManager constructor) + _adminRolesSet.add(ADMIN_ROLE); + } + + /// @inheritdoc IAccessManagerEnumerable + function getRole(uint256 index) external view returns (uint64) { + return uint64(_rolesSet.at(index)); + } + + /// @inheritdoc IAccessManagerEnumerable + function getRoleCount() external view returns (uint256) { + return _rolesSet.length(); + } + + /// @inheritdoc IAccessManagerEnumerable + function getRoles(uint256 start, uint256 end) external view returns (uint64[] memory) { + uint256[] memory listedRoles = _rolesSet.values(start, end); + uint64[] memory roles; + assembly ('memory-safe') { + roles := listedRoles + } + return roles; + } - constructor(address initialAdmin_) AccessManager(initialAdmin_) {} + /// @inheritdoc IAccessManagerEnumerable + function getAdminRole(uint256 index) external view returns (uint64) { + return uint64(_adminRolesSet.at(index)); + } + + /// @inheritdoc IAccessManagerEnumerable + function getAdminRoleCount() external view returns (uint256) { + return _adminRolesSet.length(); + } + + /// @inheritdoc IAccessManagerEnumerable + function getAdminRoles(uint256 start, uint256 end) external view returns (uint64[] memory) { + uint256[] memory listedAdminRoles = _adminRolesSet.values(start, end); + uint64[] memory adminRoles; + assembly ('memory-safe') { + adminRoles := listedAdminRoles + } + return adminRoles; + } /// @inheritdoc IAccessManagerEnumerable function getRoleMember(uint64 roleId, uint256 index) external view returns (address) { @@ -42,30 +104,73 @@ contract AccessManagerEnumerable is AccessManager, IAccessManagerEnumerable { } /// @inheritdoc IAccessManagerEnumerable - function getRoleTargetFunction( + function getAdminOfRole(uint64 adminRoleId, uint256 index) external view returns (uint64) { + return uint64(_adminOfRoles[adminRoleId].at(index)); + } + + /// @inheritdoc IAccessManagerEnumerable + function getAdminOfRoleCount(uint64 adminRoleId) external view returns (uint256) { + return _adminOfRoles[adminRoleId].length(); + } + + /// @inheritdoc IAccessManagerEnumerable + function getAdminOfRoles( + uint64 adminRoleId, + uint256 start, + uint256 end + ) external view returns (uint64[] memory) { + uint256[] memory listedRoles = _adminOfRoles[adminRoleId].values(start, end); + uint64[] memory roles; + assembly ('memory-safe') { + roles := listedRoles + } + return roles; + } + + /// @inheritdoc IAccessManagerEnumerable + function getRoleTarget(uint64 roleId, uint256 index) external view returns (address) { + return _roleTargets[roleId].at(index); + } + + /// @inheritdoc IAccessManagerEnumerable + function getRoleTargetCount(uint64 roleId) external view returns (uint256) { + return _roleTargets[roleId].length(); + } + + /// @inheritdoc IAccessManagerEnumerable + function getRoleTargets( + uint64 roleId, + uint256 start, + uint256 end + ) external view returns (address[] memory) { + return _roleTargets[roleId].values(start, end); + } + + /// @inheritdoc IAccessManagerEnumerable + function getRoleTargetSelector( uint64 roleId, address target, uint256 index ) external view returns (bytes4) { - return bytes4(_roleTargetFunctions[roleId][target].at(index)); + return bytes4(_roleTargetSelectors[roleId][target].at(index)); } /// @inheritdoc IAccessManagerEnumerable - function getRoleTargetFunctionCount( + function getRoleTargetSelectorCount( uint64 roleId, address target ) external view returns (uint256) { - return _roleTargetFunctions[roleId][target].length(); + return _roleTargetSelectors[roleId][target].length(); } /// @inheritdoc IAccessManagerEnumerable - function getRoleTargetFunctions( + function getRoleTargetSelectors( uint64 roleId, address target, uint256 start, uint256 end ) external view returns (bytes4[] memory) { - bytes32[] memory targetFunctions = _roleTargetFunctions[roleId][target].values(start, end); + bytes32[] memory targetFunctions = _roleTargetSelectors[roleId][target].values(start, end); bytes4[] memory targetFunctionSelectors; assembly ('memory-safe') { targetFunctionSelectors := targetFunctions @@ -73,7 +178,48 @@ contract AccessManagerEnumerable is AccessManager, IAccessManagerEnumerable { return targetFunctionSelectors; } - /// @dev Override AccessManager `_grantRole` function to track role members. + /// @dev Tracks all role identifiers when a new role is created. + function _trackRole(uint64 roleId) internal { + _rolesSet.add(uint256(roleId)); + } + + /// @dev Tracks all admin role identifiers when a new admin role is set. + function _trackAdminRole(uint64 roleId, uint64 admin) internal { + _adminRolesSet.add(uint256(admin)); + uint64 oldAdmin = getRoleAdmin(roleId); + _adminOfRoles[oldAdmin].remove(uint256(roleId)); + _adminOfRoles[admin].add(uint256(roleId)); + } + + /// @dev Tracks all targets where a selector was assigned to a role. + function _trackRoleTarget(uint64 roleId, address target, bytes4 selector) internal { + uint64 oldRole = _targetSelectorRoles[target][selector]; + if (oldRole == roleId) { + return; + } + if (oldRole != ADMIN_ROLE && _roleTargetSelectors[oldRole][target].length() == 0) { + _roleTargets[oldRole].remove(target); + } + if (roleId != ADMIN_ROLE) { + _roleTargets[roleId].add(target); + } + _targetSelectorRoles[target][selector] = roleId; + } + + /// @dev Override AccessManager `_setRoleAdmin` function to track created roles. + function _setRoleAdmin(uint64 roleId, uint64 admin) internal override { + _trackRole(roleId); + _trackAdminRole(roleId, admin); + super._setRoleAdmin(roleId, admin); + } + + /// @dev Override AccessManager `_setRoleGuardian` function to track created roles. + function _setRoleGuardian(uint64 roleId, uint64 guardian) internal override { + _trackRole(roleId); + super._setRoleGuardian(roleId, guardian); + } + + /// @dev Override AccessManager `_grantRole` function to track roles' membership. function _grantRole( uint64 roleId, address account, @@ -82,12 +228,13 @@ contract AccessManagerEnumerable is AccessManager, IAccessManagerEnumerable { ) internal override returns (bool) { bool granted = super._grantRole(roleId, account, grantDelay, executionDelay); if (granted) { + _trackRole(roleId); _roleMembers[roleId].add(account); } return granted; } - /// @dev Override AccessManager `_revokeRole` function to remove from tracked role members. + /// @dev Override AccessManager `_revokeRole` function to remove from tracked roles' membership. function _revokeRole(uint64 roleId, address account) internal override returns (bool) { bool revoked = super._revokeRole(roleId, account); if (revoked) { @@ -105,10 +252,12 @@ contract AccessManagerEnumerable is AccessManager, IAccessManagerEnumerable { uint64 oldRoleId = getTargetFunctionRole(target, selector); super._setTargetFunctionRole(target, selector, roleId); if (oldRoleId != ADMIN_ROLE) { - _roleTargetFunctions[oldRoleId][target].remove(bytes32(selector)); + _roleTargetSelectors[oldRoleId][target].remove(bytes32(selector)); } if (roleId != ADMIN_ROLE) { - _roleTargetFunctions[roleId][target].add(bytes32(selector)); + _roleTargetSelectors[roleId][target].add(bytes32(selector)); } + // also track the target under the role (will be added if not already present) + _trackRoleTarget(roleId, target, selector); } } diff --git a/src/access/interfaces/IAccessManagerEnumerable.sol b/src/access/interfaces/IAccessManagerEnumerable.sol index c62833460..a11483407 100644 --- a/src/access/interfaces/IAccessManagerEnumerable.sol +++ b/src/access/interfaces/IAccessManagerEnumerable.sol @@ -8,6 +8,37 @@ import {IAccessManager} from 'src/dependencies/openzeppelin/IAccessManager.sol'; /// @author Aave Labs /// @notice Interface for AccessManagerEnumerable extension. interface IAccessManagerEnumerable is IAccessManager { + /// @notice Returns the identifier of the role at a specified index. + /// @param index The index in the role list. + /// @return The identifier of the role. + function getRole(uint256 index) external view returns (uint64); + + /// @notice Returns the total number of existing roles. + /// @dev Does not account for the built-in `PUBLIC_ROLE` role. + /// @return The number of roles. + function getRoleCount() external view returns (uint256); + + /// @notice Returns the list of role identifiers between the specified indexes. + /// @param start The starting index for the role list. + /// @param end The ending index for the role list. + /// @return The list of role identifiers. + function getRoles(uint256 start, uint256 end) external view returns (uint64[] memory); + + /// @notice Returns the identifier of the admin role at a specified index. + /// @param index The index in the admin role list. + /// @return The identifier of the admin role. + function getAdminRole(uint256 index) external view returns (uint64); + + /// @notice Returns the total number of existing admin roles. + /// @return The number of admin roles. + function getAdminRoleCount() external view returns (uint256); + + /// @notice Returns the list of admin role identifiers between the specified indexes. + /// @param start The starting index for the admin role list. + /// @param end The ending index for the admin role list. + /// @return The list of admin role identifiers. + function getAdminRoles(uint256 start, uint256 end) external view returns (uint64[] memory); + /// @notice Returns the address of the role member at a specified index. /// @param roleId The identifier of the role. /// @param index The index in the role member list. @@ -19,7 +50,7 @@ interface IAccessManagerEnumerable is IAccessManager { /// @return The number of members for the role. function getRoleMemberCount(uint64 roleId) external view returns (uint256); - /// @notice Returns the list of members for a specified role. + /// @notice Returns the list of members for a specified role between the specified indexes. /// @param roleId The identifier of the role. /// @param start The starting index for the member list. /// @param end The ending index for the member list. @@ -30,12 +61,56 @@ interface IAccessManagerEnumerable is IAccessManager { uint256 end ) external view returns (address[] memory); + /// @notice Returns the role identifier of the listed roles for a specified admin role at a specified index. + /// @param adminRoleId The identifier of the admin role. + /// @param index The index in the admin controlled role list. + /// @return The indentifier of the role. + function getAdminOfRole(uint64 adminRoleId, uint256 index) external view returns (uint64); + + /// @notice Returns the number of members for a specified role. + /// @param adminRoleId The identifier of the admin role. + /// @return The number of members for the role. + function getAdminOfRoleCount(uint64 adminRoleId) external view returns (uint256); + + /// @notice Returns the list of role identifiers controlled by a specified admin role between the specified indexes. + /// @param adminRoleId The identifier of the admin role. + /// @param start The starting index for the admin controlled role list. + /// @param end The ending index for the admin controlled role list. + /// @return The list of admin controlled role indentifiers for the admin role. + function getAdminOfRoles( + uint64 adminRoleId, + uint256 start, + uint256 end + ) external view returns (uint64[] memory); + + /// @notice Returns the address of the target contract for a specified role and index. + /// @param roleId The identifier of the role. + /// @param index The index in the role target list. + /// @return The address of the target contract. + function getRoleTarget(uint64 roleId, uint256 index) external view returns (address); + + /// @notice Returns the number of target contracts for a specified role. + /// @param roleId The identifier of the role. + /// @return The number of target contracts for the role. + function getRoleTargetCount(uint64 roleId) external view returns (uint256); + + /// @notice Returns the list of target contracts for a specified role between the specified indexes. + /// @param roleId The identifier of the role. + /// @param start The starting index for the role target list. + /// @param end The ending index for the role target list. + /// @return The list of target contracts for the role. + function getRoleTargets( + uint64 roleId, + uint256 start, + uint256 end + ) external view returns (address[] memory); + /// @notice Returns the function selector assigned to a given role at the specified index. /// @param roleId The identifier of the role. /// @param target The address of the target contract. /// @param index The index in the role member list. /// @return The selector at the index. - function getRoleTargetFunction( + function getRoleTargetSelector( uint64 roleId, address target, uint256 index @@ -45,7 +120,7 @@ interface IAccessManagerEnumerable is IAccessManager { /// @param roleId The identifier of the role. /// @param target The address of the target contract. /// @return The number of selectors assigned to the role. - function getRoleTargetFunctionCount( + function getRoleTargetSelectorCount( uint64 roleId, address target ) external view returns (uint256); @@ -56,7 +131,7 @@ interface IAccessManagerEnumerable is IAccessManager { /// @param start The starting index for the selector list. /// @param end The ending index for the selector list. /// @return The list of selectors assigned to the role. - function getRoleTargetFunctions( + function getRoleTargetSelectors( uint64 roleId, address target, uint256 start, diff --git a/tests/unit/AccessManagerEnumerable.t.sol b/tests/unit/AccessManagerEnumerable.t.sol index e84eaee92..9bf8376eb 100644 --- a/tests/unit/AccessManagerEnumerable.t.sol +++ b/tests/unit/AccessManagerEnumerable.t.sol @@ -8,12 +8,20 @@ import {AccessManagerEnumerable} from 'src/access/AccessManagerEnumerable.sol'; contract AccessManagerEnumerableTest is Test { using EnumerableSet for EnumerableSet.AddressSet; + using EnumerableSet for EnumerableSet.UintSet; address internal ADMIN = makeAddr('ADMIN'); + uint64 constant ADMIN_ROLE = 0; + uint64 constant GUARDIAN_ROLE_1 = 111111111; + uint64 constant GUARDIAN_ROLE_2 = 222222222; + AccessManagerEnumerable internal accessManagerEnumerable; EnumerableSet.AddressSet members; + EnumerableSet.UintSet internalRoles; + EnumerableSet.UintSet internalAdminRoles; + mapping(uint64 => EnumerableSet.UintSet) internalAdminOfRoles; function setUp() public virtual { accessManagerEnumerable = new AccessManagerEnumerable(ADMIN); @@ -39,6 +47,12 @@ contract AccessManagerEnumerableTest is Test { assertEq(roleMembers.length, 1); assertEq(roleMembers[0], user1); + assertEq(accessManagerEnumerable.getRole(1), roleId); + assertEq(accessManagerEnumerable.getRoleCount(), 2); + uint64[] memory roles = accessManagerEnumerable.getRoles(0, 2); + assertEq(roles.length, 2); + assertEq(roles[1], roleId); + accessManagerEnumerable.grantRole(roleId, user2, 0); assertEq(accessManagerEnumerable.getRoleMember(roleId, 1), user2); assertEq(accessManagerEnumerable.getRoleMemberCount(roleId), 2); @@ -50,7 +64,12 @@ contract AccessManagerEnumerableTest is Test { assertEq(roleMembers.length, 2); assertEq(roleMembers[0], user1); assertEq(roleMembers[1], user2); - vm.stopPrank(); + + assertEq(accessManagerEnumerable.getRole(1), roleId); + assertEq(accessManagerEnumerable.getRoleCount(), 2); + roles = accessManagerEnumerable.getRoles(0, 2); + assertEq(roles.length, 2); + assertEq(roles[1], roleId); } function test_grantRole_fuzz(uint64 roleId, uint256 membersCount) public { @@ -86,6 +105,351 @@ contract AccessManagerEnumerableTest is Test { assertEq(roleMembers[i], members.at(i)); assertEq(accessManagerEnumerable.getRoleMember(roleId, i), members.at(i)); } + + assertEq(accessManagerEnumerable.getRole(1), roleId); + assertEq(accessManagerEnumerable.getRoleCount(), 2); + uint64[] memory roles = accessManagerEnumerable.getRoles(0, 2); + assertEq(roles.length, 2); + assertEq(roles[1], roleId); + } + + function test_setRoleAdmin_trackRolesAndTrackAdminRoles() public { + assertEq(accessManagerEnumerable.getRoleCount(), 1); + assertEq(accessManagerEnumerable.getRole(0), ADMIN_ROLE); + assertEq(accessManagerEnumerable.getAdminRoleCount(), 1); + assertEq(accessManagerEnumerable.getAdminRole(0), ADMIN_ROLE); + + vm.startPrank(ADMIN); + accessManagerEnumerable.setRoleAdmin(GUARDIAN_ROLE_1, ADMIN_ROLE); + accessManagerEnumerable.setRoleAdmin(GUARDIAN_ROLE_2, ADMIN_ROLE); + vm.stopPrank(); + + uint64[] memory roleList = accessManagerEnumerable.getRoles(0, 3); + assertEq(accessManagerEnumerable.getRoleCount(), 3); + assertEq(roleList.length, 3); + assertEq(roleList[0], ADMIN_ROLE); + assertEq(roleList[1], GUARDIAN_ROLE_1); + assertEq(roleList[2], GUARDIAN_ROLE_2); + assertEq(accessManagerEnumerable.getRole(0), ADMIN_ROLE); + assertEq(accessManagerEnumerable.getRole(1), GUARDIAN_ROLE_1); + assertEq(accessManagerEnumerable.getRole(2), GUARDIAN_ROLE_2); + + uint64[] memory adminRoleList = accessManagerEnumerable.getAdminRoles(0, 1); + assertEq(accessManagerEnumerable.getAdminRoleCount(), 1); + assertEq(adminRoleList.length, 1); + assertEq(adminRoleList[0], ADMIN_ROLE); + assertEq(accessManagerEnumerable.getAdminRole(0), ADMIN_ROLE); + } + + function test_setRoleAdmin_trackAdminRoles() public { + uint64 adminRole1 = 1; + uint64 adminRole2 = 2; + + uint64 newRole1 = 111; + uint64 newRole2 = 222; + + assertEq(accessManagerEnumerable.getAdminRoleCount(), 1); + assertEq(accessManagerEnumerable.getAdminRole(0), ADMIN_ROLE); + + vm.startPrank(ADMIN); + accessManagerEnumerable.setRoleAdmin(GUARDIAN_ROLE_1, adminRole1); + accessManagerEnumerable.setRoleAdmin(GUARDIAN_ROLE_2, ADMIN_ROLE); + accessManagerEnumerable.setRoleAdmin(newRole1, adminRole1); + accessManagerEnumerable.setRoleAdmin(newRole2, adminRole2); + vm.stopPrank(); + + uint64[] memory adminRoleList = accessManagerEnumerable.getAdminRoles(0, 3); + assertEq(accessManagerEnumerable.getAdminRoleCount(), 3); + assertEq(adminRoleList.length, 3); + assertEq(adminRoleList[0], ADMIN_ROLE); + assertEq(adminRoleList[1], adminRole1); + assertEq(adminRoleList[2], adminRole2); + assertEq(accessManagerEnumerable.getAdminRole(0), ADMIN_ROLE); + assertEq(accessManagerEnumerable.getAdminRole(1), adminRole1); + assertEq(accessManagerEnumerable.getAdminRole(2), adminRole2); + } + + function test_setRoleAdmin_trackAdminOfRoles() public { + uint64 adminRole1 = 1; + + uint64 newRole1 = 111; + uint64 newRole2 = 222; + uint64 newRole3 = 333; + + assertEq(accessManagerEnumerable.getAdminRoleCount(), 1); + assertEq(accessManagerEnumerable.getAdminRole(0), ADMIN_ROLE); + + vm.startPrank(ADMIN); + accessManagerEnumerable.setRoleAdmin(GUARDIAN_ROLE_1, ADMIN_ROLE); + accessManagerEnumerable.setRoleAdmin(GUARDIAN_ROLE_2, ADMIN_ROLE); + accessManagerEnumerable.setRoleAdmin(newRole1, adminRole1); + accessManagerEnumerable.setRoleAdmin(newRole2, adminRole1); + accessManagerEnumerable.setRoleAdmin(newRole3, adminRole1); + vm.stopPrank(); + + uint64[] memory adminRoleList = accessManagerEnumerable.getAdminRoles(0, 2); + assertEq(accessManagerEnumerable.getAdminRoleCount(), 2); + assertEq(adminRoleList.length, 2); + assertEq(adminRoleList[0], ADMIN_ROLE); + assertEq(adminRoleList[1], adminRole1); + assertEq(accessManagerEnumerable.getAdminRole(0), ADMIN_ROLE); + assertEq(accessManagerEnumerable.getAdminRole(1), adminRole1); + + uint64[] memory adminOfRolesList = accessManagerEnumerable.getAdminOfRoles( + ADMIN_ROLE, + 0, + accessManagerEnumerable.getAdminOfRoleCount(ADMIN_ROLE) + ); + assertEq(accessManagerEnumerable.getAdminOfRoleCount(ADMIN_ROLE), 2); + assertEq(adminOfRolesList.length, 2); + assertEq(adminOfRolesList[0], GUARDIAN_ROLE_1); + assertEq(adminOfRolesList[1], GUARDIAN_ROLE_2); + assertEq(accessManagerEnumerable.getAdminOfRole(ADMIN_ROLE, 0), GUARDIAN_ROLE_1); + assertEq(accessManagerEnumerable.getAdminOfRole(ADMIN_ROLE, 1), GUARDIAN_ROLE_2); + + adminOfRolesList = accessManagerEnumerable.getAdminOfRoles( + adminRole1, + 0, + accessManagerEnumerable.getAdminOfRoleCount(adminRole1) + ); + assertEq(accessManagerEnumerable.getAdminOfRoleCount(adminRole1), 3); + assertEq(adminOfRolesList.length, 3); + assertEq(adminOfRolesList[0], newRole1); + assertEq(adminOfRolesList[1], newRole2); + assertEq(adminOfRolesList[2], newRole3); + assertEq(accessManagerEnumerable.getAdminOfRole(adminRole1, 0), newRole1); + assertEq(accessManagerEnumerable.getAdminOfRole(adminRole1, 1), newRole2); + assertEq(accessManagerEnumerable.getAdminOfRole(adminRole1, 2), newRole3); + } + + function test_setRoleAdmin_trackAdminOfRoles_changeAdminRole() public { + uint64 adminRole1 = 1; + + uint64 newRole1 = 111; + uint64 newRole2 = 222; + uint64 newRole3 = 333; + + assertEq(accessManagerEnumerable.getAdminRoleCount(), 1); + assertEq(accessManagerEnumerable.getAdminRole(0), ADMIN_ROLE); + + vm.startPrank(ADMIN); + accessManagerEnumerable.setRoleAdmin(GUARDIAN_ROLE_1, ADMIN_ROLE); + accessManagerEnumerable.setRoleAdmin(GUARDIAN_ROLE_2, ADMIN_ROLE); + accessManagerEnumerable.setRoleAdmin(newRole1, adminRole1); + accessManagerEnumerable.setRoleAdmin(newRole2, adminRole1); + accessManagerEnumerable.setRoleAdmin(newRole3, adminRole1); + vm.stopPrank(); + + uint64[] memory adminRoleList = accessManagerEnumerable.getAdminRoles(0, 2); + assertEq(accessManagerEnumerable.getAdminRoleCount(), 2); + assertEq(adminRoleList.length, 2); + assertEq(adminRoleList[0], ADMIN_ROLE); + assertEq(adminRoleList[1], adminRole1); + assertEq(accessManagerEnumerable.getAdminRole(0), ADMIN_ROLE); + assertEq(accessManagerEnumerable.getAdminRole(1), adminRole1); + + uint64[] memory adminOfRolesList = accessManagerEnumerable.getAdminOfRoles( + ADMIN_ROLE, + 0, + accessManagerEnumerable.getAdminOfRoleCount(ADMIN_ROLE) + ); + assertEq(accessManagerEnumerable.getAdminOfRoleCount(ADMIN_ROLE), 2); + assertEq(adminOfRolesList.length, 2); + assertEq(adminOfRolesList[0], GUARDIAN_ROLE_1); + assertEq(adminOfRolesList[1], GUARDIAN_ROLE_2); + assertEq(accessManagerEnumerable.getAdminOfRole(ADMIN_ROLE, 0), GUARDIAN_ROLE_1); + assertEq(accessManagerEnumerable.getAdminOfRole(ADMIN_ROLE, 1), GUARDIAN_ROLE_2); + + adminOfRolesList = accessManagerEnumerable.getAdminOfRoles( + adminRole1, + 0, + accessManagerEnumerable.getAdminOfRoleCount(adminRole1) + ); + assertEq(accessManagerEnumerable.getAdminOfRoleCount(adminRole1), 3); + assertEq(adminOfRolesList.length, 3); + assertEq(adminOfRolesList[0], newRole1); + assertEq(adminOfRolesList[1], newRole2); + assertEq(adminOfRolesList[2], newRole3); + assertEq(accessManagerEnumerable.getAdminOfRole(adminRole1, 0), newRole1); + assertEq(accessManagerEnumerable.getAdminOfRole(adminRole1, 1), newRole2); + assertEq(accessManagerEnumerable.getAdminOfRole(adminRole1, 2), newRole3); + + vm.startPrank(ADMIN); + accessManagerEnumerable.setRoleAdmin(newRole2, ADMIN_ROLE); + vm.stopPrank(); + + adminRoleList = accessManagerEnumerable.getAdminRoles(0, 2); + assertEq(accessManagerEnumerable.getAdminRoleCount(), 2); + assertEq(adminRoleList.length, 2); + assertEq(adminRoleList[0], ADMIN_ROLE); + assertEq(adminRoleList[1], adminRole1); + assertEq(accessManagerEnumerable.getAdminRole(0), ADMIN_ROLE); + assertEq(accessManagerEnumerable.getAdminRole(1), adminRole1); + + adminOfRolesList = accessManagerEnumerable.getAdminOfRoles( + ADMIN_ROLE, + 0, + accessManagerEnumerable.getAdminOfRoleCount(ADMIN_ROLE) + ); + assertEq(accessManagerEnumerable.getAdminOfRoleCount(ADMIN_ROLE), 3); + assertEq(adminOfRolesList.length, 3); + assertEq(adminOfRolesList[0], GUARDIAN_ROLE_1); + assertEq(adminOfRolesList[1], GUARDIAN_ROLE_2); + assertEq(adminOfRolesList[2], newRole2); + assertEq(accessManagerEnumerable.getAdminOfRole(ADMIN_ROLE, 0), GUARDIAN_ROLE_1); + assertEq(accessManagerEnumerable.getAdminOfRole(ADMIN_ROLE, 1), GUARDIAN_ROLE_2); + assertEq(accessManagerEnumerable.getAdminOfRole(ADMIN_ROLE, 2), newRole2); + + adminOfRolesList = accessManagerEnumerable.getAdminOfRoles( + adminRole1, + 0, + accessManagerEnumerable.getAdminOfRoleCount(adminRole1) + ); + assertEq(accessManagerEnumerable.getAdminOfRoleCount(adminRole1), 2); + assertEq(adminOfRolesList.length, 2); + assertEq(adminOfRolesList[0], newRole1); + assertEq(adminOfRolesList[1], newRole3); + assertEq(accessManagerEnumerable.getAdminOfRole(adminRole1, 0), newRole1); + assertEq(accessManagerEnumerable.getAdminOfRole(adminRole1, 1), newRole3); + } + + function test_setRoleGuardian_trackRoles() public { + uint64 newRole1 = 111; + uint64 newRole2 = 222; + uint64 newRole3 = 333; + assertLe(accessManagerEnumerable.getRoleCount(), 1); + assertEq(accessManagerEnumerable.getRole(0), ADMIN_ROLE); + + vm.startPrank(ADMIN); + accessManagerEnumerable.setRoleGuardian(newRole1, GUARDIAN_ROLE_1); + accessManagerEnumerable.setRoleGuardian(newRole2, GUARDIAN_ROLE_2); + accessManagerEnumerable.setRoleGuardian(newRole3, GUARDIAN_ROLE_1); + vm.stopPrank(); + + uint64[] memory roleList = accessManagerEnumerable.getRoles(0, 4); + assertLe(accessManagerEnumerable.getRoleCount(), 4); + assertEq(roleList.length, 4); + assertEq(roleList[0], ADMIN_ROLE); + assertEq(roleList[1], newRole1); + assertEq(roleList[2], newRole2); + assertEq(roleList[3], newRole3); + assertEq(accessManagerEnumerable.getRole(0), ADMIN_ROLE); + assertEq(accessManagerEnumerable.getRole(1), newRole1); + assertEq(accessManagerEnumerable.getRole(2), newRole2); + assertEq(accessManagerEnumerable.getRole(3), newRole3); + } + + function test_setRoleAdmin_fuzz_trackRolesAndTrackAdminRoles_multipleRoles( + uint256 rolesCount + ) public { + rolesCount = bound(rolesCount, 1, 15); + uint256 expectedTotalRoleCount = rolesCount + 1; // +1 for ADMIN_ROLE + + vm.startPrank(ADMIN); + + for (uint256 i = 0; i < rolesCount; i++) { + uint64 roleId = _getRandomRoleId(); + internalRoles.add(roleId); + internalAdminOfRoles[ADMIN_ROLE].add(uint256(roleId)); + accessManagerEnumerable.setRoleAdmin(roleId, ADMIN_ROLE); + } + vm.stopPrank(); + + uint64[] memory roleList = accessManagerEnumerable.getRoles( + 0, + accessManagerEnumerable.getRoleCount() + ); + assertLe(accessManagerEnumerable.getRoleCount(), expectedTotalRoleCount); + assertEq(roleList.length, expectedTotalRoleCount); + + assertEq(roleList[0], ADMIN_ROLE); + assertEq(accessManagerEnumerable.getRole(0), ADMIN_ROLE); + for (uint256 i = 1; i < rolesCount; i++) { + assertEq(roleList[i], internalRoles.at(i - 1)); + assertEq(accessManagerEnumerable.getRole(i), internalRoles.at(i - 1)); + } + + uint64[] memory adminOfRolesList = accessManagerEnumerable.getAdminOfRoles( + ADMIN_ROLE, + 0, + accessManagerEnumerable.getAdminOfRoleCount(ADMIN_ROLE) + ); + assertLe( + accessManagerEnumerable.getAdminOfRoleCount(ADMIN_ROLE), + internalAdminOfRoles[ADMIN_ROLE].length() + ); + assertEq(adminOfRolesList.length, internalAdminOfRoles[ADMIN_ROLE].length()); + for (uint256 i = 0; i < internalAdminOfRoles[ADMIN_ROLE].length(); i++) { + assertEq(adminOfRolesList[i], uint64(internalAdminOfRoles[ADMIN_ROLE].at(i))); + assertEq( + accessManagerEnumerable.getAdminOfRole(ADMIN_ROLE, i), + uint64(internalAdminOfRoles[ADMIN_ROLE].at(i)) + ); + } + } + + function test_setRoleAdmin_fuzz_trackAdminRoles_multipleRoles_multipleAdmins( + uint256 rolesCount + ) public { + rolesCount = bound(rolesCount, 1, 15); + uint256 expectedTotalRoleCount = rolesCount + 1; // +1 for ADMIN_ROLE + internalAdminRoles.add(ADMIN_ROLE); + + vm.startPrank(ADMIN); + + for (uint256 i = 0; i < rolesCount; i++) { + uint64 roleId = _getRandomRoleId(); + uint64 adminRoleId = _getRandomAdminRoleId(); + internalRoles.add(roleId); + internalAdminRoles.add(adminRoleId); + internalAdminOfRoles[adminRoleId].add(uint256(roleId)); + accessManagerEnumerable.setRoleAdmin(roleId, adminRoleId); + } + vm.stopPrank(); + + uint64[] memory roleList = accessManagerEnumerable.getRoles( + 0, + accessManagerEnumerable.getRoleCount() + ); + assertLe(accessManagerEnumerable.getRoleCount(), expectedTotalRoleCount); + assertEq(roleList.length, expectedTotalRoleCount); + + assertEq(roleList[0], ADMIN_ROLE); + assertEq(accessManagerEnumerable.getRole(0), ADMIN_ROLE); + for (uint256 i = 1; i < rolesCount; i++) { + assertEq(roleList[i], internalRoles.at(i - 1)); + assertEq(accessManagerEnumerable.getRole(i), internalRoles.at(i - 1)); + } + + uint64[] memory adminRoleList = accessManagerEnumerable.getAdminRoles( + 0, + accessManagerEnumerable.getAdminRoleCount() + ); + assertLe(accessManagerEnumerable.getAdminRoleCount(), internalAdminRoles.length()); + assertEq(adminRoleList.length, internalAdminRoles.length()); + for (uint256 i = 0; i < internalAdminRoles.length(); i++) { + uint64 adminRoleId = uint64(internalAdminRoles.at(i)); + assertEq(adminRoleList[i], adminRoleId); + assertEq(accessManagerEnumerable.getAdminRole(i), adminRoleId); + + uint64[] memory adminOfRolesList = accessManagerEnumerable.getAdminOfRoles( + adminRoleId, + 0, + accessManagerEnumerable.getAdminOfRoleCount(adminRoleId) + ); + assertLe( + accessManagerEnumerable.getAdminOfRoleCount(adminRoleId), + internalAdminOfRoles[adminRoleId].length() + ); + assertEq(adminOfRolesList.length, internalAdminOfRoles[adminRoleId].length()); + for (uint256 j = 0; j < internalAdminOfRoles[adminRoleId].length(); j++) { + assertEq(adminOfRolesList[j], uint64(internalAdminOfRoles[adminRoleId].at(j))); + assertEq( + accessManagerEnumerable.getAdminOfRole(adminRoleId, j), + uint64(internalAdminOfRoles[adminRoleId].at(j)) + ); + } + } } function test_revokeRole() public { @@ -137,20 +501,30 @@ contract AccessManagerEnumerableTest is Test { accessManagerEnumerable.setTargetFunctionRole(target, selectors, roleId); vm.stopPrank(); - assertEq(accessManagerEnumerable.getRoleTargetFunctionCount(roleId, target), 3); - assertEq(accessManagerEnumerable.getRoleTargetFunction(roleId, target, 0), selector1); - assertEq(accessManagerEnumerable.getRoleTargetFunction(roleId, target, 1), selector2); - assertEq(accessManagerEnumerable.getRoleTargetFunction(roleId, target, 2), selector3); - bytes4[] memory roleSelectors = accessManagerEnumerable.getRoleTargetFunctions( + assertEq(accessManagerEnumerable.getRoleTargetSelectorCount(roleId, target), 3); + assertEq(accessManagerEnumerable.getRoleTargetSelector(roleId, target, 0), selector1); + assertEq(accessManagerEnumerable.getRoleTargetSelector(roleId, target, 1), selector2); + assertEq(accessManagerEnumerable.getRoleTargetSelector(roleId, target, 2), selector3); + bytes4[] memory roleSelectors = accessManagerEnumerable.getRoleTargetSelectors( roleId, target, 0, - accessManagerEnumerable.getRoleTargetFunctionCount(roleId, target) + accessManagerEnumerable.getRoleTargetSelectorCount(roleId, target) ); assertEq(roleSelectors.length, 3); assertEq(roleSelectors[0], selector1); assertEq(roleSelectors[1], selector2); assertEq(roleSelectors[2], selector3); + + assertEq(accessManagerEnumerable.getRoleTargetCount(roleId), 1); + assertEq(accessManagerEnumerable.getRoleTarget(roleId, 0), target); + address[] memory roleTargets = accessManagerEnumerable.getRoleTargets( + roleId, + 0, + accessManagerEnumerable.getRoleTargetCount(roleId) + ); + assertEq(roleTargets.length, 1); + assertEq(roleTargets[0], target); } function test_setTargetFunctionRole_withReplace() public { @@ -174,11 +548,11 @@ contract AccessManagerEnumerableTest is Test { accessManagerEnumerable.setTargetFunctionRole(target, selectors, roleId); - assertEq(accessManagerEnumerable.getRoleTargetFunctionCount(roleId, target), 3); - assertEq(accessManagerEnumerable.getRoleTargetFunction(roleId, target, 0), selector1); - assertEq(accessManagerEnumerable.getRoleTargetFunction(roleId, target, 1), selector2); - assertEq(accessManagerEnumerable.getRoleTargetFunction(roleId, target, 2), selector3); - bytes4[] memory roleSelectors = accessManagerEnumerable.getRoleTargetFunctions( + assertEq(accessManagerEnumerable.getRoleTargetSelectorCount(roleId, target), 3); + assertEq(accessManagerEnumerable.getRoleTargetSelector(roleId, target, 0), selector1); + assertEq(accessManagerEnumerable.getRoleTargetSelector(roleId, target, 1), selector2); + assertEq(accessManagerEnumerable.getRoleTargetSelector(roleId, target, 2), selector3); + bytes4[] memory roleSelectors = accessManagerEnumerable.getRoleTargetSelectors( roleId, target, 0, @@ -189,31 +563,148 @@ contract AccessManagerEnumerableTest is Test { assertEq(roleSelectors[1], selector2); assertEq(roleSelectors[2], selector3); + assertEq(accessManagerEnumerable.getRoleTargetCount(roleId), 1); + assertEq(accessManagerEnumerable.getRoleTarget(roleId, 0), target); + address[] memory roleTargets = accessManagerEnumerable.getRoleTargets( + roleId, + 0, + accessManagerEnumerable.getRoleTargetCount(roleId) + ); + assertEq(roleTargets.length, 1); + assertEq(roleTargets[0], target); + accessManagerEnumerable.setTargetFunctionRole(target, updatedSelectors, roleId2); vm.stopPrank(); - assertEq(accessManagerEnumerable.getRoleTargetFunctionCount(roleId, target), 2); - assertEq(accessManagerEnumerable.getRoleTargetFunctionCount(roleId2, target), 1); - assertEq(accessManagerEnumerable.getRoleTargetFunction(roleId, target, 0), selector1); - assertEq(accessManagerEnumerable.getRoleTargetFunction(roleId, target, 1), selector3); - assertEq(accessManagerEnumerable.getRoleTargetFunction(roleId2, target, 0), selector2); - bytes4[] memory roleSelectors1 = accessManagerEnumerable.getRoleTargetFunctions( + assertEq(accessManagerEnumerable.getRoleTargetSelectorCount(roleId, target), 2); + assertEq(accessManagerEnumerable.getRoleTargetSelectorCount(roleId2, target), 1); + assertEq(accessManagerEnumerable.getRoleTargetSelector(roleId, target, 0), selector1); + assertEq(accessManagerEnumerable.getRoleTargetSelector(roleId, target, 1), selector3); + assertEq(accessManagerEnumerable.getRoleTargetSelector(roleId2, target, 0), selector2); + { + bytes4[] memory roleSelectors1 = accessManagerEnumerable.getRoleTargetSelectors( + roleId, + target, + 0, + 3 + ); + bytes4[] memory roleSelectors2 = accessManagerEnumerable.getRoleTargetSelectors( + roleId2, + target, + 0, + 3 + ); + assertEq(roleSelectors1.length, 2); + assertEq(roleSelectors2.length, 1); + assertEq(roleSelectors1[0], selector1); + assertEq(roleSelectors1[1], selector3); + assertEq(roleSelectors2[0], selector2); + } + + assertEq(accessManagerEnumerable.getRoleTargetCount(roleId), 1); + assertEq(accessManagerEnumerable.getRoleTarget(roleId, 0), target); + roleTargets = accessManagerEnumerable.getRoleTargets( roleId, - target, 0, - 3 + accessManagerEnumerable.getRoleTargetCount(roleId) ); - bytes4[] memory roleSelectors2 = accessManagerEnumerable.getRoleTargetFunctions( - roleId2, - target, + assertEq(roleTargets.length, 1); + assertEq(roleTargets[0], target); + } + + function test_setTargetFunctionRole_multipleTargets() public { + uint64 roleId = 1; + address target1 = makeAddr('target1'); + address target2 = makeAddr('target2'); + address target3 = makeAddr('target3'); + bytes4 selector1 = bytes4(keccak256('functionOne()')); + bytes4 selector2 = bytes4(keccak256('functionTwo()')); + bytes4 selector3 = bytes4(keccak256('functionThree()')); + + address[] memory targets = new address[](3); + targets[0] = target1; + targets[1] = target2; + targets[2] = target3; + + bytes4[] memory selectors = new bytes4[](3); + selectors[0] = selector1; + selectors[1] = selector2; + selectors[2] = selector3; + + vm.startPrank(ADMIN); + accessManagerEnumerable.setTargetFunctionRole(target1, selectors, roleId); + accessManagerEnumerable.setTargetFunctionRole(target2, selectors, roleId); + accessManagerEnumerable.setTargetFunctionRole(target3, selectors, roleId); + vm.stopPrank(); + + assertEq(accessManagerEnumerable.getRoleTargetCount(roleId), 3); + assertEq(accessManagerEnumerable.getRoleTarget(roleId, 0), target1); + assertEq(accessManagerEnumerable.getRoleTarget(roleId, 1), target2); + assertEq(accessManagerEnumerable.getRoleTarget(roleId, 2), target3); + address[] memory roleTargets = accessManagerEnumerable.getRoleTargets( + roleId, 0, - 3 + accessManagerEnumerable.getRoleTargetCount(roleId) + ); + assertEq(roleTargets.length, 3); + assertEq(roleTargets[0], target1); + assertEq(roleTargets[1], target2); + assertEq(roleTargets[2], target3); + } + + function test_setTargetFunctionRole_removeTarget() public { + uint64 roleId = 1; + uint64 otherRoleId = 2; + address target1 = makeAddr('target1'); + address target2 = makeAddr('target2'); + address target3 = makeAddr('target3'); + bytes4 selector1 = bytes4(keccak256('functionOne()')); + bytes4 selector2 = bytes4(keccak256('functionTwo()')); + + address[] memory targets = new address[](3); + targets[0] = target1; + targets[1] = target2; + targets[2] = target3; + + bytes4[] memory selectors = new bytes4[](2); + selectors[0] = selector1; + selectors[1] = selector2; + + vm.startPrank(ADMIN); + accessManagerEnumerable.setTargetFunctionRole(target1, selectors, roleId); + accessManagerEnumerable.setTargetFunctionRole(target2, selectors, roleId); + accessManagerEnumerable.setTargetFunctionRole(target3, selectors, roleId); + vm.stopPrank(); + + assertEq(accessManagerEnumerable.getRoleTargetCount(roleId), 3); + assertEq(accessManagerEnumerable.getRoleTarget(roleId, 0), target1); + assertEq(accessManagerEnumerable.getRoleTarget(roleId, 1), target2); + assertEq(accessManagerEnumerable.getRoleTarget(roleId, 2), target3); + address[] memory roleTargets = accessManagerEnumerable.getRoleTargets( + roleId, + 0, + accessManagerEnumerable.getRoleTargetCount(roleId) ); - assertEq(roleSelectors1.length, 2); - assertEq(roleSelectors2.length, 1); - assertEq(roleSelectors1[0], selector1); - assertEq(roleSelectors1[1], selector3); - assertEq(roleSelectors2[0], selector2); + assertEq(roleTargets.length, 3); + assertEq(roleTargets[0], target1); + assertEq(roleTargets[1], target2); + assertEq(roleTargets[2], target3); + + vm.startPrank(ADMIN); + accessManagerEnumerable.setTargetFunctionRole(target2, selectors, otherRoleId); + vm.stopPrank(); + + assertEq(accessManagerEnumerable.getRoleTargetCount(roleId), 2); + assertEq(accessManagerEnumerable.getRoleTarget(roleId, 0), target1); + assertEq(accessManagerEnumerable.getRoleTarget(roleId, 1), target3); + roleTargets = accessManagerEnumerable.getRoleTargets( + roleId, + 0, + accessManagerEnumerable.getRoleTargetCount(roleId) + ); + assertEq(roleTargets.length, 2); + assertEq(roleTargets[0], target1); + assertEq(roleTargets[1], target3); } function test_setTargetFunctionRole_skipAddToAdminRole() public { @@ -228,7 +719,7 @@ contract AccessManagerEnumerableTest is Test { accessManagerEnumerable.setTargetFunctionRole(target, selectors, roleId); // should not track selectors for ADMIN_ROLE - assertEq(accessManagerEnumerable.getRoleTargetFunctionCount(roleId, target), 0); + assertEq(accessManagerEnumerable.getRoleTargetSelectorCount(roleId, target), 0); } function test_getRoleMembers_fuzz(uint256 startIndex, uint256 endIndex) public { @@ -261,7 +752,7 @@ contract AccessManagerEnumerableTest is Test { } } - function test_getRoleTargetFunctions_fuzz(uint256 startIndex, uint256 endIndex) public { + function test_getRoleTargetSelectors_fuzz(uint256 startIndex, uint256 endIndex) public { startIndex = bound(startIndex, 0, 14); endIndex = bound(endIndex, startIndex + 1, 15); uint64 roleId = 1; @@ -278,7 +769,7 @@ contract AccessManagerEnumerableTest is Test { accessManagerEnumerable.setTargetFunctionRole(target, selectors, roleId); vm.stopPrank(); - bytes4[] memory roleSelectors = accessManagerEnumerable.getRoleTargetFunctions( + bytes4[] memory roleSelectors = accessManagerEnumerable.getRoleTargetSelectors( roleId, target, startIndex, @@ -288,5 +779,25 @@ contract AccessManagerEnumerableTest is Test { for (uint256 i = startIndex; i < endIndex; i++) { assertEq(roleSelectors[i - startIndex], selectors[i]); } + + assertEq(accessManagerEnumerable.getRoleTargetCount(roleId), 1); + assertEq(accessManagerEnumerable.getRoleTarget(roleId, 0), target); + address[] memory roleTargets = accessManagerEnumerable.getRoleTargets( + roleId, + 0, + accessManagerEnumerable.getRoleTargetCount(roleId) + ); + assertEq(roleTargets.length, 1); + assertEq(roleTargets[0], target); + } + + function _getRandomAdminRoleId() internal returns (uint64) { + uint256 adminRoleId = vm.randomUint(0, 4); + return uint64(adminRoleId); + } + + function _getRandomRoleId() internal returns (uint64) { + uint256 roleId = vm.randomUint(5, type(uint64).max - 1); + return uint64(roleId); } } diff --git a/tests/unit/Hub/Hub.Config.t.sol b/tests/unit/Hub/Hub.Config.t.sol index 8428a31c7..998c94ba4 100644 --- a/tests/unit/Hub/Hub.Config.t.sol +++ b/tests/unit/Hub/Hub.Config.t.sol @@ -807,6 +807,7 @@ contract HubConfigTest is HubBase { newConfig.liquidityFee = bound(newConfig.liquidityFee, 0, PercentageMath.PERCENTAGE_FACTOR) .toUint16(); vm.assume(address(newConfig.feeReceiver) != address(0) || newConfig.liquidityFee == 0); + assumeNotZeroAddress(newConfig.feeReceiver); assumeNotPrecompile(newConfig.feeReceiver); assumeNotForgeAddress(newConfig.feeReceiver); assumeNotZeroAddress(newConfig.irStrategy);