Skip to content

Commit bfe6d6f

Browse files
Merge pull request opendatahub-io#137 from LukaszCmielowski/autox_clean_up_stages
feat(automl): Aggregate and align progress stages for AutoML and AutoRAG pipelines
2 parents 2040846 + 974f684 commit bfe6d6f

31 files changed

Lines changed: 374 additions & 325 deletions

File tree

components/data_processing/automl/tabular_data_loader/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,8 +224,8 @@ load_task = automl_data_loader(
224224

225225
In the tabular training pipeline, this component writes ``component_status.json`` under the
226226
``component_status`` output artifact. The file includes ``component_id`` (``automl_data_loader``),
227-
``started_at``, ``completed_at``, a ``stages`` list (ids such as ``validate_inputs``,
228-
``read_and_sample``, ``cleanse``, ``split``, ``write_outputs``), and optional ``metadata``.
227+
``started_at``, ``completed_at``, a ``stages`` list (ids such as ``prepare_data``,
228+
``split_and_export``), and optional ``metadata``.
229229
Match stage ids to the tabular pipeline entry in ``component_stage_map.json`` from the
230230
``publish-component-stage-map`` task.
231231

components/data_processing/automl/tabular_data_loader/component.py

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,9 @@ def automl_data_loader( # noqa: D417
132132
# Initialize status tracker
133133
status = ComponentStatusTracker(component_status.path, "automl_data_loader")
134134
with status:
135-
# Stage: validate_inputs
136-
status.record("validate_inputs", "started")
137-
# Validation happens inline below
138-
status.record("validate_inputs", "completed")
135+
status.set_metadata(display_name="Data Loader Status")
136+
component_status.metadata["display_name"] = "Data Loader Status"
137+
status.record("prepare_data", "started")
139138

140139
if sampling_method is None:
141140
if task_type in ("binary", "multiclass"):
@@ -324,10 +323,9 @@ def load_data_in_batches(
324323
return _sample_random(text_stream, PANDAS_CHUNK_SIZE, max_size_bytes)
325324
return _sample_first_n_rows(text_stream, PANDAS_CHUNK_SIZE, max_size_bytes)
326325

327-
# Stage: read_and_sample
328326
status.record(
329-
"read_and_sample",
330-
"started",
327+
"prepare_data",
328+
"running",
331329
sampling_method=sampling_method,
332330
source=f"s3://{bucket_name}/{file_key}",
333331
)
@@ -347,11 +345,6 @@ def load_data_in_batches(
347345
f"Available columns: {list(sampled_dataframe.columns)}"
348346
)
349347

350-
status.record("read_and_sample", "completed", rows=len(sampled_dataframe))
351-
352-
# Stage: cleanse
353-
status.record("cleanse", "started")
354-
355348
sampled_dataframe.replace([math.inf, -math.inf], float("nan"), inplace=True)
356349

357350
n_before_dedup = len(sampled_dataframe)
@@ -400,15 +393,14 @@ def load_data_in_batches(
400393
sampling_method,
401394
)
402395
status.record(
403-
"cleanse",
396+
"prepare_data",
404397
"completed",
405398
rows=n_samples,
406399
duplicates_dropped=n_dup_dropped,
407400
labels_dropped=n_dropped,
408401
)
409402

410-
# Stage: split
411-
status.record("split", "started")
403+
status.record("split_and_export", "started")
412404

413405
# --- Train/test split ---
414406
from pathlib import Path
@@ -465,19 +457,13 @@ def load_data_in_batches(
465457
X_y_test.to_csv(sampled_test_dataset.path, index=False)
466458

467459
status.record(
468-
"split",
460+
"split_and_export",
469461
"completed",
470462
test_size=test_size,
471463
selection_train_size=selection_train_size,
472464
stratify=stratify_effective,
473465
)
474466

475-
# Stage: write_outputs
476-
status.record("write_outputs", "started")
477-
status.record("write_outputs", "completed")
478-
479-
component_status.metadata["display_name"] = "Data Loader Status"
480-
481467
# Sample row for downstream use (JSON string to avoid NaN issues)
482468
sample_row = X_y_test.head(1).to_json(orient="records")
483469

components/data_processing/automl/tabular_data_loader/tests/test_component_unit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,8 +212,8 @@ def test_writes_component_status_json(self, tmp_path, monkeypatch):
212212
payload = json.loads(status_path.read_text())
213213
assert payload["component_id"] == "automl_data_loader"
214214
stage_ids = [stage["id"] for stage in payload["stages"]]
215-
assert "read_and_sample" in stage_ids
216-
assert "split" in stage_ids
215+
assert "prepare_data" in stage_ids
216+
assert "split_and_export" in stage_ids
217217

218218
@mock.patch.dict("os.environ", mocked_env_variables)
219219
def test_sets_component_status_display_name(self, tmp_path):

components/data_processing/automl/timeseries_data_loader/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,6 @@ def example_pipeline(
104104

105105
In the time series training pipeline, this component writes ``component_status.json`` under the
106106
``component_status`` output artifact. The file includes ``component_id`` (``timeseries_data_loader``),
107-
timestamps, and per-stage status (e.g. ``validate_inputs``, ``read_and_sample``, ``split``,
108-
``write_outputs``). Dashboards align stage ids with ``component_stage_map.json`` from
107+
timestamps, and per-stage status (e.g. ``prepare_data``, ``split_and_export``).
108+
Dashboards align stage ids with ``component_stage_map.json`` from
109109
``publish-component-stage-map``.

components/data_processing/automl/timeseries_data_loader/component.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,9 @@ def timeseries_data_loader(
9797

9898
status = ComponentStatusTracker(component_status.path, "timeseries_data_loader")
9999
with status:
100-
status.record("validate_inputs", "started")
101-
status.record("validate_inputs", "completed")
100+
status.set_metadata(display_name="Timeseries Data Loader Status")
101+
component_status.metadata["display_name"] = "Timeseries Data Loader Status"
102+
status.record("prepare_data", "started")
102103

103104
def get_s3_client(verify=True):
104105
"""Create and return an S3 client using credentials from environment variables."""
@@ -290,8 +291,8 @@ def _clean_timeseries_dataframe(data, id_col, ts_col, log):
290291
return out.reset_index(drop=True)
291292

292293
status.record(
293-
"read_and_sample",
294-
"started",
294+
"prepare_data",
295+
"running",
295296
source=f"s3://{bucket_name}/{file_key}",
296297
)
297298
df = load_timeseries_data_truncate(bucket_name, file_key, MAX_SIZE_BYTES, PANDAS_CHUNK_SIZE)
@@ -309,9 +310,6 @@ def _clean_timeseries_dataframe(data, id_col, ts_col, log):
309310
f"with columns {sorted(required_columns)}."
310311
)
311312

312-
status.record("read_and_sample", "completed", rows=len(df))
313-
status.record("cleanse", "started")
314-
315313
df = _clean_timeseries_dataframe(df, id_column, timestamp_column, logger)
316314

317315
n_valid = len(df)
@@ -322,8 +320,8 @@ def _clean_timeseries_dataframe(data, id_col, ts_col, log):
322320
"Provide a larger dataset or fix invalid timestamps, null ids, and duplicate keys."
323321
)
324322

325-
status.record("cleanse", "completed", rows=n_valid)
326-
status.record("split", "started")
323+
status.record("prepare_data", "completed", rows=n_valid)
324+
status.record("split_and_export", "started")
327325

328326
# Create workspace datasets directory
329327
datasets_dir = Path(workspace_path) / "datasets"
@@ -391,13 +389,6 @@ def _concat_sorted(parts: list, sort_by: list) -> pd.DataFrame:
391389
"each series has enough train rows for the selection segment."
392390
)
393391

394-
status.record(
395-
"split",
396-
"completed",
397-
test_size=test_size,
398-
selection_train_size=selection_train_size,
399-
)
400-
401392
# Save test dataset to artifact
402393
test_df.to_csv(sampled_test_dataset.path, index=False)
403394

@@ -407,6 +398,13 @@ def _concat_sorted(parts: list, sort_by: list) -> pd.DataFrame:
407398
selection_train_df.to_csv(selection_path, index=False)
408399
extra_train_df.to_csv(extra_path, index=False)
409400

401+
status.record(
402+
"split_and_export",
403+
"completed",
404+
test_size=test_size,
405+
selection_train_size=selection_train_size,
406+
)
407+
410408
logger.info(
411409
"Timeseries loader: %s rows from s3://%s/%s; split selection=%s extra=%s test=%s",
412410
len(df),
@@ -425,10 +423,6 @@ def _concat_sorted(parts: list, sort_by: list) -> pd.DataFrame:
425423
"selection_train_size": selection_train_size,
426424
}
427425

428-
status.record("write_outputs", "started")
429-
status.record("write_outputs", "completed")
430-
component_status.metadata["display_name"] = "Timeseries Data Loader Status"
431-
432426
# Sample rows for downstream use (ISO timestamps when supported; JSON string to avoid NaN issues)
433427
sample_tail = test_df.tail(min(5, len(test_df)))
434428
if hasattr(sample_tail, "to_dict"):

components/data_processing/autorag/documents_discovery/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@ Lists available documents from S3, performs sampling if applied and writes a JSO
1313
| Parameter | Type | Default | Description |
1414
| --------- | ---- | ------- | ----------- |
1515
| `input_data_bucket_name` | `str` | `None` | S3 (or compatible) bucket containing input data. |
16+
| `component_status` | `dsl.Output[dsl.Artifact]` | `None` | Output artifact containing stage-level progress tracking. |
1617
| `input_data_path` | `str` | `""` | Path to folder with input documents within the bucket. |
1718
| `test_data` | `dsl.Input[dsl.Artifact]` | `None` | Optional input artifact containing test data for sampling. |
1819
| `sampling_enabled` | `bool` | `True` | Whether to enable sampling or not. |
1920
| `sampling_max_size` | `float` | `1` | Maximum size of sampled documents (in gigabytes). |
2021
| `discovered_documents` | `dsl.Output[dsl.Artifact]` | `None` | Output artifact containing the documents descriptor JSON file. |
21-
| `component_status` | `dsl.Output[dsl.Artifact]` | `None` | Output artifact containing stage-level progress tracking. |
2222
| `embedded_artifact` | `dsl.EmbeddedInput[dsl.Dataset]` | `None` | Embedded ``autorag.shared`` helpers injected by KFP at runtime. |
2323

2424
## Usage Examples 🧪

components/data_processing/autorag/documents_discovery/component.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313
)
1414
def documents_discovery(
1515
input_data_bucket_name: str,
16+
component_status: dsl.Output[dsl.Artifact],
1617
input_data_path: str = "",
1718
test_data: dsl.Input[dsl.Artifact] = None,
1819
sampling_enabled: bool = True,
1920
sampling_max_size: float = 1,
2021
discovered_documents: dsl.Output[dsl.Artifact] = None,
21-
component_status: dsl.Output[dsl.Artifact] = None,
2222
embedded_artifact: dsl.EmbeddedInput[dsl.Dataset] = None,
2323
):
2424
"""Documents discovery component.
@@ -86,7 +86,9 @@ def get_test_data_docs_names() -> list[str]:
8686
_spec.loader.exec_module(_status_module)
8787
status = _status_module.bootstrap_status_tracker(embedded_artifact, component_status, "documents_discovery")
8888
with status:
89-
with status.stage("validate_inputs"):
89+
status.set_metadata(display_name="Documents Discovery Status")
90+
component_status.metadata["display_name"] = "Documents Discovery Status"
91+
with status.stage("discover_documents"):
9092
s3_creds = {k: os.environ.get(k) for k in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_S3_ENDPOINT"]}
9193
for k, v in s3_creds.items():
9294
if v is None:
@@ -105,7 +107,6 @@ def _make_s3_client(verify=True):
105107
verify=verify,
106108
)
107109

108-
with status.stage("list_and_sample"):
109110
# Use paginator to handle buckets with >1,000 objects
110111
def _list_all_objects(s3_client):
111112
"""List all objects under prefix using pagination."""
@@ -186,7 +187,6 @@ def _list_all_objects(s3_client):
186187
f"enabled_max={sampling_max_size}GB" if sampling_enabled else "disabled",
187188
)
188189

189-
with status.stage("write_descriptor"):
190190
os.makedirs(discovered_documents.path, exist_ok=True)
191191
descriptor_path = os.path.join(discovered_documents.path, DOCUMENTS_DESCRIPTOR_FILENAME)
192192
with open(descriptor_path, "w") as f:

components/data_processing/autorag/test_data_loader/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ The component reads S3-compatible credentials from environment variables (inject
1414
| --------- | ---- | ------- | ----------- |
1515
| `test_data_bucket_name` | `str` | `None` | S3 (or compatible) bucket that contains the test data file. |
1616
| `test_data_path` | `str` | `None` | S3 object key to the JSON test data file. |
17+
| `component_status` | `dsl.Output[dsl.Artifact]` | `None` | Output artifact containing stage-level progress tracking. |
1718
| `benchmark_sample_size` | `int` | `25` | Maximum number of records to keep from the test data. When the dataset exceeds this limit, a reproducible random sample is drawn (seed 42). Set to 0 to disable sampling and keep all records. |
1819
| `test_data` | `dsl.Output[dsl.Artifact]` | `None` | Output artifact that receives the (possibly sampled) file. |
19-
| `component_status` | `dsl.Output[dsl.Artifact]` | `None` | Output artifact containing stage-level progress tracking. |
2020
| `embedded_artifact` | `dsl.EmbeddedInput[dsl.Dataset]` | `None` | Embedded ``autorag.shared`` helpers injected by KFP at runtime. |
2121

2222
## Usage Examples 🧪

components/data_processing/autorag/test_data_loader/component.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
def test_data_loader(
1616
test_data_bucket_name: str,
1717
test_data_path: str,
18+
component_status: dsl.Output[dsl.Artifact],
1819
benchmark_sample_size: int = 25,
1920
test_data: dsl.Output[dsl.Artifact] = None,
20-
component_status: dsl.Output[dsl.Artifact] = None,
2121
embedded_artifact: dsl.EmbeddedInput[dsl.Dataset] = None,
2222
):
2323
"""Download test data JSON from S3 and sample it for benchmarking.
@@ -78,7 +78,9 @@ class TestDataLoaderException(Exception):
7878
_spec.loader.exec_module(_status_module)
7979
status = _status_module.bootstrap_status_tracker(embedded_artifact, component_status, "test_data_loader")
8080
with status:
81-
with status.stage("validate_inputs"):
81+
status.set_metadata(display_name="Test Data Loader Status")
82+
component_status.metadata["display_name"] = "Test Data Loader Status"
83+
with status.stage("load_benchmark"):
8284
if not test_data_bucket_name:
8385
raise TypeError("test_data_bucket_name must be a non-empty string")
8486

@@ -93,17 +95,16 @@ class TestDataLoaderException(Exception):
9395

9496
s3_creds["AWS_DEFAULT_REGION"] = os.environ.get("AWS_DEFAULT_REGION")
9597

96-
def _make_s3_client(verify=True):
97-
return boto3.client(
98-
"s3",
99-
endpoint_url=s3_creds["AWS_S3_ENDPOINT"],
100-
region_name=s3_creds["AWS_DEFAULT_REGION"],
101-
aws_access_key_id=s3_creds["AWS_ACCESS_KEY_ID"],
102-
aws_secret_access_key=s3_creds["AWS_SECRET_ACCESS_KEY"],
103-
verify=verify,
104-
)
105-
106-
with status.stage("download_and_sample"):
98+
def _make_s3_client(verify=True):
99+
return boto3.client(
100+
"s3",
101+
endpoint_url=s3_creds["AWS_S3_ENDPOINT"],
102+
region_name=s3_creds["AWS_DEFAULT_REGION"],
103+
aws_access_key_id=s3_creds["AWS_ACCESS_KEY_ID"],
104+
aws_secret_access_key=s3_creds["AWS_SECRET_ACCESS_KEY"],
105+
verify=verify,
106+
)
107+
107108
s3_client = _make_s3_client()
108109

109110
logger.info("Fetching test data from S3: bucket='%s', path='%s'.", test_data_bucket_name, test_data_path)
@@ -149,7 +150,6 @@ def _make_s3_client(verify=True):
149150
f"Make sure that each test data records contains following keys: {benchmark_record_keys}."
150151
)
151152

152-
with status.stage("write_output"):
153153
if 0 < benchmark_sample_size < len(benchmark_data) and isinstance(benchmark_data, list):
154154
import random
155155

components/data_processing/autorag/text_extraction/component.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
def text_extraction(
1616
documents_descriptor: dsl.Input[dsl.Artifact],
1717
extracted_text: dsl.Output[dsl.Artifact],
18-
component_status: dsl.Output[dsl.Artifact] = None,
18+
component_status: dsl.Output[dsl.Artifact],
1919
embedded_artifact: dsl.EmbeddedInput[dsl.Dataset] = None,
2020
error_tolerance: Optional[float] = None,
2121
max_extraction_workers: Optional[int] = None,
@@ -357,9 +357,11 @@ def raise_if_threshold_exceeded(error_details: list, total_docs: int, tolerance:
357357
_spec.loader.exec_module(_status_module)
358358
status = _status_module.bootstrap_status_tracker(embedded_artifact, component_status, "text_extraction")
359359
with status:
360+
status.set_metadata(display_name="Text Extraction Status")
361+
component_status.metadata["display_name"] = "Text Extraction Status"
360362
descriptor_path = Path(documents_descriptor.path) / DOCUMENTS_DESCRIPTOR_FILENAME
361363

362-
with status.stage("load_descriptor"):
364+
with status.stage("extract_documents"):
363365
if not descriptor_path.exists():
364366
raise FileNotFoundError(f"documents_descriptor.json not found at {descriptor_path}")
365367

@@ -390,7 +392,6 @@ def raise_if_threshold_exceeded(error_details: list, total_docs: int, tolerance:
390392
logger.info("No documents to process.")
391393
return
392394

393-
with status.stage("extract_documents"):
394395
documents = sorted(documents, key=lambda d: d.get("size_bytes", 0), reverse=True)
395396

396397
if max_extraction_workers is not None:

0 commit comments

Comments
 (0)