-
Notifications
You must be signed in to change notification settings - Fork 152
Expand file tree
/
Copy pathasync_scheduler.py
More file actions
929 lines (806 loc) · 41.1 KB
/
async_scheduler.py
File metadata and controls
929 lines (806 loc) · 41.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import asyncio
import contextlib
import logging
import time
from collections import deque
from collections.abc import Coroutine
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable
import data_designer.lazy_heavy_imports as lazy
from data_designer.config.column_configs import GenerationStrategy
from data_designer.engine.context import current_row_group
from data_designer.engine.dataset_builders.errors import DatasetGenerationError
from data_designer.engine.dataset_builders.multi_column_configs import MultiColumnConfig
from data_designer.engine.dataset_builders.utils.async_progress_reporter import (
DEFAULT_REPORT_INTERVAL,
AsyncProgressReporter,
)
from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker
from data_designer.engine.dataset_builders.utils.progress_tracker import ProgressTracker
from data_designer.engine.dataset_builders.utils.skip_evaluator import should_skip_column_for_record
from data_designer.engine.dataset_builders.utils.skip_tracker import (
apply_skip_to_record,
strip_skip_metadata_from_records,
)
from data_designer.engine.dataset_builders.utils.sticky_progress_bar import StickyProgressBar
from data_designer.engine.dataset_builders.utils.task_model import SliceRef, Task, TaskTrace
from data_designer.engine.models.errors import (
ModelAPIConnectionError,
ModelInternalServerError,
ModelRateLimitError,
ModelTimeoutError,
)
if TYPE_CHECKING:
from data_designer.engine.column_generators.generators.base import ColumnGenerator
from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph
from data_designer.engine.dataset_builders.utils.row_group_buffer import RowGroupBufferManager
logger = logging.getLogger(__name__)
# Default cap on concurrently live task coroutines (submission pool size).
DEFAULT_TASK_POOL_SIZE: int = 256
# NOTE(review): not referenced in this chunk — presumably used by callers to
# size the LLM-wait pool relative to the submission pool; verify at call sites.
LLM_WAIT_POOL_MULTIPLIER: int = 2
# Transient model-call failures: tasks raising these are deferred and retried
# in salvage rounds instead of dropping the affected row(s).
_RETRYABLE_MODEL_ERRORS = (
    ModelRateLimitError,
    ModelTimeoutError,
    ModelInternalServerError,
    ModelAPIConnectionError,
)
class TrackingSemaphore(asyncio.Semaphore):
    """``asyncio.Semaphore`` that publicly exposes its remaining permit count."""

    @property
    def available_permits(self) -> int:
        """Permits currently available (reads the internal counter)."""
        return self._value  # type: ignore[attr-defined]

    def try_acquire(self) -> bool:
        """Take a permit without blocking; return ``True`` on success."""
        permits = self._value  # type: ignore[attr-defined]
        if permits <= 0:
            return False
        self._value = permits - 1  # type: ignore[attr-defined]
        return True
@dataclass
class _RowGroupState:
    """Lifecycle state for a single admitted row group."""

    # Number of rows in this row group.
    size: int
    # True once the from_scratch (seed) tasks for this group were dispatched.
    seeds_dispatched: bool = False
    # True once the pre-batch callback (if configured) has run for this group.
    pre_batch_done: bool = False
    # Number of this group's tasks currently executing; gates pre-batch
    # callbacks and stalled-row-group salvage.
    in_flight_count: int = 0
