Skip to content

Commit 850f8f8

Browse files
fregataaclaude
andcommitted
feat(BA-5777): add batch RBAC filtering infrastructure
- BaseBatchAction becomes a generic dataclass with entity_ids: list[str] + typed_entity_ids() abstract for native-type access - BatchActionValidator exposes name() and validate() returning BatchValidationResult (allowed IDs + DeniedEntity with deny_reason) - BatchActionProcessor runs each validator inside _validator_scope, narrows entity_ids in place, and returns BatchProcessResult carrying per-validator decisions for partial-success responses - Remove orphan check_and_transit_status_multi path and 10 unused intermediate batch base classes Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent f127c60 commit 850f8f8

17 files changed

Lines changed: 194 additions & 289 deletions

File tree

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,26 @@
11
from abc import abstractmethod
2-
from typing import TypeVar, override
2+
from dataclasses import dataclass
3+
from typing import Any, TypeVar, override
34

45
from .base import BaseAction, BaseActionResult
5-
from .types import BatchFieldData
66

77

8-
class BaseBatchAction(BaseAction):
9-
@override
10-
def entity_id(self) -> str | None:
11-
return None
8+
@dataclass
9+
class BaseBatchAction[T](BaseAction):
10+
"""Base class for actions operating on a batch of entities.
1211
13-
@abstractmethod
14-
def entity_ids(self) -> list[str]:
15-
raise NotImplementedError
12+
``entity_ids`` is stored as ``list[str]`` so ``BatchActionProcessor`` can
13+
narrow it in place after each RBAC validator's verdict arrives without
14+
caring about the concrete ID type. Concrete subclasses implement
15+
``typed_entity_ids()`` to surface the native ``list[T]`` (e.g.
16+
``list[SessionId]``) the service layer expects.
17+
"""
18+
19+
entity_ids: list[str]
1620

1721
@abstractmethod
18-
def field_data(self) -> BatchFieldData | None:
19-
"""
20-
Returns batch field data containing the field type and IDs when the
21-
action's targets exist as fields of another entity.
22-
Returns None if these entities are not fields.
23-
"""
22+
def typed_entity_ids(self) -> list[T]:
23+
"""Return ``entity_ids`` converted back to the native ID type ``T``."""
2424
raise NotImplementedError
2525

2626

@@ -34,5 +34,5 @@ def entity_ids(self) -> list[str]:
3434
raise NotImplementedError
3535

3636

37-
TBatchAction = TypeVar("TBatchAction", bound=BaseBatchAction)
37+
TBatchAction = TypeVar("TBatchAction", bound=BaseBatchAction[Any])
3838
TBatchActionResult = TypeVar("TBatchActionResult", bound=BaseBatchActionResult)

src/ai/backend/manager/actions/processor/batch.py

Lines changed: 98 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import logging
22
import uuid
3-
from collections.abc import Awaitable, Callable, Sequence
3+
from collections.abc import AsyncIterator, Awaitable, Callable, Sequence
4+
from contextlib import asynccontextmanager
5+
from dataclasses import dataclass
46
from datetime import UTC, datetime
7+
from typing import Any
58

69
from ai.backend.logging.utils import BraceStyleAdapter
710
from ai.backend.manager.actions.action import (
@@ -12,15 +15,45 @@
1215
BaseBatchActionResult,
1316
)
1417
from ai.backend.manager.actions.monitors.monitor import ActionMonitor
15-
from ai.backend.manager.actions.validator.batch import BatchActionValidator
18+
from ai.backend.manager.actions.validator.batch import (
19+
BatchActionValidator,
20+
BatchValidationResult,
21+
)
1622

1723
from .base import ActionRunner
1824

1925
log = BraceStyleAdapter(logging.getLogger(__spec__.name))
2026

2127

