diff --git a/changes/11240.feature.md b/changes/11240.feature.md new file mode 100644 index 00000000000..ba37aaf6121 --- /dev/null +++ b/changes/11240.feature.md @@ -0,0 +1 @@ +Wire BulkActionRBACValidator to the bulk permission check so bulk actions filter unauthorized entities and surface them via partial-success responses. diff --git a/src/ai/backend/manager/actions/validators/rbac/bulk.py b/src/ai/backend/manager/actions/validators/rbac/bulk.py index e17e11804e4..2f1115d9aea 100644 --- a/src/ai/backend/manager/actions/validators/rbac/bulk.py +++ b/src/ai/backend/manager/actions/validators/rbac/bulk.py @@ -1,15 +1,21 @@ from typing import Any, override +from ai.backend.common.contexts.user import current_user +from ai.backend.common.exception import UnreachableError from ai.backend.manager.actions.action import BaseActionTriggerMeta from ai.backend.manager.actions.action.bulk import BaseBulkAction from ai.backend.manager.actions.validator.bulk import ( BulkActionValidator, BulkValidationResult, + DeniedEntity, ) +from ai.backend.manager.data.permission.role import BulkPermissionCheckInput from ai.backend.manager.repositories.permission_controller.repository import ( PermissionControllerRepository, ) +_DENY_REASON = "permission_denied" + class BulkActionRBACValidator(BulkActionValidator): def __init__( @@ -27,9 +33,31 @@ def name(cls) -> str: async def validate( self, action: BaseBulkAction[Any], meta: BaseActionTriggerMeta ) -> BulkValidationResult: - # TODO: wire this to PermissionControllerRepository.check_bulk_permission_with_scope_chain(). - # Until then, every entity is treated as allowed so legacy behavior is preserved. + user = current_user() + if user is None: + raise UnreachableError("User context is not available") + entity_ids = list(action.entity_ids) + if user.is_superadmin: + return BulkValidationResult( + allowed_entity_ids=entity_ids, + denied_entities=[], + ) + permission_map = await self._repository.check_bulk_permission_with_scope_chain( + BulkPermissionCheckInput( + user_id=user.user_id, + target_element_type=action.entity_type().to_element(), + target_entity_ids=entity_ids, + operation=action.operation_type().to_permission_operation(), + ) + ) + allowed_entity_ids: list[str] = [] + denied_entities: list[DeniedEntity] = [] + for eid in entity_ids: + if permission_map.get(eid, False): + allowed_entity_ids.append(eid) + else: + denied_entities.append(DeniedEntity(entity_id=eid, deny_reason=_DENY_REASON)) return BulkValidationResult( - allowed_entity_ids=list(action.entity_ids), - denied_entities=[], + allowed_entity_ids=allowed_entity_ids, + denied_entities=denied_entities, ) diff --git a/src/ai/backend/manager/actions/validators/rbac/legacy.py b/src/ai/backend/manager/actions/validators/rbac/legacy.py index 913f6924d43..df1c82adf41 100644 --- a/src/ai/backend/manager/actions/validators/rbac/legacy.py +++ b/src/ai/backend/manager/actions/validators/rbac/legacy.py @@ -9,6 +9,7 @@ from typing import override from ai.backend.common.contexts.user import current_user +from ai.backend.common.exception import UnreachableError from ai.backend.common.metrics.safe import SafeCounter from ai.backend.logging.utils import BraceStyleAdapter from ai.backend.manager.actions.action import BaseActionTriggerMeta @@ -17,7 +18,6 @@ from ai.backend.manager.actions.validator.scope import ScopeActionValidator from ai.backend.manager.actions.validator.single_entity import SingleEntityActionValidator from ai.backend.manager.data.permission.role import ScopeChainPermissionCheckInput -from ai.backend.manager.errors.user import UserNotFound from ai.backend.manager.repositories.permission_controller.repository import ( PermissionControllerRepository, ) @@ -53,7 +53,7 @@ def __init__( async def validate(self, action: BaseSingleEntityAction, meta: BaseActionTriggerMeta) -> None: user = current_user() if user is None: - raise UserNotFound("User not found in context") + raise UnreachableError("User context is not available") if user.is_superadmin: return @@ -94,7 +94,7 @@ def __init__( async def validate(self, action: BaseScopeAction, meta: BaseActionTriggerMeta) -> None: user = current_user() if user is None: - raise UserNotFound("User not found in context") + raise UnreachableError("User context is not available") if user.is_superadmin: return diff --git a/src/ai/backend/manager/actions/validators/rbac/scope.py b/src/ai/backend/manager/actions/validators/rbac/scope.py index 9842ff1ac90..55ff419f191 100644 --- a/src/ai/backend/manager/actions/validators/rbac/scope.py +++ b/src/ai/backend/manager/actions/validators/rbac/scope.py @@ -1,12 +1,12 @@ from typing import override from ai.backend.common.contexts.user import current_user +from ai.backend.common.exception import UnreachableError from ai.backend.manager.actions.action import BaseActionTriggerMeta from ai.backend.manager.actions.action.scope import BaseScopeAction from ai.backend.manager.actions.validator.scope import ScopeActionValidator from ai.backend.manager.data.permission.role import ScopeChainPermissionCheckInput from ai.backend.manager.errors.permission import NotEnoughPermission -from ai.backend.manager.errors.user import UserNotFound from ai.backend.manager.repositories.permission_controller.repository import ( PermissionControllerRepository, ) @@ -23,7 +23,7 @@ def __init__( async def validate(self, action: BaseScopeAction, meta: BaseActionTriggerMeta) -> None: user = current_user() if user is None: - raise UserNotFound("User not found in context") + raise UnreachableError("User context is not available") if user.is_superadmin: return diff --git a/src/ai/backend/manager/actions/validators/rbac/single_entity.py b/src/ai/backend/manager/actions/validators/rbac/single_entity.py index b2ad80e373e..7fae8e65d3d 100644 --- a/src/ai/backend/manager/actions/validators/rbac/single_entity.py +++ b/src/ai/backend/manager/actions/validators/rbac/single_entity.py @@ -1,12 +1,12 @@ from typing import override from ai.backend.common.contexts.user import current_user +from ai.backend.common.exception import UnreachableError from ai.backend.manager.actions.action import BaseActionTriggerMeta from ai.backend.manager.actions.action.single_entity import BaseSingleEntityAction from ai.backend.manager.actions.validator.single_entity import SingleEntityActionValidator from ai.backend.manager.data.permission.role import ScopeChainPermissionCheckInput from ai.backend.manager.errors.permission import NotEnoughPermission -from ai.backend.manager.errors.user import UserNotFound from ai.backend.manager.repositories.permission_controller.repository import ( PermissionControllerRepository, ) @@ -23,7 +23,7 @@ def __init__( async def validate(self, action: BaseSingleEntityAction, meta: BaseActionTriggerMeta) -> None: user = current_user() if user is None: - raise UserNotFound("User not found in context") + raise UnreachableError("User context is not available") if user.is_superadmin: return diff --git a/tests/unit/manager/actions/validators/test_rbac_validators.py b/tests/unit/manager/actions/validators/test_rbac_validators.py index 12a52787a3d..bae9a9586ac 100644 --- a/tests/unit/manager/actions/validators/test_rbac_validators.py +++ b/tests/unit/manager/actions/validators/test_rbac_validators.py @@ -13,6 +13,7 @@ import uuid from collections.abc import AsyncIterator +from dataclasses import dataclass from datetime import UTC, datetime from typing import override @@ -26,11 +27,15 @@ ScopeType, ) from ai.backend.common.data.user.types import UserData, UserRole +from ai.backend.common.exception import UnreachableError from ai.backend.manager.actions.action.base import BaseActionTriggerMeta +from ai.backend.manager.actions.action.bulk import BaseBulkAction from ai.backend.manager.actions.action.scope import BaseScopeAction from ai.backend.manager.actions.action.single_entity import BaseSingleEntityAction from ai.backend.manager.actions.action.types import FieldData from ai.backend.manager.actions.types import ActionOperationType +from ai.backend.manager.actions.validator.bulk import DeniedEntity +from ai.backend.manager.actions.validators.rbac.bulk import BulkActionRBACValidator from ai.backend.manager.actions.validators.rbac.legacy import ( LegacyScopeActionRBACValidator, LegacySingleEntityActionRBACValidator, @@ -42,7 +47,6 @@ from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.data.user.types import UserStatus from ai.backend.manager.errors.permission import NotEnoughPermission -from ai.backend.manager.errors.user import UserNotFound from ai.backend.manager.models.domain import DomainRow from ai.backend.manager.models.keypair import KeyPairRow from ai.backend.manager.models.rbac_models import UserRoleRow @@ -65,6 +69,8 @@ _TARGET_DOMAIN = "default" _TARGET_VFOLDER = "vf-1" +_BULK_VFOLDER_GRANTED = "bulk-vf-granted" +_BULK_VFOLDER_DENIED = "bulk-vf-denied" class _ProjectCreateAction(BaseScopeAction): @@ -125,6 +131,25 @@ def field_data(self) -> FieldData | None: return None +@dataclass +class _BulkVfolderUpdateAction(BaseBulkAction[str]): + """VFOLDER:UPDATE on multiple vfolders — exercises the bulk validator path.""" + + @override + def typed_entity_ids(self) -> list[str]: + return list(self.entity_ids) + + @classmethod + @override + def entity_type(cls) -> EntityType: + return EntityType.VFOLDER + + @classmethod + @override + def operation_type(cls) -> ActionOperationType: + return ActionOperationType.UPDATE + + def _make_user_data(user_id: uuid.UUID, *, is_superadmin: bool) -> UserData: return UserData( user_id=user_id, @@ -298,6 +323,36 @@ async def regular_user_with_vfolder_update( return _make_user_data(user_id, is_superadmin=False) +@pytest.fixture +def bulk_vfolder_action() -> _BulkVfolderUpdateAction: + return _BulkVfolderUpdateAction( + entity_ids=[_BULK_VFOLDER_GRANTED, _BULK_VFOLDER_DENIED], + ) + + +@pytest.fixture +async def regular_user_with_partial_bulk_vfolder_update( + db_with_rbac_tables: ExtendedAsyncSAEngine, +) -> UserData: + """User granted VFOLDER:UPDATE only on ``_BULK_VFOLDER_GRANTED``. + + Self-scope permission lets the bulk validator return a partial + success — the granted vfolder is allowed, the other denied. + """ + user_id = uuid.uuid4() + role_id = uuid.uuid4() + await _seed_user_with_role(db_with_rbac_tables, user_id=user_id, role_id=role_id) + await _grant_permission( + db_with_rbac_tables, + role_id=role_id, + scope_type=ScopeType.VFOLDER, + scope_id=_BULK_VFOLDER_GRANTED, + entity_type=EntityType.VFOLDER, + operation=OperationType.UPDATE, + ) + return _make_user_data(user_id, is_superadmin=False) + + class TestScopeActionRBACValidator: async def test_superadmin_bypasses_check( self, @@ -318,7 +373,7 @@ async def test_missing_user_raises( trigger_meta: BaseActionTriggerMeta, ) -> None: validator = ScopeActionRBACValidator(repository) - with pytest.raises(UserNotFound): + with pytest.raises(UnreachableError): await validator.validate(scope_action, trigger_meta) async def test_non_superadmin_with_permission_passes( @@ -364,7 +419,7 @@ async def test_missing_user_raises( trigger_meta: BaseActionTriggerMeta, ) -> None: validator = SingleEntityActionRBACValidator(repository) - with pytest.raises(UserNotFound): + with pytest.raises(UnreachableError): await validator.validate(single_entity_action, trigger_meta) async def test_non_superadmin_with_permission_passes( @@ -410,7 +465,7 @@ async def test_missing_user_raises( trigger_meta: BaseActionTriggerMeta, ) -> None: validator = LegacySingleEntityActionRBACValidator(repository) - with pytest.raises(UserNotFound): + with pytest.raises(UnreachableError): await validator.validate(single_entity_action, trigger_meta) async def test_non_superadmin_with_permission_passes( @@ -455,7 +510,7 @@ async def test_missing_user_raises( trigger_meta: BaseActionTriggerMeta, ) -> None: validator = LegacyScopeActionRBACValidator(repository) - with pytest.raises(UserNotFound): + with pytest.raises(UnreachableError): await validator.validate(scope_action, trigger_meta) async def test_non_superadmin_with_permission_passes( @@ -479,3 +534,79 @@ async def test_non_superadmin_without_permission_does_not_raise( validator = LegacyScopeActionRBACValidator(repository) with with_user(regular_user_without_permission): await validator.validate(scope_action, trigger_meta) + + +class TestBulkActionRBACValidator: + async def test_superadmin_bypasses_check( + self, + repository: PermissionControllerRepository, + bulk_vfolder_action: _BulkVfolderUpdateAction, + trigger_meta: BaseActionTriggerMeta, + superadmin_user: UserData, + ) -> None: + # No permission rows seeded; bypass must approve every entity_id. + validator = BulkActionRBACValidator(repository) + with with_user(superadmin_user): + result = await validator.validate(bulk_vfolder_action, trigger_meta) + + assert result.allowed_entity_ids == [_BULK_VFOLDER_GRANTED, _BULK_VFOLDER_DENIED] + assert result.denied_entities == [] + + async def test_missing_user_raises( + self, + repository: PermissionControllerRepository, + bulk_vfolder_action: _BulkVfolderUpdateAction, + trigger_meta: BaseActionTriggerMeta, + ) -> None: + validator = BulkActionRBACValidator(repository) + with pytest.raises(UnreachableError): + await validator.validate(bulk_vfolder_action, trigger_meta) + + async def test_partial_permission_splits_allowed_and_denied( + self, + repository: PermissionControllerRepository, + bulk_vfolder_action: _BulkVfolderUpdateAction, + trigger_meta: BaseActionTriggerMeta, + regular_user_with_partial_bulk_vfolder_update: UserData, + ) -> None: + validator = BulkActionRBACValidator(repository) + with with_user(regular_user_with_partial_bulk_vfolder_update): + result = await validator.validate(bulk_vfolder_action, trigger_meta) + + assert result.allowed_entity_ids == [_BULK_VFOLDER_GRANTED] + assert result.denied_entities == [ + DeniedEntity(entity_id=_BULK_VFOLDER_DENIED, deny_reason="permission_denied"), + ] + + async def test_no_permission_denies_every_entity( + self, + repository: PermissionControllerRepository, + bulk_vfolder_action: _BulkVfolderUpdateAction, + trigger_meta: BaseActionTriggerMeta, + regular_user_without_permission: UserData, + ) -> None: + validator = BulkActionRBACValidator(repository) + with with_user(regular_user_without_permission): + result = await validator.validate(bulk_vfolder_action, trigger_meta) + + assert result.allowed_entity_ids == [] + assert result.denied_entities == [ + DeniedEntity(entity_id=_BULK_VFOLDER_GRANTED, deny_reason="permission_denied"), + DeniedEntity(entity_id=_BULK_VFOLDER_DENIED, deny_reason="permission_denied"), + ] + + async def test_empty_entity_ids_returns_empty_result( + self, + repository: PermissionControllerRepository, + trigger_meta: BaseActionTriggerMeta, + regular_user_without_permission: UserData, + ) -> None: + validator = BulkActionRBACValidator(repository) + with with_user(regular_user_without_permission): + result = await validator.validate( + _BulkVfolderUpdateAction(entity_ids=[]), + trigger_meta, + ) + + assert result.allowed_entity_ids == [] + assert result.denied_entities == []