class AsyncTaskScheduler:
"""Dependency-aware async task scheduler for the dataset builder.
Replaces sequential column-by-column processing with parallel dispatch
based on the ``ExecutionGraph`` and ``CompletionTracker``.
"""
def __init__(
    self,
    generators: dict[str, ColumnGenerator],
    graph: ExecutionGraph,
    tracker: CompletionTracker,
    row_groups: list[tuple[int, int]],
    buffer_manager: RowGroupBufferManager | None = None,
    *,
    max_concurrent_row_groups: int = 3,
    max_submitted_tasks: int = DEFAULT_TASK_POOL_SIZE,
    max_llm_wait_tasks: int = DEFAULT_TASK_POOL_SIZE,
    salvage_max_rounds: int = 2,
    on_finalize_row_group: Callable[[int], None] | None = None,
    on_seeds_complete: Callable[[int, int], None] | None = None,
    on_before_checkpoint: Callable[[int, int], None] | None = None,
    shutdown_error_rate: float = 0.5,
    shutdown_error_window: int = 10,
    disable_early_shutdown: bool = False,
    trace: bool = False,
    num_records: int = 0,
    buffer_size: int = 0,
    progress_interval: float | None = None,
    progress_bar: bool = False,
) -> None:
    """Wire up scheduler state; no tasks are started here (see ``run()``).

    Args:
        generators: Column name → generator. Multiple columns may share one
            generator instance (multi-column generators).
        graph: Static column dependency graph.
        tracker: Completion tracker used to compute the ready frontier.
        row_groups: ``(row_group_id, size)`` pairs to process, in order.
        buffer_manager: Optional in-memory buffer receiving generated values.
        max_concurrent_row_groups: Cap on simultaneously admitted row groups.
        max_submitted_tasks: Cap on live (submitted) task coroutines.
        max_llm_wait_tasks: Cap on tasks concurrently waiting on LLM calls.
        salvage_max_rounds: Retry rounds for deferred transient failures.
        on_finalize_row_group: Called with the row-group id at checkpoint.
        on_seeds_complete: Pre-batch callback ``(rg_id, rg_size)``.
        on_before_checkpoint: Post-batch callback ``(rg_id, rg_size)``.
        shutdown_error_rate: Windowed error-rate threshold for early shutdown.
        shutdown_error_window: Size of the sliding outcome window.
        disable_early_shutdown: Disable the error-rate kill switch.
        trace: Record a ``TaskTrace`` per task into ``self.traces``.
        num_records: Total records for progress reporting; <= 0 disables it.
        buffer_size: Row-group buffer size for progress reporting.
        progress_interval: Reporter interval; ``None`` uses the default.
        progress_bar: Show a sticky progress bar during ``run()``.
    """
    self._generators = generators
    self._graph = graph
    self._tracker = tracker
    self._row_groups = row_groups
    self._buffer_manager = buffer_manager
    # Plain semaphore: admission blocks on it; permits released at checkpoint.
    self._rg_semaphore = asyncio.Semaphore(max_concurrent_row_groups)
    self._submission_semaphore = TrackingSemaphore(max_submitted_tasks)
    self._llm_wait_semaphore = TrackingSemaphore(max_llm_wait_tasks)
    self._llm_bound_lookup = build_llm_bound_lookup(generators)
    self._dispatched: set[Task] = set()
    self._in_flight: set[Task] = set()
    self._worker_tasks: set[asyncio.Task] = set()
    self._wake_event = asyncio.Event()
    self._salvage_max_rounds = salvage_max_rounds
    self._on_finalize_row_group = on_finalize_row_group
    self._on_seeds_complete = on_seeds_complete
    self._on_before_checkpoint = on_before_checkpoint
    # Error rate shutdown (caller passes pre-normalized values via RunConfig)
    self._shutdown_error_rate = shutdown_error_rate
    self._shutdown_error_window = shutdown_error_window
    self._disable_early_shutdown = disable_early_shutdown
    self._early_shutdown = False
    # Multi-column dedup: group output columns by generator identity.
    # _gen_instance_to_columns holds only real (graph-registered) columns
    # and is used for completion tracking.
    # _gen_instance_to_columns_including_side_effects extends that with
    # side-effect columns for buffer writes only.
    gen_instance_to_columns: dict[int, list[str]] = {}
    for col, gen in generators.items():
        gen_instance_to_columns.setdefault(id(gen), []).append(col)
    self._gen_instance_to_columns = gen_instance_to_columns
    seen_cols: set[str] = {col for col in generators}
    gen_instance_to_columns_incl_se: dict[int, list[str]] = {k: list(v) for k, v in gen_instance_to_columns.items()}
    for col, gen in generators.items():
        for side_effect_col in getattr(gen.config, "side_effect_columns", []):
            if side_effect_col not in seen_cols:
                gen_instance_to_columns_incl_se.setdefault(id(gen), []).append(side_effect_col)
                seen_cols.add(side_effect_col)
    self._gen_instance_to_columns_including_side_effects = gen_instance_to_columns_incl_se
    # Stateful generator tracking: instance_id → asyncio.Lock
    self._stateful_locks: dict[int, asyncio.Lock] = {}
    for col, gen in generators.items():
        if gen.is_order_dependent and id(gen) not in self._stateful_locks:
            self._stateful_locks[id(gen)] = asyncio.Lock()
    # Per-RG lifecycle state (admitted but not yet checkpointed)
    self._rg_states: dict[int, _RowGroupState] = {}
    # Deferred retryable failures (retried in salvage rounds)
    self._deferred: list[Task] = []
    # Tracing
    self._trace = trace
    self.traces: list[TaskTrace] = []
    # Sliding window for error rate shutdown
    self._recent_outcomes: deque[bool] = deque(maxlen=shutdown_error_window)
    self._all_rgs_admitted = False
    # Pre-compute row-group sizes for O(1) lookup
    self._rg_size_map: dict[int, int] = dict(row_groups)
    # Pre-compute seed columns (graph is static)
    self._seed_cols: frozenset[str] = frozenset(c for c in graph.columns if not graph.get_upstream_columns(c))
    # Per-column progress tracking (cell-by-cell only; full-column tasks are instant)
    self._progress_bar = StickyProgressBar() if progress_bar else None
    self._reporter = self._setup_async_progress_reporter(num_records, buffer_size, progress_interval)
def _setup_async_progress_reporter(
    self,
    num_records: int,
    buffer_size: int,
    progress_interval: float | None,
) -> AsyncProgressReporter | None:
    """Build the progress reporter, or ``None`` when nothing would be reported.

    Only cell-by-cell columns are tracked; reporting is disabled when
    ``num_records`` or ``buffer_size`` is not positive, or when no column
    uses the cell-by-cell strategy.
    """
    if num_records <= 0 or buffer_size <= 0:
        return None
    counts = self._graph.compute_task_count(num_records, buffer_size)
    cell_columns = [
        col
        for col in self._graph.columns
        if self._graph.get_strategy(col) == GenerationStrategy.CELL_BY_CELL
    ]
    if not cell_columns:
        return None
    trackers = {
        col: ProgressTracker(
            total_records=counts[col],
            label=f"column '{col}'",
            quiet=True,
        )
        for col in cell_columns
    }
    if progress_interval is None:
        interval = DEFAULT_REPORT_INTERVAL
    else:
        interval = progress_interval
    return AsyncProgressReporter(
        trackers,
        report_interval=interval,
        progress_bar=self._progress_bar,
    )