28+
@dataclass(frozen=True)
29+
class BatchValidatorDecision:
30+
"""One validator's per-entity verdict observed during batch processing.
31+
32+
Mirrors the ``SubStepResult`` pattern used by the scheduler history so
33+
callers can trace where in the validator chain each ID was filtered and
34+
*why*. ``results`` carries the validator's classification unchanged.
35+
"""
36+
37+
validator_name: str
38+
results: BatchValidationResult
39+
40+
41+
@dataclass(frozen=True)
42+
class BatchProcessResult[TBatchActionResult: BaseBatchActionResult]:
43+
"""Outcome of a ``BatchActionProcessor`` run.
44+
45+
``result`` is what the service function returned for the permitted subset
46+
of entity IDs. ``validator_decisions`` keeps the per-validator trace in
47+
iteration order; callers assemble the partial-success response by
48+
walking it (each decision carries the denied IDs and their reasons).
49+
"""
50+
51+
result: TBatchActionResult
52+
validator_decisions: list[BatchValidatorDecision]
53+
54+
2255
class BatchActionProcessor[
23-
TBatchAction: BaseBatchAction,
56+
TBatchAction: BaseBatchAction[Any],
2457
TBatchActionResult: BaseBatchActionResult,
2558
]:
2659
_validators: Sequence[BatchActionValidator]
@@ -37,14 +70,72 @@ def __init__(
3770

3871
self._validators = validators or []
3972

40-
async def _run(self, action: TBatchAction) -> TBatchActionResult:
73+
@asynccontextmanager
74+
async def _validator_scope(
75+
self,
76+
validator: BatchActionValidator,
77+
action: TBatchAction,
78+
meta: BaseActionTriggerMeta,
79+
) -> AsyncIterator[BatchValidationResult]:
80+
"""Run one validator inside a bookend scope.
81+
82+
Yields the validator's ``BatchValidationResult`` so the caller can
83+
record the decision inside the block. Timing and per-validator
84+
logging live here rather than inside each validator implementation.
85+
"""
86+
started_at = datetime.now(UTC)
87+
validation = await validator.validate(action, meta)
88+
try:
89+
yield validation
90+
finally:
91+
duration = (datetime.now(UTC) - started_at).total_seconds()
92+
log.debug(
93+
"batch validator {} saw {} ids, denied {} in {:.3f}s",
94+
validator.name(),
95+
len(validation.allowed_entity_ids) + len(validation.denied_entities),
96+
len(validation.denied_entities),
97+
duration,
98+
)
99+
100+
def _process_action(
101+
self,
102+
current_action: TBatchAction,
103+
validation: BatchValidationResult,
104+
) -> TBatchAction:
105+
"""Narrow ``current_action.entity_ids`` in place using one validator's verdict.
106+
107+
Returns the action unchanged when the validator denied nothing.
108+
"""
109+
if not validation.denied_entities:
110+
return current_action
111+
allowed_set = set(validation.allowed_entity_ids)
112+
current_action.entity_ids = [eid for eid in current_action.entity_ids if eid in allowed_set]
113+
return current_action
114+
115+
async def _run(self, action: TBatchAction) -> BatchProcessResult[TBatchActionResult]:
41116
started_at = datetime.now(UTC)
42117
action_id = uuid.uuid4()
43118
action_trigger_meta = BaseActionTriggerMeta(action_id=action_id, started_at=started_at)
119+
120+
current_action: TBatchAction = action
121+
decisions: list[BatchValidatorDecision] = []
122+
44123
for validator in self._validators:
45-
await validator.validate(action, action_trigger_meta)
124+
async with self._validator_scope(
125+
validator, current_action, action_trigger_meta
126+
) as validation:
127+
decisions.append(
128+
BatchValidatorDecision(
129+
validator_name=validator.name(),
130+
results=validation,
131+
)
132+
)
133+
current_action = self._process_action(current_action, validation)
46134

47-
return await self._runner.run(action, action_trigger_meta)
135+
action_result = await self._runner.run(current_action, action_trigger_meta)
136+
return BatchProcessResult(result=action_result, validator_decisions=decisions)
48137

49-
async def wait_for_complete(self, action: TBatchAction) -> TBatchActionResult:
138+
async def wait_for_complete(
139+
self, action: TBatchAction
140+
) -> BatchProcessResult[TBatchActionResult]:
50141
return await self._run(action)
Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,57 @@
11
from abc import ABC, abstractmethod
2+
from dataclasses import dataclass
3+
from typing import Any
24

35
from ai.backend.manager.actions.action import BaseActionTriggerMeta
46
from ai.backend.manager.actions.action.batch import BaseBatchAction
57

68

9+
@dataclass(frozen=True)
10+
class DeniedEntity:
11+
"""A batch entity that a validator rejected, paired with its reason."""
12+
13+
entity_id: str
14+
deny_reason: str
15+
16+
17+
@dataclass(frozen=True)
18+
class BatchValidationResult:
19+
"""Per-entity validation outcome for a batch action.
20+
21+
``BatchActionProcessor`` intersects ``allowed_entity_ids`` across
22+
validators and records each ``DeniedEntity`` — with its reason — on the
23+
corresponding ``BatchValidatorDecision`` so the final response can
24+
surface *why* each ID was filtered out.
25+
"""
26+
27+
allowed_entity_ids: list[str]
28+
denied_entities: list[DeniedEntity]
29+
30+
731
class BatchActionValidator(ABC):
32+
@classmethod
33+
@abstractmethod
34+
def name(cls) -> str:
35+
"""Stable identifier used in ``BatchValidatorDecision.validator_name``.
36+
37+
Chosen by the implementation so logs and partial-success responses can
38+
attribute denials to a specific validator independently of the Python
39+
class name.
40+
"""
41+
raise NotImplementedError
42+
843
@abstractmethod
9-
async def validate(self, action: BaseBatchAction, meta: BaseActionTriggerMeta) -> None:
10-
raise NotImplementedError("Subclasses must implement the validate method")
44+
async def validate(
45+
self, action: BaseBatchAction[Any], meta: BaseActionTriggerMeta
46+
) -> BatchValidationResult:
47+
"""Validate the batch action and return per-entity permission results.
48+
49+
Implementations must classify every ID in ``action.entity_ids`` as
50+
either allowed or denied. Validators that cannot make a decision for
51+
an ID should treat it as allowed.
52+
53+
The processor wraps each call in its own async context manager so
54+
cross-cutting concerns (timing, audit) live in one place — validators
55+
do not need to own them.
56+
"""
57+
raise NotImplementedError
Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
from typing import override
1+
from typing import Any, override
22

33
from ai.backend.manager.actions.action import BaseActionTriggerMeta
44
from ai.backend.manager.actions.action.batch import BaseBatchAction
5-
from ai.backend.manager.actions.validator.batch import BatchActionValidator
5+
from ai.backend.manager.actions.validator.batch import (
6+
BatchActionValidator,
7+
BatchValidationResult,
8+
)
69
from ai.backend.manager.repositories.permission_controller.repository import (
710
PermissionControllerRepository,
811
)
@@ -15,7 +18,18 @@ def __init__(
1518
) -> None:
1619
self._repository = repository
1720

21+
@classmethod
22+
@override
23+
def name(cls) -> str:
24+
return "rbac"
25+
1826
@override
19-
async def validate(self, action: BaseBatchAction, meta: BaseActionTriggerMeta) -> None:
20-
# TODO: implement RBAC validation logic
21-
pass
27+
async def validate(
28+
self, action: BaseBatchAction[Any], meta: BaseActionTriggerMeta
29+
) -> BatchValidationResult:
30+
# TODO: wire this to PermissionControllerRepository.check_batch_permission_with_scope_chain().
31+
# Until then, every entity is treated as allowed so legacy behavior is preserved.
32+
return BatchValidationResult(
33+
allowed_entity_ids=list(action.entity_ids),
34+
denied_entities=[],
35+
)

src/ai/backend/manager/services/artifact/actions/base.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import override
33

44
from ai.backend.common.data.permission.types import EntityType
5-
from ai.backend.manager.actions.action import BaseAction, BaseBatchAction
5+
from ai.backend.manager.actions.action import BaseAction
66
from ai.backend.manager.actions.action.scope import BaseScopeAction, BaseScopeActionResult
77
from ai.backend.manager.actions.action.single_entity import (
88
BaseSingleEntityAction,
@@ -19,14 +19,6 @@ def entity_type(cls) -> EntityType:
1919
return EntityType.ARTIFACT
2020

2121

22-
@dataclass
23-
class ArtifactBatchAction(BaseBatchAction):
24-
@override
25-
@classmethod
26-
def entity_type(cls) -> EntityType:
27-
return EntityType.ARTIFACT
28-
29-
3022
@dataclass
3123
class ArtifactScopeAction(BaseScopeAction):
3224
@override

src/ai/backend/manager/services/artifact_registry/actions/base.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import override
22

33
from ai.backend.common.data.permission.types import EntityType
4-
from ai.backend.manager.actions.action import BaseAction, BaseBatchAction
4+
from ai.backend.manager.actions.action import BaseAction
55
from ai.backend.manager.actions.action.scope import BaseScopeAction, BaseScopeActionResult
66
from ai.backend.manager.actions.action.single_entity import (
77
BaseSingleEntityAction,
@@ -17,13 +17,6 @@ def entity_type(cls) -> EntityType:
1717
return EntityType.ARTIFACT_REGISTRY
1818

1919

20-
class ArtifactBatchRegistryAction(BaseBatchAction):
21-
@override
22-
@classmethod
23-
def entity_type(cls) -> EntityType:
24-
return EntityType.ARTIFACT_REGISTRY
25-
26-
2720
class ArtifactRegistryScopeAction(BaseScopeAction):
2821
@override
2922
@classmethod

src/ai/backend/manager/services/artifact_revision/actions/base.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import override
33

44
from ai.backend.common.data.permission.types import EntityType
5-
from ai.backend.manager.actions.action import BaseAction, BaseBatchAction
5+
from ai.backend.manager.actions.action import BaseAction
66

77

88
@dataclass
@@ -11,11 +11,3 @@ class ArtifactRevisionAction(BaseAction):
1111
@classmethod
1212
def entity_type(cls) -> EntityType:
1313
return EntityType.ARTIFACT_REVISION
14-
15-
16-
@dataclass
17-
class ArtifactRevisionBatchAction(BaseBatchAction):
18-
@override
19-
@classmethod
20-
def entity_type(cls) -> EntityType:
21-
return EntityType.ARTIFACT_REVISION

src/ai/backend/manager/services/container_registry/actions/base.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import override
33

44
from ai.backend.common.data.permission.types import EntityType
5-
from ai.backend.manager.actions.action import BaseAction, BaseBatchAction
5+
from ai.backend.manager.actions.action import BaseAction
66

77

88
@dataclass
@@ -11,11 +11,3 @@ class ContainerRegistryAction(BaseAction):
1111
@classmethod
1212
def entity_type(cls) -> EntityType:
1313
return EntityType.CONTAINER_REGISTRY
14-
15-
16-
@dataclass
17-
class ContainerRegistryBatchAction(BaseBatchAction):
18-
@override
19-
@classmethod
20-
def entity_type(cls) -> EntityType:
21-
return EntityType.CONTAINER_REGISTRY

src/ai/backend/manager/services/image/actions/base.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import override
33

44
from ai.backend.common.data.permission.types import EntityType
5-
from ai.backend.manager.actions.action import BaseAction, BaseBatchAction
5+
from ai.backend.manager.actions.action import BaseAction
66
from ai.backend.manager.actions.action.single_entity import (
77
BaseSingleEntityAction,
88
BaseSingleEntityActionResult,
@@ -18,14 +18,6 @@ def entity_type(cls) -> EntityType:
1818
return EntityType.IMAGE
1919

2020

21-
@dataclass
22-
class ImageBatchAction(BaseBatchAction):
23-
@override
24-
@classmethod
25-
def entity_type(cls) -> EntityType:
26-
return EntityType.IMAGE
27-
28-
2921
@dataclass
3022
class ImageSingleEntityAction(BaseSingleEntityAction):
3123
@override

0 commit comments

Comments
 (0)