Skip to content

Commit d2f9dc7

Browse files
refactor: make component_status required, remove None handling
Address Daniel feedback: component_status should never be None in production KFP pipelines. Remove unnecessary None handling to simplify code and enforce proper usage. Signed-off-by: Lukasz Cmielowski <lcmielow@redhat.com> Assisted-by: Cursor
1 parent 1f97d20 commit d2f9dc7

13 files changed

Lines changed: 31 additions & 60 deletions

File tree

components/data_processing/autorag/documents_discovery/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@ Lists available documents from S3, performs sampling if applied and writes a JSO
1313
| Parameter | Type | Default | Description |
1414
| --------- | ---- | ------- | ----------- |
1515
| `input_data_bucket_name` | `str` | `None` | S3 (or compatible) bucket containing input data. |
16+
| `component_status` | `dsl.Output[dsl.Artifact]` | `None` | Output artifact containing stage-level progress tracking. |
1617
| `input_data_path` | `str` | `""` | Path to folder with input documents within the bucket. |
1718
| `test_data` | `dsl.Input[dsl.Artifact]` | `None` | Optional input artifact containing test data for sampling. |
1819
| `sampling_enabled` | `bool` | `True` | Whether to enable sampling or not. |
1920
| `sampling_max_size` | `float` | `1` | Maximum size of sampled documents (in gigabytes). |
2021
| `discovered_documents` | `dsl.Output[dsl.Artifact]` | `None` | Output artifact containing the documents descriptor JSON file. |
21-
| `component_status` | `dsl.Output[dsl.Artifact]` | `None` | Output artifact containing stage-level progress tracking. |
2222
| `embedded_artifact` | `dsl.EmbeddedInput[dsl.Dataset]` | `None` | Embedded ``autorag.shared`` helpers injected by KFP at runtime. |
2323

2424
## Usage Examples 🧪

components/data_processing/autorag/documents_discovery/component.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313
)
1414
def documents_discovery(
1515
input_data_bucket_name: str,
16+
component_status: dsl.Output[dsl.Artifact],
1617
input_data_path: str = "",
1718
test_data: dsl.Input[dsl.Artifact] = None,
1819
sampling_enabled: bool = True,
1920
sampling_max_size: float = 1,
2021
discovered_documents: dsl.Output[dsl.Artifact] = None,
21-
component_status: dsl.Output[dsl.Artifact] = None,
2222
embedded_artifact: dsl.EmbeddedInput[dsl.Dataset] = None,
2323
):
2424
"""Documents discovery component.
@@ -86,8 +86,7 @@ def get_test_data_docs_names() -> list[str]:
8686
_spec.loader.exec_module(_status_module)
8787
status = _status_module.bootstrap_status_tracker(embedded_artifact, component_status, "documents_discovery")
8888
with status:
89-
if component_status is not None:
90-
component_status.metadata["display_name"] = "Documents Discovery Status"
89+
component_status.metadata["display_name"] = "Documents Discovery Status"
9190
with status.stage("discover_documents"):
9291
s3_creds = {k: os.environ.get(k) for k in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_S3_ENDPOINT"]}
9392
for k, v in s3_creds.items():

components/data_processing/autorag/test_data_loader/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ The component reads S3-compatible credentials from environment variables (inject
1414
| --------- | ---- | ------- | ----------- |
1515
| `test_data_bucket_name` | `str` | `None` | S3 (or compatible) bucket that contains the test data file. |
1616
| `test_data_path` | `str` | `None` | S3 object key to the JSON test data file. |
17+
| `component_status` | `dsl.Output[dsl.Artifact]` | `None` | Output artifact containing stage-level progress tracking. |
1718
| `benchmark_sample_size` | `int` | `25` | Maximum number of records to keep from the test data. When the dataset exceeds this limit, a reproducible random sample is drawn (seed 42). Set to 0 to disable sampling and keep all records. |
1819
| `test_data` | `dsl.Output[dsl.Artifact]` | `None` | Output artifact that receives the (possibly sampled) file. |
19-
| `component_status` | `dsl.Output[dsl.Artifact]` | `None` | Output artifact containing stage-level progress tracking. |
2020
| `embedded_artifact` | `dsl.EmbeddedInput[dsl.Dataset]` | `None` | Embedded ``autorag.shared`` helpers injected by KFP at runtime. |
2121

2222
## Usage Examples 🧪

components/data_processing/autorag/test_data_loader/component.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
def test_data_loader(
1616
test_data_bucket_name: str,
1717
test_data_path: str,
18+
component_status: dsl.Output[dsl.Artifact],
1819
benchmark_sample_size: int = 25,
1920
test_data: dsl.Output[dsl.Artifact] = None,
20-
component_status: dsl.Output[dsl.Artifact] = None,
2121
embedded_artifact: dsl.EmbeddedInput[dsl.Dataset] = None,
2222
):
2323
"""Download test data JSON from S3 and sample it for benchmarking.
@@ -78,8 +78,7 @@ class TestDataLoaderException(Exception):
7878
_spec.loader.exec_module(_status_module)
7979
status = _status_module.bootstrap_status_tracker(embedded_artifact, component_status, "test_data_loader")
8080
with status:
81-
if component_status is not None:
82-
component_status.metadata["display_name"] = "Test Data Loader Status"
81+
component_status.metadata["display_name"] = "Test Data Loader Status"
8382
with status.stage("load_benchmark"):
8483
if not test_data_bucket_name:
8584
raise TypeError("test_data_bucket_name must be a non-empty string")