def _spawn_worker(self, coro: Coroutine[Any, Any, None]) -> asyncio.Task:
    """Schedule *coro* as a tracked worker; it unregisters itself when done."""
    worker = asyncio.create_task(coro)
    self._worker_tasks.add(worker)
    worker.add_done_callback(self._worker_tasks.discard)
    return worker
async def _cancel_workers(self) -> None:
    """Cancel every tracked worker and wait until all of them have stopped."""
    workers = self._worker_tasks
    for worker in workers:
        worker.cancel()
    if workers:
        # return_exceptions=True: swallow CancelledError from each worker.
        await asyncio.gather(*workers, return_exceptions=True)
    workers.clear()
async def _admit_row_groups(self) -> None:
    """Admit row groups one by one as concurrency slots free up.

    Each admission registers lifecycle state, initializes the buffer slot
    (when a buffer manager exists), dispatches the group's seed tasks, and
    pokes the dispatch loop. Sets ``_all_rgs_admitted`` at the end.
    """
    for group_id, group_size in self._row_groups:
        await self._rg_semaphore.acquire()
        self._rg_states[group_id] = _RowGroupState(size=group_size)
        buffer = self._buffer_manager
        if buffer is not None:
            buffer.init_row_group(group_id, group_size)
        await self._dispatch_seeds(group_id, group_size)
        self._wake_event.set()
    self._all_rgs_admitted = True
    self._wake_event.set()
async def run(self) -> None:
    """Main scheduler loop.

    On cancellation (``CancelledError``), all tracked worker tasks are
    cancelled and awaited so that held semaphore permits are released
    before the error propagates.
    """
    all_columns = self._graph.columns
    seed_cols = self._seed_cols
    has_pre_batch = self._on_seeds_complete is not None
    num_rgs = len(self._row_groups)
    # The sticky progress bar doubles as a context manager when enabled.
    with self._progress_bar or contextlib.nullcontext():
        if self._reporter:
            self._reporter.log_start(num_row_groups=num_rgs)
        # Launch admission as a background task so it interleaves with dispatch.
        admission_task = asyncio.create_task(self._admit_row_groups())
        try:
            # Main dispatch loop
            await self._main_dispatch_loop(seed_cols, has_pre_batch, all_columns)
        finally:
            # Always cancel admission + drain in-flight workers, regardless
            # of how the dispatch loop exited (normal, early shutdown,
            # CancelledError, or processor failure).
            if not admission_task.done():
                admission_task.cancel()
                with contextlib.suppress(asyncio.CancelledError):
                    await admission_task
            # shield: finish worker cleanup even if we are being cancelled.
            await asyncio.shield(self._cancel_workers())
            if self._reporter:
                self._reporter.log_final()
            if self._rg_states:
                incomplete = list(self._rg_states)
                logger.error(
                    f"Scheduler exited with {len(self._rg_states)} unfinished row group(s): {incomplete}. "
                    "These row groups were not checkpointed."
                )
async def _main_dispatch_loop(
    self,
    seed_cols: frozenset[str],
    has_pre_batch: bool,
    all_columns: list[str],
) -> None:
    """Core dispatch loop extracted from ``run()``.

    Each iteration: run the pre-batch check, dispatch as many ready tasks
    as the submission semaphore allows, checkpoint finished row groups,
    salvage stalled ones, then sleep on the wake event (set by finishing
    workers and by admission) until there is new work.
    """
    while True:
        if self._early_shutdown:
            logger.warning("Early shutdown triggered - error rate exceeded threshold")
            if self._deferred:
                await self._salvage_stalled_row_groups(seed_cols, has_pre_batch, all_columns)
            self._checkpoint_completed_row_groups(all_columns)
            break
        # Clear before computing the frontier so a wake that races with
        # dispatch below is not lost.
        self._wake_event.clear()
        if has_pre_batch:
            self._run_seeds_complete_check(seed_cols)
        admitted_ids = set(self._rg_states)
        ready = self._tracker.get_ready_tasks(self._dispatched, admitted_ids)
        # Gate non-seed tasks on pre-batch completion when a pre-batch callback is configured
        if has_pre_batch:
            # Precedence: (row group's pre-batch done) OR (task is a seed column).
            ready = [
                t
                for t in ready
                if (s := self._rg_states.get(t.row_group)) is not None and s.pre_batch_done or t.column in seed_cols
            ]
        semaphore_full = False
        for task in ready:
            if not self._submission_semaphore.try_acquire():
                semaphore_full = True
                break
            self._dispatched.add(task)
            self._in_flight.add(task)
            if (s := self._rg_states.get(task.row_group)) is not None:
                s.in_flight_count += 1
            self._spawn_worker(self._execute_task(task))
        self._checkpoint_completed_row_groups(all_columns)
        # Eagerly salvage any row groups that have only deferred tasks,
        # even if other row groups are still in-flight. This frees
        # semaphore slots so admission doesn't lose capacity.
        if self._deferred:
            await self._salvage_stalled_row_groups(seed_cols, has_pre_batch, all_columns)
        # Are we done?
        all_done = self._all_rgs_admitted and not self._rg_states and not self._in_flight
        if all_done:
            break
        if not ready and not self._in_flight:
            if self._all_rgs_admitted:
                break
        # Otherwise block until a worker finishes or admission makes progress.
        if not ready or semaphore_full:
            await self._wake_event.wait()
