From d396a1015abc056cb286d78b9294af7984a98c49 Mon Sep 17 00:00:00 2001 From: Kogaroshi <25688223+Kogaroshi@users.noreply.github.com> Date: Mon, 15 Dec 2025 12:21:42 +0100 Subject: [PATCH 1/9] feat : more tracking in AccessManagerEnumerable --- src/access/AccessManagerEnumerable.sol | 86 +++++++ .../interfaces/IAccessManagerEnumerable.sol | 37 +++ tests/unit/AccessManagerEnumerable.t.sol | 212 ++++++++++++++++-- 3 files changed, 322 insertions(+), 13 deletions(-) diff --git a/src/access/AccessManagerEnumerable.sol b/src/access/AccessManagerEnumerable.sol index 94d30b2d3..c2d50b6f9 100644 --- a/src/access/AccessManagerEnumerable.sol +++ b/src/access/AccessManagerEnumerable.sol @@ -12,16 +12,46 @@ import {IAccessManagerEnumerable} from 'src/access/interfaces/IAccessManagerEnum contract AccessManagerEnumerable is AccessManager, IAccessManagerEnumerable { using EnumerableSet for EnumerableSet.AddressSet; using EnumerableSet for EnumerableSet.Bytes32Set; + using EnumerableSet for EnumerableSet.UintSet; + + /// @dev Set of all role identifiers. + EnumerableSet.UintSet private _rolesSet; /// @dev Map of role identifiers to their respective member sets. mapping(uint64 roleId => EnumerableSet.AddressSet) private _roleMembers; + /// @dev Map of role identifiers to their respective target contract addresses. + mapping(uint64 roleId => EnumerableSet.AddressSet) private _roleTargets; + + /// @dev Map of target contract addresses to their current role identifiers. + mapping(address target => uint64 roleId) private _targetRoles; + /// @dev Map of role identifiers and target contract addresses to their respective set of function selectors. mapping(uint64 roleId => mapping(address target => EnumerableSet.Bytes32Set)) private _roleTargetFunctions; constructor(address initialAdmin_) AccessManager(initialAdmin_) {} + /// @inheritdoc IAccessManagerEnumerable + function getRole(uint256 index) external view returns (uint64) { + return uint64(_rolesSet.at(index)); + } + + /// @inheritdoc IAccessManagerEnumerable + function getRoleCount() external view returns (uint256) { + return _rolesSet.length(); + } + + /// @inheritdoc IAccessManagerEnumerable + function getRoles(uint256 start, uint256 end) external view returns (uint64[] memory) { + uint256[] memory listedRoles = _rolesSet.values(start, end); + uint64[] memory roles; + assembly ('memory-safe') { + roles := listedRoles + } + return roles; + } + /// @inheritdoc IAccessManagerEnumerable function getRoleMember(uint64 roleId, uint256 index) external view returns (address) { return _roleMembers[roleId].at(index); @@ -41,6 +71,25 @@ contract AccessManagerEnumerable is AccessManager, IAccessManagerEnumerable { return _roleMembers[roleId].values(start, end); } + /// @inheritdoc IAccessManagerEnumerable + function getRoleTarget(uint64 roleId, uint256 index) external view returns (address) { + return _roleTargets[roleId].at(index); + } + + /// @inheritdoc IAccessManagerEnumerable + function getRoleTargetCount(uint64 roleId) external view returns (uint256) { + return _roleTargets[roleId].length(); + } + + /// @inheritdoc IAccessManagerEnumerable + function getRoleTargets( + uint64 roleId, + uint256 start, + uint256 end + ) external view returns (address[] memory) { + return _roleTargets[roleId].values(start, end); + } + /// @inheritdoc IAccessManagerEnumerable function getRoleTargetFunction( uint64 roleId, @@ -73,6 +122,40 @@ contract AccessManagerEnumerable is AccessManager, IAccessManagerEnumerable { return targetFunctionSelectors; } + /// @dev Tracks all role identifiers when a new role is created. + function _trackRole(uint64 roleId) internal { + if (!_rolesSet.contains(uint256(roleId))) { + _rolesSet.add(uint256(roleId)); + } + } + + /// @dev Tracks all targets where a selector was assigned to a role. + function _trackRoleTarget(uint64 roleId, address target) internal { + uint256 oldRole = _targetRoles[target]; + if (oldRole == roleId) { + return; + } + if (oldRole != ADMIN_ROLE && _roleTargetFunctions[uint64(oldRole)][target].length() == 0) { + _roleTargets[uint64(oldRole)].remove(target); + } + if (roleId != ADMIN_ROLE && !_roleTargets[roleId].contains(target)) { + _roleTargets[roleId].add(target); + } + _targetRoles[target] = roleId; + } + + /// @dev Override AccessManager `_setRoleAdmin` function to track created roles. + function _setRoleAdmin(uint64 roleId, uint64 admin) internal override { + _trackRole(roleId); + super._setRoleAdmin(roleId, admin); + } + + /// @dev Override AccessManager `_setRoleGuardian` function to track created roles. + function _setRoleGuardian(uint64 roleId, uint64 guardian) internal override { + _trackRole(roleId); + super._setRoleGuardian(roleId, guardian); + } + /// @dev Override AccessManager `_grantRole` function to track role members. function _grantRole( uint64 roleId, @@ -80,6 +163,7 @@ contract AccessManagerEnumerable is AccessManager, IAccessManagerEnumerable { uint32 grantDelay, uint32 executionDelay ) internal override returns (bool) { + _trackRole(roleId); bool granted = super._grantRole(roleId, account, grantDelay, executionDelay); if (granted) { _roleMembers[roleId].add(account); @@ -110,5 +194,7 @@ contract AccessManagerEnumerable is AccessManager, IAccessManagerEnumerable { if (roleId != ADMIN_ROLE) { _roleTargetFunctions[roleId][target].add(bytes32(selector)); } + // also track the target under the role (will be added if not already present) + _trackRoleTarget(roleId, target); } } diff --git a/src/access/interfaces/IAccessManagerEnumerable.sol b/src/access/interfaces/IAccessManagerEnumerable.sol index c62833460..d0a30a3c9 100644 --- a/src/access/interfaces/IAccessManagerEnumerable.sol +++ b/src/access/interfaces/IAccessManagerEnumerable.sol @@ -8,6 +8,21 @@ import {IAccessManager} from 'src/dependencies/openzeppelin/IAccessManager.sol'; /// @author Aave Labs /// @notice Interface for AccessManagerEnumerable extension. interface IAccessManagerEnumerable is IAccessManager { + /// @notice Returns the indentifier of the role at a specified index. + /// @param index The index in the role member list. + /// @return The identifier of the role. + function getRole(uint256 index) external view returns (uint64); + + /// @notice Returns the number of roles tracked by the AccessManager. + /// @return The number of roles. + function getRoleCount() external view returns (uint256); + + /// @notice Returns the list of role identifiers between the specified indexes. + /// @param start The starting index for the role list. + /// @param end The ending index for the role list. + /// @return The list of role identifiers. + function getRoles(uint256 start, uint256 end) external view returns (uint64[] memory); + /// @notice Returns the address of the role member at a specified index. /// @param roleId The identifier of the role. /// @param index The index in the role member list. @@ -30,6 +45,28 @@ interface IAccessManagerEnumerable is IAccessManager { uint256 end ) external view returns (address[] memory); + /// @notice Returns the address of the target contract at a specified index. + /// @param roleId The identifier of the role. + /// @param index The index in the role member list. + /// @return The address of the target contract. + function getRoleTarget(uint64 roleId, uint256 index) external view returns (address); + + /// @notice Returns the number of targets for a specified role. + /// @param roleId The identifier of the role. + /// @return The number of targets for the role. + function getRoleTargetCount(uint64 roleId) external view returns (uint256); + + /// @notice Returns the list of targets for a specified role. + /// @param roleId The identifier of the role. + /// @param start The starting index for the target list. + /// @param end The ending index for the target list. + /// @return The list of targets for the role. + function getRoleTargets( + uint64 roleId, + uint256 start, + uint256 end + ) external view returns (address[] memory); + /// @notice Returns the function selector assigned to a given role at the specified index. /// @param roleId The identifier of the role. /// @param target The address of the target contract. diff --git a/tests/unit/AccessManagerEnumerable.t.sol b/tests/unit/AccessManagerEnumerable.t.sol index e84eaee92..90f77690a 100644 --- a/tests/unit/AccessManagerEnumerable.t.sol +++ b/tests/unit/AccessManagerEnumerable.t.sol @@ -8,12 +8,18 @@ import {AccessManagerEnumerable} from 'src/access/AccessManagerEnumerable.sol'; contract AccessManagerEnumerableTest is Test { using EnumerableSet for EnumerableSet.AddressSet; + using EnumerableSet for EnumerableSet.UintSet; address internal ADMIN = makeAddr('ADMIN'); + uint64 constant ADMIN_ROLE = 0; + uint64 constant GUARDIAN_ROLE_1 = 111111111; + uint64 constant GUARDIAN_ROLE_2 = 222222222; + AccessManagerEnumerable internal accessManagerEnumerable; EnumerableSet.AddressSet members; + EnumerableSet.UintSet internalRoles; function setUp() public virtual { accessManagerEnumerable = new AccessManagerEnumerable(ADMIN); @@ -39,6 +45,12 @@ contract AccessManagerEnumerableTest is Test { assertEq(roleMembers.length, 1); assertEq(roleMembers[0], user1); + assertEq(accessManagerEnumerable.getRole(1), roleId); + assertEq(accessManagerEnumerable.getRoleCount(), 2); + uint64[] memory roles = accessManagerEnumerable.getRoles(0, 2); + assertEq(roles.length, 2); + assertEq(roles[1], roleId); + accessManagerEnumerable.grantRole(roleId, user2, 0); assertEq(accessManagerEnumerable.getRoleMember(roleId, 1), user2); assertEq(accessManagerEnumerable.getRoleMemberCount(roleId), 2); @@ -50,7 +62,12 @@ contract AccessManagerEnumerableTest is Test { assertEq(roleMembers.length, 2); assertEq(roleMembers[0], user1); assertEq(roleMembers[1], user2); - vm.stopPrank(); + + assertEq(accessManagerEnumerable.getRole(1), roleId); + assertEq(accessManagerEnumerable.getRoleCount(), 2); + roles = accessManagerEnumerable.getRoles(0, 2); + assertEq(roles.length, 2); + assertEq(roles[1], roleId); } function test_grantRole_fuzz(uint64 roleId, uint256 membersCount) public { @@ -86,6 +103,88 @@ contract AccessManagerEnumerableTest is Test { assertEq(roleMembers[i], members.at(i)); assertEq(accessManagerEnumerable.getRoleMember(roleId, i), members.at(i)); } + + assertEq(accessManagerEnumerable.getRole(1), roleId); + assertEq(accessManagerEnumerable.getRoleCount(), 2); + uint64[] memory roles = accessManagerEnumerable.getRoles(0, 2); + assertEq(roles.length, 2); + assertEq(roles[1], roleId); + } + + function test_setRoleAdmin_trackRoles() public { + assertLe(accessManagerEnumerable.getRoleCount(), 1); + assertEq(accessManagerEnumerable.getRole(0), ADMIN_ROLE); + + vm.startPrank(ADMIN); + accessManagerEnumerable.setRoleAdmin(GUARDIAN_ROLE_1, ADMIN_ROLE); + accessManagerEnumerable.setRoleAdmin(GUARDIAN_ROLE_2, ADMIN_ROLE); + vm.stopPrank(); + + uint64[] memory roleList = accessManagerEnumerable.getRoles(0, 3); + assertLe(accessManagerEnumerable.getRoleCount(), 3); + assertEq(roleList.length, 3); + assertEq(roleList[0], ADMIN_ROLE); + assertEq(roleList[1], GUARDIAN_ROLE_1); + assertEq(roleList[2], GUARDIAN_ROLE_2); + assertEq(accessManagerEnumerable.getRole(0), ADMIN_ROLE); + assertEq(accessManagerEnumerable.getRole(1), GUARDIAN_ROLE_1); + assertEq(accessManagerEnumerable.getRole(2), GUARDIAN_ROLE_2); + } + + function test_setRoleGuardian_trackRoles() public { + uint64 new_role_1 = 111; + uint64 new_role_2 = 222; + uint64 new_role_3 = 333; + assertLe(accessManagerEnumerable.getRoleCount(), 1); + assertEq(accessManagerEnumerable.getRole(0), ADMIN_ROLE); + + vm.startPrank(ADMIN); + accessManagerEnumerable.setRoleGuardian(new_role_1, GUARDIAN_ROLE_1); + accessManagerEnumerable.setRoleGuardian(new_role_2, GUARDIAN_ROLE_2); + accessManagerEnumerable.setRoleGuardian(new_role_3, GUARDIAN_ROLE_1); + vm.stopPrank(); + + uint64[] memory roleList = accessManagerEnumerable.getRoles(0, 4); + assertLe(accessManagerEnumerable.getRoleCount(), 4); + assertEq(roleList.length, 4); + assertEq(roleList[0], ADMIN_ROLE); + assertEq(roleList[1], new_role_1); + assertEq(roleList[2], new_role_2); + assertEq(roleList[3], new_role_3); + assertEq(accessManagerEnumerable.getRole(0), ADMIN_ROLE); + assertEq(accessManagerEnumerable.getRole(1), new_role_1); + assertEq(accessManagerEnumerable.getRole(2), new_role_2); + assertEq(accessManagerEnumerable.getRole(3), new_role_3); + } + + function test_setRoleAdmin_fuzz_trackRoles_multipleRoles(uint256 rolesCount) public { + rolesCount = bound(rolesCount, 1, 15); + uint256 expectedTotalRoleCount = rolesCount + 1; // +1 for ADMIN_ROLE + + vm.startPrank(ADMIN); + + for (uint256 i = 0; i < rolesCount; i++) { + uint64 roleId = _getRandomRoleId(); + if (!internalRoles.contains(roleId)) { + internalRoles.add(roleId); + } + accessManagerEnumerable.setRoleAdmin(roleId, ADMIN_ROLE); + } + vm.stopPrank(); + + uint64[] memory roleList = accessManagerEnumerable.getRoles( + 0, + accessManagerEnumerable.getRoleCount() + ); + assertLe(accessManagerEnumerable.getRoleCount(), expectedTotalRoleCount); + assertEq(roleList.length, expectedTotalRoleCount); + + assertEq(roleList[0], ADMIN_ROLE); + assertEq(accessManagerEnumerable.getRole(0), ADMIN_ROLE); + for (uint256 i = 1; i < rolesCount; i++) { + assertEq(roleList[i], internalRoles.at(i - 1)); + assertEq(accessManagerEnumerable.getRole(i), internalRoles.at(i - 1)); + } } function test_revokeRole() public { @@ -151,6 +250,16 @@ contract AccessManagerEnumerableTest is Test { assertEq(roleSelectors[0], selector1); assertEq(roleSelectors[1], selector2); assertEq(roleSelectors[2], selector3); + + assertEq(accessManagerEnumerable.getRoleTargetCount(roleId), 1); + assertEq(accessManagerEnumerable.getRoleTarget(roleId, 0), target); + address[] memory roleTargets = accessManagerEnumerable.getRoleTargets( + roleId, + 0, + accessManagerEnumerable.getRoleTargetCount(roleId) + ); + assertEq(roleTargets.length, 1); + assertEq(roleTargets[0], target); } function test_setTargetFunctionRole_withReplace() public { @@ -189,6 +298,16 @@ contract AccessManagerEnumerableTest is Test { assertEq(roleSelectors[1], selector2); assertEq(roleSelectors[2], selector3); + assertEq(accessManagerEnumerable.getRoleTargetCount(roleId), 1); + assertEq(accessManagerEnumerable.getRoleTarget(roleId, 0), target); + address[] memory roleTargets = accessManagerEnumerable.getRoleTargets( + roleId, + 0, + accessManagerEnumerable.getRoleTargetCount(roleId) + ); + assertEq(roleTargets.length, 1); + assertEq(roleTargets[0], target); + accessManagerEnumerable.setTargetFunctionRole(target, updatedSelectors, roleId2); vm.stopPrank(); @@ -197,23 +316,75 @@ contract AccessManagerEnumerableTest is Test { assertEq(accessManagerEnumerable.getRoleTargetFunction(roleId, target, 0), selector1); assertEq(accessManagerEnumerable.getRoleTargetFunction(roleId, target, 1), selector3); assertEq(accessManagerEnumerable.getRoleTargetFunction(roleId2, target, 0), selector2); - bytes4[] memory roleSelectors1 = accessManagerEnumerable.getRoleTargetFunctions( + { + bytes4[] memory roleSelectors1 = accessManagerEnumerable.getRoleTargetFunctions( + roleId, + target, + 0, + 3 + ); + bytes4[] memory roleSelectors2 = accessManagerEnumerable.getRoleTargetFunctions( + roleId2, + target, + 0, + 3 + ); + assertEq(roleSelectors1.length, 2); + assertEq(roleSelectors2.length, 1); + assertEq(roleSelectors1[0], selector1); + assertEq(roleSelectors1[1], selector3); + assertEq(roleSelectors2[0], selector2); + } + + assertEq(accessManagerEnumerable.getRoleTargetCount(roleId), 1); + assertEq(accessManagerEnumerable.getRoleTarget(roleId, 0), target); + roleTargets = accessManagerEnumerable.getRoleTargets( roleId, - target, 0, - 3 + accessManagerEnumerable.getRoleTargetCount(roleId) ); - bytes4[] memory roleSelectors2 = accessManagerEnumerable.getRoleTargetFunctions( - roleId2, - target, + assertEq(roleTargets.length, 1); + assertEq(roleTargets[0], target); + } + + function test_setTargetFunctionRole_multipleTargets() public { + uint64 roleId = 1; + address target1 = makeAddr('target1'); + address target2 = makeAddr('target2'); + address target3 = makeAddr('target3'); + bytes4 selector1 = bytes4(keccak256('functionOne()')); + bytes4 selector2 = bytes4(keccak256('functionTwo()')); + bytes4 selector3 = bytes4(keccak256('functionThree()')); + + address[] memory targets = new address[](3); + targets[0] = target1; + targets[1] = target2; + targets[2] = target3; + + bytes4[] memory selectors = new bytes4[](3); + selectors[0] = selector1; + selectors[1] = selector2; + selectors[2] = selector3; + + vm.startPrank(ADMIN); + accessManagerEnumerable.setTargetFunctionRole(target1, selectors, roleId); + accessManagerEnumerable.setTargetFunctionRole(target2, selectors, roleId); + accessManagerEnumerable.setTargetFunctionRole(target3, selectors, roleId); + vm.stopPrank(); + + assertEq(accessManagerEnumerable.getRoleTargetCount(roleId), 3); + assertEq(accessManagerEnumerable.getRoleTarget(roleId, 0), target1); + assertEq(accessManagerEnumerable.getRoleTarget(roleId, 1), target2); + assertEq(accessManagerEnumerable.getRoleTarget(roleId, 2), target3); + address[] memory roleTargets = accessManagerEnumerable.getRoleTargets( + roleId, 0, - 3 + accessManagerEnumerable.getRoleTargetCount(roleId) ); - assertEq(roleSelectors1.length, 2); - assertEq(roleSelectors2.length, 1); - assertEq(roleSelectors1[0], selector1); - assertEq(roleSelectors1[1], selector3); - assertEq(roleSelectors2[0], selector2); + assertEq(roleTargets.length, 3); + assertEq(roleTargets[0], target1); + assertEq(roleTargets[1], target2); + assertEq(roleTargets[2], target3); } function test_setTargetFunctionRole_skipAddToAdminRole() public { @@ -288,5 +459,20 @@ contract AccessManagerEnumerableTest is Test { for (uint256 i = startIndex; i < endIndex; i++) { assertEq(roleSelectors[i - startIndex], selectors[i]); } + + assertEq(accessManagerEnumerable.getRoleTargetCount(roleId), 1); + assertEq(accessManagerEnumerable.getRoleTarget(roleId, 0), target); + address[] memory roleTargets = accessManagerEnumerable.getRoleTargets( + roleId, + 0, + accessManagerEnumerable.getRoleTargetCount(roleId) + ); + assertEq(roleTargets.length, 1); + assertEq(roleTargets[0], target); + } + + function _getRandomRoleId() internal returns (uint64) { + uint256 roleId = vm.randomUint(1, type(uint64).max - 1); + return uint64(roleId); } } From d7c08850d872cdf39cd934db49f726960b89111f Mon Sep 17 00:00:00 2001 From: Kogaroshi <25688223+Kogaroshi@users.noreply.github.com> Date: Mon, 15 Dec 2025 14:37:09 +0100 Subject: [PATCH 2/9] fix : remove redundant check --- src/access/AccessManagerEnumerable.sol | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/access/AccessManagerEnumerable.sol b/src/access/AccessManagerEnumerable.sol index c2d50b6f9..20a30b1a9 100644 --- a/src/access/AccessManagerEnumerable.sol +++ b/src/access/AccessManagerEnumerable.sol @@ -124,9 +124,7 @@ contract AccessManagerEnumerable is AccessManager, IAccessManagerEnumerable { /// @dev Tracks all role identifiers when a new role is created. function _trackRole(uint64 roleId) internal { - if (!_rolesSet.contains(uint256(roleId))) { - _rolesSet.add(uint256(roleId)); - } + _rolesSet.add(uint256(roleId)); } /// @dev Tracks all targets where a selector was assigned to a role. From 59d387583ac8475f30c745fa762f5983ad0015be Mon Sep 17 00:00:00 2001 From: Kogaroshi <25688223+Kogaroshi@users.noreply.github.com> Date: Mon, 15 Dec 2025 15:28:16 +0100 Subject: [PATCH 3/9] fix : address pr comments --- src/access/AccessManagerEnumerable.sol | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/access/AccessManagerEnumerable.sol b/src/access/AccessManagerEnumerable.sol index 20a30b1a9..d0289c13b 100644 --- a/src/access/AccessManagerEnumerable.sol +++ b/src/access/AccessManagerEnumerable.sol @@ -136,7 +136,7 @@ contract AccessManagerEnumerable is AccessManager, IAccessManagerEnumerable { if (oldRole != ADMIN_ROLE && _roleTargetFunctions[uint64(oldRole)][target].length() == 0) { _roleTargets[uint64(oldRole)].remove(target); } - if (roleId != ADMIN_ROLE && !_roleTargets[roleId].contains(target)) { + if (roleId != ADMIN_ROLE) { _roleTargets[roleId].add(target); } _targetRoles[target] = roleId; From 9ed9dc8474bf184a04d033b762d28a397a7b6dc8 Mon Sep 17 00:00:00 2001 From: Kogaroshi <25688223+Kogaroshi@users.noreply.github.com> Date: Thu, 18 Dec 2025 14:45:35 +0100 Subject: [PATCH 4/9] fix : fix natspec based on pr comments --- src/access/AccessManagerEnumerable.sol | 2 ++ .../interfaces/IAccessManagerEnumerable.sol | 17 +++++++++-------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/access/AccessManagerEnumerable.sol b/src/access/AccessManagerEnumerable.sol index d0289c13b..869dd1cd7 100644 --- a/src/access/AccessManagerEnumerable.sol +++ b/src/access/AccessManagerEnumerable.sol @@ -30,6 +30,8 @@ contract AccessManagerEnumerable is AccessManager, IAccessManagerEnumerable { mapping(uint64 roleId => mapping(address target => EnumerableSet.Bytes32Set)) private _roleTargetFunctions; + /// @dev Constructor. + /// @param initialAdmin_ The address of the initial admin. constructor(address initialAdmin_) AccessManager(initialAdmin_) {} /// @inheritdoc IAccessManagerEnumerable diff --git a/src/access/interfaces/IAccessManagerEnumerable.sol b/src/access/interfaces/IAccessManagerEnumerable.sol index d0a30a3c9..94c97dc49 100644 --- a/src/access/interfaces/IAccessManagerEnumerable.sol +++ b/src/access/interfaces/IAccessManagerEnumerable.sol @@ -9,11 +9,12 @@ import {IAccessManager} from 'src/dependencies/openzeppelin/IAccessManager.sol'; /// @notice Interface for AccessManagerEnumerable extension. interface IAccessManagerEnumerable is IAccessManager { /// @notice Returns the indentifier of the role at a specified index. - /// @param index The index in the role member list. + /// @param index The index in the role list. /// @return The identifier of the role. function getRole(uint256 index) external view returns (uint64); - /// @notice Returns the number of roles tracked by the AccessManager. + /// @notice Returns the total number of existing roles. + /// @dev Does not account for the built-in `ADMIN_ROLE` & `PUBLIC_ROLE` roles. /// @return The number of roles. function getRoleCount() external view returns (uint256); @@ -34,7 +35,7 @@ interface IAccessManagerEnumerable is IAccessManager { /// @return The number of members for the role. function getRoleMemberCount(uint64 roleId) external view returns (uint256); - /// @notice Returns the list of members for a specified role. + /// @notice Returns the list of members for a specified role between the specified indexes. /// @param roleId The identifier of the role. /// @param start The starting index for the member list. /// @param end The ending index for the member list. @@ -45,9 +46,9 @@ interface IAccessManagerEnumerable is IAccessManager { uint256 end ) external view returns (address[] memory); - /// @notice Returns the address of the target contract at a specified index. + /// @notice Returns the address of the target contract for a specified role and index. /// @param roleId The identifier of the role. - /// @param index The index in the role member list. + /// @param index The index in the role target list. /// @return The address of the target contract. function getRoleTarget(uint64 roleId, uint256 index) external view returns (address); @@ -56,10 +57,10 @@ interface IAccessManagerEnumerable is IAccessManager { /// @return The number of targets for the role. function getRoleTargetCount(uint64 roleId) external view returns (uint256); - /// @notice Returns the list of targets for a specified role. + /// @notice Returns the list of targets for a specified role between the specified indexes. /// @param roleId The identifier of the role. - /// @param start The starting index for the target list. - /// @param end The ending index for the target list. + /// @param start The starting index for the role target list. + /// @param end The ending index for the role target list. /// @return The list of targets for the role. function getRoleTargets( uint64 roleId, From 444b0d2fcfec8b423c75173e3cb4b86c59497592 Mon Sep 17 00:00:00 2001 From: Kogaroshi <25688223+Kogaroshi@users.noreply.github.com> Date: Fri, 19 Dec 2025 09:07:43 +0100 Subject: [PATCH 5/9] fix : address pr comments --- src/access/AccessManagerEnumerable.sol | 12 ++++++------ src/access/interfaces/IAccessManagerEnumerable.sol | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/access/AccessManagerEnumerable.sol b/src/access/AccessManagerEnumerable.sol index 869dd1cd7..dffe09850 100644 --- a/src/access/AccessManagerEnumerable.sol +++ b/src/access/AccessManagerEnumerable.sol @@ -131,12 +131,12 @@ contract AccessManagerEnumerable is AccessManager, IAccessManagerEnumerable { /// @dev Tracks all targets where a selector was assigned to a role. function _trackRoleTarget(uint64 roleId, address target) internal { - uint256 oldRole = _targetRoles[target]; + uint64 oldRole = _targetRoles[target]; if (oldRole == roleId) { return; } - if (oldRole != ADMIN_ROLE && _roleTargetFunctions[uint64(oldRole)][target].length() == 0) { - _roleTargets[uint64(oldRole)].remove(target); + if (oldRole != ADMIN_ROLE && _roleTargetFunctions[oldRole][target].length() == 0) { + _roleTargets[oldRole].remove(target); } if (roleId != ADMIN_ROLE) { _roleTargets[roleId].add(target); @@ -156,22 +156,22 @@ contract AccessManagerEnumerable is AccessManager, IAccessManagerEnumerable { super._setRoleGuardian(roleId, guardian); } - /// @dev Override AccessManager `_grantRole` function to track role members. + /// @dev Override AccessManager `_grantRole` function to track roles' membership. function _grantRole( uint64 roleId, address account, uint32 grantDelay, uint32 executionDelay ) internal override returns (bool) { - _trackRole(roleId); bool granted = super._grantRole(roleId, account, grantDelay, executionDelay); if (granted) { + _trackRole(roleId); _roleMembers[roleId].add(account); } return granted; } - /// @dev Override AccessManager `_revokeRole` function to remove from tracked role members. + /// @dev Override AccessManager `_revokeRole` function to remove from tracked roles' membership. function _revokeRole(uint64 roleId, address account) internal override returns (bool) { bool revoked = super._revokeRole(roleId, account); if (revoked) { diff --git a/src/access/interfaces/IAccessManagerEnumerable.sol b/src/access/interfaces/IAccessManagerEnumerable.sol index 94c97dc49..bbb3a974e 100644 --- a/src/access/interfaces/IAccessManagerEnumerable.sol +++ b/src/access/interfaces/IAccessManagerEnumerable.sol @@ -14,7 +14,7 @@ interface IAccessManagerEnumerable is IAccessManager { function getRole(uint256 index) external view returns (uint64); /// @notice Returns the total number of existing roles. - /// @dev Does not account for the built-in `ADMIN_ROLE` & `PUBLIC_ROLE` roles. + /// @dev Does not account for the built-in `PUBLIC_ROLE` role. /// @return The number of roles. function getRoleCount() external view returns (uint256); From 865a591f3ddc15070b72a7434f23af97cb3c4d7d Mon Sep 17 00:00:00 2001 From: Kogaroshi <25688223+Kogaroshi@users.noreply.github.com> Date: Fri, 19 Dec 2025 10:38:09 +0100 Subject: [PATCH 6/9] fix : fix Hub config test fail on CI --- tests/unit/Hub/Hub.Config.t.sol | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/Hub/Hub.Config.t.sol b/tests/unit/Hub/Hub.Config.t.sol index 8428a31c7..998c94ba4 100644 --- a/tests/unit/Hub/Hub.Config.t.sol +++ b/tests/unit/Hub/Hub.Config.t.sol @@ -807,6 +807,7 @@ contract HubConfigTest is HubBase { newConfig.liquidityFee = bound(newConfig.liquidityFee, 0, PercentageMath.PERCENTAGE_FACTOR) .toUint16(); vm.assume(address(newConfig.feeReceiver) != address(0) || newConfig.liquidityFee == 0); + assumeNotZeroAddress(newConfig.feeReceiver); assumeNotPrecompile(newConfig.feeReceiver); assumeNotForgeAddress(newConfig.feeReceiver); assumeNotZeroAddress(newConfig.irStrategy); From 2b57bff4c1f9153ab6bfc357cc3e7d4236b345de Mon Sep 17 00:00:00 2001 From: Kogaroshi <25688223+Kogaroshi@users.noreply.github.com> Date: Tue, 6 Jan 2026 11:52:09 +0100 Subject: [PATCH 7/9] New test and fix incorrect tracking & address other comments --- src/access/AccessManagerEnumerable.sol | 12 ++-- .../interfaces/IAccessManagerEnumerable.sol | 10 ++-- tests/unit/AccessManagerEnumerable.t.sol | 55 +++++++++++++++++++ 3 files changed, 66 insertions(+), 11 deletions(-) diff --git a/src/access/AccessManagerEnumerable.sol b/src/access/AccessManagerEnumerable.sol index dffe09850..473efd161 100644 --- a/src/access/AccessManagerEnumerable.sol +++ b/src/access/AccessManagerEnumerable.sol @@ -23,8 +23,8 @@ contract AccessManagerEnumerable is AccessManager, IAccessManagerEnumerable { /// @dev Map of role identifiers to their respective target contract addresses. mapping(uint64 roleId => EnumerableSet.AddressSet) private _roleTargets; - /// @dev Map of target contract addresses to their current role identifiers. - mapping(address target => uint64 roleId) private _targetRoles; + /// @dev Map of target contract addresses and selectors to their current role identifiers. + mapping(address target => mapping(bytes4 selector => uint64 roleId)) private _targetSelectorRoles; /// @dev Map of role identifiers and target contract addresses to their respective set of function selectors. mapping(uint64 roleId => mapping(address target => EnumerableSet.Bytes32Set)) @@ -130,8 +130,8 @@ contract AccessManagerEnumerable is AccessManager, IAccessManagerEnumerable { } /// @dev Tracks all targets where a selector was assigned to a role. - function _trackRoleTarget(uint64 roleId, address target) internal { - uint64 oldRole = _targetRoles[target]; + function _trackRoleTarget(uint64 roleId, address target, bytes4 selector) internal { + uint64 oldRole = _targetSelectorRoles[target][selector]; if (oldRole == roleId) { return; } @@ -141,7 +141,7 @@ contract AccessManagerEnumerable is AccessManager, IAccessManagerEnumerable { if (roleId != ADMIN_ROLE) { _roleTargets[roleId].add(target); } - _targetRoles[target] = roleId; + _targetSelectorRoles[target][selector] = roleId; } /// @dev Override AccessManager `_setRoleAdmin` function to track created roles. @@ -195,6 +195,6 @@ contract AccessManagerEnumerable is AccessManager, IAccessManagerEnumerable { _roleTargetFunctions[roleId][target].add(bytes32(selector)); } // also track the target under the role (will be added if not already present) - _trackRoleTarget(roleId, target); + _trackRoleTarget(roleId, target, selector); } } diff --git a/src/access/interfaces/IAccessManagerEnumerable.sol b/src/access/interfaces/IAccessManagerEnumerable.sol index bbb3a974e..0a0595d0b 100644 --- a/src/access/interfaces/IAccessManagerEnumerable.sol +++ b/src/access/interfaces/IAccessManagerEnumerable.sol @@ -8,7 +8,7 @@ import {IAccessManager} from 'src/dependencies/openzeppelin/IAccessManager.sol'; /// @author Aave Labs /// @notice Interface for AccessManagerEnumerable extension. interface IAccessManagerEnumerable is IAccessManager { - /// @notice Returns the indentifier of the role at a specified index. + /// @notice Returns the identifier of the role at a specified index. /// @param index The index in the role list. /// @return The identifier of the role. function getRole(uint256 index) external view returns (uint64); @@ -52,16 +52,16 @@ interface IAccessManagerEnumerable is IAccessManager { /// @return The address of the target contract. function getRoleTarget(uint64 roleId, uint256 index) external view returns (address); - /// @notice Returns the number of targets for a specified role. + /// @notice Returns the number of target contracts for a specified role. /// @param roleId The identifier of the role. - /// @return The number of targets for the role. + /// @return The number of target contracts for the role. function getRoleTargetCount(uint64 roleId) external view returns (uint256); - /// @notice Returns the list of targets for a specified role between the specified indexes. + /// @notice Returns the list of target contracts for a specified role between the specified indexes. /// @param roleId The identifier of the role. /// @param start The starting index for the role target list. /// @param end The ending index for the role target list. - /// @return The list of targets for the role. + /// @return The list of target contracts for the role. function getRoleTargets( uint64 roleId, uint256 start, diff --git a/tests/unit/AccessManagerEnumerable.t.sol b/tests/unit/AccessManagerEnumerable.t.sol index 90f77690a..76060423b 100644 --- a/tests/unit/AccessManagerEnumerable.t.sol +++ b/tests/unit/AccessManagerEnumerable.t.sol @@ -387,6 +387,61 @@ contract AccessManagerEnumerableTest is Test { assertEq(roleTargets[2], target3); } + function test_setTargetFunctionRole_removeTarget() public { + uint64 roleId = 1; + uint64 otherRoleId = 2; + address target1 = makeAddr('target1'); + address target2 = makeAddr('target2'); + address target3 = makeAddr('target3'); + bytes4 selector1 = bytes4(keccak256('functionOne()')); + bytes4 selector2 = bytes4(keccak256('functionTwo()')); + + address[] memory targets = new address[](3); + targets[0] = target1; + targets[1] = target2; + targets[2] = target3; + + bytes4[] memory selectors = new bytes4[](2); + selectors[0] = selector1; + selectors[1] = selector2; + + vm.startPrank(ADMIN); + accessManagerEnumerable.setTargetFunctionRole(target1, selectors, roleId); + accessManagerEnumerable.setTargetFunctionRole(target2, selectors, roleId); + accessManagerEnumerable.setTargetFunctionRole(target3, selectors, roleId); + vm.stopPrank(); + + assertEq(accessManagerEnumerable.getRoleTargetCount(roleId), 3); + assertEq(accessManagerEnumerable.getRoleTarget(roleId, 0), target1); + assertEq(accessManagerEnumerable.getRoleTarget(roleId, 1), target2); + assertEq(accessManagerEnumerable.getRoleTarget(roleId, 2), target3); + address[] memory roleTargets = accessManagerEnumerable.getRoleTargets( + roleId, + 0, + accessManagerEnumerable.getRoleTargetCount(roleId) + ); + assertEq(roleTargets.length, 3); + assertEq(roleTargets[0], target1); + assertEq(roleTargets[1], target2); + assertEq(roleTargets[2], target3); + + vm.startPrank(ADMIN); + accessManagerEnumerable.setTargetFunctionRole(target2, selectors, otherRoleId); + vm.stopPrank(); + + assertEq(accessManagerEnumerable.getRoleTargetCount(roleId), 2); + assertEq(accessManagerEnumerable.getRoleTarget(roleId, 0), target1); + assertEq(accessManagerEnumerable.getRoleTarget(roleId, 1), target3); + roleTargets = accessManagerEnumerable.getRoleTargets( + roleId, + 0, + accessManagerEnumerable.getRoleTargetCount(roleId) + ); + assertEq(roleTargets.length, 2); + assertEq(roleTargets[0], target1); + assertEq(roleTargets[1], target3); + } + function test_setTargetFunctionRole_skipAddToAdminRole() public { uint64 roleId = accessManagerEnumerable.ADMIN_ROLE(); address target = makeAddr('target'); From 6b7d84ea1e24bc5f92c122f0f7505854b46686a5 Mon Sep 17 00:00:00 2001 From: Kogaroshi <25688223+Kogaroshi@users.noreply.github.com> Date: Wed, 7 Jan 2026 10:22:46 +0100 Subject: [PATCH 8/9] fix : rename for consistency --- snapshots/Hub.Operations.json | 2 +- src/access/AccessManagerEnumerable.sol | 22 +++++----- .../interfaces/IAccessManagerEnumerable.sol | 6 +-- tests/unit/AccessManagerEnumerable.t.sol | 42 +++++++++---------- 4 files changed, 36 insertions(+), 36 deletions(-) diff --git a/snapshots/Hub.Operations.json b/snapshots/Hub.Operations.json index 54b1b87c4..1699f8307 100644 --- a/snapshots/Hub.Operations.json +++ b/snapshots/Hub.Operations.json @@ -4,7 +4,7 @@ "draw": "105931", "eliminateDeficit: full": "59781", "eliminateDeficit: partial": "69429", - "mintFeeShares": "84007", + "mintFeeShares": "84095", "payFee": "72302", "refreshPremium": "71999", "remove: full": "76993", diff --git a/src/access/AccessManagerEnumerable.sol b/src/access/AccessManagerEnumerable.sol index 473efd161..b3ad76a57 100644 --- a/src/access/AccessManagerEnumerable.sol +++ b/src/access/AccessManagerEnumerable.sol @@ -23,12 +23,12 @@ contract AccessManagerEnumerable is AccessManager, IAccessManagerEnumerable { /// @dev Map of role identifiers to their respective target contract addresses. mapping(uint64 roleId => EnumerableSet.AddressSet) private _roleTargets; - /// @dev Map of target contract addresses and selectors to their current role identifiers. + /// @dev Map of target contract addresses and function selectors to their assigned role identifier. mapping(address target => mapping(bytes4 selector => uint64 roleId)) private _targetSelectorRoles; /// @dev Map of role identifiers and target contract addresses to their respective set of function selectors. mapping(uint64 roleId => mapping(address target => EnumerableSet.Bytes32Set)) - private _roleTargetFunctions; + private _roleTargetSelectors; /// @dev Constructor. /// @param initialAdmin_ The address of the initial admin. @@ -93,30 +93,30 @@ contract AccessManagerEnumerable is AccessManager, IAccessManagerEnumerable { } /// @inheritdoc IAccessManagerEnumerable - function getRoleTargetFunction( + function getRoleTargetSelector( uint64 roleId, address target, uint256 index ) external view returns (bytes4) { - return bytes4(_roleTargetFunctions[roleId][target].at(index)); + return bytes4(_roleTargetSelectors[roleId][target].at(index)); } /// @inheritdoc IAccessManagerEnumerable - function getRoleTargetFunctionCount( + function getRoleTargetSelectorCount( uint64 roleId, address target ) external view returns (uint256) { - return _roleTargetFunctions[roleId][target].length(); + return _roleTargetSelectors[roleId][target].length(); } /// @inheritdoc IAccessManagerEnumerable - function getRoleTargetFunctions( + function getRoleTargetSelectors( uint64 roleId, address target, uint256 start, uint256 end ) external view returns (bytes4[] memory) { - bytes32[] memory targetFunctions = _roleTargetFunctions[roleId][target].values(start, end); + bytes32[] memory targetFunctions = _roleTargetSelectors[roleId][target].values(start, end); bytes4[] memory targetFunctionSelectors; assembly ('memory-safe') { targetFunctionSelectors := targetFunctions @@ -135,7 +135,7 @@ contract AccessManagerEnumerable is AccessManager, IAccessManagerEnumerable { if (oldRole == roleId) { return; } - if (oldRole != ADMIN_ROLE && _roleTargetFunctions[oldRole][target].length() == 0) { + if (oldRole != ADMIN_ROLE && _roleTargetSelectors[oldRole][target].length() == 0) { _roleTargets[oldRole].remove(target); } if (roleId != ADMIN_ROLE) { @@ -189,10 +189,10 @@ contract AccessManagerEnumerable is AccessManager, IAccessManagerEnumerable { uint64 oldRoleId = getTargetFunctionRole(target, selector); super._setTargetFunctionRole(target, selector, roleId); if (oldRoleId != ADMIN_ROLE) { - _roleTargetFunctions[oldRoleId][target].remove(bytes32(selector)); + _roleTargetSelectors[oldRoleId][target].remove(bytes32(selector)); } if (roleId != ADMIN_ROLE) { - _roleTargetFunctions[roleId][target].add(bytes32(selector)); + _roleTargetSelectors[roleId][target].add(bytes32(selector)); } // also track the target under the role (will be added if not already present) _trackRoleTarget(roleId, target, selector); diff --git a/src/access/interfaces/IAccessManagerEnumerable.sol b/src/access/interfaces/IAccessManagerEnumerable.sol index 0a0595d0b..1addcc264 100644 --- a/src/access/interfaces/IAccessManagerEnumerable.sol +++ b/src/access/interfaces/IAccessManagerEnumerable.sol @@ -73,7 +73,7 @@ interface IAccessManagerEnumerable is IAccessManager { /// @param target The address of the target contract. /// @param index The index in the role member list. /// @return The selector at the index. - function getRoleTargetFunction( + function getRoleTargetSelector( uint64 roleId, address target, uint256 index @@ -83,7 +83,7 @@ interface IAccessManagerEnumerable is IAccessManager { /// @param roleId The identifier of the role. /// @param target The address of the target contract. /// @return The number of selectors assigned to the role. - function getRoleTargetFunctionCount( + function getRoleTargetSelectorCount( uint64 roleId, address target ) external view returns (uint256); @@ -94,7 +94,7 @@ interface IAccessManagerEnumerable is IAccessManager { /// @param start The starting index for the selector list. /// @param end The ending index for the selector list. /// @return The list of selectors assigned to the role. - function getRoleTargetFunctions( + function getRoleTargetSelectors( uint64 roleId, address target, uint256 start, diff --git a/tests/unit/AccessManagerEnumerable.t.sol b/tests/unit/AccessManagerEnumerable.t.sol index 76060423b..93a983834 100644 --- a/tests/unit/AccessManagerEnumerable.t.sol +++ b/tests/unit/AccessManagerEnumerable.t.sol @@ -236,15 +236,15 @@ contract AccessManagerEnumerableTest is Test { accessManagerEnumerable.setTargetFunctionRole(target, selectors, roleId); vm.stopPrank(); - assertEq(accessManagerEnumerable.getRoleTargetFunctionCount(roleId, target), 3); - assertEq(accessManagerEnumerable.getRoleTargetFunction(roleId, target, 0), selector1); - assertEq(accessManagerEnumerable.getRoleTargetFunction(roleId, target, 1), selector2); - assertEq(accessManagerEnumerable.getRoleTargetFunction(roleId, target, 2), selector3); - bytes4[] memory roleSelectors = accessManagerEnumerable.getRoleTargetFunctions( + assertEq(accessManagerEnumerable.getRoleTargetSelectorCount(roleId, target), 3); + assertEq(accessManagerEnumerable.getRoleTargetSelector(roleId, target, 0), selector1); + assertEq(accessManagerEnumerable.getRoleTargetSelector(roleId, target, 1), selector2); + assertEq(accessManagerEnumerable.getRoleTargetSelector(roleId, target, 2), selector3); + bytes4[] memory roleSelectors = accessManagerEnumerable.getRoleTargetSelectors( roleId, target, 0, - accessManagerEnumerable.getRoleTargetFunctionCount(roleId, target) + accessManagerEnumerable.getRoleTargetSelectorCount(roleId, target) ); assertEq(roleSelectors.length, 3); assertEq(roleSelectors[0], selector1); @@ -283,11 +283,11 @@ contract AccessManagerEnumerableTest is Test { accessManagerEnumerable.setTargetFunctionRole(target, selectors, roleId); - assertEq(accessManagerEnumerable.getRoleTargetFunctionCount(roleId, target), 3); - assertEq(accessManagerEnumerable.getRoleTargetFunction(roleId, target, 0), selector1); - assertEq(accessManagerEnumerable.getRoleTargetFunction(roleId, target, 1), selector2); - assertEq(accessManagerEnumerable.getRoleTargetFunction(roleId, target, 2), selector3); - bytes4[] memory roleSelectors = accessManagerEnumerable.getRoleTargetFunctions( + assertEq(accessManagerEnumerable.getRoleTargetSelectorCount(roleId, target), 3); + assertEq(accessManagerEnumerable.getRoleTargetSelector(roleId, target, 0), selector1); + assertEq(accessManagerEnumerable.getRoleTargetSelector(roleId, target, 1), selector2); + assertEq(accessManagerEnumerable.getRoleTargetSelector(roleId, target, 2), selector3); + bytes4[] memory roleSelectors = accessManagerEnumerable.getRoleTargetSelectors( roleId, target, 0, @@ -311,19 +311,19 @@ contract AccessManagerEnumerableTest is Test { accessManagerEnumerable.setTargetFunctionRole(target, updatedSelectors, roleId2); vm.stopPrank(); - assertEq(accessManagerEnumerable.getRoleTargetFunctionCount(roleId, target), 2); - assertEq(accessManagerEnumerable.getRoleTargetFunctionCount(roleId2, target), 1); - assertEq(accessManagerEnumerable.getRoleTargetFunction(roleId, target, 0), selector1); - assertEq(accessManagerEnumerable.getRoleTargetFunction(roleId, target, 1), selector3); - assertEq(accessManagerEnumerable.getRoleTargetFunction(roleId2, target, 0), selector2); + assertEq(accessManagerEnumerable.getRoleTargetSelectorCount(roleId, target), 2); + assertEq(accessManagerEnumerable.getRoleTargetSelectorCount(roleId2, target), 1); + assertEq(accessManagerEnumerable.getRoleTargetSelector(roleId, target, 0), selector1); + assertEq(accessManagerEnumerable.getRoleTargetSelector(roleId, target, 1), selector3); + assertEq(accessManagerEnumerable.getRoleTargetSelector(roleId2, target, 0), selector2); { - bytes4[] memory roleSelectors1 = accessManagerEnumerable.getRoleTargetFunctions( + bytes4[] memory roleSelectors1 = accessManagerEnumerable.getRoleTargetSelectors( roleId, target, 0, 3 ); - bytes4[] memory roleSelectors2 = accessManagerEnumerable.getRoleTargetFunctions( + bytes4[] memory roleSelectors2 = accessManagerEnumerable.getRoleTargetSelectors( roleId2, target, 0, @@ -454,7 +454,7 @@ contract AccessManagerEnumerableTest is Test { accessManagerEnumerable.setTargetFunctionRole(target, selectors, roleId); // should not track selectors for ADMIN_ROLE - assertEq(accessManagerEnumerable.getRoleTargetFunctionCount(roleId, target), 0); + assertEq(accessManagerEnumerable.getRoleTargetSelectorCount(roleId, target), 0); } function test_getRoleMembers_fuzz(uint256 startIndex, uint256 endIndex) public { @@ -487,7 +487,7 @@ contract AccessManagerEnumerableTest is Test { } } - function test_getRoleTargetFunctions_fuzz(uint256 startIndex, uint256 endIndex) public { + function test_getRoleTargetSelectors_fuzz(uint256 startIndex, uint256 endIndex) public { startIndex = bound(startIndex, 0, 14); endIndex = bound(endIndex, startIndex + 1, 15); uint64 roleId = 1; @@ -504,7 +504,7 @@ contract AccessManagerEnumerableTest is Test { accessManagerEnumerable.setTargetFunctionRole(target, selectors, roleId); vm.stopPrank(); - bytes4[] memory roleSelectors = accessManagerEnumerable.getRoleTargetFunctions( + bytes4[] memory roleSelectors = accessManagerEnumerable.getRoleTargetSelectors( roleId, target, startIndex, From 79beb08c36fc11ea163e8d499f309ea488314ea9 Mon Sep 17 00:00:00 2001 From: Kogaroshi <25688223+Kogaroshi@users.noreply.github.com> Date: Wed, 7 Jan 2026 18:01:55 +0100 Subject: [PATCH 9/9] feat : track admins and controlled roles --- src/access/AccessManagerEnumerable.sol | 65 +++- .../interfaces/IAccessManagerEnumerable.sol | 37 +++ tests/unit/AccessManagerEnumerable.t.sol | 310 ++++++++++++++++-- 3 files changed, 391 insertions(+), 21 deletions(-) diff --git a/src/access/AccessManagerEnumerable.sol b/src/access/AccessManagerEnumerable.sol index b3ad76a57..ddff64c75 100644 --- a/src/access/AccessManagerEnumerable.sol +++ b/src/access/AccessManagerEnumerable.sol @@ -17,9 +17,15 @@ contract AccessManagerEnumerable is AccessManager, IAccessManagerEnumerable { /// @dev Set of all role identifiers. EnumerableSet.UintSet private _rolesSet; + /// @dev Set of all admin role identifiers. + EnumerableSet.UintSet private _adminRolesSet; + /// @dev Map of role identifiers to their respective member sets. mapping(uint64 roleId => EnumerableSet.AddressSet) private _roleMembers; + /// @dev Map of admin role identifiers to their respective role identifier sets. + mapping(uint64 roleId => EnumerableSet.UintSet) private _adminOfRoles; + /// @dev Map of role identifiers to their respective target contract addresses. mapping(uint64 roleId => EnumerableSet.AddressSet) private _roleTargets; @@ -32,7 +38,11 @@ contract AccessManagerEnumerable is AccessManager, IAccessManagerEnumerable { /// @dev Constructor. /// @param initialAdmin_ The address of the initial admin. - constructor(address initialAdmin_) AccessManager(initialAdmin_) {} + constructor(address initialAdmin_) AccessManager(initialAdmin_) { + // Track the ADMIN_ROLE by default. + // (already tracked as a default role via AccessManager constructor) + _adminRolesSet.add(ADMIN_ROLE); + } /// @inheritdoc IAccessManagerEnumerable function getRole(uint256 index) external view returns (uint64) { @@ -54,6 +64,26 @@ contract AccessManagerEnumerable is AccessManager, IAccessManagerEnumerable { return roles; } + /// @inheritdoc IAccessManagerEnumerable + function getAdminRole(uint256 index) external view returns (uint64) { + return uint64(_adminRolesSet.at(index)); + } + + /// @inheritdoc IAccessManagerEnumerable + function getAdminRoleCount() external view returns (uint256) { + return _adminRolesSet.length(); + } + + /// @inheritdoc IAccessManagerEnumerable + function getAdminRoles(uint256 start, uint256 end) external view returns (uint64[] memory) { + uint256[] memory listedAdminRoles = _adminRolesSet.values(start, end); + uint64[] memory adminRoles; + assembly ('memory-safe') { + adminRoles := listedAdminRoles + } + return adminRoles; + } + /// @inheritdoc IAccessManagerEnumerable function getRoleMember(uint64 roleId, uint256 index) external view returns (address) { return _roleMembers[roleId].at(index); @@ -73,6 +103,30 @@ contract AccessManagerEnumerable is AccessManager, IAccessManagerEnumerable { return _roleMembers[roleId].values(start, end); } + /// @inheritdoc IAccessManagerEnumerable + function getAdminOfRole(uint64 adminRoleId, uint256 index) external view returns (uint64) { + return uint64(_adminOfRoles[adminRoleId].at(index)); + } + + /// @inheritdoc IAccessManagerEnumerable + function getAdminOfRoleCount(uint64 adminRoleId) external view returns (uint256) { + return _adminOfRoles[adminRoleId].length(); + } + + /// @inheritdoc IAccessManagerEnumerable + function getAdminOfRoles( + uint64 adminRoleId, + uint256 start, + uint256 end + ) external view returns (uint64[] memory) { + uint256[] memory listedRoles = _adminOfRoles[adminRoleId].values(start, end); + uint64[] memory roles; + assembly ('memory-safe') { + roles := listedRoles + } + return roles; + } + /// @inheritdoc IAccessManagerEnumerable function getRoleTarget(uint64 roleId, uint256 index) external view returns (address) { return _roleTargets[roleId].at(index); @@ -129,6 +183,14 @@ contract AccessManagerEnumerable is AccessManager, IAccessManagerEnumerable { _rolesSet.add(uint256(roleId)); } + /// @dev Tracks all admin role identifiers when a new admin role is set. + function _trackAdminRole(uint64 roleId, uint64 admin) internal { + _adminRolesSet.add(uint256(admin)); + uint64 oldAdmin = getRoleAdmin(roleId); + _adminOfRoles[oldAdmin].remove(uint256(roleId)); + _adminOfRoles[admin].add(uint256(roleId)); + } + /// @dev Tracks all targets where a selector was assigned to a role. function _trackRoleTarget(uint64 roleId, address target, bytes4 selector) internal { uint64 oldRole = _targetSelectorRoles[target][selector]; @@ -147,6 +209,7 @@ contract AccessManagerEnumerable is AccessManager, IAccessManagerEnumerable { /// @dev Override AccessManager `_setRoleAdmin` function to track created roles. function _setRoleAdmin(uint64 roleId, uint64 admin) internal override { _trackRole(roleId); + _trackAdminRole(roleId, admin); super._setRoleAdmin(roleId, admin); } diff --git a/src/access/interfaces/IAccessManagerEnumerable.sol b/src/access/interfaces/IAccessManagerEnumerable.sol index 1addcc264..a11483407 100644 --- a/src/access/interfaces/IAccessManagerEnumerable.sol +++ b/src/access/interfaces/IAccessManagerEnumerable.sol @@ -24,6 +24,21 @@ interface IAccessManagerEnumerable is IAccessManager { /// @return The list of role identifiers. function getRoles(uint256 start, uint256 end) external view returns (uint64[] memory); + /// @notice Returns the identifier of the admin role at a specified index. + /// @param index The index in the admin role list. + /// @return The identifier of the admin role. + function getAdminRole(uint256 index) external view returns (uint64); + + /// @notice Returns the total number of existing admin roles. + /// @return The number of admin roles. + function getAdminRoleCount() external view returns (uint256); + + /// @notice Returns the list of admin role identifiers between the specified indexes. + /// @param start The starting index for the admin role list. + /// @param end The ending index for the admin role list. + /// @return The list of admin role identifiers. + function getAdminRoles(uint256 start, uint256 end) external view returns (uint64[] memory); + /// @notice Returns the address of the role member at a specified index. /// @param roleId The identifier of the role. /// @param index The index in the role member list. @@ -46,6 +61,28 @@ interface IAccessManagerEnumerable is IAccessManager { uint256 end ) external view returns (address[] memory); + /// @notice Returns the role identifier of the listed roles for a specified admin role at a specified index. + /// @param adminRoleId The identifier of the admin role. + /// @param index The index in the admin controlled role list. + /// @return The indentifier of the role. + function getAdminOfRole(uint64 adminRoleId, uint256 index) external view returns (uint64); + + /// @notice Returns the number of members for a specified role. + /// @param adminRoleId The identifier of the admin role. + /// @return The number of members for the role. + function getAdminOfRoleCount(uint64 adminRoleId) external view returns (uint256); + + /// @notice Returns the list of role identifiers controlled by a specified admin role between the specified indexes. + /// @param adminRoleId The identifier of the admin role. + /// @param start The starting index for the admin controlled role list. + /// @param end The ending index for the admin controlled role list. + /// @return The list of admin controlled role indentifiers for the admin role. + function getAdminOfRoles( + uint64 adminRoleId, + uint256 start, + uint256 end + ) external view returns (uint64[] memory); + /// @notice Returns the address of the target contract for a specified role and index. /// @param roleId The identifier of the role. /// @param index The index in the role target list. diff --git a/tests/unit/AccessManagerEnumerable.t.sol b/tests/unit/AccessManagerEnumerable.t.sol index 93a983834..9bf8376eb 100644 --- a/tests/unit/AccessManagerEnumerable.t.sol +++ b/tests/unit/AccessManagerEnumerable.t.sol @@ -20,6 +20,8 @@ contract AccessManagerEnumerableTest is Test { EnumerableSet.AddressSet members; EnumerableSet.UintSet internalRoles; + EnumerableSet.UintSet internalAdminRoles; + mapping(uint64 => EnumerableSet.UintSet) internalAdminOfRoles; function setUp() public virtual { accessManagerEnumerable = new AccessManagerEnumerable(ADMIN); @@ -111,9 +113,11 @@ contract AccessManagerEnumerableTest is Test { assertEq(roles[1], roleId); } - function test_setRoleAdmin_trackRoles() public { - assertLe(accessManagerEnumerable.getRoleCount(), 1); + function test_setRoleAdmin_trackRolesAndTrackAdminRoles() public { + assertEq(accessManagerEnumerable.getRoleCount(), 1); assertEq(accessManagerEnumerable.getRole(0), ADMIN_ROLE); + assertEq(accessManagerEnumerable.getAdminRoleCount(), 1); + assertEq(accessManagerEnumerable.getAdminRole(0), ADMIN_ROLE); vm.startPrank(ADMIN); accessManagerEnumerable.setRoleAdmin(GUARDIAN_ROLE_1, ADMIN_ROLE); @@ -121,7 +125,7 @@ contract AccessManagerEnumerableTest is Test { vm.stopPrank(); uint64[] memory roleList = accessManagerEnumerable.getRoles(0, 3); - assertLe(accessManagerEnumerable.getRoleCount(), 3); + assertEq(accessManagerEnumerable.getRoleCount(), 3); assertEq(roleList.length, 3); assertEq(roleList[0], ADMIN_ROLE); assertEq(roleList[1], GUARDIAN_ROLE_1); @@ -129,35 +133,215 @@ contract AccessManagerEnumerableTest is Test { assertEq(accessManagerEnumerable.getRole(0), ADMIN_ROLE); assertEq(accessManagerEnumerable.getRole(1), GUARDIAN_ROLE_1); assertEq(accessManagerEnumerable.getRole(2), GUARDIAN_ROLE_2); + + uint64[] memory adminRoleList = accessManagerEnumerable.getAdminRoles(0, 1); + assertEq(accessManagerEnumerable.getAdminRoleCount(), 1); + assertEq(adminRoleList.length, 1); + assertEq(adminRoleList[0], ADMIN_ROLE); + assertEq(accessManagerEnumerable.getAdminRole(0), ADMIN_ROLE); + } + + function test_setRoleAdmin_trackAdminRoles() public { + uint64 adminRole1 = 1; + uint64 adminRole2 = 2; + + uint64 newRole1 = 111; + uint64 newRole2 = 222; + + assertEq(accessManagerEnumerable.getAdminRoleCount(), 1); + assertEq(accessManagerEnumerable.getAdminRole(0), ADMIN_ROLE); + + vm.startPrank(ADMIN); + accessManagerEnumerable.setRoleAdmin(GUARDIAN_ROLE_1, adminRole1); + accessManagerEnumerable.setRoleAdmin(GUARDIAN_ROLE_2, ADMIN_ROLE); + accessManagerEnumerable.setRoleAdmin(newRole1, adminRole1); + accessManagerEnumerable.setRoleAdmin(newRole2, adminRole2); + vm.stopPrank(); + + uint64[] memory adminRoleList = accessManagerEnumerable.getAdminRoles(0, 3); + assertEq(accessManagerEnumerable.getAdminRoleCount(), 3); + assertEq(adminRoleList.length, 3); + assertEq(adminRoleList[0], ADMIN_ROLE); + assertEq(adminRoleList[1], adminRole1); + assertEq(adminRoleList[2], adminRole2); + assertEq(accessManagerEnumerable.getAdminRole(0), ADMIN_ROLE); + assertEq(accessManagerEnumerable.getAdminRole(1), adminRole1); + assertEq(accessManagerEnumerable.getAdminRole(2), adminRole2); + } + + function test_setRoleAdmin_trackAdminOfRoles() public { + uint64 adminRole1 = 1; + + uint64 newRole1 = 111; + uint64 newRole2 = 222; + uint64 newRole3 = 333; + + assertEq(accessManagerEnumerable.getAdminRoleCount(), 1); + assertEq(accessManagerEnumerable.getAdminRole(0), ADMIN_ROLE); + + vm.startPrank(ADMIN); + accessManagerEnumerable.setRoleAdmin(GUARDIAN_ROLE_1, ADMIN_ROLE); + accessManagerEnumerable.setRoleAdmin(GUARDIAN_ROLE_2, ADMIN_ROLE); + accessManagerEnumerable.setRoleAdmin(newRole1, adminRole1); + accessManagerEnumerable.setRoleAdmin(newRole2, adminRole1); + accessManagerEnumerable.setRoleAdmin(newRole3, adminRole1); + vm.stopPrank(); + + uint64[] memory adminRoleList = accessManagerEnumerable.getAdminRoles(0, 2); + assertEq(accessManagerEnumerable.getAdminRoleCount(), 2); + assertEq(adminRoleList.length, 2); + assertEq(adminRoleList[0], ADMIN_ROLE); + assertEq(adminRoleList[1], adminRole1); + assertEq(accessManagerEnumerable.getAdminRole(0), ADMIN_ROLE); + assertEq(accessManagerEnumerable.getAdminRole(1), adminRole1); + + uint64[] memory adminOfRolesList = accessManagerEnumerable.getAdminOfRoles( + ADMIN_ROLE, + 0, + accessManagerEnumerable.getAdminOfRoleCount(ADMIN_ROLE) + ); + assertEq(accessManagerEnumerable.getAdminOfRoleCount(ADMIN_ROLE), 2); + assertEq(adminOfRolesList.length, 2); + assertEq(adminOfRolesList[0], GUARDIAN_ROLE_1); + assertEq(adminOfRolesList[1], GUARDIAN_ROLE_2); + assertEq(accessManagerEnumerable.getAdminOfRole(ADMIN_ROLE, 0), GUARDIAN_ROLE_1); + assertEq(accessManagerEnumerable.getAdminOfRole(ADMIN_ROLE, 1), GUARDIAN_ROLE_2); + + adminOfRolesList = accessManagerEnumerable.getAdminOfRoles( + adminRole1, + 0, + accessManagerEnumerable.getAdminOfRoleCount(adminRole1) + ); + assertEq(accessManagerEnumerable.getAdminOfRoleCount(adminRole1), 3); + assertEq(adminOfRolesList.length, 3); + assertEq(adminOfRolesList[0], newRole1); + assertEq(adminOfRolesList[1], newRole2); + assertEq(adminOfRolesList[2], newRole3); + assertEq(accessManagerEnumerable.getAdminOfRole(adminRole1, 0), newRole1); + assertEq(accessManagerEnumerable.getAdminOfRole(adminRole1, 1), newRole2); + assertEq(accessManagerEnumerable.getAdminOfRole(adminRole1, 2), newRole3); + } + + function test_setRoleAdmin_trackAdminOfRoles_changeAdminRole() public { + uint64 adminRole1 = 1; + + uint64 newRole1 = 111; + uint64 newRole2 = 222; + uint64 newRole3 = 333; + + assertEq(accessManagerEnumerable.getAdminRoleCount(), 1); + assertEq(accessManagerEnumerable.getAdminRole(0), ADMIN_ROLE); + + vm.startPrank(ADMIN); + accessManagerEnumerable.setRoleAdmin(GUARDIAN_ROLE_1, ADMIN_ROLE); + accessManagerEnumerable.setRoleAdmin(GUARDIAN_ROLE_2, ADMIN_ROLE); + accessManagerEnumerable.setRoleAdmin(newRole1, adminRole1); + accessManagerEnumerable.setRoleAdmin(newRole2, adminRole1); + accessManagerEnumerable.setRoleAdmin(newRole3, adminRole1); + vm.stopPrank(); + + uint64[] memory adminRoleList = accessManagerEnumerable.getAdminRoles(0, 2); + assertEq(accessManagerEnumerable.getAdminRoleCount(), 2); + assertEq(adminRoleList.length, 2); + assertEq(adminRoleList[0], ADMIN_ROLE); + assertEq(adminRoleList[1], adminRole1); + assertEq(accessManagerEnumerable.getAdminRole(0), ADMIN_ROLE); + assertEq(accessManagerEnumerable.getAdminRole(1), adminRole1); + + uint64[] memory adminOfRolesList = accessManagerEnumerable.getAdminOfRoles( + ADMIN_ROLE, + 0, + accessManagerEnumerable.getAdminOfRoleCount(ADMIN_ROLE) + ); + assertEq(accessManagerEnumerable.getAdminOfRoleCount(ADMIN_ROLE), 2); + assertEq(adminOfRolesList.length, 2); + assertEq(adminOfRolesList[0], GUARDIAN_ROLE_1); + assertEq(adminOfRolesList[1], GUARDIAN_ROLE_2); + assertEq(accessManagerEnumerable.getAdminOfRole(ADMIN_ROLE, 0), GUARDIAN_ROLE_1); + assertEq(accessManagerEnumerable.getAdminOfRole(ADMIN_ROLE, 1), GUARDIAN_ROLE_2); + + adminOfRolesList = accessManagerEnumerable.getAdminOfRoles( + adminRole1, + 0, + accessManagerEnumerable.getAdminOfRoleCount(adminRole1) + ); + assertEq(accessManagerEnumerable.getAdminOfRoleCount(adminRole1), 3); + assertEq(adminOfRolesList.length, 3); + assertEq(adminOfRolesList[0], newRole1); + assertEq(adminOfRolesList[1], newRole2); + assertEq(adminOfRolesList[2], newRole3); + assertEq(accessManagerEnumerable.getAdminOfRole(adminRole1, 0), newRole1); + assertEq(accessManagerEnumerable.getAdminOfRole(adminRole1, 1), newRole2); + assertEq(accessManagerEnumerable.getAdminOfRole(adminRole1, 2), newRole3); + + vm.startPrank(ADMIN); + accessManagerEnumerable.setRoleAdmin(newRole2, ADMIN_ROLE); + vm.stopPrank(); + + adminRoleList = accessManagerEnumerable.getAdminRoles(0, 2); + assertEq(accessManagerEnumerable.getAdminRoleCount(), 2); + assertEq(adminRoleList.length, 2); + assertEq(adminRoleList[0], ADMIN_ROLE); + assertEq(adminRoleList[1], adminRole1); + assertEq(accessManagerEnumerable.getAdminRole(0), ADMIN_ROLE); + assertEq(accessManagerEnumerable.getAdminRole(1), adminRole1); + + adminOfRolesList = accessManagerEnumerable.getAdminOfRoles( + ADMIN_ROLE, + 0, + accessManagerEnumerable.getAdminOfRoleCount(ADMIN_ROLE) + ); + assertEq(accessManagerEnumerable.getAdminOfRoleCount(ADMIN_ROLE), 3); + assertEq(adminOfRolesList.length, 3); + assertEq(adminOfRolesList[0], GUARDIAN_ROLE_1); + assertEq(adminOfRolesList[1], GUARDIAN_ROLE_2); + assertEq(adminOfRolesList[2], newRole2); + assertEq(accessManagerEnumerable.getAdminOfRole(ADMIN_ROLE, 0), GUARDIAN_ROLE_1); + assertEq(accessManagerEnumerable.getAdminOfRole(ADMIN_ROLE, 1), GUARDIAN_ROLE_2); + assertEq(accessManagerEnumerable.getAdminOfRole(ADMIN_ROLE, 2), newRole2); + + adminOfRolesList = accessManagerEnumerable.getAdminOfRoles( + adminRole1, + 0, + accessManagerEnumerable.getAdminOfRoleCount(adminRole1) + ); + assertEq(accessManagerEnumerable.getAdminOfRoleCount(adminRole1), 2); + assertEq(adminOfRolesList.length, 2); + assertEq(adminOfRolesList[0], newRole1); + assertEq(adminOfRolesList[1], newRole3); + assertEq(accessManagerEnumerable.getAdminOfRole(adminRole1, 0), newRole1); + assertEq(accessManagerEnumerable.getAdminOfRole(adminRole1, 1), newRole3); } function test_setRoleGuardian_trackRoles() public { - uint64 new_role_1 = 111; - uint64 new_role_2 = 222; - uint64 new_role_3 = 333; + uint64 newRole1 = 111; + uint64 newRole2 = 222; + uint64 newRole3 = 333; assertLe(accessManagerEnumerable.getRoleCount(), 1); assertEq(accessManagerEnumerable.getRole(0), ADMIN_ROLE); vm.startPrank(ADMIN); - accessManagerEnumerable.setRoleGuardian(new_role_1, GUARDIAN_ROLE_1); - accessManagerEnumerable.setRoleGuardian(new_role_2, GUARDIAN_ROLE_2); - accessManagerEnumerable.setRoleGuardian(new_role_3, GUARDIAN_ROLE_1); + accessManagerEnumerable.setRoleGuardian(newRole1, GUARDIAN_ROLE_1); + accessManagerEnumerable.setRoleGuardian(newRole2, GUARDIAN_ROLE_2); + accessManagerEnumerable.setRoleGuardian(newRole3, GUARDIAN_ROLE_1); vm.stopPrank(); uint64[] memory roleList = accessManagerEnumerable.getRoles(0, 4); assertLe(accessManagerEnumerable.getRoleCount(), 4); assertEq(roleList.length, 4); assertEq(roleList[0], ADMIN_ROLE); - assertEq(roleList[1], new_role_1); - assertEq(roleList[2], new_role_2); - assertEq(roleList[3], new_role_3); + assertEq(roleList[1], newRole1); + assertEq(roleList[2], newRole2); + assertEq(roleList[3], newRole3); assertEq(accessManagerEnumerable.getRole(0), ADMIN_ROLE); - assertEq(accessManagerEnumerable.getRole(1), new_role_1); - assertEq(accessManagerEnumerable.getRole(2), new_role_2); - assertEq(accessManagerEnumerable.getRole(3), new_role_3); + assertEq(accessManagerEnumerable.getRole(1), newRole1); + assertEq(accessManagerEnumerable.getRole(2), newRole2); + assertEq(accessManagerEnumerable.getRole(3), newRole3); } - function test_setRoleAdmin_fuzz_trackRoles_multipleRoles(uint256 rolesCount) public { + function test_setRoleAdmin_fuzz_trackRolesAndTrackAdminRoles_multipleRoles( + uint256 rolesCount + ) public { rolesCount = bound(rolesCount, 1, 15); uint256 expectedTotalRoleCount = rolesCount + 1; // +1 for ADMIN_ROLE @@ -165,9 +349,8 @@ contract AccessManagerEnumerableTest is Test { for (uint256 i = 0; i < rolesCount; i++) { uint64 roleId = _getRandomRoleId(); - if (!internalRoles.contains(roleId)) { - internalRoles.add(roleId); - } + internalRoles.add(roleId); + internalAdminOfRoles[ADMIN_ROLE].add(uint256(roleId)); accessManagerEnumerable.setRoleAdmin(roleId, ADMIN_ROLE); } vm.stopPrank(); @@ -185,6 +368,88 @@ contract AccessManagerEnumerableTest is Test { assertEq(roleList[i], internalRoles.at(i - 1)); assertEq(accessManagerEnumerable.getRole(i), internalRoles.at(i - 1)); } + + uint64[] memory adminOfRolesList = accessManagerEnumerable.getAdminOfRoles( + ADMIN_ROLE, + 0, + accessManagerEnumerable.getAdminOfRoleCount(ADMIN_ROLE) + ); + assertLe( + accessManagerEnumerable.getAdminOfRoleCount(ADMIN_ROLE), + internalAdminOfRoles[ADMIN_ROLE].length() + ); + assertEq(adminOfRolesList.length, internalAdminOfRoles[ADMIN_ROLE].length()); + for (uint256 i = 0; i < internalAdminOfRoles[ADMIN_ROLE].length(); i++) { + assertEq(adminOfRolesList[i], uint64(internalAdminOfRoles[ADMIN_ROLE].at(i))); + assertEq( + accessManagerEnumerable.getAdminOfRole(ADMIN_ROLE, i), + uint64(internalAdminOfRoles[ADMIN_ROLE].at(i)) + ); + } + } + + function test_setRoleAdmin_fuzz_trackAdminRoles_multipleRoles_multipleAdmins( + uint256 rolesCount + ) public { + rolesCount = bound(rolesCount, 1, 15); + uint256 expectedTotalRoleCount = rolesCount + 1; // +1 for ADMIN_ROLE + internalAdminRoles.add(ADMIN_ROLE); + + vm.startPrank(ADMIN); + + for (uint256 i = 0; i < rolesCount; i++) { + uint64 roleId = _getRandomRoleId(); + uint64 adminRoleId = _getRandomAdminRoleId(); + internalRoles.add(roleId); + internalAdminRoles.add(adminRoleId); + internalAdminOfRoles[adminRoleId].add(uint256(roleId)); + accessManagerEnumerable.setRoleAdmin(roleId, adminRoleId); + } + vm.stopPrank(); + + uint64[] memory roleList = accessManagerEnumerable.getRoles( + 0, + accessManagerEnumerable.getRoleCount() + ); + assertLe(accessManagerEnumerable.getRoleCount(), expectedTotalRoleCount); + assertEq(roleList.length, expectedTotalRoleCount); + + assertEq(roleList[0], ADMIN_ROLE); + assertEq(accessManagerEnumerable.getRole(0), ADMIN_ROLE); + for (uint256 i = 1; i < rolesCount; i++) { + assertEq(roleList[i], internalRoles.at(i - 1)); + assertEq(accessManagerEnumerable.getRole(i), internalRoles.at(i - 1)); + } + + uint64[] memory adminRoleList = accessManagerEnumerable.getAdminRoles( + 0, + accessManagerEnumerable.getAdminRoleCount() + ); + assertLe(accessManagerEnumerable.getAdminRoleCount(), internalAdminRoles.length()); + assertEq(adminRoleList.length, internalAdminRoles.length()); + for (uint256 i = 0; i < internalAdminRoles.length(); i++) { + uint64 adminRoleId = uint64(internalAdminRoles.at(i)); + assertEq(adminRoleList[i], adminRoleId); + assertEq(accessManagerEnumerable.getAdminRole(i), adminRoleId); + + uint64[] memory adminOfRolesList = accessManagerEnumerable.getAdminOfRoles( + adminRoleId, + 0, + accessManagerEnumerable.getAdminOfRoleCount(adminRoleId) + ); + assertLe( + accessManagerEnumerable.getAdminOfRoleCount(adminRoleId), + internalAdminOfRoles[adminRoleId].length() + ); + assertEq(adminOfRolesList.length, internalAdminOfRoles[adminRoleId].length()); + for (uint256 j = 0; j < internalAdminOfRoles[adminRoleId].length(); j++) { + assertEq(adminOfRolesList[j], uint64(internalAdminOfRoles[adminRoleId].at(j))); + assertEq( + accessManagerEnumerable.getAdminOfRole(adminRoleId, j), + uint64(internalAdminOfRoles[adminRoleId].at(j)) + ); + } + } } function test_revokeRole() public { @@ -526,8 +791,13 @@ contract AccessManagerEnumerableTest is Test { assertEq(roleTargets[0], target); } + function _getRandomAdminRoleId() internal returns (uint64) { + uint256 adminRoleId = vm.randomUint(0, 4); + return uint64(adminRoleId); + } + function _getRandomRoleId() internal returns (uint64) { - uint256 roleId = vm.randomUint(1, type(uint64).max - 1); + uint256 roleId = vm.randomUint(5, type(uint64).max - 1); return uint64(roleId); } }