Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/11191.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add bulk RBAC filtering infrastructure so `BulkActionProcessor` can narrow actions per-entity and report per-validator decisions.
10 changes: 5 additions & 5 deletions src/ai/backend/manager/actions/action/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
TAction,
TActionResult,
)
from .batch import (
BaseBatchAction,
BaseBatchActionResult,
from .bulk import (
BaseBulkAction,
BaseBulkActionResult,
)
from .rbac import (
BaseRBACAction,
Expand Down Expand Up @@ -123,8 +123,8 @@
"BaseActionResult",
"BaseActionResultMeta",
"BaseActionTriggerMeta",
"BaseBatchAction",
"BaseBatchActionResult",
"BaseBulkAction",
"BaseBulkActionResult",
"BaseRBACAction",
"RBACActionName",
"RBACRequiredPermission",
Expand Down
38 changes: 0 additions & 38 deletions src/ai/backend/manager/actions/action/batch.py

This file was deleted.

44 changes: 44 additions & 0 deletions src/ai/backend/manager/actions/action/bulk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, TypeVar, override

from .base import BaseAction, BaseActionResult


@dataclass
class BaseBulkAction[T](BaseAction):
"""Base class for actions operating on a bulk of entities.

``entity_ids`` is stored as ``list[str]`` so ``BulkActionValidator``
implementations can match against validator verdicts directly. The
original ``T``-typed view is exposed via ``typed_entity_ids()``.

Bulk actions intentionally carry **only** ``entity_ids``. User context
(user id, role) flows through ``current_user()``, not the action, so
``BulkActionProcessor`` can reconstruct a filtered action by calling
``type(action)(entity_ids=...)`` directly — no ``__init__`` override or
factory hook is required. Subclasses that try to add required fields
break that constructor call and will fail fast at runtime, which is
intentional.
"""

entity_ids: list[str]

@abstractmethod
def typed_entity_ids(self) -> list[T]:
"""Return ``entity_ids`` converted back to the native ID type ``T``."""
raise NotImplementedError


class BaseBulkActionResult(BaseActionResult):
@override
def entity_id(self) -> str | None:
return None

@abstractmethod
def entity_ids(self) -> list[str]:
raise NotImplementedError


TBulkAction = TypeVar("TBulkAction", bound=BaseBulkAction[Any])
TBulkActionResult = TypeVar("TBulkActionResult", bound=BaseBulkActionResult)
50 changes: 0 additions & 50 deletions src/ai/backend/manager/actions/processor/batch.py

This file was deleted.

112 changes: 112 additions & 0 deletions src/ai/backend/manager/actions/processor/bulk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import logging
import uuid
from collections.abc import Awaitable, Callable, Sequence
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Any

from ai.backend.logging.utils import BraceStyleAdapter
from ai.backend.manager.actions.action import (
BaseActionTriggerMeta,
)
from ai.backend.manager.actions.action.bulk import (
BaseBulkAction,
BaseBulkActionResult,
)
from ai.backend.manager.actions.monitors.monitor import ActionMonitor
from ai.backend.manager.actions.validator.bulk import (
BulkActionValidator,
BulkValidationResult,
)

from .base import ActionRunner

log = BraceStyleAdapter(logging.getLogger(__spec__.name))


@dataclass(frozen=True)
class ValidatorDecision:
"""One validator's per-entity verdict observed during bulk processing.

Mirrors the ``SubStepResult`` pattern used by the scheduler history so
callers can trace where in the validator chain each ID was filtered and
*why*. ``results`` carries the validator's classification unchanged.
"""

validator_name: str
results: BulkValidationResult


@dataclass(frozen=True)
class BulkProcessResult[TBulkActionResult: BaseBulkActionResult]:
"""Outcome of a ``BulkActionProcessor`` run.

``result`` is what the service function returned for the permitted subset
of entity IDs. ``validator_decisions`` keeps the per-validator trace in
iteration order; callers assemble the partial-success response by
walking it (each decision carries the denied IDs and their reasons).
"""

result: TBulkActionResult
validator_decisions: list[ValidatorDecision]


class BulkActionProcessor[
TBulkAction: BaseBulkAction[Any],
TBulkActionResult: BaseBulkActionResult,
]:
_validators: Sequence[BulkActionValidator]

_runner: ActionRunner[TBulkAction, TBulkActionResult]

def __init__(
self,
func: Callable[[TBulkAction], Awaitable[TBulkActionResult]],
monitors: Sequence[ActionMonitor] | None = None,
validators: Sequence[BulkActionValidator] | None = None,
) -> None:
self._runner = ActionRunner(func, monitors)

Comment thread
fregataa marked this conversation as resolved.
self._validators = validators or []

def _filter_by_validation(
self,
action: TBulkAction,
validation: BulkValidationResult,
) -> TBulkAction:
"""Return a new action narrowed to the IDs this validator permitted.

Returns the incoming action unchanged when the validator denied
nothing; otherwise constructs a fresh instance of the same class
via its ``entity_ids``-only constructor so the original stays
immutable.
"""
if not validation.denied_entities:
return action
allowed_set = set(validation.allowed_entity_ids)
filtered_ids = [eid for eid in action.entity_ids if eid in allowed_set]
return type(action)(entity_ids=filtered_ids)

async def _run(self, action: TBulkAction) -> BulkProcessResult[TBulkActionResult]:
started_at = datetime.now(UTC)
action_id = uuid.uuid4()
action_trigger_meta = BaseActionTriggerMeta(action_id=action_id, started_at=started_at)

filtered_action: TBulkAction = action
decisions: list[ValidatorDecision] = []

for validator in self._validators:
validation = await validator.validate(filtered_action, action_trigger_meta)
decisions.append(
ValidatorDecision(
validator_name=validator.name(),
results=validation,
)
)
filtered_action = self._filter_by_validation(filtered_action, validation)

action_result = await self._runner.run(filtered_action, action_trigger_meta)
return BulkProcessResult(result=action_result, validator_decisions=decisions)

async def wait_for_complete(self, action: TBulkAction) -> BulkProcessResult[TBulkActionResult]:
return await self._run(action)
10 changes: 0 additions & 10 deletions src/ai/backend/manager/actions/validator/batch.py

This file was deleted.

57 changes: 57 additions & 0 deletions src/ai/backend/manager/actions/validator/bulk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any

from ai.backend.manager.actions.action import BaseActionTriggerMeta
from ai.backend.manager.actions.action.bulk import BaseBulkAction


@dataclass(frozen=True)
class DeniedEntity:
"""A bulk entity that a validator rejected, paired with its reason."""

entity_id: str
deny_reason: str


@dataclass(frozen=True)
class BulkValidationResult:
"""Per-entity validation outcome for a bulk action.

``BulkActionProcessor`` intersects ``allowed_entity_ids`` across
validators and records each ``DeniedEntity`` — with its reason — on the
corresponding ``ValidatorDecision`` so the final response can
surface *why* each ID was filtered out.
"""

allowed_entity_ids: list[str]
denied_entities: list[DeniedEntity]


class BulkActionValidator(ABC):
@classmethod
@abstractmethod
def name(cls) -> str:
"""Stable identifier used in ``ValidatorDecision.validator_name``.

Chosen by the implementation so logs and partial-success responses can
attribute denials to a specific validator independently of the Python
class name.
"""
raise NotImplementedError

@abstractmethod
async def validate(
self, action: BaseBulkAction[Any], meta: BaseActionTriggerMeta
) -> BulkValidationResult:
"""Validate the bulk action and return per-entity permission results.

Implementations must classify every ID in ``action.entity_ids`` as
either allowed or denied. Validators that cannot make a decision for
an ID should treat it as allowed.

The processor wraps each call in its own async context manager so
cross-cutting concerns (timing, audit) live in one place — validators
do not need to own them.
"""
raise NotImplementedError
21 changes: 0 additions & 21 deletions src/ai/backend/manager/actions/validators/rbac/batch.py

This file was deleted.

Loading
Loading