async def _salvage_rounds(
    self,
    seed_cols: frozenset[str],
    has_pre_batch: bool,
    all_columns: list[str],
) -> None:
    """Phase 3: retry deferred (transient-failure) tasks.

    Runs up to ``_salvage_max_rounds`` rounds. from_scratch tasks are
    re-dispatched directly (they never reappear in the frontier); all
    other deferred tasks are simply removed from ``_dispatched`` so the
    frontier drain re-issues them. Tasks that fail again transiently are
    re-appended to ``_deferred`` by the executor.
    """
    for round_num in range(self._salvage_max_rounds):
        if not self._deferred:
            break
        logger.debug(f"Salvage round {round_num + 1}/{self._salvage_max_rounds}: {len(self._deferred)} tasks")
        to_retry = self._deferred
        self._deferred = []
        for task in to_retry:
            if task.task_type == "from_scratch":
                # from_scratch tasks are not in the frontier; re-dispatch directly
                gid = id(self._generators[task.column])
                self._dispatched.discard(task)
                # Also clear the batch alias so completion tracking works
                self._dispatched.discard(
                    Task(column=task.column, row_group=task.row_group, row_index=None, task_type="batch")
                )
                for sibling in self._gen_instance_to_columns.get(gid, []):
                    if sibling != task.column:
                        self._dispatched.discard(
                            Task(column=sibling, row_group=task.row_group, row_index=None, task_type="from_scratch")
                        )
                        self._dispatched.discard(
                            Task(column=sibling, row_group=task.row_group, row_index=None, task_type="batch")
                        )
                # Acquire stateful lock (mirrors _dispatch_seeds) so
                # _execute_seed_task can safely release it in finally.
                if gid in self._stateful_locks:
                    await self._stateful_locks[gid].acquire()
                await self._submission_semaphore.acquire()
                self._dispatched.add(task)
                # Re-register batch alias to mirror _dispatch_seeds and prevent
                # duplicate dispatch if the frontier contains a stale batch task.
                self._dispatched.add(
                    Task(column=task.column, row_group=task.row_group, row_index=None, task_type="batch")
                )
                # Re-mark sibling columns as dispatched to mirror _dispatch_seeds
                # and prevent _drain_frontier from re-dispatching them.
                for sibling in self._gen_instance_to_columns.get(gid, []):
                    if sibling != task.column:
                        self._dispatched.add(
                            Task(column=sibling, row_group=task.row_group, row_index=None, task_type="from_scratch")
                        )
                        self._dispatched.add(
                            Task(column=sibling, row_group=task.row_group, row_index=None, task_type="batch")
                        )
                self._in_flight.add(task)
                if (s := self._rg_states.get(task.row_group)) is not None:
                    s.in_flight_count += 1
                self._spawn_worker(self._execute_seed_task(task, gid))
            else:
                # Cell/batch tasks: un-dispatch so the frontier re-issues them.
                self._dispatched.discard(task)
        # Drain: dispatch frontier tasks and any newly-ready downstream tasks
        # until nothing remains in-flight or in the frontier.
        await self._drain_frontier(seed_cols, has_pre_batch, all_columns)
        self._checkpoint_completed_row_groups(all_columns)
async def _drain_frontier(self, seed_cols: frozenset[str], has_pre_batch: bool, all_columns: list[str]) -> None:
    """Dispatch all frontier tasks and their downstream until quiescent.

    Quiescent means: nothing ready in the frontier and nothing in flight.
    While anything is in flight, blocks on the wake event that finishing
    workers set.
    """
    while True:
        if has_pre_batch:
            self._run_seeds_complete_check(seed_cols)
        admitted_ids = set(self._rg_states)
        ready = self._tracker.get_ready_tasks(self._dispatched, admitted_ids)
        if has_pre_batch:
            # Precedence: (row group's pre-batch done) OR (task is a seed column).
            ready = [
                t
                for t in ready
                if (s := self._rg_states.get(t.row_group)) is not None and s.pre_batch_done or t.column in seed_cols
            ]
        for task in ready:
            if not self._submission_semaphore.try_acquire():
                break
            self._dispatched.add(task)
            self._in_flight.add(task)
            if (s := self._rg_states.get(task.row_group)) is not None:
                s.in_flight_count += 1
            self._spawn_worker(self._execute_task(task))
        if not ready and not self._in_flight:
            break
        # Ready tasks remain but nothing is in flight: loop again rather
        # than wait — no worker exists to set the wake event.
        if not self._in_flight:
            continue
        self._wake_event.clear()
        await self._wake_event.wait()
