diff --git a/changes/11191.feature.md b/changes/11191.feature.md new file mode 100644 index 00000000000..1f69be05e1a --- /dev/null +++ b/changes/11191.feature.md @@ -0,0 +1 @@ +Add bulk RBAC filtering infrastructure so `BulkActionProcessor` can narrow actions per-entity and report per-validator decisions. diff --git a/src/ai/backend/manager/actions/action/__init__.py b/src/ai/backend/manager/actions/action/__init__.py index dca1571092f..3f2628ecabf 100644 --- a/src/ai/backend/manager/actions/action/__init__.py +++ b/src/ai/backend/manager/actions/action/__init__.py @@ -8,9 +8,9 @@ TAction, TActionResult, ) -from .batch import ( - BaseBatchAction, - BaseBatchActionResult, +from .bulk import ( + BaseBulkAction, + BaseBulkActionResult, ) from .rbac import ( BaseRBACAction, @@ -123,8 +123,8 @@ "BaseActionResult", "BaseActionResultMeta", "BaseActionTriggerMeta", - "BaseBatchAction", - "BaseBatchActionResult", + "BaseBulkAction", + "BaseBulkActionResult", "BaseRBACAction", "RBACActionName", "RBACRequiredPermission", diff --git a/src/ai/backend/manager/actions/action/batch.py b/src/ai/backend/manager/actions/action/batch.py deleted file mode 100644 index 33068438578..00000000000 --- a/src/ai/backend/manager/actions/action/batch.py +++ /dev/null @@ -1,38 +0,0 @@ -from abc import abstractmethod -from typing import TypeVar, override - -from .base import BaseAction, BaseActionResult -from .types import BatchFieldData - - -class BaseBatchAction(BaseAction): - @override - def entity_id(self) -> str | None: - return None - - @abstractmethod - def entity_ids(self) -> list[str]: - raise NotImplementedError - - @abstractmethod - def field_data(self) -> BatchFieldData | None: - """ - Returns batch field data containing the field type and IDs when the - action's targets exist as fields of another entity. - Returns None if these entities are not fields. - """ - raise NotImplementedError - - -class BaseBatchActionResult(BaseActionResult): - @override - def entity_id(self) -> str | None: - return None - - @abstractmethod - def entity_ids(self) -> list[str]: - raise NotImplementedError - - -TBatchAction = TypeVar("TBatchAction", bound=BaseBatchAction) -TBatchActionResult = TypeVar("TBatchActionResult", bound=BaseBatchActionResult) diff --git a/src/ai/backend/manager/actions/action/bulk.py b/src/ai/backend/manager/actions/action/bulk.py new file mode 100644 index 00000000000..71f6364b48f --- /dev/null +++ b/src/ai/backend/manager/actions/action/bulk.py @@ -0,0 +1,44 @@ +from abc import abstractmethod +from dataclasses import dataclass +from typing import Any, TypeVar, override + +from .base import BaseAction, BaseActionResult + + +@dataclass +class BaseBulkAction[T](BaseAction): + """Base class for actions operating on a bulk of entities. + + ``entity_ids`` is stored as ``list[str]`` so ``BulkActionValidator`` + implementations can match against validator verdicts directly. The + original ``T``-typed view is exposed via ``typed_entity_ids()``. + + Bulk actions intentionally carry **only** ``entity_ids``. User context + (user id, role) flows through ``current_user()``, not the action, so + ``BulkActionProcessor`` can reconstruct a filtered action by calling + ``type(action)(entity_ids=...)`` directly — no ``__init__`` override or + factory hook is required. Subclasses that try to add required fields + break that constructor call and will fail fast at runtime, which is + intentional. + """ + + entity_ids: list[str] + + @abstractmethod + def typed_entity_ids(self) -> list[T]: + """Return ``entity_ids`` converted back to the native ID type ``T``.""" + raise NotImplementedError + + +class BaseBulkActionResult(BaseActionResult): + @override + def entity_id(self) -> str | None: + return None + + @abstractmethod + def entity_ids(self) -> list[str]: + raise NotImplementedError + + +TBulkAction = TypeVar("TBulkAction", bound=BaseBulkAction[Any]) +TBulkActionResult = TypeVar("TBulkActionResult", bound=BaseBulkActionResult) diff --git a/src/ai/backend/manager/actions/processor/batch.py b/src/ai/backend/manager/actions/processor/batch.py deleted file mode 100644 index cf27dd6c76e..00000000000 --- a/src/ai/backend/manager/actions/processor/batch.py +++ /dev/null @@ -1,50 +0,0 @@ -import logging -import uuid -from collections.abc import Awaitable, Callable, Sequence -from datetime import UTC, datetime - -from ai.backend.logging.utils import BraceStyleAdapter -from ai.backend.manager.actions.action import ( - BaseActionTriggerMeta, -) -from ai.backend.manager.actions.action.batch import ( - BaseBatchAction, - BaseBatchActionResult, -) -from ai.backend.manager.actions.monitors.monitor import ActionMonitor -from ai.backend.manager.actions.validator.batch import BatchActionValidator - -from .base import ActionRunner - -log = BraceStyleAdapter(logging.getLogger(__spec__.name)) - - -class BatchActionProcessor[ - TBatchAction: BaseBatchAction, - TBatchActionResult: BaseBatchActionResult, -]: - _validators: Sequence[BatchActionValidator] - - _runner: ActionRunner[TBatchAction, TBatchActionResult] - - def __init__( - self, - func: Callable[[TBatchAction], Awaitable[TBatchActionResult]], - monitors: Sequence[ActionMonitor] | None = None, - validators: Sequence[BatchActionValidator] | None = None, - ) -> None: - self._runner = ActionRunner(func, monitors) - - self._validators = validators or [] - - async def _run(self, action: TBatchAction) -> TBatchActionResult: - started_at = datetime.now(UTC) - action_id = uuid.uuid4() - action_trigger_meta = BaseActionTriggerMeta(action_id=action_id, started_at=started_at) - for validator in self._validators: - await validator.validate(action, action_trigger_meta) - - return await self._runner.run(action, action_trigger_meta) - - async def wait_for_complete(self, action: TBatchAction) -> TBatchActionResult: - return await self._run(action) diff --git a/src/ai/backend/manager/actions/processor/bulk.py b/src/ai/backend/manager/actions/processor/bulk.py new file mode 100644 index 00000000000..7dd6fd09c70 --- /dev/null +++ b/src/ai/backend/manager/actions/processor/bulk.py @@ -0,0 +1,112 @@ +import logging +import uuid +from collections.abc import Awaitable, Callable, Sequence +from dataclasses import dataclass +from datetime import UTC, datetime +from typing import Any + +from ai.backend.logging.utils import BraceStyleAdapter +from ai.backend.manager.actions.action import ( + BaseActionTriggerMeta, +) +from ai.backend.manager.actions.action.bulk import ( + BaseBulkAction, + BaseBulkActionResult, +) +from ai.backend.manager.actions.monitors.monitor import ActionMonitor +from ai.backend.manager.actions.validator.bulk import ( + BulkActionValidator, + BulkValidationResult, +) + +from .base import ActionRunner + +log = BraceStyleAdapter(logging.getLogger(__spec__.name)) + + +@dataclass(frozen=True) +class ValidatorDecision: + """One validator's per-entity verdict observed during bulk processing. + + Mirrors the ``SubStepResult`` pattern used by the scheduler history so + callers can trace where in the validator chain each ID was filtered and + *why*. ``results`` carries the validator's classification unchanged. + """ + + validator_name: str + results: BulkValidationResult + + +@dataclass(frozen=True) +class BulkProcessResult[TBulkActionResult: BaseBulkActionResult]: + """Outcome of a ``BulkActionProcessor`` run. + + ``result`` is what the service function returned for the permitted subset + of entity IDs. ``validator_decisions`` keeps the per-validator trace in + iteration order; callers assemble the partial-success response by + walking it (each decision carries the denied IDs and their reasons). + """ + + result: TBulkActionResult + validator_decisions: list[ValidatorDecision] + + +class BulkActionProcessor[ + TBulkAction: BaseBulkAction[Any], + TBulkActionResult: BaseBulkActionResult, +]: + _validators: Sequence[BulkActionValidator] + + _runner: ActionRunner[TBulkAction, TBulkActionResult] + + def __init__( + self, + func: Callable[[TBulkAction], Awaitable[TBulkActionResult]], + monitors: Sequence[ActionMonitor] | None = None, + validators: Sequence[BulkActionValidator] | None = None, + ) -> None: + self._runner = ActionRunner(func, monitors) + + self._validators = validators or [] + + def _filter_by_validation( + self, + action: TBulkAction, + validation: BulkValidationResult, + ) -> TBulkAction: + """Return a new action narrowed to the IDs this validator permitted. + + Returns the incoming action unchanged when the validator denied + nothing; otherwise constructs a fresh instance of the same class + via its ``entity_ids``-only constructor so the original stays + immutable. + """ + if not validation.denied_entities: + return action + allowed_set = set(validation.allowed_entity_ids) + filtered_ids = [eid for eid in action.entity_ids if eid in allowed_set] + return type(action)(entity_ids=filtered_ids) + + async def _run(self, action: TBulkAction) -> BulkProcessResult[TBulkActionResult]: + started_at = datetime.now(UTC) + action_id = uuid.uuid4() + action_trigger_meta = BaseActionTriggerMeta(action_id=action_id, started_at=started_at) + + filtered_action: TBulkAction = action + decisions: list[ValidatorDecision] = [] + + for validator in self._validators: + validation = await validator.validate(filtered_action, action_trigger_meta) + decisions.append( + ValidatorDecision( + validator_name=validator.name(), + results=validation, + ) + ) + filtered_action = self._filter_by_validation(filtered_action, validation) + + action_result = await self._runner.run(filtered_action, action_trigger_meta) + return BulkProcessResult(result=action_result, validator_decisions=decisions) + + async def wait_for_complete(self, action: TBulkAction) -> BulkProcessResult[TBulkActionResult]: + return await self._run(action) diff --git a/src/ai/backend/manager/actions/validator/batch.py b/src/ai/backend/manager/actions/validator/batch.py deleted file mode 100644 index 2fd7648d130..00000000000 --- a/src/ai/backend/manager/actions/validator/batch.py +++ /dev/null @@ -1,10 +0,0 @@ -from abc import ABC, abstractmethod - -from ai.backend.manager.actions.action import BaseActionTriggerMeta -from ai.backend.manager.actions.action.batch import BaseBatchAction - - -class BatchActionValidator(ABC): - @abstractmethod - async def validate(self, action: BaseBatchAction, meta: BaseActionTriggerMeta) -> None: - raise NotImplementedError("Subclasses must implement the validate method") diff --git a/src/ai/backend/manager/actions/validator/bulk.py b/src/ai/backend/manager/actions/validator/bulk.py new file mode 100644 index 00000000000..bcf1f6afdcd --- /dev/null +++ b/src/ai/backend/manager/actions/validator/bulk.py @@ -0,0 +1,57 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any + +from ai.backend.manager.actions.action import BaseActionTriggerMeta +from ai.backend.manager.actions.action.bulk import BaseBulkAction + + +@dataclass(frozen=True) +class DeniedEntity: + """A bulk entity that a validator rejected, paired with its reason.""" + + entity_id: str + deny_reason: str + + +@dataclass(frozen=True) +class BulkValidationResult: + """Per-entity validation outcome for a bulk action. + + ``BulkActionProcessor`` intersects ``allowed_entity_ids`` across + validators and records each ``DeniedEntity`` — with its reason — on the + corresponding ``ValidatorDecision`` so the final response can + surface *why* each ID was filtered out. + """ + + allowed_entity_ids: list[str] + denied_entities: list[DeniedEntity] + + +class BulkActionValidator(ABC): + @classmethod + @abstractmethod + def name(cls) -> str: + """Stable identifier used in ``ValidatorDecision.validator_name``. + + Chosen by the implementation so logs and partial-success responses can + attribute denials to a specific validator independently of the Python + class name. + """ + raise NotImplementedError + + @abstractmethod + async def validate( + self, action: BaseBulkAction[Any], meta: BaseActionTriggerMeta + ) -> BulkValidationResult: + """Validate the bulk action and return per-entity permission results. + + Implementations must classify every ID in ``action.entity_ids`` as + either allowed or denied. Validators that cannot make a decision for + an ID should treat it as allowed. + + The processor wraps each call in its own async context manager so + cross-cutting concerns (timing, audit) live in one place — validators + do not need to own them. + """ + raise NotImplementedError diff --git a/src/ai/backend/manager/actions/validators/rbac/batch.py b/src/ai/backend/manager/actions/validators/rbac/batch.py deleted file mode 100644 index 1a53dd55232..00000000000 --- a/src/ai/backend/manager/actions/validators/rbac/batch.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import override - -from ai.backend.manager.actions.action import BaseActionTriggerMeta -from ai.backend.manager.actions.action.batch import BaseBatchAction -from ai.backend.manager.actions.validator.batch import BatchActionValidator -from ai.backend.manager.repositories.permission_controller.repository import ( - PermissionControllerRepository, -) - - -class BatchActionRBACValidator(BatchActionValidator): - def __init__( - self, - repository: PermissionControllerRepository, - ) -> None: - self._repository = repository - - @override - async def validate(self, action: BaseBatchAction, meta: BaseActionTriggerMeta) -> None: - # TODO: implement RBAC validation logic - pass diff --git a/src/ai/backend/manager/actions/validators/rbac/bulk.py b/src/ai/backend/manager/actions/validators/rbac/bulk.py new file mode 100644 index 00000000000..e17e11804e4 --- /dev/null +++ b/src/ai/backend/manager/actions/validators/rbac/bulk.py @@ -0,0 +1,35 @@ +from typing import Any, override + +from ai.backend.manager.actions.action import BaseActionTriggerMeta +from ai.backend.manager.actions.action.bulk import BaseBulkAction +from ai.backend.manager.actions.validator.bulk import ( + BulkActionValidator, + BulkValidationResult, +) +from ai.backend.manager.repositories.permission_controller.repository import ( + PermissionControllerRepository, +) + + +class BulkActionRBACValidator(BulkActionValidator): + def __init__( + self, + repository: PermissionControllerRepository, + ) -> None: + self._repository = repository + + @classmethod + @override + def name(cls) -> str: + return "rbac" + + @override + async def validate( + self, action: BaseBulkAction[Any], meta: BaseActionTriggerMeta + ) -> BulkValidationResult: + # TODO: wire this to PermissionControllerRepository.check_bulk_permission_with_scope_chain(). + # Until then, every entity is treated as allowed so legacy behavior is preserved. + return BulkValidationResult( + allowed_entity_ids=list(action.entity_ids), + denied_entities=[], + ) diff --git a/src/ai/backend/manager/services/artifact/actions/base.py b/src/ai/backend/manager/services/artifact/actions/base.py index bb9cb6f9375..2c61901ff4e 100644 --- a/src/ai/backend/manager/services/artifact/actions/base.py +++ b/src/ai/backend/manager/services/artifact/actions/base.py @@ -2,7 +2,7 @@ from typing import override from ai.backend.common.data.permission.types import EntityType -from ai.backend.manager.actions.action import BaseAction, BaseBatchAction +from ai.backend.manager.actions.action import BaseAction from ai.backend.manager.actions.action.scope import BaseScopeAction, BaseScopeActionResult from ai.backend.manager.actions.action.single_entity import ( BaseSingleEntityAction, @@ -19,14 +19,6 @@ def entity_type(cls) -> EntityType: return EntityType.ARTIFACT -@dataclass -class ArtifactBatchAction(BaseBatchAction): - @override - @classmethod - def entity_type(cls) -> EntityType: - return EntityType.ARTIFACT - - @dataclass class ArtifactScopeAction(BaseScopeAction): @override diff --git a/src/ai/backend/manager/services/artifact_registry/actions/base.py b/src/ai/backend/manager/services/artifact_registry/actions/base.py index fbe9b450946..45b14598956 100644 --- a/src/ai/backend/manager/services/artifact_registry/actions/base.py +++ b/src/ai/backend/manager/services/artifact_registry/actions/base.py @@ -1,7 +1,7 @@ from typing import override from ai.backend.common.data.permission.types import EntityType -from ai.backend.manager.actions.action import BaseAction, BaseBatchAction +from ai.backend.manager.actions.action import BaseAction from ai.backend.manager.actions.action.scope import BaseScopeAction, BaseScopeActionResult from ai.backend.manager.actions.action.single_entity import ( BaseSingleEntityAction, @@ -17,13 +17,6 @@ def entity_type(cls) -> EntityType: return EntityType.ARTIFACT_REGISTRY -class ArtifactBatchRegistryAction(BaseBatchAction): - @override - @classmethod - def entity_type(cls) -> EntityType: - return EntityType.ARTIFACT_REGISTRY - - class ArtifactRegistryScopeAction(BaseScopeAction): @override @classmethod diff --git a/src/ai/backend/manager/services/artifact_revision/actions/base.py b/src/ai/backend/manager/services/artifact_revision/actions/base.py index 1190de0677b..3d550ebeaee 100644 --- a/src/ai/backend/manager/services/artifact_revision/actions/base.py +++ b/src/ai/backend/manager/services/artifact_revision/actions/base.py @@ -2,7 +2,7 @@ from typing import override from ai.backend.common.data.permission.types import EntityType -from ai.backend.manager.actions.action import BaseAction, BaseBatchAction +from ai.backend.manager.actions.action import BaseAction @dataclass @@ -11,11 +11,3 @@ class ArtifactRevisionAction(BaseAction): @classmethod def entity_type(cls) -> EntityType: return EntityType.ARTIFACT_REVISION - - -@dataclass -class ArtifactRevisionBatchAction(BaseBatchAction): - @override - @classmethod - def entity_type(cls) -> EntityType: - return EntityType.ARTIFACT_REVISION diff --git a/src/ai/backend/manager/services/container_registry/actions/base.py b/src/ai/backend/manager/services/container_registry/actions/base.py index 439c15246fa..85802e9c063 100644 --- a/src/ai/backend/manager/services/container_registry/actions/base.py +++ b/src/ai/backend/manager/services/container_registry/actions/base.py @@ -2,7 +2,7 @@ from typing import override from ai.backend.common.data.permission.types import EntityType -from ai.backend.manager.actions.action import BaseAction, BaseBatchAction +from ai.backend.manager.actions.action import BaseAction @dataclass @@ -11,11 +11,3 @@ class ContainerRegistryAction(BaseAction): @classmethod def entity_type(cls) -> EntityType: return EntityType.CONTAINER_REGISTRY - - -@dataclass -class ContainerRegistryBatchAction(BaseBatchAction): - @override - @classmethod - def entity_type(cls) -> EntityType: - return EntityType.CONTAINER_REGISTRY diff --git a/src/ai/backend/manager/services/image/actions/base.py b/src/ai/backend/manager/services/image/actions/base.py index b2cd8966667..650e920798a 100644 --- a/src/ai/backend/manager/services/image/actions/base.py +++ b/src/ai/backend/manager/services/image/actions/base.py @@ -2,7 +2,7 @@ from typing import override from ai.backend.common.data.permission.types import EntityType -from ai.backend.manager.actions.action import BaseAction, BaseBatchAction +from ai.backend.manager.actions.action import BaseAction from ai.backend.manager.actions.action.single_entity import ( BaseSingleEntityAction, BaseSingleEntityActionResult, @@ -18,14 +18,6 @@ def entity_type(cls) -> EntityType: return EntityType.IMAGE -@dataclass -class ImageBatchAction(BaseBatchAction): - @override - @classmethod - def entity_type(cls) -> EntityType: - return EntityType.IMAGE - - @dataclass class ImageSingleEntityAction(BaseSingleEntityAction): @override diff --git a/src/ai/backend/manager/services/keypair_resource_policy/actions/base.py b/src/ai/backend/manager/services/keypair_resource_policy/actions/base.py index cee39341ed9..51cf586003d 100644 --- a/src/ai/backend/manager/services/keypair_resource_policy/actions/base.py +++ b/src/ai/backend/manager/services/keypair_resource_policy/actions/base.py @@ -2,7 +2,7 @@ from typing import override from ai.backend.common.data.permission.types import EntityType -from ai.backend.manager.actions.action import BaseAction, BaseBatchAction +from ai.backend.manager.actions.action import BaseAction @dataclass @@ -11,11 +11,3 @@ class KeypairResourcePolicyAction(BaseAction): @classmethod def entity_type(cls) -> EntityType: return EntityType.KEYPAIR_RESOURCE_POLICY - - -@dataclass -class KeypairResourcePolicyBatchAction(BaseBatchAction): - @override - @classmethod - def entity_type(cls) -> EntityType: - return EntityType.KEYPAIR_RESOURCE_POLICY diff --git a/src/ai/backend/manager/services/project_resource_policy/actions/base.py b/src/ai/backend/manager/services/project_resource_policy/actions/base.py index 66b2f4dcca9..2b194457dd2 100644 --- a/src/ai/backend/manager/services/project_resource_policy/actions/base.py +++ b/src/ai/backend/manager/services/project_resource_policy/actions/base.py @@ -2,7 +2,7 @@ from typing import override from ai.backend.common.data.permission.types import EntityType -from ai.backend.manager.actions.action import BaseAction, BaseBatchAction +from ai.backend.manager.actions.action import BaseAction @dataclass @@ -11,11 +11,3 @@ class ProjectResourcePolicyAction(BaseAction): @classmethod def entity_type(cls) -> EntityType: return EntityType.PROJECT_RESOURCE_POLICY - - -@dataclass -class ProjectResourcePolicyBatchAction(BaseBatchAction): - @override - @classmethod - def entity_type(cls) -> EntityType: - return EntityType.PROJECT_RESOURCE_POLICY diff --git a/src/ai/backend/manager/services/resource_preset/actions/base.py b/src/ai/backend/manager/services/resource_preset/actions/base.py index e0b6364ed0e..815a7a4c3b8 100644 --- a/src/ai/backend/manager/services/resource_preset/actions/base.py +++ b/src/ai/backend/manager/services/resource_preset/actions/base.py @@ -2,7 +2,7 @@ from typing import override from ai.backend.common.data.permission.types import EntityType -from ai.backend.manager.actions.action import BaseAction, BaseBatchAction +from ai.backend.manager.actions.action import BaseAction @dataclass @@ -11,11 +11,3 @@ class ResourcePresetAction(BaseAction): @classmethod def entity_type(cls) -> EntityType: return EntityType.RESOURCE_PRESET - - -@dataclass -class ResourcePresetBatchAction(BaseBatchAction): - @override - @classmethod - def entity_type(cls) -> EntityType: - return EntityType.RESOURCE_PRESET diff --git a/src/ai/backend/manager/services/session/actions/check_and_transit_status.py b/src/ai/backend/manager/services/session/actions/check_and_transit_status.py index dcfc156ad28..38e1cad87ad 100644 --- a/src/ai/backend/manager/services/session/actions/check_and_transit_status.py +++ b/src/ai/backend/manager/services/session/actions/check_and_transit_status.py @@ -5,10 +5,10 @@ from typing import TYPE_CHECKING, override from ai.backend.common.types import SessionId -from ai.backend.manager.actions.action import BaseActionResult, BaseBatchActionResult +from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.session.types import SessionData -from ai.backend.manager.services.session.base import SessionAction, SessionBatchAction +from ai.backend.manager.services.session.base import SessionAction if TYPE_CHECKING: from ai.backend.manager.models.user import UserRole @@ -39,30 +39,3 @@ class CheckAndTransitStatusActionResult(BaseActionResult): @override def entity_id(self) -> str | None: return str(self.session_data.id) - - -# TODO: Change this to BatchAction -@dataclass -class CheckAndTransitStatusBatchAction(SessionBatchAction): - user_id: uuid.UUID - user_role: UserRole - session_ids: list[SessionId] - - @override - def entity_ids(self) -> list[str]: - return [str(session_id) for session_id in self.session_ids] - - @override - @classmethod - def operation_type(cls) -> ActionOperationType: - return ActionOperationType.UPDATE - - -@dataclass -class CheckAndTransitStatusBatchActionResult(BaseBatchActionResult): - # TODO: Add proper type - session_status_map: dict[SessionId, str] - - @override - def entity_ids(self) -> list[str]: - return [str(session_id) for session_id in self.session_status_map.keys()] diff --git a/src/ai/backend/manager/services/session/base.py b/src/ai/backend/manager/services/session/base.py index 7062cba4380..732b07da03c 100644 --- a/src/ai/backend/manager/services/session/base.py +++ b/src/ai/backend/manager/services/session/base.py @@ -2,7 +2,7 @@ from typing import override from ai.backend.common.data.permission.types import EntityType -from ai.backend.manager.actions.action import BaseAction, BaseBatchAction +from ai.backend.manager.actions.action import BaseAction from ai.backend.manager.actions.action.scope import BaseScopeAction, BaseScopeActionResult from ai.backend.manager.actions.action.single_entity import ( BaseSingleEntityAction, @@ -19,14 +19,6 @@ def entity_type(cls) -> EntityType: return EntityType.SESSION -@dataclass -class SessionBatchAction(BaseBatchAction): - @override - @classmethod - def entity_type(cls) -> EntityType: - return EntityType.SESSION - - @dataclass class SessionScopeAction(BaseScopeAction): """Base class for session actions that operate within a scope. diff --git a/src/ai/backend/manager/services/session/service.py b/src/ai/backend/manager/services/session/service.py index 5d03412d374..27e0e6fdb8c 100644 --- a/src/ai/backend/manager/services/session/service.py +++ b/src/ai/backend/manager/services/session/service.py @@ -80,8 +80,6 @@ from ai.backend.manager.services.session.actions.check_and_transit_status import ( CheckAndTransitStatusAction, CheckAndTransitStatusActionResult, - CheckAndTransitStatusBatchAction, - CheckAndTransitStatusBatchActionResult, ) from ai.backend.manager.services.session.actions.commit_session import ( CommitSessionAction, @@ -1531,45 +1529,6 @@ async def check_and_transit_status( result=result, session_data=session_row.to_dataclass(owner=session_owner_data) ) - async def check_and_transit_status_multi( - self, action: CheckAndTransitStatusBatchAction - ) -> CheckAndTransitStatusBatchActionResult: - user_id = action.user_id - user_role = action.user_role - session_ids = action.session_ids - accessible_session_ids: list[SessionId] = [] - - for sid in session_ids: - if user_role in (UserRole.ADMIN, UserRole.SUPERADMIN): - accessible_session_ids.append(sid) - else: - try: - session_row = await self._session_repository.get_session_to_determine_status( - sid - ) - if session_row.user_uuid == user_id: - accessible_session_ids.append(sid) - else: - log.warning( - f"You are not allowed to transit others's sessions status, skip (s:{sid})" - ) - except Exception: - log.warning(f"Session not found or access denied, skip (s:{sid})") - - now = datetime.now(tzutc()) - if accessible_session_ids: - session_rows = await self._agent_registry.session_lifecycle_mgr.transit_session_status( - accessible_session_ids, now - ) - await self._agent_registry.session_lifecycle_mgr.deregister_status_updatable_session([ - row.id for row, is_transited in session_rows if is_transited - ]) - result = {row.id: row.status.name for row, _ in session_rows} - else: - result = {} - - return CheckAndTransitStatusBatchActionResult(session_status_map=result) - async def search(self, action: SearchSessionsAction) -> SearchSessionsActionResult: """Search sessions with querier pattern.""" result = await self._session_repository.search(action.querier) diff --git a/src/ai/backend/manager/services/user_resource_policy/actions/base.py b/src/ai/backend/manager/services/user_resource_policy/actions/base.py index 155525f70c7..39738cdef60 100644 --- a/src/ai/backend/manager/services/user_resource_policy/actions/base.py +++ b/src/ai/backend/manager/services/user_resource_policy/actions/base.py @@ -2,7 +2,7 @@ from typing import override from ai.backend.common.data.permission.types import EntityType -from ai.backend.manager.actions.action import BaseAction, BaseBatchAction +from ai.backend.manager.actions.action import BaseAction @dataclass @@ -11,11 +11,3 @@ class UserResourcePolicyAction(BaseAction): @classmethod def entity_type(cls) -> EntityType: return EntityType.USER_RESOURCE_POLICY - - -@dataclass -class UserResourcePolicyBatchAction(BaseBatchAction): - @override - @classmethod - def entity_type(cls) -> EntityType: - return EntityType.USER_RESOURCE_POLICY diff --git a/tests/unit/manager/actions/test_bulk_processor.py b/tests/unit/manager/actions/test_bulk_processor.py new file mode 100644 index 00000000000..ac462de13a4 --- /dev/null +++ b/tests/unit/manager/actions/test_bulk_processor.py @@ -0,0 +1,241 @@ +"""Tests for ``BulkActionProcessor`` filtering infrastructure (BA-5777). + +Verifies that the processor narrows ``entity_ids`` exactly to what each +validator allowed — no more, no less — and that later validators only +see IDs that survived earlier ones. +""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from dataclasses import dataclass, field +from typing import Any, override + +import pytest + +from ai.backend.common.data.permission.types import EntityType +from ai.backend.manager.actions.action import BaseActionTriggerMeta +from ai.backend.manager.actions.action.bulk import BaseBulkAction, BaseBulkActionResult +from ai.backend.manager.actions.processor.bulk import ( + BulkActionProcessor, +) +from ai.backend.manager.actions.types import ActionOperationType +from ai.backend.manager.actions.validator.bulk import ( + BulkActionValidator, + BulkValidationResult, + DeniedEntity, +) + + +@dataclass +class _MockBulkAction(BaseBulkAction[str]): + @override + def typed_entity_ids(self) -> list[str]: + return list(self.entity_ids) + + @override + @classmethod + def entity_type(cls) -> EntityType: + return EntityType.SESSION + + @override + @classmethod + def operation_type(cls) -> ActionOperationType: + return ActionOperationType.UPDATE + + +@dataclass +class _MockBulkActionResult(BaseBulkActionResult): + processed_ids: list[str] = field(default_factory=list) + + @override + def entity_ids(self) -> list[str]: + return list(self.processed_ids) + + +class _AllowSetValidator(BulkActionValidator): + """Approves any ID in ``allowed``; anything else visible is denied.""" + + def __init__(self, allowed: set[str]) -> None: + self._allowed = set(allowed) + + @classmethod + @override + def name(cls) -> str: + return "allow-set" + + @override + async def validate( + self, action: BaseBulkAction[Any], meta: BaseActionTriggerMeta + ) -> BulkValidationResult: + current = list(action.entity_ids) + allowed = [eid for eid in current if eid in self._allowed] + denied = [ + DeniedEntity(entity_id=eid, deny_reason="not in allow-set") + for eid in current + if eid not in self._allowed + ] + return BulkValidationResult(allowed_entity_ids=allowed, denied_entities=denied) + + +class _RecordingValidator(BulkActionValidator): + """Captures the entity IDs each ``validate()`` call received.""" + + def __init__(self, allowed: set[str]) -> None: + self._allowed = set(allowed) + self.observed_batches: list[list[str]] = [] + + @classmethod + @override + def name(cls) -> str: + return "recording" + + @override + async def validate( + self, action: BaseBulkAction[Any], meta: BaseActionTriggerMeta + ) -> BulkValidationResult: + current = list(action.entity_ids) + self.observed_batches.append(current) + allowed = [eid for eid in current if eid in self._allowed] + denied = [ + DeniedEntity(entity_id=eid, deny_reason="blocked") + for eid in current + if eid not in self._allowed + ] + return BulkValidationResult(allowed_entity_ids=allowed, denied_entities=denied) + + +def _echo_func() -> Callable[[_MockBulkAction], Awaitable[_MockBulkActionResult]]: + async def _run(action: _MockBulkAction) -> _MockBulkActionResult: + return _MockBulkActionResult(processed_ids=list(action.entity_ids)) + + return _run + + +class TestBulkActionProcessorFiltering: + async def test_no_validators_passes_all_ids_through(self) -> None: + processor = BulkActionProcessor[_MockBulkAction, _MockBulkActionResult]( + func=_echo_func(), + ) + action = _MockBulkAction(entity_ids=["a", "b", "c"]) + + outcome = await processor.wait_for_complete(action) + + assert outcome.result.processed_ids == ["a", "b", "c"] + assert outcome.validator_decisions == [] + + async def test_validator_denies_subset_reports_denied_ids(self) -> None: + processor = BulkActionProcessor[_MockBulkAction, _MockBulkActionResult]( + func=_echo_func(), + validators=[_AllowSetValidator(allowed={"a", "c"})], + ) + action = _MockBulkAction(entity_ids=["a", "b", "c"]) + + outcome = await processor.wait_for_complete(action) + + assert outcome.result.processed_ids == ["a", "c"] + assert len(outcome.validator_decisions) == 1 + decision = outcome.validator_decisions[0] + assert decision.validator_name == "allow-set" + assert decision.results.allowed_entity_ids == ["a", "c"] + assert decision.results.denied_entities == [ + DeniedEntity(entity_id="b", deny_reason="not in allow-set"), + ] + + async def test_validator_denies_all_still_runs_service_with_empty_batch(self) -> None: + processor = BulkActionProcessor[_MockBulkAction, _MockBulkActionResult]( + func=_echo_func(), + validators=[_AllowSetValidator(allowed=set())], + ) + action = _MockBulkAction(entity_ids=["a", "b"]) + + outcome = await processor.wait_for_complete(action) + + assert outcome.result.processed_ids == [] + decision = outcome.validator_decisions[0] + assert decision.results.allowed_entity_ids == [] + assert [d.entity_id for d in decision.results.denied_entities] == ["a", "b"] + + async def test_later_validator_only_sees_surviving_ids(self) -> None: + first = _RecordingValidator(allowed={"a", "b"}) + second = _RecordingValidator(allowed={"a"}) + processor = BulkActionProcessor[_MockBulkAction, _MockBulkActionResult]( + func=_echo_func(), + validators=[first, second], + ) + action = _MockBulkAction(entity_ids=["a", "b", "c"]) + + outcome = await processor.wait_for_complete(action) + + # First validator sees the full batch; second only sees IDs that + # survived the first. + assert first.observed_batches == [["a", "b", "c"]] + assert second.observed_batches == [["a", "b"]] + + assert outcome.result.processed_ids == ["a"] + assert [ + ( + d.validator_name, + d.results.allowed_entity_ids, + [de.entity_id for de in d.results.denied_entities], + ) + for d in outcome.validator_decisions + ] == [ + ("recording", ["a", "b"], ["c"]), + ("recording", ["a"], ["b"]), + ] + + async def test_original_action_is_not_mutated(self) -> None: + processor = BulkActionProcessor[_MockBulkAction, _MockBulkActionResult]( + func=_echo_func(), + validators=[_AllowSetValidator(allowed={"a"})], + ) + original = _MockBulkAction(entity_ids=["a", "b"]) + + outcome = await processor.wait_for_complete(original) + + assert outcome.result.processed_ids == ["a"] + # The processor constructs a fresh action; it must not mutate the caller's. + assert original.entity_ids == ["a", "b"] + + async def test_pass_through_reuses_same_action_instance(self) -> None: + seen: list[_MockBulkAction] = [] + + async def _capture(action: _MockBulkAction) -> _MockBulkActionResult: + seen.append(action) + return _MockBulkActionResult(processed_ids=list(action.entity_ids)) + + processor = BulkActionProcessor[_MockBulkAction, _MockBulkActionResult]( + func=_capture, + validators=[_AllowSetValidator(allowed={"a", "b"})], + ) + original = _MockBulkAction(entity_ids=["a", "b"]) + + await processor.wait_for_complete(original) + + # No denials → no filtering copy was created. + assert seen[0] is original + + +@pytest.mark.parametrize( + ("allowed", "batch", "expected_processed"), + [ + ({"a", "b"}, ["a", "b"], ["a", "b"]), + ({"a"}, ["a", "b"], ["a"]), + (set(), ["a", "b"], []), + ], +) +async def test_single_validator_scenarios( + allowed: set[str], + batch: list[str], + expected_processed: list[str], +) -> None: + processor = BulkActionProcessor[_MockBulkAction, _MockBulkActionResult]( + func=_echo_func(), + validators=[_AllowSetValidator(allowed=allowed)], + ) + action = _MockBulkAction(entity_ids=batch) + + outcome = await processor.wait_for_complete(action) + + assert outcome.result.processed_ids == expected_processed diff --git a/tests/unit/manager/services/session/test_session_lifecycle_service.py b/tests/unit/manager/services/session/test_session_lifecycle_service.py index 2d4a1c70b1f..e1d048dce88 100644 --- a/tests/unit/manager/services/session/test_session_lifecycle_service.py +++ b/tests/unit/manager/services/session/test_session_lifecycle_service.py @@ -44,9 +44,6 @@ from ai.backend.manager.models.user import UserRole from ai.backend.manager.registry import AgentRegistry from ai.backend.manager.repositories.session.repository import SessionRepository -from ai.backend.manager.services.session.actions.check_and_transit_status import ( - CheckAndTransitStatusBatchAction, -) from ai.backend.manager.services.session.actions.commit_session import ( CommitSessionAction, ) @@ -1898,100 +1895,3 @@ async def test_session_not_found( with pytest.raises(SessionNotFound): await session_service.shutdown_service(action) - - -# ==================== CheckAndTransitStatusBatch Tests ==================== - - -class _TestCheckAndTransitStatusBatchAction(CheckAndTransitStatusBatchAction): - """Concrete subclass for testing (field_data is abstract in BaseBatchAction).""" - - def field_data(self) -> None: - return None - - -class TestCheckAndTransitStatusBatch: - async def test_admin_processes_all_sessions( - self, - session_service: SessionService, - mock_session_repository: MagicMock, - mock_agent_registry: MagicMock, - sample_user_id: UUID, - ) -> None: - sid1 = SessionId(uuid4()) - sid2 = SessionId(uuid4()) - - mock_row1 = MagicMock() - mock_row1.id = sid1 - mock_row1.status = SessionStatus.RUNNING - mock_row2 = MagicMock() - mock_row2.id = sid2 - mock_row2.status = SessionStatus.RUNNING - - mock_agent_registry.session_lifecycle_mgr.transit_session_status = AsyncMock( - return_value=[(mock_row1, True), (mock_row2, True)] - ) - - action = _TestCheckAndTransitStatusBatchAction( - user_id=sample_user_id, - user_role=UserRole.ADMIN, - session_ids=[sid1, sid2], - ) - result = await session_service.check_and_transit_status_multi(action) - - assert sid1 in result.session_status_map - assert sid2 in result.session_status_map - - async def test_user_role_only_processes_owned_sessions( - self, - session_service: SessionService, - mock_session_repository: MagicMock, - mock_agent_registry: MagicMock, - sample_user_id: UUID, - ) -> None: - owned_sid = SessionId(uuid4()) - other_sid = SessionId(uuid4()) - other_user_id = uuid4() - - owned_session_row = MagicMock() - owned_session_row.user_uuid = sample_user_id - other_session_row = MagicMock() - other_session_row.user_uuid = other_user_id - - mock_session_repository.get_session_to_determine_status = AsyncMock( - side_effect=lambda sid: owned_session_row if sid == owned_sid else other_session_row - ) - - mock_row = MagicMock() - mock_row.id = owned_sid - mock_row.status = SessionStatus.RUNNING - - mock_agent_registry.session_lifecycle_mgr.transit_session_status = AsyncMock( - return_value=[(mock_row, True)] - ) - - action = _TestCheckAndTransitStatusBatchAction( - user_id=sample_user_id, - user_role=UserRole.USER, - session_ids=[owned_sid, other_sid], - ) - result = await session_service.check_and_transit_status_multi(action) - - assert owned_sid in result.session_status_map - assert other_sid not in result.session_status_map - - async def test_empty_session_ids_returns_empty( - self, - session_service: SessionService, - mock_session_repository: MagicMock, - mock_agent_registry: MagicMock, - sample_user_id: UUID, - ) -> None: - action = _TestCheckAndTransitStatusBatchAction( - user_id=sample_user_id, - user_role=UserRole.USER, - session_ids=[], - ) - result = await session_service.check_and_transit_status_multi(action) - - assert result.session_status_map == {}