-
Notifications
You must be signed in to change notification settings - Fork 175
Expand file tree
/
Copy pathbulk.py
More file actions
112 lines (89 loc) · 3.88 KB
/
bulk.py
File metadata and controls
112 lines (89 loc) · 3.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import logging
import uuid
from collections.abc import Awaitable, Callable, Sequence
from dataclasses import dataclass, fields, is_dataclass, replace
from datetime import UTC, datetime
from typing import Any

from ai.backend.logging.utils import BraceStyleAdapter
from ai.backend.manager.actions.action import (
    BaseActionTriggerMeta,
)
from ai.backend.manager.actions.action.bulk import (
    BaseBulkAction,
    BaseBulkActionResult,
)
from ai.backend.manager.actions.monitors.monitor import ActionMonitor
from ai.backend.manager.actions.validator.bulk import (
    BulkActionValidator,
    BulkValidationResult,
)

from .base import ActionRunner
# Module-level logger named after this module's import path (``__spec__.name``),
# wrapped in ``BraceStyleAdapter`` (presumably enabling ``{}``-style message
# formatting — confirm against ai.backend.logging.utils).
# NOTE(review): not referenced anywhere else in this file.
log = BraceStyleAdapter(logging.getLogger(__spec__.name))
@dataclass(frozen=True)
class ValidatorDecision:
    """One validator's per-entity verdict observed during bulk processing.

    Mirrors the ``SubStepResult`` pattern used by the scheduler history so
    callers can trace where in the validator chain each ID was filtered and
    *why*. ``results`` carries the validator's classification unchanged.

    ``BulkActionProcessor`` records exactly one instance per configured
    validator, in chain order, including validators that denied nothing.
    Instances are immutable (``frozen=True``).
    """

    # The string returned by the validator's ``name()`` method.
    validator_name: str
    # The object returned by the validator's ``validate()`` call, stored as-is.
    results: BulkValidationResult
@dataclass(frozen=True)
class BulkProcessResult[TBulkActionResult: BaseBulkActionResult]:
    """Outcome of a ``BulkActionProcessor`` run.

    ``result`` is what the service function returned for the permitted subset
    of entity IDs. ``validator_decisions`` keeps the per-validator trace in
    iteration order; callers assemble the partial-success response by
    walking it (each decision carries that validator's ``BulkValidationResult``,
    including its denied entities).
    """

    # Return value of the wrapped service function, called with the action
    # narrowed by every validator in the chain.
    result: TBulkActionResult
    # One entry per validator, in the order the validators ran. The dataclass
    # is frozen, but this list object itself is still mutable.
    validator_decisions: list[ValidatorDecision]
class BulkActionProcessor[
TBulkAction: BaseBulkAction[Any],
TBulkActionResult: BaseBulkActionResult,
]:
_validators: Sequence[BulkActionValidator]
_runner: ActionRunner[TBulkAction, TBulkActionResult]
def __init__(
self,
func: Callable[[TBulkAction], Awaitable[TBulkActionResult]],
monitors: Sequence[ActionMonitor] | None = None,
validators: Sequence[BulkActionValidator] | None = None,
) -> None:
self._runner = ActionRunner(func, monitors)
self._validators = validators or []
def _filter_by_validation(
self,
action: TBulkAction,
validation: BulkValidationResult,
) -> TBulkAction:
"""Return a new action narrowed to the IDs this validator permitted.
Returns the incoming action unchanged when the validator denied
nothing; otherwise constructs a fresh instance of the same class
via its ``entity_ids``-only constructor so the original stays
immutable.
"""
if not validation.denied_entities:
return action
allowed_set = set(validation.allowed_entity_ids)
filtered_ids = [eid for eid in action.entity_ids if eid in allowed_set]
return type(action)(entity_ids=filtered_ids)
async def _run(self, action: TBulkAction) -> BulkProcessResult[TBulkActionResult]:
started_at = datetime.now(UTC)
action_id = uuid.uuid4()
action_trigger_meta = BaseActionTriggerMeta(action_id=action_id, started_at=started_at)
filtered_action: TBulkAction = action
decisions: list[ValidatorDecision] = []
for validator in self._validators:
validation = await validator.validate(filtered_action, action_trigger_meta)
decisions.append(
ValidatorDecision(
validator_name=validator.name(),
results=validation,
)
)
filtered_action = self._filter_by_validation(filtered_action, validation)
action_result = await self._runner.run(filtered_action, action_trigger_meta)
return BulkProcessResult(result=action_result, validator_decisions=decisions)
async def wait_for_complete(self, action: TBulkAction) -> BulkProcessResult[TBulkActionResult]:
return await self._run(action)