11import logging
22import uuid
3- from collections .abc import Awaitable , Callable , Sequence
3+ from collections .abc import AsyncIterator , Awaitable , Callable , Sequence
4+ from contextlib import asynccontextmanager
5+ from dataclasses import dataclass
46from datetime import UTC , datetime
7+ from typing import Any
58
69from ai .backend .logging .utils import BraceStyleAdapter
710from ai .backend .manager .actions .action import (
1215 BaseBatchActionResult ,
1316)
1417from ai .backend .manager .actions .monitors .monitor import ActionMonitor
15- from ai .backend .manager .actions .validator .batch import BatchActionValidator
18+ from ai .backend .manager .actions .validator .batch import (
19+ BatchActionValidator ,
20+ BatchValidationResult ,
21+ )
1622
1723from .base import ActionRunner
1824
1925log = BraceStyleAdapter (logging .getLogger (__spec__ .name ))
2026
2127
28+ @dataclass (frozen = True )
29+ class BatchValidatorDecision :
30+ """One validator's per-entity verdict observed during batch processing.
31+
32+ Mirrors the ``SubStepResult`` pattern used by the scheduler history so
33+ callers can trace where in the validator chain each ID was filtered and
34+ *why*. ``results`` carries the validator's classification unchanged.
35+ """
36+
37+ validator_name : str
38+ results : BatchValidationResult
39+
40+
41+ @dataclass (frozen = True )
42+ class BatchProcessResult [TBatchActionResult : BaseBatchActionResult ]:
43+ """Outcome of a ``BatchActionProcessor`` run.
44+
45+ ``result`` is what the service function returned for the permitted subset
46+ of entity IDs. ``validator_decisions`` keeps the per-validator trace in
47+ iteration order; callers assemble the partial-success response by
48+ walking it (each decision carries the denied IDs and their reasons).
49+ """
50+
51+ result : TBatchActionResult
52+ validator_decisions : list [BatchValidatorDecision ]
53+
54+
2255class BatchActionProcessor [
23- TBatchAction : BaseBatchAction ,
56+ TBatchAction : BaseBatchAction [ Any ] ,
2457 TBatchActionResult : BaseBatchActionResult ,
2558]:
2659 _validators : Sequence [BatchActionValidator ]
@@ -37,14 +70,72 @@ def __init__(
3770
3871 self ._validators = validators or []
3972
40- async def _run (self , action : TBatchAction ) -> TBatchActionResult :
73+ @asynccontextmanager
74+ async def _validator_scope (
75+ self ,
76+ validator : BatchActionValidator ,
77+ action : TBatchAction ,
78+ meta : BaseActionTriggerMeta ,
79+ ) -> AsyncIterator [BatchValidationResult ]:
80+ """Run one validator inside a bookend scope.
81+
82+ Yields the validator's ``BatchValidationResult`` so the caller can
83+ record the decision inside the block. Timing and per-validator
84+ logging live here rather than inside each validator implementation.
85+ """
86+ started_at = datetime .now (UTC )
87+ validation = await validator .validate (action , meta )
88+ try :
89+ yield validation
90+ finally :
91+ duration = (datetime .now (UTC ) - started_at ).total_seconds ()
92+ log .debug (
93+ "batch validator {} saw {} ids, denied {} in {:.3f}s" ,
94+ validator .name (),
95+ len (validation .allowed_entity_ids ) + len (validation .denied_entities ),
96+ len (validation .denied_entities ),
97+ duration ,
98+ )
99+
100+ def _process_action (
101+ self ,
102+ current_action : TBatchAction ,
103+ validation : BatchValidationResult ,
104+ ) -> TBatchAction :
105+ """Narrow ``current_action.entity_ids`` in place using one validator's verdict.
106+
107+ Returns the action unchanged when the validator denied nothing.
108+ """
109+ if not validation .denied_entities :
110+ return current_action
111+ allowed_set = set (validation .allowed_entity_ids )
112+ current_action .entity_ids = [eid for eid in current_action .entity_ids if eid in allowed_set ]
113+ return current_action
114+
115+ async def _run (self , action : TBatchAction ) -> BatchProcessResult [TBatchActionResult ]:
41116 started_at = datetime .now (UTC )
42117 action_id = uuid .uuid4 ()
43118 action_trigger_meta = BaseActionTriggerMeta (action_id = action_id , started_at = started_at )
119+
120+ current_action : TBatchAction = action
121+ decisions : list [BatchValidatorDecision ] = []
122+
44123 for validator in self ._validators :
45- await validator .validate (action , action_trigger_meta )
124+ async with self ._validator_scope (
125+ validator , current_action , action_trigger_meta
126+ ) as validation :
127+ decisions .append (
128+ BatchValidatorDecision (
129+ validator_name = validator .name (),
130+ results = validation ,
131+ )
132+ )
133+ current_action = self ._process_action (current_action , validation )
46134
47- return await self ._runner .run (action , action_trigger_meta )
135+ action_result = await self ._runner .run (current_action , action_trigger_meta )
136+ return BatchProcessResult (result = action_result , validator_decisions = decisions )
48137
49- async def wait_for_complete (self , action : TBatchAction ) -> TBatchActionResult :
138+ async def wait_for_complete (
139+ self , action : TBatchAction
140+ ) -> BatchProcessResult [TBatchActionResult ]:
50141 return await self ._run (action )
0 commit comments