components/data_processing/autorag/text_extraction/component.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
def text_extraction(
1616
documents_descriptor: dsl.Input[dsl.Artifact],
1717
extracted_text: dsl.Output[dsl.Artifact],
18-
component_status: dsl.Output[dsl.Artifact] = None,
18+
component_status: dsl.Output[dsl.Artifact],
1919
embedded_artifact: dsl.EmbeddedInput[dsl.Dataset] = None,
2020
error_tolerance: Optional[float] = None,
2121
max_extraction_workers: Optional[int] = None,
@@ -357,8 +357,7 @@ def raise_if_threshold_exceeded(error_details: list, total_docs: int, tolerance:
357357
_spec.loader.exec_module(_status_module)
358358
status = _status_module.bootstrap_status_tracker(embedded_artifact, component_status, "text_extraction")
359359
with status:
360-
if component_status is not None:
361-
component_status.metadata["display_name"] = "Text Extraction Status"
360+
component_status.metadata["display_name"] = "Text Extraction Status"
362361
descriptor_path = Path(documents_descriptor.path) / DOCUMENTS_DESCRIPTOR_FILENAME
363362

364363
with status.stage("extract_documents"):

components/training/automl/shared/component_status.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,16 +61,14 @@ class ComponentStatusTracker:
6161
Each component independently tracks its stages and metadata.
6262
"""
6363

64-
def __init__(self, artifact_path: str | None, component_id: str) -> None:
64+
def __init__(self, artifact_path: str, component_id: str) -> None:
6565
"""Initialize the status tracker.
6666
6767
Args:
68-
artifact_path: Path to the KFP artifact directory where status.json will be written.
69-
When ``None``, tracking is disabled (e.g. unit tests without a mock artifact).
68+
artifact_path: Path to the KFP artifact directory where component_status.json will be written.
7069
component_id: Unique component identifier (e.g., "autogluon_models_training").
7170
"""
72-
self._enabled = artifact_path is not None
73-
self.artifact_path = Path(artifact_path) if self._enabled else Path(".")
71+
self.artifact_path = Path(artifact_path)
7472
self.component_id = component_id
7573
self.stages: list[dict[str, Any]] = []
7674
self.started_at = utc_now_z()
@@ -131,9 +129,6 @@ def save(self) -> None:
131129
Creates the artifact directory if needed and writes component_status.json
132130
with all recorded stages and metadata.
133131
"""
134-
if not self._enabled:
135-
return
136-
137132
self.artifact_path.mkdir(parents=True, exist_ok=True)
138133

139134
data = {

components/training/automl/shared/tests/test_component_status.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -95,13 +95,6 @@ def test_stage_context_manager_records_failed(self, tmp_path: Path) -> None:
9595
assert data["stages"][-1]["status"] == "failed"
9696
assert "bad split" in data["stages"][-1]["error"]
9797

98-
def test_disabled_tracker_skips_save(self, tmp_path: Path) -> None:
99-
"""When artifact_path is None, save() is a no-op."""
100-
tracker = ComponentStatusTracker(None, "automl_data_loader")
101-
tracker.record("prepare_data", "completed")
102-
tracker.save()
103-
assert not (tmp_path / COMPONENT_STATUS_FILENAME).exists()
104-
10598
def test_stage_skips_auto_complete_when_completed_inside_block(self, tmp_path: Path) -> None:
10699
"""stage() does not overwrite a completed record written inside the block."""
107100
tracker = ComponentStatusTracker(str(tmp_path), "autogluon_models_training")

components/training/autorag/leaderboard_evaluation/component.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
def leaderboard_evaluation(
1515
rag_patterns: dsl.InputPath(dsl.Artifact),
1616
html_artifact: dsl.Output[dsl.HTML],
17-
component_status: dsl.Output[dsl.Artifact] = None,
17+
component_status: dsl.Output[dsl.Artifact],
1818
embedded_artifact: dsl.EmbeddedInput[dsl.Dataset] = None,
1919
optimization_metric: str = "faithfulness",
2020
):
@@ -338,8 +338,7 @@ def _build_leaderboard_html(
338338
_spec.loader.exec_module(_status_module)
339339
status = _status_module.bootstrap_status_tracker(embedded_artifact, component_status, "leaderboard_evaluation")
340340
with status:
341-
if component_status is not None:
342-
component_status.metadata["display_name"] = "Leaderboard Evaluation Status"
341+
component_status.metadata["display_name"] = "Leaderboard Evaluation Status"
343342
with status.stage("build_leaderboard"):
344343
if not rag_patterns_dir.is_dir():
345344
raise FileNotFoundError("rag_patterns path is not a directory: %s" % rag_patterns_dir)

components/training/autorag/rag_templates_optimization/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ Carries out the iterative RAG optimization process.
1818
| `rag_patterns` | `dsl.Output[dsl.Artifact]` | `None` | kfp-enforced argument specifying an output artifact. Provided by kfp backend automatically. |
1919
| `test_data_key` | `Optional[str]` | `None` | Path to the benchmark JSON file in object storage used by generated notebooks. |
2020
| `vector_io_provider_id` | `str` | `None` | Vector I/O provider identifier as registered in OGX. |
21-
| `embedded_artifact` | `dsl.EmbeddedInput[dsl.Dataset]` | `None` | Embedded ``autorag.shared`` helpers injected by KFP at runtime. |
2221
| `component_status` | `dsl.Output[dsl.Artifact]` | `None` | Output artifact containing stage-level progress tracking. |
22+
| `embedded_artifact` | `dsl.EmbeddedInput[dsl.Dataset]` | `None` | Embedded ``autorag.shared`` helpers injected by KFP at runtime. |
2323
| `optimization_settings` | `Optional[dict]` | `None` | Additional settings customising the experiment. |
2424
| `input_data_key` | `Optional[str]` | `""` | A path to documents dir within a bucket used as an input to AI4RAG experiment. |
2525

components/training/autorag/rag_templates_optimization/component.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ def rag_templates_optimization(
1919
rag_patterns: dsl.Output[dsl.Artifact],
2020
test_data_key: Optional[str],
2121
vector_io_provider_id: str,
22+
component_status: dsl.Output[dsl.Artifact],
2223
embedded_artifact: dsl.EmbeddedInput[dsl.Dataset] = None,
23-
component_status: dsl.Output[dsl.Artifact] = None,
2424
optimization_settings: Optional[dict] = None,
2525
input_data_key: Optional[str] = "",
2626
):
@@ -551,8 +551,7 @@ def on_pattern_creation(self, payload: dict, evaluation_results: list, **kwargs)
551551
pass
552552

553553
with status:
554-
if component_status is not None:
555-
component_status.metadata["display_name"] = "RAG Templates Optimization Status"
554+
component_status.metadata["display_name"] = "RAG Templates Optimization Status"
556555
with status.stage("optimize_templates", steps=optimize_templates_steps):
557556
if not ogx_client_base_url or not ogx_client_api_key:
558557
raise ValueError(

0 commit comments

Comments
 (0)