Commit 39a73a4

Added a per-split cache for the loader generators, moved the loading-limit logging to once per dataset, and increased the loader cache size.
Signed-off-by: dafnapension <[email protected]>
1 parent b9c3c4f commit 39a73a4

File tree: 7 files changed, +106 −114 lines


performance/bluebench_profiler.py

+14 −19

@@ -14,7 +14,6 @@

 logger = get_logger()
 settings = get_settings()
-settings.allow_unverified_code = True


 class BlueBenchProfiler:
@@ -59,26 +58,25 @@ def profiler_instantiate_benchmark_recipe(
     def profiler_generate_benchmark_dataset(
         self, benchmark_recipe: Benchmark, split: str, **kwargs
     ) -> List[Dict[str, Any]]:
+        stream = benchmark_recipe()[split]
+
+        # to charge here for the time of generating all instances of the split
+        return list(stream)
+
+    def profiler_do_the_profiling(self, dataset_query: str, split: str, **kwargs):
         with settings.context(
             disable_hf_datasets_cache=False,
             allow_unverified_code=True,
-            mock_inference_mode=True,
         ):
-            stream = benchmark_recipe()[split]
-
-            # to charge here for the time of generating all instances
-            return list(stream)
-
-    def profiler_do_the_profiling(self, dataset_query: str, split: str, **kwargs):
-        benchmark_recipe = self.profiler_instantiate_benchmark_recipe(
-            dataset_query=dataset_query, **kwargs
-        )
+            benchmark_recipe = self.profiler_instantiate_benchmark_recipe(
+                dataset_query=dataset_query, **kwargs
+            )

-        dataset = self.profiler_generate_benchmark_dataset(
-            benchmark_recipe=benchmark_recipe, split=split, **kwargs
-        )
+            dataset = self.profiler_generate_benchmark_dataset(
+                benchmark_recipe=benchmark_recipe, split=split, **kwargs
+            )

-        logger.critical(f"length of evaluation_result: {len(dataset)}")
+        logger.critical(f"length of bluegench generated dataset: {len(dataset)}")


 dataset_query = "benchmarks.bluebench[loader_limit=30,max_samples_per_subset=30]"
@@ -140,9 +138,7 @@ def main():
         "profile_benchmark_blue_bench", "bluebench_profiler.py", s
     )
     load_time = find_cummtime_of("load_data", "loaders.py", s)
-    just_load_no_initial_ms_time = find_cummtime_of(
-        "load_iterables", "loaders.py", s
-    )
+
     instantiate_benchmark_time = find_cummtime_of(
         "profiler_instantiate_benchmark_recipe", "bluebench_profiler.py", s
     )
@@ -155,7 +151,6 @@ def main():
         "dataset_query": dataset_query,
         "total_time": overall_tot_time,
         "load_time": load_time,
-        "load_time_no_initial_ms": just_load_no_initial_ms_time,
         "instantiate_benchmark_time": instantiate_benchmark_time,
         "generate_benchmark_dataset_time": generate_benchmark_dataset_time,
         "used_eager_mode": settings.use_eager_execution,

performance/compare_benchmark_performance_results.py

+4 −13

@@ -24,16 +24,9 @@
 print(f"used_eager_mode in PR = {pr_perf['used_eager_mode']}")

 ratio1 = (
-    (pr_perf["generate_benchmark_dataset_time"] - pr_perf["load_time_no_initial_ms"])
-    / (
-        main_perf["generate_benchmark_dataset_time"]
-        - main_perf["load_time_no_initial_ms"]
-    )
-    if (
-        main_perf["generate_benchmark_dataset_time"]
-        - main_perf["load_time_no_initial_ms"]
-    )
-    > 0
+    (pr_perf["generate_benchmark_dataset_time"] - pr_perf["load_time"])
+    / (main_perf["generate_benchmark_dataset_time"] - main_perf["load_time"])
+    if (main_perf["generate_benchmark_dataset_time"] - main_perf["load_time"]) > 0
     else 1
 )
 # Markdown table formatting
@@ -42,9 +35,7 @@
 line2 = "--------------------|-------------|-------------|---------------\n"
 line3 = f" Total time | {main_perf['total_time']:>11} | {pr_perf['total_time']:>11} | {pr_perf['total_time'] / main_perf['total_time']:.2f}\n"
 ratio_line4 = (
-    pr_perf["load_time_no_initial_ms"] / main_perf["load_time_no_initial_ms"]
-    if main_perf["load_time_no_initial_ms"] > 0
-    else 1
+    pr_perf["load_time"] / main_perf["load_time"] if main_perf["load_time"] > 0 else 1
 )
 line4 = f" Load time | {main_perf['load_time_no_initial_ms']:>11} | {pr_perf['load_time_no_initial_ms']:>11} | {ratio_line4:.2f}\n"
 line5 = f" DS Gen. inc. Load | {main_perf['generate_benchmark_dataset_time']:>11} | {pr_perf['generate_benchmark_dataset_time']:>11} | {pr_perf['generate_benchmark_dataset_time'] / main_perf['generate_benchmark_dataset_time']:.2f}\n"
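
Note: with load_time_no_initial_ms gone, the comparison keys off load_time, and the guard still protects against a zero denominator. A tiny illustration with made-up numbers (the dictionaries are placeholders, not real profiler output):

main_perf = {"generate_benchmark_dataset_time": 42.0, "load_time": 10.0}
pr_perf = {"generate_benchmark_dataset_time": 35.0, "load_time": 9.0}

# ratio of generation time excluding loading, guarded against a zero denominator
denominator = main_perf["generate_benchmark_dataset_time"] - main_perf["load_time"]
ratio1 = (
    (pr_perf["generate_benchmark_dataset_time"] - pr_perf["load_time"]) / denominator
    if denominator > 0
    else 1
)
print(f"{ratio1:.2f}")  # 0.81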

src/unitxt/fusion.py

+27 −20

@@ -4,9 +4,12 @@
 from .dataclass import NonPositionalField
 from .operator import SourceOperator
 from .random_utils import new_random_generator
+from .settings_utils import get_settings
 from .stream import DynamicStream, MultiStream
 from .type_utils import isoftype

+settings = get_settings()
+

 class BaseFusion(SourceOperator):
     """BaseFusion operator that combines multiple multistreams into one.
@@ -75,26 +78,30 @@ def prepare(self):

     # flake8: noqa: C901
     def fusion_generator(self, split) -> Generator:
-        for origin_name, origin in self.named_subsets.items():
-            multi_stream = origin()
-            if split not in multi_stream:
-                continue
-            emitted_from_this_split = 0
-            try:
-                for instance in multi_stream[split]:
-                    if (
-                        self.max_instances_per_subset is not None
-                        and emitted_from_this_split >= self.max_instances_per_subset
-                    ):
-                        break
-                    if isinstance(origin_name, str):
-                        if "subset" not in instance:
-                            instance["subset"] = []
-                        instance["subset"].insert(0, origin_name)
-                    emitted_from_this_split += 1
-                    yield instance
-            except Exception as e:
-                raise RuntimeError(f"Exception in subset: {origin_name}") from e
+        with settings.context(
+            disable_hf_datasets_cache=False,
+            allow_unverified_code=True,
+        ):
+            for origin_name, origin in self.named_subsets.items():
+                multi_stream = origin()
+                if split not in multi_stream:
+                    continue
+                emitted_from_this_split = 0
+                try:
+                    for instance in multi_stream[split]:
+                        if (
+                            self.max_instances_per_subset is not None
+                            and emitted_from_this_split >= self.max_instances_per_subset
+                        ):
+                            break
+                        if isinstance(origin_name, str):
+                            if "subset" not in instance:
+                                instance["subset"] = []
+                            instance["subset"].insert(0, origin_name)
+                        emitted_from_this_split += 1
+                        yield instance
+                except Exception as e:
+                    raise RuntimeError(f"Exception in subset: {origin_name}") from e


 class WeightedFusion(BaseFusion):
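
Note: fusion_generator now pins disable_hf_datasets_cache=False and allow_unverified_code=True for the whole subset walk via settings.context. Below is a rough sketch of the idea behind such a context helper, a temporary attribute override restored on exit; it illustrates the pattern only and is not unitxt's actual Settings implementation.

from contextlib import contextmanager


class DemoSettings:
    """Toy settings object with a context() helper that restores previous values on exit."""

    def __init__(self):
        self.disable_hf_datasets_cache = True
        self.allow_unverified_code = False

    @contextmanager
    def context(self, **overrides):
        previous = {key: getattr(self, key) for key in overrides}
        try:
            for key, value in overrides.items():
                setattr(self, key, value)
            yield self
        finally:
            for key, value in previous.items():
                setattr(self, key, value)


settings = DemoSettings()
with settings.context(disable_hf_datasets_cache=False, allow_unverified_code=True):
    assert settings.allow_unverified_code is True
assert settings.allow_unverified_code is False  # restored after the block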

src/unitxt/loaders.py

+42-46
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,12 @@
5353

5454
import pandas as pd
5555
import requests
56-
from datasets import IterableDataset, IterableDatasetDict, load_dataset_builder
56+
from datasets import (
57+
IterableDataset,
58+
IterableDatasetDict,
59+
get_dataset_split_names,
60+
load_dataset_builder,
61+
)
5762
from datasets import load_dataset as hf_load_dataset
5863
from huggingface_hub import HfApi
5964
from tqdm import tqdm
@@ -232,11 +237,6 @@ def filter_load(self, dataset):
232237
logger.info(f"\nLoading filtered by: {self.filtering_lambda};")
233238
return dataset.filter(eval(self.filtering_lambda))
234239

235-
def log_limited_loading(self, split: str):
236-
logger.info(
237-
f"\nLoading of split {split} limited to {self.get_limit()} instances by setting {self.get_limiter()};"
238-
)
239-
240240
# returns Dict when split names are not known in advance, and just the single split dataset - if known
241241
def stream_dataset(self, split: str) -> Union[IterableDatasetDict, IterableDataset]:
242242
with tempfile.TemporaryDirectory() as dir_to_be_deleted:
@@ -259,6 +259,9 @@ def stream_dataset(self, split: str) -> Union[IterableDatasetDict, IterableDatas
259259
)
260260
except ValueError as e:
261261
if "trust_remote_code" in str(e):
262+
logger.critical(
263+
f"while raising trust_remote error, settings.allow_unverified_code = {settings.allow_unverified_code}"
264+
)
262265
raise ValueError(
263266
f"{self.__class__.__name__} cannot run remote code from huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE."
264267
) from e
@@ -312,6 +315,10 @@ def _maybe_set_classification_policy(self):
312315
def load_iterables(
313316
self
314317
) -> Union[Dict[str, ReusableGenerator], IterableDatasetDict]:
318+
# log limit once for the whole data
319+
if self.get_limit() is not None:
320+
self.log_limited_loading()
321+
315322
if not isinstance(self, LoadFromHFSpace):
316323
# try the following for LoadHF only
317324
if self.split is not None:
@@ -322,22 +329,17 @@ def load_iterables(
322329
}
323330

324331
try:
325-
ds_builder = load_dataset_builder(
326-
self.path,
327-
self.name,
332+
split_names = get_dataset_split_names(
333+
path=self.path,
334+
config_name=self.name,
328335
trust_remote_code=settings.allow_unverified_code,
329336
)
330-
dataset_info = ds_builder.info
331-
if dataset_info.splits is not None:
332-
# split names are known before the split themselves are pulled from HF,
333-
# and we can postpone that pulling of the splits until actually demanded
334-
split_names = list(dataset_info.splits.keys())
335-
return {
336-
split_name: ReusableGenerator(
337-
self.split_generator, gen_kwargs={"split": split_name}
338-
)
339-
for split_name in split_names
340-
}
337+
return {
338+
split_name: ReusableGenerator(
339+
self.split_generator, gen_kwargs={"split": split_name}
340+
)
341+
for split_name in split_names
342+
}
341343

342344
except:
343345
pass # do nothing, and just continue to the usual load dataset
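
Note: get_dataset_split_names resolves split names from the dataset metadata without materializing the splits, which is what lets each split stay deferred behind a ReusableGenerator. A small usage sketch; "squad" is just an example dataset, and the exact keyword set supported depends on the installed datasets version.

from datasets import get_dataset_split_names

# Resolve split names from metadata only; the split data itself is not pulled here.
split_names = get_dataset_split_names("squad")
print(split_names)  # e.g. ['train', 'validation']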
@@ -356,8 +358,6 @@ def load_iterables(

         limit = self.get_limit()
         if limit is not None:
-            for split_name in dataset:
-                self.log_limited_loading(split_name)
             result = {}
             for split_name in dataset:
                 result[split_name] = dataset[split_name].take(limit)
@@ -366,26 +366,24 @@ def load_iterables(
         return dataset

     def split_generator(self, split: str) -> Generator:
-        try:
-            dataset = self.stream_dataset(split)
-        except (
-            NotImplementedError
-        ):  # streaming is not supported for zipped files so we load without streaming
-            dataset = self.load_dataset(split)
+        dataset = self.__class__._loader_cache.get(str(self) + "_" + split, None)
+        if dataset is None:
+            try:
+                dataset = self.stream_dataset(split)
+            except NotImplementedError:  # streaming is not supported for zipped files so we load without streaming
+                dataset = self.load_dataset(split)

-        if self.filtering_lambda is not None:
-            dataset = self.filter_load(dataset)
+            if self.filtering_lambda is not None:
+                dataset = self.filter_load(dataset)

-        limit = self.get_limit()
-        if limit is not None:
-            self.log_limited_loading(split)
-            dataset = dataset.take(limit)
+            limit = self.get_limit()
+            if limit is not None:
+                dataset = dataset.take(limit)

-        yield from dataset
+            self.__class__._loader_cache.max_size = settings.loader_cache_size
+            self.__class__._loader_cache[str(self) + "_" + split] = dataset

-    def process(self) -> MultiStream:
-        self._maybe_set_classification_policy()
-        return self.add_data_classification(self.load_data())
+        yield from dataset


 class LoadCSV(Loader):
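
Note: split_generator now keeps each prepared split in a class-level _loader_cache keyed by str(self) + "_" + split, so iterating the same split again does not re-stream it from the hub, and max_size is refreshed from settings.loader_cache_size before insertion. Below is a minimal sketch of a size-bounded cache with that shape; it illustrates the assumed behavior of such a cache, not unitxt's actual cache class.

from collections import OrderedDict


class BoundedCache:
    """Tiny LRU-style cache with a mutable max_size, mimicking the assumed _loader_cache interface."""

    def __init__(self, max_size: int = 25):
        self.max_size = max_size
        self._data = OrderedDict()

    def get(self, key, default=None):
        if key in self._data:
            self._data.move_to_end(key)  # mark as recently used
            return self._data[key]
        return default

    def __setitem__(self, key, value):
        self._data[key] = value
        self._data.move_to_end(key)
        while len(self._data) > self.max_size:
            self._data.popitem(last=False)  # evict the least recently used entry


cache = BoundedCache(max_size=2)
cache["loader_a_train"] = ["instance1", "instance2"]
cache["loader_a_test"] = ["instance3"]
cache["loader_b_train"] = ["instance4"]  # evicts loader_a_train
print(cache.get("loader_a_train"))       # None
print(cache.get("loader_a_test"))        # ['instance3']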
@@ -442,6 +440,9 @@ def get_args(self):
         return args

     def load_iterables(self):
+        # log once for the whole data
+        if self.get_limit() is not None:
+            self.log_limited_loading()
         iterables = {}
         for split_name in self.files.keys():
             iterables[split_name] = ReusableGenerator(
@@ -451,14 +452,9 @@ def load_iterables(self):
         return iterables

     def split_generator(self, split: str) -> Generator:
-        if self.get_limit() is not None:
-            self.log_limited_loading()
-            dataset = pd.read_csv(
-                self.files[split], nrows=self.get_limit(), sep=self.sep
-            ).to_dict("records")
-        else:
-            dataset = pd.read_csv(self.files[split], sep=self.sep).to_dict("records")
-
+        dataset = pd.read_csv(
+            self.files[split], nrows=self.get_limit(), sep=self.sep
+        ).to_dict("records")
         yield from dataset

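
Note: the if/else in LoadCSV.split_generator collapses because pandas treats nrows=None as "read everything", so passing self.get_limit() directly covers both the limited and the unlimited case. A quick check of that behavior (the CSV content is made up):

import io

import pandas as pd

csv_text = "a,b\n1,2\n3,4\n5,6\n"
all_rows = pd.read_csv(io.StringIO(csv_text), nrows=None).to_dict("records")
two_rows = pd.read_csv(io.StringIO(csv_text), nrows=2).to_dict("records")
print(len(all_rows), len(two_rows))  # 3 2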

src/unitxt/settings_utils.py

+1 −1

@@ -150,7 +150,7 @@ def __getattr__(self, key):
 settings.data_classification_policy = None
 settings.mock_inference_mode = (bool, False)
 settings.disable_hf_datasets_cache = (bool, True)
-settings.loader_cache_size = (int, 1)
+settings.loader_cache_size = (int, 25)
 settings.task_data_as_text = (bool, True)
 settings.default_provider = "watsonx"
 settings.default_format = None
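
Note: the default loader cache grows from 1 to 25 entries; the (int, 25) tuple follows the surrounding (type, default) pattern. The value can still be tuned per run, for example via an environment variable following the UNITXT_* convention that loaders.py already mentions for UNITXT_ALLOW_UNVERIFIED_CODE. The exact variable name below is an assumption based on that convention, not confirmed by this diff.

import os

# assumed: settings honor UNITXT_<NAME> environment overrides, as UNITXT_ALLOW_UNVERIFIED_CODE does
os.environ["UNITXT_LOADER_CACHE_SIZE"] = "50"

from unitxt.settings_utils import get_settings

settings = get_settings()
print(settings.loader_cache_size)  # expected 50 if the override is honored, otherwise the new default 25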

src/unitxt/test_utils/card.py

+17 −14

@@ -291,18 +291,21 @@ def test_card(
     else:
         template_card_indices = range(len(card.templates))

-    for template_card_index in template_card_indices:
-        examples = load_examples_from_dataset_recipe(
-            card, template_card_index=template_card_index, debug=debug, **kwargs
-        )
-        if test_exact_match_score_when_predictions_equal_references:
-            test_correct_predictions(
-                examples=examples, strict=strict, exact_match_score=exact_match_score
-            )
-        if test_full_mismatch_score_with_full_mismatch_prediction_values:
-            test_wrong_predictions(
-                examples=examples,
-                strict=strict,
-                maximum_full_mismatch_score=maximum_full_mismatch_score,
-                full_mismatch_prediction_values=full_mismatch_prediction_values,
+    with settings.context(allow_unverified_code=True):
+        for template_card_index in template_card_indices:
+            examples = load_examples_from_dataset_recipe(
+                card, template_card_index=template_card_index, debug=debug, **kwargs
             )
+            if test_exact_match_score_when_predictions_equal_references:
+                test_correct_predictions(
+                    examples=examples,
+                    strict=strict,
+                    exact_match_score=exact_match_score,
+                )
+            if test_full_mismatch_score_with_full_mismatch_prediction_values:
+                test_wrong_predictions(
+                    examples=examples,
+                    strict=strict,
+                    maximum_full_mismatch_score=maximum_full_mismatch_score,
+                    full_mismatch_prediction_values=full_mismatch_prediction_values,
+                )

utils/.secrets.baseline

+1 −1

@@ -151,7 +151,7 @@
       "filename": "src/unitxt/loaders.py",
       "hashed_secret": "840268f77a57d5553add023cfa8a4d1535f49742",
       "is_verified": false,
-      "line_number": 570,
+      "line_number": 565,
       "is_secret": false
     }
   ],