async def _salvage_stalled_row_groups(
    self,
    seed_cols: frozenset[str],
    has_pre_batch: bool,
    all_columns: list[str],
) -> None:
    """Salvage row groups whose tasks are all deferred (0 in-flight).

    Retries deferred tasks inline so the row groups can be checkpointed
    and their semaphore slots freed, preventing deadlock when admission
    is blocked.
    """
    stalled_rgs = {
        t.row_group
        for t in self._deferred
        if (s := self._rg_states.get(t.row_group)) is not None and s.in_flight_count == 0
    }
    if not stalled_rgs:
        return
    num_rgs = len(self._row_groups)
    width = len(str(num_rgs))
    for rg_id in sorted(stalled_rgs):
        rg_deferred = [t for t in self._deferred if t.row_group == rg_id]
        logger.info(f"🔄 ({rg_id + 1:0{width}d}/{num_rgs}) Salvaging {len(rg_deferred)} deferred task(s)")
    # Partition deferred into stalled (retry now) and other (keep for later).
    stalled_deferred = [t for t in self._deferred if t.row_group in stalled_rgs]
    other_deferred = [t for t in self._deferred if t.row_group not in stalled_rgs]
    self._deferred = stalled_deferred
    await self._salvage_rounds(seed_cols, has_pre_batch, all_columns)
    # Separate stalled tasks that exhausted retries from any new failures
    # that _drain_frontier may have appended for non-stalled row groups.
    exhausted = [t for t in self._deferred if t.row_group in stalled_rgs]
    newly_deferred = [t for t in self._deferred if t.row_group not in stalled_rgs]
    for task in exhausted:
        # If the row was already dropped by an earlier task in this loop,
        # the skip was already counted; don't also record a failure.
        already_dropped = task.row_index is not None and self._tracker.is_dropped(task.row_group, task.row_index)
        if not already_dropped and self._reporter:
            self._reporter.record_failure(task.column)
        if task.row_index is not None:
            self._drop_row(task.row_group, task.row_index, exclude_columns={task.column})
        else:
            # Batch-level failure: the whole row group is lost.
            rg_size = self._get_rg_size(task.row_group)
            self._drop_row_group(task.row_group, rg_size, exclude_columns={task.column})
    self._checkpoint_completed_row_groups(all_columns)
    self._deferred = other_deferred + newly_deferred
def _checkpoint_completed_row_groups(self, all_columns: list[str]) -> None:
    """Checkpoint any row groups that reached completion.

    For each completed group: run the post-batch callback (failures are
    wrapped in ``DatasetGenerationError`` and propagate), then either free
    the buffer (all rows dropped) or invoke the finalize callback. The
    row-group semaphore permit is released in ``finally`` — once per
    completed group, even on failure — so admission can make progress.
    """
    completed = [
        (rg_id, state.size)
        for rg_id, state in self._rg_states.items()
        if self._tracker.is_row_group_complete(rg_id, state.size, all_columns)
    ]
    for rg_id, rg_size in completed:
        try:
            if self._on_before_checkpoint:
                try:
                    self._on_before_checkpoint(rg_id, rg_size)
                except DatasetGenerationError:
                    raise
                except Exception as exc:
                    raise DatasetGenerationError(
                        f"Post-batch processor failed for row group {rg_id}: {exc}"
                    ) from exc
            # Remove from tracking only after the callback succeeds.
            del self._rg_states[rg_id]
            # If all rows were dropped (e.g. seed failure), free instead of finalizing
            if all(self._tracker.is_dropped(rg_id, ri) for ri in range(rg_size)):
                if self._buffer_manager:
                    self._buffer_manager.free_row_group(rg_id)
            elif self._on_finalize_row_group is not None:
                self._on_finalize_row_group(rg_id)
        except DatasetGenerationError:
            raise
        except Exception:
            # Best-effort: a failed finalize must not kill the whole run.
            logger.error(f"Failed to checkpoint row group {rg_id}.", exc_info=True)
        finally:
            self._rg_semaphore.release()
    # Clean up deferred tasks for checkpointed row groups
    if completed:
        checkpointed = {rg_id for rg_id, _ in completed}
        self._deferred = [t for t in self._deferred if t.row_group not in checkpointed]
def _run_seeds_complete_check(self, seed_cols: frozenset[str]) -> None:
    """Run pre-batch callbacks for row groups whose seeds just completed.

    A group is eligible once its seeds were dispatched, every seed column
    is complete for it, and nothing is in flight. ``pre_batch_done`` is set
    before invoking the callback so the check never fires twice.
    """
    for rg_id, state in list(self._rg_states.items()):
        if state.seeds_dispatched and not state.pre_batch_done:
            all_seeds_done = all(self._tracker.is_column_complete_for_rg(col, rg_id) for col in seed_cols)
            if all_seeds_done and state.in_flight_count == 0:
                state.pre_batch_done = True
                if self._on_seeds_complete:
                    try:
                        self._on_seeds_complete(rg_id, state.size)
                    except DatasetGenerationError:
                        raise
                    except Exception as exc:
                        raise DatasetGenerationError(
                            f"Pre-batch processor failed for row group {rg_id}: {exc}"
                        ) from exc
                # The callback may drop rows (e.g. pre-batch filtering).
                # Record skipped tasks for any newly-dropped rows so
                # progress reporting stays accurate.
                if self._reporter:
                    for ri in range(state.size):
                        if self._tracker.is_dropped(rg_id, ri):
                            self._record_skipped_tasks_for_row(rg_id, ri)
def _drop_row(self, row_group: int, row_index: int, *, exclude_columns: set[str] | None = None) -> None:
    """Mark one row as dropped, first recording skips for its pending cells.

    Idempotent: an already-dropped row is left untouched.
    """
    if self._tracker.is_dropped(row_group, row_index):
        return
    self._record_skipped_tasks_for_row(row_group, row_index, exclude_columns=exclude_columns)
    self._tracker.drop_row(row_group, row_index)
    buffer = self._buffer_manager
    if buffer:
        buffer.drop_row(row_group, row_index)
