Skip to content

Commit c30f70c

Browse files
Merge branch 'opendatahub-io:main' into main
2 parents 0bb3b60 + 1c1e7ae commit c30f70c

12 files changed

Lines changed: 649 additions & 160 deletions

File tree

.github/workflows/readme-check.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ on:
1212
- '.github/scripts/**'
1313
- 'pyproject.toml'
1414
- 'uv.lock'
15-
- '!**/OWNERS'
1615

1716
jobs:
1817
check-readme-sync:

components/data_processing/autorag/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,5 @@ This subcategory contains components in the **Autorag** group:
44

55
- [Documents Discovery](./documents_discovery/README.md): Documents discovery component.
66
- [Documents Indexing](./documents_indexing/README.md): Index extracted text into a vector store with optional batch processing.
7-
- [Test Data Loader](./test_data_loader/README.md): Download test data json file from S3 into a KFP artifact.
7+
- [Test Data Loader](./test_data_loader/README.md): Download test data JSON from S3 and sample it for benchmarking.
88
- [Text Extraction](./text_extraction/README.md): Text Extraction component.

components/data_processing/autorag/test_data_loader/README.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,18 @@
44
55
## Overview 🧾
66

7-
Download test data json file from S3 into a KFP artifact.
7+
Download test data JSON from S3 and sample it for benchmarking.
88

9-
The component reads S3-compatible credentials from environment variables (injected by the pipeline from a Kubernetes secret) and downloads a JSON test data file from the provided bucket and path to the output artifact.
9+
The component reads S3-compatible credentials from environment variables (injected by the pipeline from a Kubernetes secret), downloads a JSON test data file, and randomly samples up to ``benchmark_sample_size`` records to limit evaluation cost in downstream components.
1010

1111
## Inputs 📥
1212

1313
| Parameter | Type | Default | Description |
1414
| --------- | ---- | ------- | ----------- |
1515
| `test_data_bucket_name` | `str` | `None` | S3 (or compatible) bucket that contains the test data file. |
1616
| `test_data_path` | `str` | `None` | S3 object key to the JSON test data file. |
17-
| `test_data` | `dsl.Output[dsl.Artifact]` | `None` | Output artifact that receives the downloaded file. |
17+
| `benchmark_sample_size` | `int` | `25` | Maximum number of records to keep from the test data. When the dataset exceeds this limit, a reproducible random sample is drawn (seed 42). Set to 0 to disable sampling and keep all records. |
18+
| `test_data` | `dsl.Output[dsl.Artifact]` | `None` | Output artifact that receives the (possibly sampled) file. |
1819

1920
## Usage Examples 🧪
2021

components/data_processing/autorag/test_data_loader/component.py

