|
1 | 1 | import builtins |
2 | 2 | import datetime |
3 | 3 | import logging |
| 4 | +from sys import getsizeof |
4 | 5 | from typing import Any, Callable, Optional, Union |
5 | 6 |
|
6 | 7 | from attrs import define as _attrs_define |
7 | 8 | from attrs import field as _attrs_field |
8 | 9 |
|
9 | 10 | from galileo import galileo_context, log |
10 | 11 | from galileo.config import GalileoPythonConfig |
11 | | -from galileo.datasets import Dataset |
| 12 | +from galileo.datasets import Dataset, convert_dataset_row_to_record |
12 | 13 | from galileo.experiment_tags import upsert_experiment_tag |
13 | 14 | from galileo.jobs import Jobs |
14 | 15 | from galileo.projects import Project, Projects |
|
20 | 21 | from galileo.resources.models import ExperimentResponse, HTTPValidationError, PromptRunSettings, ScorerConfig, TaskType |
21 | 22 | from galileo.schema.datasets import DatasetRecord |
22 | 23 | from galileo.schema.metrics import GalileoScorers, LocalMetricConfig, Metric |
23 | | -from galileo.utils.datasets import load_dataset_and_records |
| 24 | +from galileo.utils.datasets import create_rows_from_records, load_dataset |
24 | 25 | from galileo.utils.logging import get_logger |
25 | 26 | from galileo.utils.metrics import create_metric_configs |
26 | 27 |
|
# Module-level logger for experiment runs.
_logger = get_logger(__name__)

# Task type identifier sent when creating experiment jobs.
# NOTE(review): 16 is a bare int annotated as TaskType — presumably a valid
# member value of the generated TaskType enum; confirm against the model.
EXPERIMENT_TASK_TYPE: TaskType = 16

# Batching limits used while flushing logged traces during an experiment run.
MAX_REQUEST_SIZE_BYTES = 10 * 1024 * 1024  # 10 MB per ingest request
MAX_INGEST_BATCH_SIZE = 128  # max buffered rows before a flush
DATASET_CONTENT_PAGE_SIZE = 1000  # rows fetched per dataset-content page
31 | 36 |
|
32 | 37 | @_attrs_define |
33 | 38 | class ExperimentCreateRequest: |
        self,
        project_obj: Project,
        experiment_obj: ExperimentResponse,
        dataset_obj: Optional[Dataset],
        records: Optional[builtins.list[DatasetRecord]],
        func: Callable,
        local_metrics: builtins.list[LocalMetricConfig],
    ) -> dict[str, Any]:
        """
        Run *func* over every row of an experiment's dataset, logging each row
        as a trace against ``experiment_obj``.

        Exactly one data source must be provided: either an in-memory list of
        ``records`` (processed in one pass) or a ``dataset_obj`` whose content
        is fetched page-by-page (``DATASET_CONTENT_PAGE_SIZE`` rows at a time).

        Raises:
            ValueError: if both ``dataset_obj`` and ``records`` are ``None``.
        """
        if dataset_obj is None and records is None:
            raise ValueError("Either dataset_obj or records must be provided")
        results = []
        # Bind the logging context to this project/experiment so every trace
        # produced below is attributed to the experiment run.
        galileo_context.init(project=project_obj.name, experiment_id=experiment_obj.id, local_metrics=local_metrics)

        # Wrap the user function so each invocation is logged with its row.
        def logged_process_func(row: DatasetRecord) -> Callable:
            return log(name=experiment_obj.name, dataset_record=row)(func)

        # For static records (list), process once
        if records is not None:
            _logger.info(f"Processing {len(records)} rows from dataset")
            for row in records:
                results.append(process_row(row, logged_process_func(row)))
                galileo_context.reset_trace_context()
                # NOTE(review): sys.getsizeof on a list is shallow — it measures
                # the list object's pointer array (~8 bytes/element), not the
                # size of the buffered results, so the 10 MB guard is effectively
                # dead and the MAX_INGEST_BATCH_SIZE count check always fires
                # first. Confirm whether payload size was meant to be measured.
                if getsizeof(results) > MAX_REQUEST_SIZE_BYTES or len(results) >= MAX_INGEST_BATCH_SIZE:
                    _logger.info("Flushing logger due to size limit")
                    galileo_context.flush()
                    results = []
        # For dataset object, paginate through content
        elif dataset_obj is not None:
            starting_token = 0
            has_more_data = True

            while has_more_data:
                _logger.info(f"Loading dataset content starting at token {starting_token}")
                content = dataset_obj.get_content(starting_token=starting_token, limit=DATASET_CONTENT_PAGE_SIZE)

                # An empty/absent page marks the end of the dataset.
                if not content or not content.rows:
                    _logger.info("No more dataset content to process")
                    has_more_data = False
                else:
                    batch_records = [convert_dataset_row_to_record(row) for row in content.rows]
                    _logger.info(f"Processing {len(batch_records)} rows from dataset")

                    # NOTE(review): this loop duplicates the flush logic of the
                    # records branch above (including the shallow-getsizeof
                    # issue) — consider extracting a shared helper.
                    for row in batch_records:
                        results.append(process_row(row, logged_process_func(row)))
                        galileo_context.reset_trace_context()
                        if getsizeof(results) > MAX_REQUEST_SIZE_BYTES or len(results) >= MAX_INGEST_BATCH_SIZE:
                            _logger.info("Flushing logger due to size limit")
                            galileo_context.flush()
                            results = []

                    # Advance pagination by the number of rows just consumed;
                    # assumes get_content's starting_token is a row offset.
                    starting_token += len(batch_records)

        # flush the logger
        galileo_context.flush()
        If required parameters are missing or invalid
    """
    # Load dataset and records
    dataset_obj = load_dataset(dataset, dataset_id, dataset_name)

    # Validate experiment configuration
    if prompt_template and not dataset_obj:
        raise ValueError("A dataset record, id, or name of a dataset must be provided when a prompt_template is used")

    if function and prompt_template:
        raise ValueError("A function or prompt_template should be provided, but not both")

    # Only materialize in-memory records when the caller passed a raw list of
    # rows rather than a named/identified dataset; dataset rows are otherwise
    # streamed page-by-page inside run_with_function.
    records = None
    if not dataset_obj and isinstance(dataset, list):
        records = create_rows_from_records(dataset)

    # A function-based experiment needs some data source to iterate over.
    if function and not dataset_obj and not records:
        raise ValueError("A dataset record, id, name, or a list of records must be provided when a function is used")

    # Get the project from the name or Id
    project_obj = Projects().get_with_env_fallbacks(id=project_id, name=project)
|
    # Execute the experiment by running the user-supplied function over the
    # dataset (paginated) or the in-memory records.
    return Experiments().run_with_function(
        project_obj=project_obj,
        experiment_obj=experiment_obj,
        dataset_obj=dataset_obj,
        records=records,
        func=function,
        local_metrics=local_metrics,
|
0 commit comments