Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults
from data_designer.config.config_builder import DataDesignerConfigBuilder
Expand All @@ -23,6 +23,7 @@ def __init__(
dataset: pd.DataFrame | None = None,
analysis: DatasetProfilerResults | None = None,
processor_artifacts: dict[str, list[dict]] | None = None,
task_traces: list[Any] | None = None,
):
"""Creates a new instance with results from a Data Designer preview run.

Expand All @@ -32,9 +33,11 @@ def __init__(
dataset: Dataset of the preview run.
analysis: Analysis of the preview run.
processor_artifacts: Artifacts generated by the processors.
task_traces: Async scheduler task traces (when DATA_DESIGNER_ASYNC_TRACE=1).
"""
self.dataset: pd.DataFrame | None = dataset
self.analysis: DatasetProfilerResults | None = analysis
self.processor_artifacts: dict[str, list[dict]] | None = processor_artifacts
self.dataset_metadata: DatasetMetadata | None = dataset_metadata
self.task_traces: list[Any] | None = task_traces
self._config_builder = config_builder
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import data_designer.lazy_heavy_imports as lazy
from data_designer.config.column_configs import GenerationStrategy
from data_designer.engine.context import current_row_group
from data_designer.engine.dataset_builders.errors import DatasetGenerationError
from data_designer.engine.dataset_builders.multi_column_configs import MultiColumnConfig
from data_designer.engine.dataset_builders.utils.async_progress_reporter import (
DEFAULT_REPORT_INTERVAL,
Expand Down Expand Up @@ -268,30 +269,25 @@ async def run(self) -> None:
try:
# Main dispatch loop
await self._main_dispatch_loop(seed_cols, has_pre_batch, all_columns)

# Cancel admission if still running
finally:
# Always cancel admission + drain in-flight workers, regardless
# of how the dispatch loop exited (normal, early shutdown,
# CancelledError, or processor failure).
if not admission_task.done():
admission_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await admission_task
await asyncio.shield(self._cancel_workers())

if self._reporter:
self._reporter.log_final()

if self._rg_states:
incomplete = list(self._rg_states)
logger.error(
f"Scheduler exited with {len(self._rg_states)} unfinished row group(s): {incomplete}. "
"These row groups were not checkpointed."
)
if self._reporter:
self._reporter.log_final()

except asyncio.CancelledError:
if not admission_task.done():
admission_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await admission_task
await asyncio.shield(self._cancel_workers())
raise
if self._rg_states:
incomplete = list(self._rg_states)
logger.error(
f"Scheduler exited with {len(self._rg_states)} unfinished row group(s): {incomplete}. "
"These row groups were not checkpointed."
)

async def _main_dispatch_loop(
self,
Expand Down Expand Up @@ -500,29 +496,26 @@ def _checkpoint_completed_row_groups(self, all_columns: list[str]) -> None:
if self._tracker.is_row_group_complete(rg_id, state.size, all_columns)
]
for rg_id, rg_size in completed:
dropped = False
try:
del self._rg_states[rg_id]
if self._on_before_checkpoint:
try:
self._on_before_checkpoint(rg_id, rg_size)
except Exception:
# Post-batch is mandatory; drop rather than checkpoint unprocessed data.
logger.error(
f"on_before_checkpoint failed for row group {rg_id}, dropping row group.",
exc_info=True,
)
self._drop_row_group(rg_id, rg_size)
if self._buffer_manager:
self._buffer_manager.free_row_group(rg_id)
dropped = True
except DatasetGenerationError:
raise
except Exception as exc:
raise DatasetGenerationError(
f"Post-batch processor failed for row group {rg_id}: {exc}"
) from exc
# Remove from tracking only after the callback succeeds.
del self._rg_states[rg_id]
# If all rows were dropped (e.g. seed failure), free instead of finalizing
if not dropped and all(self._tracker.is_dropped(rg_id, ri) for ri in range(rg_size)):
if all(self._tracker.is_dropped(rg_id, ri) for ri in range(rg_size)):
if self._buffer_manager:
self._buffer_manager.free_row_group(rg_id)
dropped = True
if not dropped and self._on_finalize_row_group is not None:
elif self._on_finalize_row_group is not None:
self._on_finalize_row_group(rg_id)
except DatasetGenerationError:
raise
except Exception:
logger.error(f"Failed to checkpoint row group {rg_id}.", exc_info=True)
finally:
Expand All @@ -543,19 +536,19 @@ def _run_seeds_complete_check(self, seed_cols: frozenset[str]) -> None:
if self._on_seeds_complete:
try:
self._on_seeds_complete(rg_id, state.size)
# The callback may drop rows (e.g. pre-batch filtering).
# Record skipped tasks for any newly-dropped rows so
# progress reporting stays accurate.
if self._reporter:
for ri in range(state.size):
if self._tracker.is_dropped(rg_id, ri):
self._record_skipped_tasks_for_row(rg_id, ri)
except Exception:
logger.warning(
f"Pre-batch processor failed for row group {rg_id}, skipping.",
exc_info=True,
)
self._drop_row_group(rg_id, state.size)
except DatasetGenerationError:
raise
except Exception as exc:
raise DatasetGenerationError(
f"Pre-batch processor failed for row group {rg_id}: {exc}"
) from exc
# The callback may drop rows (e.g. pre-batch filtering).
# Record skipped tasks for any newly-dropped rows so
# progress reporting stays accurate.
if self._reporter:
for ri in range(state.size):
if self._tracker.is_dropped(rg_id, ri):
self._record_skipped_tasks_for_row(rg_id, ri)

def _drop_row(self, row_group: int, row_index: int, *, exclude_columns: set[str] | None = None) -> None:
if self._tracker.is_dropped(row_group, row_index):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import os
import time
import uuid
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable

Expand Down Expand Up @@ -106,6 +107,7 @@ def __init__(
self._task_traces: list[TaskTrace] = []
self._registry = registry or DataDesignerRegistry()
self._graph: ExecutionGraph | None = None
self._use_async: bool = DATA_DESIGNER_ASYNC_ENGINE

self._data_designer_config = compile_data_designer_config(data_designer_config, resource_provider)
self._column_configs = compile_dataset_builder_column_configs(self._data_designer_config)
Expand Down Expand Up @@ -185,8 +187,8 @@ def build(
start_time = time.perf_counter()
buffer_size = self._resource_provider.run_config.buffer_size

if DATA_DESIGNER_ASYNC_ENGINE:
self._validate_async_compatibility()
self._use_async = DATA_DESIGNER_ASYNC_ENGINE and self._resolve_async_compatibility()
if self._use_async:
self._build_async(generators, num_records, buffer_size, on_batch_complete)
else:
group_id = uuid.uuid4().hex
Expand Down Expand Up @@ -218,8 +220,8 @@ def build_preview(self, *, num_records: int) -> pd.DataFrame:
generators, self._graph = self._initialize_generators_and_graph()
start_time = time.perf_counter()

if DATA_DESIGNER_ASYNC_ENGINE:
self._validate_async_compatibility()
self._use_async = DATA_DESIGNER_ASYNC_ENGINE and self._resolve_async_compatibility()
if self._use_async:
dataset = self._build_async_preview(generators, num_records)
else:
group_id = uuid.uuid4().hex
Expand All @@ -236,11 +238,15 @@ def _build_async_preview(self, generators: list[ColumnGenerator], num_records: i
"""Async preview path - single row group, no disk writes, returns in-memory DataFrame."""
logger.info("⚑ DATA_DESIGNER_ASYNC_ENGINE is enabled - using async task-queue preview")

settings = self._resource_provider.run_config
trace_enabled = settings.async_trace or os.environ.get("DATA_DESIGNER_ASYNC_TRACE", "0") == "1"

scheduler, buffer_manager = self._prepare_async_run(
generators,
num_records,
buffer_size=num_records,
run_post_batch_in_scheduler=False,
trace=trace_enabled,
)

loop = ensure_async_engine_loop()
Expand All @@ -256,15 +262,23 @@ def _build_async_preview(self, generators: list[ColumnGenerator], num_records: i
buffer_manager.free_row_group(0)
return dataset

def _validate_async_compatibility(self) -> None:
"""Raise if any column uses allow_resize=True with the async scheduler."""
def _resolve_async_compatibility(self) -> bool:
"""Check if the async engine can be used; auto-fallback to sync if not.

Returns True if async is usable, False if allow_resize forces sync fallback.
"""
offending = [config.name for config in self.single_column_configs if getattr(config, "allow_resize", False)]
if offending:
raise DatasetGenerationError(
f"allow_resize=True is not supported with DATA_DESIGNER_ASYNC_ENGINE=1. "
f"Offending column(s): {offending}. Either remove allow_resize=True or "
f"disable the async scheduler."
msg = (
f"allow_resize=True detected on column(s) {offending}. "
"Falling back to sync engine for this run. "
"allow_resize is deprecated and will be removed in a future release; "
"use workflow chaining instead (see issue #552)."
)
logger.warning(f"⚠️ {msg}")
warnings.warn(msg, DeprecationWarning, stacklevel=4)
return False
return True

def _build_async(
self,
Expand Down Expand Up @@ -318,6 +332,15 @@ def on_complete(final_path: Path | str | None) -> None:
# Write metadata
buffer_manager.write_metadata(target_num_records=num_records, buffer_size=buffer_size)

# Surface partial completion
actual = buffer_manager.actual_num_records
if actual < num_records:
pct = actual / num_records * 100 if num_records > 0 else 0
logger.warning(
f"⚠️ Generated {actual} of {num_records} requested records ({pct:.0f}%). "
"The dataset may be incomplete due to errors or early shutdown."
)

def _prepare_async_run(
self,
generators: list[ColumnGenerator],
Expand Down Expand Up @@ -366,10 +389,10 @@ def _prepare_async_run(
buffer_manager = RowGroupBufferManager(self.artifact_storage)

# Pre-batch processor callback: runs after seed tasks complete for a row group.
# If it raises, the scheduler drops all rows in the row group (skips it).
# If it raises, the scheduler propagates the error as DatasetGenerationError (fail-fast).
def on_seeds_complete(rg_id: int, rg_size: int) -> None:
df = buffer_manager.get_dataframe(rg_id)
df = self._processor_runner.run_pre_batch_on_df(df)
df = self._processor_runner.run_pre_batch_on_df(df, strict_row_count=True)
buffer_manager.replace_dataframe(rg_id, df)
for ri in range(rg_size):
if buffer_manager.is_dropped(rg_id, ri) and not tracker.is_dropped(rg_id, ri):
Expand All @@ -378,7 +401,13 @@ def on_seeds_complete(rg_id: int, rg_size: int) -> None:
# Post-batch processor callback: runs after all columns, before finalization.
def on_before_checkpoint(rg_id: int, rg_size: int) -> None:
df = buffer_manager.get_dataframe(rg_id)
original_len = len(df)
df = self._processor_runner.run_post_batch(df, current_batch_number=rg_id)
if len(df) != original_len:
raise DatasetGenerationError(
f"Post-batch processor changed row count from {original_len} to {len(df)}. "
"Row-count changes in post-batch processors are not supported with the async engine."
)
buffer_manager.replace_dataframe(rg_id, df)

# Coarse upper bound: sums all registered aliases, not just those used
Expand Down Expand Up @@ -505,7 +534,7 @@ def _run_cell_by_cell_generator(self, generator: ColumnGenerator) -> None:
max_workers = self._resource_provider.run_config.non_inference_max_parallel_workers
if isinstance(generator, ColumnGeneratorWithModel):
max_workers = generator.inference_parameters.max_parallel_requests
if DATA_DESIGNER_ASYNC_ENGINE:
if self._use_async:
logger.info("⚑ Using async engine for concurrent execution")
self._fan_out_with_async(generator, max_workers=max_workers)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,23 @@ def run_pre_batch(self, batch_manager: DatasetBatchManager) -> None:
df = self._run_stage(df, ProcessorStage.PRE_BATCH)
batch_manager.replace_buffer(df.to_dict(orient="records"), allow_resize=True)

def run_pre_batch_on_df(self, df: pd.DataFrame, *, strict_row_count: bool = False) -> pd.DataFrame:
    """Run PRE_BATCH processors on a DataFrame and return the result.

    Args:
        df: Input DataFrame.
        strict_row_count: If True, raise ``DatasetProcessingError`` when a
            processor changes the row count. Used by the async engine where
            row-count changes are not supported.

    Returns:
        The DataFrame after all PRE_BATCH processors have run.

    Raises:
        DatasetProcessingError: If ``strict_row_count`` is True and a
            processor added or removed rows.
    """
    original_len = len(df)
    df = self._run_stage(df, ProcessorStage.PRE_BATCH)
    # The async scheduler tracks per-row task state by (row_group, row_index),
    # so a processor that resizes the frame here would desynchronize tracking.
    if strict_row_count and len(df) != original_len:
        raise DatasetProcessingError(
            f"Pre-batch processor changed row count from {original_len} to {len(df)}. "
            "Row-count changes in pre-batch processors are not supported with the async engine."
        )
    return df

def run_post_batch(self, df: pd.DataFrame, current_batch_number: int | None) -> pd.DataFrame:
"""Run process_after_batch() on processors that implement it."""
Expand Down
Loading
Loading