def _drop_row_group(self, row_group: int, row_group_size: int, *, exclude_columns: set[str] | None = None) -> None:
    """Drop every row in the group (per-row drop is idempotent)."""
    for idx in range(row_group_size):
        self._drop_row(row_group, idx, exclude_columns=exclude_columns)
def _record_skipped_tasks_for_row(
    self,
    row_group: int,
    row_index: int,
    *,
    exclude_columns: set[str] | None = None,
) -> None:
    """Count a 'skipped' outcome for each pending cell task of a dropped row.

    Columns that are excluded, not cell-by-cell, currently in flight for
    this row, or already complete are not counted.
    """
    reporter = self._reporter
    if reporter is None:
        return
    excluded = exclude_columns if exclude_columns else set()
    busy_columns: set[str] = set()
    for task in self._in_flight:
        if task.row_group == row_group and task.row_index == row_index:
            busy_columns.add(task.column)
    for column in self._graph.columns:
        if column in excluded:
            continue
        if self._graph.get_strategy(column) != GenerationStrategy.CELL_BY_CELL:
            continue
        if column in busy_columns:
            continue
        ref = SliceRef(column=column, row_group=row_group, row_index=row_index)
        if self._tracker.is_complete(ref):
            continue
        reporter.record_skipped(column)
def _check_error_rate(self, *, success: bool) -> None:
    """Record one task outcome; trip early shutdown when the windowed error
    rate reaches the configured threshold. No-op once tripped or disabled."""
    if self._disable_early_shutdown or self._early_shutdown:
        return
    window = self._recent_outcomes
    window.append(success)
    # Only judge once the window is full.
    if len(window) < self._shutdown_error_window:
        return
    failures = window.count(False)
    if failures / self._shutdown_error_window >= self._shutdown_error_rate:
        self._early_shutdown = True
async def _dispatch_seeds(self, rg_id: int, rg_size: int) -> None:
    """Dispatch from_scratch tasks for a row group.

    One task is dispatched per generator *instance* (multi-column
    generators produce several columns from one call); sibling columns
    and the ``batch`` alias are pre-marked as dispatched so the frontier
    never duplicates them.
    """
    self._rg_states[rg_id].seeds_dispatched = True
    seed_cols = self._seed_cols
    if not seed_cols:
        return
    num_rgs = len(self._rg_size_map)
    width = len(str(num_rgs))
    logger.info(f"🚀 ({rg_id + 1:0{width}d}/{num_rgs}) Dispatching with {rg_size} records")
    seen_instances: set[int] = set()
    for col in seed_cols:
        gen = self._generators[col]
        gid = id(gen)
        if gid in seen_instances:
            continue
        seen_instances.add(gid)
        task = Task(column=col, row_group=rg_id, row_index=None, task_type="from_scratch")
        # Also mark the "batch" variant as dispatched to prevent get_ready_tasks
        # from generating a duplicate for this column
        batch_alias = Task(column=col, row_group=rg_id, row_index=None, task_type="batch")
        if task in self._dispatched or batch_alias in self._dispatched:
            continue
        # Acquire stateful lock *before* submission semaphore to preserve
        # row-group ordering. Held until generation completes (_execute_seed_task).
        if gid in self._stateful_locks:
            await self._stateful_locks[gid].acquire()
        await self._submission_semaphore.acquire()
        self._dispatched.add(task)
        self._dispatched.add(batch_alias)
        # Also mark all sibling output columns as dispatched (multi-column dedup)
        for sibling_col in self._gen_instance_to_columns.get(gid, []):
            if sibling_col != col:
                self._dispatched.add(
                    Task(column=sibling_col, row_group=rg_id, row_index=None, task_type="from_scratch")
                )
                self._dispatched.add(Task(column=sibling_col, row_group=rg_id, row_index=None, task_type="batch"))
        self._in_flight.add(task)
        if (s := self._rg_states.get(task.row_group)) is not None:
            s.in_flight_count += 1
        self._spawn_worker(self._execute_seed_task(task, gid))
async def _execute_seed_task(self, task: Task, generator_id: int) -> None:
    """Run a from_scratch task, then release the generator's stateful lock
    (acquired by the dispatcher) if one exists."""
    try:
        await self._execute_task_inner(task)
    finally:
        lock = self._stateful_locks.get(generator_id)
        if lock is not None:
            lock.release()
async def _execute_task(self, task: Task) -> None:
    """Run a single non-seed task (cell or batch)."""
    await self._execute_task_inner(task)
async def _execute_task_inner(self, task: Task) -> None:
    """Run the task with the ``current_row_group`` contextvar set for its
    duration, restoring the previous value afterwards.

    For LLM-bound tasks, uses a one-way semaphore handoff: acquires the
    LLM-wait slot while still holding the submission slot, then releases
    the submission slot (never reacquired). This prevents cross-key
    starvation while bounding live coroutines.
    """
    total_groups = len(self._row_groups)
    token = current_row_group.set((task.row_group, total_groups))
    try:
        await self._execute_task_inner_impl(task)
    finally:
        current_row_group.reset(token)
