
Commit 6cbef3c

feat: Add batch get dataset content and batch flushing for large experiments (#438)
1 parent 2f2d276 commit 6cbef3c

4 files changed: +269 -58 lines changed

src/galileo/datasets.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -63,7 +63,7 @@ def __init__(self, dataset_db: DatasetDB) -> None:
         self.dataset = dataset_db
         self.config = GalileoPythonConfig.get()
 
-    def get_content(self) -> Union[None, DatasetContent]:
+    def get_content(self, starting_token: int = 0, limit: int = MAX_DATASET_ROWS) -> Union[None, DatasetContent]:
         """
         Gets and returns the content of the dataset.
         Also refreshes the content of the local dataset instance.
@@ -85,7 +85,7 @@ def get_content(self) -> Union[None, DatasetContent]:
             return None
 
         content: DatasetContent = get_dataset_content_datasets_dataset_id_content_get.sync(
-            client=self.config.api_client, dataset_id=self.dataset.id, limit=MAX_DATASET_ROWS
+            client=self.config.api_client, dataset_id=self.dataset.id, limit=limit, starting_token=starting_token
         )
 
         self.content = content
```
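
With starting_token and limit exposed, callers can walk a large dataset page by page instead of pulling MAX_DATASET_ROWS in one request. A minimal sketch, assuming a Dataset fetched via get_dataset (the dataset name and page size are illustrative):

```python
from galileo.datasets import get_dataset

dataset = get_dataset(name="my-large-dataset")  # illustrative name

starting_token = 0
page_size = 100  # illustrative page size

while True:
    content = dataset.get_content(starting_token=starting_token, limit=page_size)
    if not content or not content.rows:
        break  # empty page: no more rows
    for row in content.rows:
        ...  # process each row here
    starting_token += len(content.rows)  # advance by the rows actually returned
```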

src/galileo/experiments.py

Lines changed: 55 additions & 13 deletions
```diff
@@ -1,14 +1,15 @@
 import builtins
 import datetime
 import logging
+from sys import getsizeof
 from typing import Any, Callable, Optional, Union
 
 from attrs import define as _attrs_define
 from attrs import field as _attrs_field
 
 from galileo import galileo_context, log
 from galileo.config import GalileoPythonConfig
-from galileo.datasets import Dataset
+from galileo.datasets import Dataset, convert_dataset_row_to_record
 from galileo.experiment_tags import upsert_experiment_tag
 from galileo.jobs import Jobs
 from galileo.projects import Project, Projects
@@ -20,14 +21,18 @@
 from galileo.resources.models import ExperimentResponse, HTTPValidationError, PromptRunSettings, ScorerConfig, TaskType
 from galileo.schema.datasets import DatasetRecord
 from galileo.schema.metrics import GalileoScorers, LocalMetricConfig, Metric
-from galileo.utils.datasets import load_dataset_and_records
+from galileo.utils.datasets import create_rows_from_records, load_dataset
 from galileo.utils.logging import get_logger
 from galileo.utils.metrics import create_metric_configs
 
 _logger = get_logger(__name__)
 
 EXPERIMENT_TASK_TYPE: TaskType = 16
 
+MAX_REQUEST_SIZE_BYTES = 10 * 1024 * 1024  # 10 MB
+MAX_INGEST_BATCH_SIZE = 128
+DATASET_CONTENT_PAGE_SIZE = 1000
+
 
 @_attrs_define
 class ExperimentCreateRequest:
```
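
One caveat worth noting about the byte bound used in run_with_function below: sys.getsizeof on a list reports only the shallow size of the list object (its header plus element pointers), not the payload of the results it references, so in practice the 128-row MAX_INGEST_BATCH_SIZE bound is the one that usually triggers flushes. A quick demonstration:

```python
import sys

# ~1 MB of row payload, but getsizeof sees only the list's pointer array
rows = [{"input": "x" * 10_000} for _ in range(100)]
print(sys.getsizeof(rows))  # on the order of 1 KB, far below 10 MB
```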
```diff
@@ -148,20 +153,54 @@ def run_with_function(
         self,
         project_obj: Project,
         experiment_obj: ExperimentResponse,
-        records: builtins.list[DatasetRecord],
+        dataset_obj: Optional[Dataset],
+        records: Optional[builtins.list[DatasetRecord]],
         func: Callable,
         local_metrics: builtins.list[LocalMetricConfig],
     ) -> dict[str, Any]:
+        if dataset_obj is None and records is None:
+            raise ValueError("Either dataset_obj or records must be provided")
         results = []
         galileo_context.init(project=project_obj.name, experiment_id=experiment_obj.id, local_metrics=local_metrics)
 
         def logged_process_func(row: DatasetRecord) -> Callable:
             return log(name=experiment_obj.name, dataset_record=row)(func)
 
-        # process each row in the dataset
-        for row in records:
-            results.append(process_row(row, logged_process_func(row)))
-            galileo_context.reset_trace_context()
+        # For static records (list), process once
+        if records is not None:
+            _logger.info(f"Processing {len(records)} rows from dataset")
+            for row in records:
+                results.append(process_row(row, logged_process_func(row)))
+                galileo_context.reset_trace_context()
+                if getsizeof(results) > MAX_REQUEST_SIZE_BYTES or len(results) >= MAX_INGEST_BATCH_SIZE:
+                    _logger.info("Flushing logger due to size limit")
+                    galileo_context.flush()
+                    results = []
+        # For dataset object, paginate through content
+        elif dataset_obj is not None:
+            starting_token = 0
+            has_more_data = True
+
+            while has_more_data:
+                _logger.info(f"Loading dataset content starting at token {starting_token}")
+                content = dataset_obj.get_content(starting_token=starting_token, limit=DATASET_CONTENT_PAGE_SIZE)
+
+                if not content or not content.rows:
+                    _logger.info("No more dataset content to process")
+                    has_more_data = False
+                else:
+                    batch_records = [convert_dataset_row_to_record(row) for row in content.rows]
+                    _logger.info(f"Processing {len(batch_records)} rows from dataset")
+
+                    for row in batch_records:
+                        results.append(process_row(row, logged_process_func(row)))
+                        galileo_context.reset_trace_context()
+                        if getsizeof(results) > MAX_REQUEST_SIZE_BYTES or len(results) >= MAX_INGEST_BATCH_SIZE:
+                            _logger.info("Flushing logger due to size limit")
+                            galileo_context.flush()
+                            results = []
+
+                    starting_token += len(batch_records)
 
         # flush the logger
         galileo_context.flush()
```
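
Read in isolation, both branches above follow one pattern: accumulate results, flush when either bound trips, then start a fresh batch. A minimal standalone sketch of that pattern (process and flush are hypothetical stand-ins for process_row and galileo_context.flush):

```python
from sys import getsizeof
from typing import Any, Callable, Iterable

MAX_REQUEST_SIZE_BYTES = 10 * 1024 * 1024  # mirrors the constants above
MAX_INGEST_BATCH_SIZE = 128


def process_in_batches(rows: Iterable[Any], process: Callable[[Any], Any], flush: Callable[[], None]) -> None:
    results: list[Any] = []
    for row in rows:
        results.append(process(row))
        # flush when either the (shallow) byte bound or the row-count bound trips
        if getsizeof(results) > MAX_REQUEST_SIZE_BYTES or len(results) >= MAX_INGEST_BATCH_SIZE:
            flush()
            results = []
    flush()  # final flush for any remaining partial batch
```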
```diff
@@ -247,20 +286,22 @@ def run_experiment(
         If required parameters are missing or invalid
     """
     # Load dataset and records
-    dataset_obj, records = load_dataset_and_records(dataset, dataset_id, dataset_name)
+    dataset_obj = load_dataset(dataset, dataset_id, dataset_name)
 
     # Validate experiment configuration
     if prompt_template and not dataset_obj:
         raise ValueError("A dataset record, id, or name of a dataset must be provided when a prompt_template is used")
 
-    if function and not records:
-        raise ValueError(
-            "A dataset record, id or name of a dataset, or list of records must be provided when a function is used"
-        )
-
     if function and prompt_template:
         raise ValueError("A function or prompt_template should be provided, but not both")
 
+    records = None
+    if not dataset_obj and isinstance(dataset, list):
+        records = create_rows_from_records(dataset)
+
+    if function and not dataset_obj and not records:
+        raise ValueError("A dataset record, id, name, or a list of records must be provided when a function is used")
+
     # Get the project from the name or Id
     project_obj = Projects().get_with_env_fallbacks(id=project_id, name=project)
 
@@ -303,6 +344,7 @@
     return Experiments().run_with_function(
         project_obj=project_obj,
         experiment_obj=experiment_obj,
+        dataset_obj=dataset_obj,
        records=records,
        func=function,
        local_metrics=local_metrics,
```
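
From user code, the two paths this change routes into run_with_function might look as follows. A sketch only: the experiment, project, and dataset names are illustrative, and the keyword arguments follow the parameters visible in the diff:

```python
from galileo.experiments import run_experiment


def runner(input: str) -> str:  # toy function under test
    return input.upper()


# Path 1: a plain list of records. load_dataset returns None and
# create_rows_from_records builds the records in memory.
run_experiment(
    "batch-flush-demo",  # illustrative experiment name
    project="my-project",  # illustrative project name
    dataset=[{"input": "hello"}, {"input": "world"}],
    function=runner,
)

# Path 2: a named dataset. run_with_function receives dataset_obj and
# pages through its content DATASET_CONTENT_PAGE_SIZE rows at a time.
run_experiment(
    "batch-flush-demo-2",
    project="my-project",
    dataset_name="my-large-dataset",  # illustrative dataset name
    function=runner,
)
```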

src/galileo/utils/datasets.py

Lines changed: 42 additions & 0 deletions
```diff
@@ -24,6 +24,48 @@ def validate_dataset_in_project(
         raise ValueError(f"Dataset '{dataset_identifier}' is not used in project '{project_identifier}'")
 
 
+def load_dataset(
+    dataset: Union["Dataset", list[Union[dict[str, Any], str]], str, None],
+    dataset_id: Optional[str],
+    dataset_name: Optional[str],
+) -> Optional["Dataset"]:
+    """
+    Load dataset based on provided parameters.
+
+    Parameters
+    ----------
+    dataset:
+        Dataset object, list of records, or dataset name
+    dataset_id:
+        ID of the dataset
+    dataset_name:
+        Name of the dataset
+
+    Returns
+    -------
+    Dataset object or None
+
+    Raises
+    ------
+    ValueError
+        If no dataset information is provided or dataset doesn't exist
+    """
+    from galileo.datasets import get_dataset
+
+    if dataset_id:
+        return get_dataset(id=dataset_id)
+    if dataset_name:
+        return get_dataset(name=dataset_name)
+    if dataset and isinstance(dataset, str):
+        return get_dataset(name=dataset)
+    if dataset and not isinstance(dataset, (str, list)):
+        # Must be a Dataset object
+        return dataset
+    if dataset and isinstance(dataset, list):
+        return None
+    raise ValueError("To load dataset records, dataset, dataset_name, or dataset_id must be provided")
+
+
 def load_dataset_and_records(
     dataset: Union["Dataset", list[Union[dict[str, Any], str]], str, None],
     dataset_id: Optional[str],
```
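
The helper's resolution order, illustrated with hypothetical identifiers (positional arguments are dataset, dataset_id, dataset_name, matching the signature above):

```python
from galileo.utils.datasets import load_dataset

load_dataset(None, "ds-123", None)           # dataset_id wins: get_dataset(id="ds-123")
load_dataset(None, None, "my-dataset")       # dataset_name: get_dataset(name="my-dataset")
load_dataset("my-dataset", None, None)       # a bare string is treated as a name
load_dataset([{"input": "hi"}], None, None)  # list of records: returns None, caller builds rows
load_dataset(None, None, None)               # nothing provided: raises ValueError
```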