Lines changed: 97 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,28 @@
77
@dsl.component(
88
base_image=AUTORAG_IMAGE, # noqa: E501
99
)
10-
def test_data_loader(test_data_bucket_name: str, test_data_path: str, test_data: dsl.Output[dsl.Artifact] = None):
11-
"""Download test data json file from S3 into a KFP artifact.
10+
def test_data_loader(
11+
test_data_bucket_name: str,
12+
test_data_path: str,
13+
benchmark_sample_size: int = 25,
14+
test_data: dsl.Output[dsl.Artifact] = None,
15+
):
16+
"""Download test data JSON from S3 and sample it for benchmarking.
1217
1318
The component reads S3-compatible credentials from environment variables
14-
(injected by the pipeline from a Kubernetes secret) and downloads a JSON
15-
test data file from the provided bucket and path to the output artifact.
19+
(injected by the pipeline from a Kubernetes secret), downloads a JSON
20+
test data file, and randomly samples up to ``benchmark_sample_size``
21+
records to limit evaluation cost in downstream components.
1622
1723
Args:
1824
test_data_bucket_name: S3 (or compatible) bucket that contains the test
1925
data file.
2026
test_data_path: S3 object key to the JSON test data file.
21-
test_data: Output artifact that receives the downloaded file.
27+
benchmark_sample_size: Maximum number of records to keep from the test
28+
data. When the dataset exceeds this limit, a reproducible random
29+
sample is drawn (seed 42). Set to 0 to disable sampling and keep
30+
all records.
31+
test_data: Output artifact that receives the (possibly sampled) file.
2232
2333
Environment variables (required when run with pipeline secret injection):
2434
AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_S3_ENDPOINT.
@@ -45,64 +55,90 @@ def test_data_loader(test_data_bucket_name: str, test_data_path: str, test_data:
4555
if not test_data_bucket_name:
4656
raise TypeError("test_data_bucket_name must be a non-empty string")
4757

48-
def get_test_data_s3():
49-
"""Validate S3 credentials and download the JSON test data file."""
50-
51-
class TestDataLoaderException(Exception):
52-
pass
53-
54-
s3_creds = {k: os.environ.get(k) for k in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_S3_ENDPOINT"]}
55-
for k, v in s3_creds.items():
56-
if v is None:
57-
raise ValueError(
58-
"%s environment variable not set. Check if kubernetes secret was configured properly" % k
59-
)
60-
s3_creds["AWS_DEFAULT_REGION"] = os.environ.get("AWS_DEFAULT_REGION")
61-
62-
def _make_s3_client(verify=True):
63-
return boto3.client(
64-
"s3",
65-
endpoint_url=s3_creds["AWS_S3_ENDPOINT"],
66-
region_name=s3_creds["AWS_DEFAULT_REGION"],
67-
aws_access_key_id=s3_creds["AWS_ACCESS_KEY_ID"],
68-
aws_secret_access_key=s3_creds["AWS_SECRET_ACCESS_KEY"],
69-
verify=verify,
58+
benchmark_record_keys = {"question", "correct_answers", "correct_answer_document_ids"}
59+
60+
class TestDataLoaderException(Exception):
61+
pass
62+
63+
s3_creds = {k: os.environ.get(k) for k in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_S3_ENDPOINT"]}
64+
missing_creds = [k for k, v in s3_creds.items() if v is None]
65+
66+
if missing_creds:
67+
raise ValueError(
68+
f"Missing environment variable(s): {missing_creds}. Check if kubernetes secret was configured properly."
69+
)
70+
71+
s3_creds["AWS_DEFAULT_REGION"] = os.environ.get("AWS_DEFAULT_REGION")
72+
73+
def _make_s3_client(verify=True):
74+
return boto3.client(
75+
"s3",
76+
endpoint_url=s3_creds["AWS_S3_ENDPOINT"],
77+
region_name=s3_creds["AWS_DEFAULT_REGION"],
78+
aws_access_key_id=s3_creds["AWS_ACCESS_KEY_ID"],
79+
aws_secret_access_key=s3_creds["AWS_SECRET_ACCESS_KEY"],
80+
verify=verify,
81+
)
82+
83+
s3_client = _make_s3_client()
84+
85+
logger.info("Fetching test data from S3: bucket='%s', path='%s'.", test_data_bucket_name, test_data_path)
86+
try:
87+
logger.info("Downloading test data...")
88+
test_data_response = s3_client.get_object(Bucket=test_data_bucket_name, Key=test_data_path)
89+
logger.info("Download completed successfully.")
90+
except SSLError:
91+
logger.warning("SSL error when downloading %s, retrying with verify=False.", test_data_path)
92+
s3_client = _make_s3_client(verify=False)
93+
test_data_response = s3_client.get_object(Bucket=test_data_bucket_name, Key=test_data_path)
94+
logger.info("Download completed successfully with verify=False.")
95+
except ClientError as e:
96+
if e.response.get("Error", {}).get("Code") in ("404", "NoSuchKey"):
97+
raise FileNotFoundError(
98+
"Test data object not found in S3. bucket=%r, key=%r. "
99+
"Check that test_data_key (pipeline parameter) is the full object key to an existing JSON file."
100+
% (test_data_bucket_name, test_data_path)
101+
) from e
102+
else:
103+
raise TestDataLoaderException(f"Failed to fetch {test_data_path}: {e}") from e
104+
except Exception as e:
105+
raise TestDataLoaderException(f"Failed to fetch {test_data_path}: {e}") from e
106+
107+
test_data_raw = test_data_response["Body"].read().decode("utf-8")
108+
109+
try:
110+
benchmark_data = json.loads(test_data_raw)
111+
except JSONDecodeError as e:
112+
raise TestDataLoaderException("test_data_path must point to a valid JSON file.") from e
113+
114+
if not isinstance(benchmark_data, list):
115+
raise TestDataLoaderException("Test data file content must be a list with benchmark records.")
116+
117+
for idx, benchmark_record in enumerate(benchmark_data):
118+
if not isinstance(benchmark_record, dict):
119+
raise TestDataLoaderException(
120+
f"Expected a dict at index {idx}, got {type(benchmark_record).__name__}: {benchmark_record!r}"
70121
)
71-
72-
s3_client = _make_s3_client()
73-
74-
logger.info(f"Fetching test data from S3: bucket={test_data_bucket_name}, path={test_data_path}")
75-
try:
76-
logger.info(f"Starting download to {test_data.path}")
77-
s3_client.download_file(test_data_bucket_name, test_data_path, test_data.path)
78-
logger.info("Download completed successfully")
79-
except SSLError:
80-
logger.warning(
81-
"SSL error when downloading %s, retrying with verify=False",
82-
test_data_path,
122+
if set(benchmark_record.keys()) != benchmark_record_keys:
123+
raise TestDataLoaderException(
124+
f"Incorrect or incomplete keys in test data record. "
125+
f"Make sure that each test data records contains following keys: {benchmark_record_keys}."
83126
)
84-
s3_client = _make_s3_client(verify=False)
85-
s3_client.download_file(test_data_bucket_name, test_data_path, test_data.path)
86-
logger.info("Download completed successfully with verify=False")
87-
except ClientError as e:
88-
if e.response.get("Error", {}).get("Code") in ("404", "NoSuchKey"):
89-
raise FileNotFoundError(
90-
"Test data object not found in S3. bucket=%r, key=%r. "
91-
"Check that test_data_key (pipeline parameter) is the full object key to an existing JSON file."
92-
% (test_data_bucket_name, test_data_path)
93-
) from e
94-
else:
95-
raise TestDataLoaderException("Failed to fetch %s: %s", test_data_path, e) from e
96-
except Exception as e:
97-
raise TestDataLoaderException("Failed to fetch %s: %s", test_data_path, e) from e
98-
99-
try:
100-
with open(test_data.path, "r") as f:
101-
json.load(f)
102-
except JSONDecodeError as e:
103-
raise TestDataLoaderException("test_data_path must point to a valid JSON file.") from e
104-
105-
get_test_data_s3()
127+
128+
if 0 < benchmark_sample_size < len(benchmark_data) and isinstance(benchmark_data, list):
129+
import random
130+
131+
original_count = len(benchmark_data)
132+
rng = random.Random(42)
133+
data = rng.sample(benchmark_data, benchmark_sample_size)
134+
with open(test_data.path, "w", encoding="utf-8") as f:
135+
json.dump(data, f, ensure_ascii=False, indent=2)
136+
logger.info("Sampled %d records from %d total.", benchmark_sample_size, original_count)
137+
else:
138+
with open(test_data.path, "w", encoding="utf-8") as f:
139+
json.dump(benchmark_data, f, ensure_ascii=False, indent=2)
140+
record_count = len(benchmark_data)
141+
logger.info("No sampling applied; record count: %s.", record_count)
106142

107143

108144
if __name__ == "__main__":

0 commit comments

Comments
 (0)