async def _execute_task_inner_impl(self, task: Task) -> None:
    """Execute one task end to end: run the generator, mark completion,
    and account for success/failure.

    Transient (retryable) failures are appended to ``_deferred`` and the
    task stays in ``_dispatched``; non-retryable failures drop the
    affected row(s). The ``finally`` block releases whichever semaphore
    permit is still held and wakes the dispatch loop.
    """
    trace: TaskTrace | None = None
    if self._trace:
        trace = TaskTrace.from_task(task)
        trace.dispatched_at = time.perf_counter()
    generator = self._generators[task.column]
    output_cols = self._gen_instance_to_columns.get(id(generator), [task.column])
    retryable = False
    # When True, skip removing from _dispatched so the task isn't re-dispatched
    # from the frontier (it was never completed, so it stays in the frontier).
    skipped = False
    is_llm = self._llm_bound_lookup.get(task.column, False)
    holds_submission = True
    holds_llm_wait = False
    try:
        # Skip tasks whose row group was already checkpointed (can happen
        # when a vacuously-ready downstream is dispatched via create_task
        # in the same loop iteration that checkpoints the row group).
        if task.row_group not in self._rg_states:
            skipped = True
            return
        if is_llm:
            # One-way handoff: take the LLM-wait slot, then give back the
            # submission slot (see _execute_task_inner docstring).
            await self._llm_wait_semaphore.acquire()
            holds_llm_wait = True
            self._submission_semaphore.release()
            holds_submission = False
        if self._trace and trace:
            trace.slot_acquired_at = time.perf_counter()
        cell_skipped = False
        if task.task_type == "from_scratch":
            await self._run_from_scratch(task, generator)
        elif task.task_type == "cell":
            _result, cell_skipped = await self._run_cell(task, generator)
        elif task.task_type == "batch":
            await self._run_batch(task, generator)
        else:
            raise ValueError(f"Unknown task type: {task.task_type}")
        # Mark all output columns complete
        for col in output_cols:
            if task.row_index is None:
                rg_size = self._get_rg_size(task.row_group)
                self._tracker.mark_row_range_complete(col, task.row_group, rg_size)
            else:
                self._tracker.mark_cell_complete(col, task.row_group, task.row_index)
        self._check_error_rate(success=True)
        if self._reporter:
            if cell_skipped:
                self._reporter.record_skipped(task.column)
            else:
                self._reporter.record_success(task.column)
        if self._trace and trace:
            trace.status = "ok"
    except Exception as exc:
        # Rate limits are backpressure, not quality problems: keep them out
        # of the early-shutdown error window.
        if not isinstance(exc, ModelRateLimitError):
            self._check_error_rate(success=False)
        retryable = self._is_retryable(exc)
        if not retryable and self._reporter:
            self._reporter.record_failure(task.column)
        if self._trace and trace:
            trace.status = "error"
            trace.error = str(exc)
        if retryable:
            self._deferred.append(task)
        else:
            # Non-retryable: drop the affected row(s)
            if task.row_index is not None:
                self._drop_row(task.row_group, task.row_index, exclude_columns={task.column})
            else:
                # Batch/from_scratch failure: drop all rows in the row group
                rg_size = self._get_rg_size(task.row_group)
                self._drop_row_group(task.row_group, rg_size, exclude_columns={task.column})
            logger.warning(
                f"Non-retryable failure on {task.column}[rg={task.row_group}, row={task.row_index}]: {exc}"
            )
    finally:
        if self._trace and trace:
            trace.completed_at = time.perf_counter()
            self.traces.append(trace)
        self._in_flight.discard(task)
        if (s := self._rg_states.get(task.row_group)) is not None:
            s.in_flight_count = max(0, s.in_flight_count - 1)
        if not retryable and not skipped:
            self._dispatched.discard(task)
        if holds_llm_wait:
            self._llm_wait_semaphore.release()
        if holds_submission:
            self._submission_semaphore.release()
        self._wake_event.set()
async def _run_from_scratch(self, task: Task, generator: ColumnGenerator) -> Any:
    """Generate a full row group from scratch and persist it to the buffer.

    Returns the generated DataFrame.
    """
    group_size = self._get_rg_size(task.row_group)
    # Runtime import: needed for isinstance check; module-level would cause circular import
    from data_designer.engine.column_generators.generators.base import FromScratchColumnGenerator

    if isinstance(generator, FromScratchColumnGenerator):
        frame = await generator.agenerate_from_scratch(group_size)
    else:
        frame = await generator.agenerate(lazy.pd.DataFrame())
    buffer = self._buffer_manager
    if buffer is not None:
        # Persist every output column this generator owns, side effects included.
        target_cols = self._gen_instance_to_columns_including_side_effects.get(id(generator), [task.column])
        for col in target_cols:
            if col in frame.columns:
                buffer.update_batch(task.row_group, col, frame[col].tolist())
    return frame
async def _run_cell(self, task: Task, generator: ColumnGenerator) -> tuple[Any, bool]:
"""Execute a cell-by-cell task. Returns ``(result, skipped)``."""
if task.row_index is None:
raise ValueError(f"Cell task requires a row_index, got None for column '{task.column}'")
if self._tracker.is_dropped(task.row_group, task.row_index):
return None, False
# Evaluate skip against the live buffer record (no copy needed —
# there is no `await` between the read and the skip-metadata write).
if self._buffer_manager is not None:
record = self._buffer_manager.get_row(task.row_group, task.row_index)
else:
record = {}
if self._should_skip_record(task.column, record):
self._apply_skip_to_record(task, record)
skip_config = self._graph.get_skip_config(task.column)
return skip_config.value if skip_config is not None else None, True
# Copy for generation: agenerate crosses an await boundary, so the
# generator must not hold a mutable reference to the live record.
result = await generator.agenerate(dict(record))
# Write back to buffer (include side-effect columns)
if self._buffer_manager is not None and not self._tracker.is_dropped(task.row_group, task.row_index):
write_cols = self._gen_instance_to_columns_including_side_effects.get(id(generator), [task.column])
for col in write_cols:
if col in result:
self._buffer_manager.update_cell(task.row_group, task.row_index, col, result[col])
return result, False
def _should_skip_record(self, column: str, record: dict) -> bool:
    """Return True when *column* should be skipped for *record*.

    Skip propagation from required columns is evaluated first, then the
    column's own `when` expression gate.
    """
    cfg = self._graph.get_skip_config(column)
    when_clause = cfg.when if cfg is not None else None
    return should_skip_column_for_record(
        record,
        propagate_skip=self._graph.should_propagate_skip(column),
        required_columns=self._graph.get_required_columns(column),
        skip_config_when=when_clause,
    )
def _apply_skip_to_record(self, task: Task, record: dict) -> None:
    """Stamp skip metadata into *record* in place (the live buffer row)."""
    cfg = self._graph.get_skip_config(task.column)
    apply_skip_to_record(
        record,
        column_name=task.column,
        cell_value=None if cfg is None else cfg.value,
        side_effect_columns=self._graph.get_side_effect_columns(task.column),
    )
async def _run_batch(self, task: Task, generator: ColumnGenerator) -> Any:
    """Execute a full-column/batch task.

    Builds a DataFrame from the row group's non-dropped, non-skipped rows,
    runs the generator on it once, then merges the result columns (including
    side-effect columns) back into the buffer by positional alignment.

    Returns the generator's result DataFrame (possibly empty).

    Raises:
        ValueError: if the generator returns a different number of rows than
            the number of active rows it was handed.
    """
    rg_size = self._get_rg_size(task.row_group)
    if self._buffer_manager is not None:
        # Snapshot rows already dropped before generation; rows dropped while
        # the generator is running are re-checked again at write-back time.
        pre_dropped: set[int] = {ri for ri in range(rg_size) if self._buffer_manager.is_dropped(task.row_group, ri)}
        active_rows_data: list[dict] = []
        # Skip evaluation only applies to single-column configs.
        # Multi-column configs (sampler/seed) are rejected by the SkipConfig
        # model validator, so they never carry skip metadata.
        pre_skipped: set[int] = set()
        is_multi = isinstance(generator.config, MultiColumnConfig)
        for ri in range(rg_size):
            if ri in pre_dropped:
                continue
            record = self._buffer_manager.get_row(task.row_group, ri)
            if not is_multi and self._should_skip_record(task.column, record):
                # Skip metadata is written into the live record; the row is
                # then excluded from the batch handed to the generator.
                self._apply_skip_to_record(task, record)
                pre_skipped.add(ri)
                continue
            active_rows_data.append(record)
        batch_df = (
            lazy.pd.DataFrame(strip_skip_metadata_from_records(active_rows_data))
            if active_rows_data
            else lazy.pd.DataFrame()
        )
    else:
        # No buffer: nothing to read from, so the batch is empty.
        batch_df = lazy.pd.DataFrame()
        pre_dropped = set()
        pre_skipped = set()
    if len(batch_df) == 0:
        # Every row was dropped or skipped (or there is no buffer) — nothing to do.
        return batch_df
    result_df = await generator.agenerate(batch_df)
    # Merge result columns back to buffer (include side-effect columns)
    if self._buffer_manager is not None:
        write_cols = self._gen_instance_to_columns_including_side_effects.get(id(generator), [task.column])
        active_rows = rg_size - len(pre_dropped) - len(pre_skipped)
        if len(result_df) != active_rows:
            raise ValueError(
                f"Batch generator for '{task.column}' returned {len(result_df)} rows "
                f"but {active_rows} were expected (rg={task.row_group})."
            )
        result_idx = 0
        for ri in range(rg_size):
            if ri in pre_dropped or ri in pre_skipped:
                continue
            if not self._buffer_manager.is_dropped(task.row_group, ri):
                for col in write_cols:
                    if col in result_df.columns:
                        self._buffer_manager.update_cell(task.row_group, ri, col, result_df.iloc[result_idx][col])
            # result_idx advances for every row that went into the batch, even
            # if it was dropped mid-flight — this keeps positional alignment
            # between buffer rows and result_df rows intact.
            result_idx += 1
    return result_df
def _get_rg_size(self, row_group: int) -> int:
try:
return self._rg_size_map[row_group]
except KeyError:
raise ValueError(f"Unknown row group: {row_group}") from None
def get_semaphore_permits(self) -> tuple[int, int]:
    """Diagnostic helper: currently available permits as ``(submission, llm_wait)``."""
    submission = self._submission_semaphore.available_permits
    llm_wait = self._llm_wait_semaphore.available_permits
    return submission, llm_wait
@staticmethod
def _is_retryable(exc: Exception) -> bool:
    """Return True when *exc* matches ``_RETRYABLE_MODEL_ERRORS``."""
    retryable_types = _RETRYABLE_MODEL_ERRORS
    return isinstance(exc, retryable_types)
def build_llm_bound_lookup(generators: dict[str, ColumnGenerator]) -> dict[str, bool]:
    """Map each column name to whether its generator is LLM-bound."""
    lookup: dict[str, bool] = {}
    for col, gen in generators.items():
        lookup[col] = gen.is_llm_bound
    return lookup