Commit 2cf85db

log limit once per data and increase loader cache

Signed-off-by: dafnapension <[email protected]>
1 parent: 2d98da8

6 files changed (+75 -68)

performance/bluebench_profiler.py (+13 -15)

@@ -14,7 +14,6 @@
 
 logger = get_logger()
 settings = get_settings()
-settings.allow_unverified_code = True
 
 
 class BlueBenchProfiler:
@@ -59,26 +58,25 @@ def profiler_instantiate_benchmark_recipe(
     def profiler_generate_benchmark_dataset(
         self, benchmark_recipe: Benchmark, split: str, **kwargs
     ) -> List[Dict[str, Any]]:
+        stream = benchmark_recipe()[split]
+
+        # to charge here for the time of generating all instances of the split
+        return list(stream)
+
+    def profiler_do_the_profiling(self, dataset_query: str, split: str, **kwargs):
         with settings.context(
             disable_hf_datasets_cache=False,
             allow_unverified_code=True,
-            mock_inference_mode=True,
         ):
-            stream = benchmark_recipe()[split]
-
-            # to charge here for the time of generating all instances
-            return list(stream)
+            benchmark_recipe = self.profiler_instantiate_benchmark_recipe(
+                dataset_query=dataset_query, **kwargs
+            )
 
-    def profiler_do_the_profiling(self, dataset_query: str, split: str, **kwargs):
-        benchmark_recipe = self.profiler_instantiate_benchmark_recipe(
-            dataset_query=dataset_query, **kwargs
-        )
-
-        dataset = self.profiler_generate_benchmark_dataset(
-            benchmark_recipe=benchmark_recipe, split=split, **kwargs
-        )
+            dataset = self.profiler_generate_benchmark_dataset(
+                benchmark_recipe=benchmark_recipe, split=split, **kwargs
+            )
 
-        logger.critical(f"length of evaluation_result: {len(dataset)}")
+        logger.critical(f"length of bluegench generated dataset: {len(dataset)}")
 
 
 dataset_query = "benchmarks.bluebench[loader_limit=30,max_samples_per_subset=30]"
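
Note on the change above: the module-level `settings.allow_unverified_code = True` is gone, the settings override now lives inside `profiler_do_the_profiling`, and `list(stream)` moved into `profiler_generate_benchmark_dataset` so the cost of materializing the split is charged to that call. A hedged sketch of driving the profiler under cProfile; the no-argument constructor and the "test" split name are assumptions for illustration, not taken from this commit:

    import cProfile
    import pstats

    prof = cProfile.Profile()
    prof.enable()
    # dataset_query is the module-level query defined above.
    BlueBenchProfiler().profiler_do_the_profiling(
        dataset_query=dataset_query, split="test"
    )
    prof.disable()

    # Because list(stream) now runs inside profiler_generate_benchmark_dataset,
    # the time to generate all instances of the split shows up under that frame.
    pstats.Stats(prof).sort_stats("cumulative").print_stats("profiler_")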

src/unitxt/fusion.py (+27 -20)

@@ -4,9 +4,12 @@
 from .dataclass import NonPositionalField
 from .operator import SourceOperator
 from .random_utils import new_random_generator
+from .settings_utils import get_settings
 from .stream import DynamicStream, MultiStream
 from .type_utils import isoftype
 
+settings = get_settings()
+
 
 class BaseFusion(SourceOperator):
     """BaseFusion operator that combines multiple multistreams into one.
@@ -75,26 +78,30 @@ def prepare(self):
 
     # flake8: noqa: C901
     def fusion_generator(self, split) -> Generator:
-        for origin_name, origin in self.named_subsets.items():
-            multi_stream = origin()
-            if split not in multi_stream:
-                continue
-            emitted_from_this_split = 0
-            try:
-                for instance in multi_stream[split]:
-                    if (
-                        self.max_instances_per_subset is not None
-                        and emitted_from_this_split >= self.max_instances_per_subset
-                    ):
-                        break
-                    if isinstance(origin_name, str):
-                        if "subset" not in instance:
-                            instance["subset"] = []
-                        instance["subset"].insert(0, origin_name)
-                    emitted_from_this_split += 1
-                    yield instance
-            except Exception as e:
-                raise RuntimeError(f"Exception in subset: {origin_name}") from e
+        with settings.context(
+            disable_hf_datasets_cache=False,
+            allow_unverified_code=True,
+        ):
+            for origin_name, origin in self.named_subsets.items():
+                multi_stream = origin()
+                if split not in multi_stream:
+                    continue
+                emitted_from_this_split = 0
+                try:
+                    for instance in multi_stream[split]:
+                        if (
+                            self.max_instances_per_subset is not None
+                            and emitted_from_this_split >= self.max_instances_per_subset
+                        ):
+                            break
+                        if isinstance(origin_name, str):
+                            if "subset" not in instance:
+                                instance["subset"] = []
+                            instance["subset"].insert(0, origin_name)
+                        emitted_from_this_split += 1
+                        yield instance
+                except Exception as e:
+                    raise RuntimeError(f"Exception in subset: {origin_name}") from e
 
 
 class WeightedFusion(BaseFusion):
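
With this change, every subset stream is pulled while `settings.context(disable_hf_datasets_cache=False, allow_unverified_code=True)` is active, so loaders reached through the fused subsets see those values too; and because `fusion_generator` is a generator, the override stays in effect between yields, for as long as the fused split is being consumed. Below is a minimal illustrative sketch of such a scoped-override context manager, written for this note only; it is not unitxt's actual `Settings` implementation.

    from contextlib import contextmanager

    class SketchSettings:
        """Attribute-style settings with a scoped override, for illustration only."""

        def __init__(self, **defaults):
            self.__dict__.update(defaults)

        @contextmanager
        def context(self, **overrides):
            previous = {key: getattr(self, key) for key in overrides}
            self.__dict__.update(overrides)
            try:
                yield self
            finally:
                # restore the values that were in place before the block
                self.__dict__.update(previous)

    settings = SketchSettings(disable_hf_datasets_cache=True, allow_unverified_code=False)

    with settings.context(disable_hf_datasets_cache=False, allow_unverified_code=True):
        assert settings.allow_unverified_code is True   # visible to nested code
    assert settings.allow_unverified_code is False      # restored on exit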

src/unitxt/loaders.py (+16 -17)

@@ -237,11 +237,6 @@ def filter_load(self, dataset):
         logger.info(f"\nLoading filtered by: {self.filtering_lambda};")
         return dataset.filter(eval(self.filtering_lambda))
 
-    def log_limited_loading(self, split: str):
-        logger.info(
-            f"\nLoading of split {split} limited to {self.get_limit()} instances by setting {self.get_limiter()};"
-        )
-
     # returns Dict when split names are not known in advance, and just the single split dataset - if known
     def stream_dataset(self, split: str) -> Union[IterableDatasetDict, IterableDataset]:
         with tempfile.TemporaryDirectory() as dir_to_be_deleted:
@@ -264,6 +259,9 @@ def stream_dataset(self, split: str) -> Union[IterableDatasetDict, IterableDataset]:
                 )
             except ValueError as e:
                 if "trust_remote_code" in str(e):
+                    logger.critical(
+                        f"while raising trust_remote error, settings.allow_unverified_code = {settings.allow_unverified_code}"
+                    )
                     raise ValueError(
                         f"{self.__class__.__name__} cannot run remote code from huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE."
                     ) from e
@@ -317,6 +315,10 @@ def _maybe_set_classification_policy(self):
     def load_iterables(
         self
     ) -> Union[Dict[str, ReusableGenerator], IterableDatasetDict]:
+        # log limit once for the whole data
+        if self.get_limit() is not None:
+            self.log_limited_loading()
+
         if not isinstance(self, LoadFromHFSpace):
             # try the following for LoadHF only
             if self.split is not None:
@@ -328,7 +330,9 @@ def load_iterables(
 
             try:
                 split_names = get_dataset_split_names(
-                    path=self.path, config_name=self.name
+                    path=self.path,
+                    config_name=self.name,
+                    trust_remote_code=settings.allow_unverified_code,
                 )
                 return {
                     split_name: ReusableGenerator(
@@ -354,8 +358,6 @@
 
         limit = self.get_limit()
         if limit is not None:
-            for split_name in dataset:
-                self.log_limited_loading(split_name)
             result = {}
             for split_name in dataset:
                 result[split_name] = dataset[split_name].take(limit)
@@ -376,7 +378,6 @@ def split_generator(self, split: str) -> Generator:
 
         limit = self.get_limit()
        if limit is not None:
-            self.log_limited_loading(split)
             dataset = dataset.take(limit)
 
         self.__class__._loader_cache.max_size = settings.loader_cache_size
@@ -439,6 +440,9 @@ def get_args(self):
         return args
 
     def load_iterables(self):
+        # log once for the whole data
+        if self.get_limit() is not None:
+            self.log_limited_loading()
         iterables = {}
         for split_name in self.files.keys():
             iterables[split_name] = ReusableGenerator(
@@ -448,14 +452,9 @@ def load_iterables(self):
         return iterables
 
     def split_generator(self, split: str) -> Generator:
-        if self.get_limit() is not None:
-            self.log_limited_loading()
-            dataset = pd.read_csv(
-                self.files[split], nrows=self.get_limit(), sep=self.sep
-            ).to_dict("records")
-        else:
-            dataset = pd.read_csv(self.files[split], sep=self.sep).to_dict("records")
-
+        dataset = pd.read_csv(
+            self.files[split], nrows=self.get_limit(), sep=self.sep
+        ).to_dict("records")
         yield from dataset
 
 
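Two simplifications in loaders.py: the "loading limited to N instances" message is now logged once per loader, at the top of `load_iterables`, instead of once per split; and the CSV path always passes `nrows=self.get_limit()`, since pandas reads the whole file when `nrows` is None, which makes the old else-branch redundant. The removed per-split helper is presumably replaced by a no-argument variant defined outside this diff; a hypothetical sketch of what it could look like (wording assumed, not taken from this commit):

    # Hypothetical sketch only -- the real no-argument method is not shown in this diff.
    def log_limited_loading(self):
        logger.info(
            f"\nLoading limited to {self.get_limit()} instances "
            f"by setting {self.get_limiter()};"
        )
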
src/unitxt/settings_utils.py (+1 -1)

@@ -150,7 +150,7 @@ def __getattr__(self, key):
 settings.data_classification_policy = None
 settings.mock_inference_mode = (bool, False)
 settings.disable_hf_datasets_cache = (bool, True)
-settings.loader_cache_size = (int, 1)
+settings.loader_cache_size = (int, 25)
 settings.task_data_as_text = (bool, True)
 settings.default_provider = "watsonx"
 settings.default_format = None
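
The `(int, 25)` pair declares the setting's type and its new default: the per-class `_loader_cache.max_size` in loaders.py is set from `settings.loader_cache_size`, so a larger default lets a benchmark that fuses many subsets keep their loaded datasets cached instead of evicting them one by one. A hedged sketch of overriding the value; the environment-variable name follows the UNITXT_* convention referenced in loaders.py and is an assumption, not something this commit confirms:

    import os

    # Assumed env-var override, following the UNITXT_* naming convention;
    # it must be set before unitxt reads its settings.
    os.environ["UNITXT_LOADER_CACHE_SIZE"] = "50"

    # Programmatic override, assuming the singleton returned by get_settings()
    # accepts plain attribute assignment as settings_utils.py suggests.
    from unitxt.settings_utils import get_settings

    settings = get_settings()
    settings.loader_cache_size = 50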

src/unitxt/test_utils/card.py (+17 -14)

@@ -291,18 +291,21 @@ def test_card(
     else:
         template_card_indices = range(len(card.templates))
 
-    for template_card_index in template_card_indices:
-        examples = load_examples_from_dataset_recipe(
-            card, template_card_index=template_card_index, debug=debug, **kwargs
-        )
-        if test_exact_match_score_when_predictions_equal_references:
-            test_correct_predictions(
-                examples=examples, strict=strict, exact_match_score=exact_match_score
-            )
-        if test_full_mismatch_score_with_full_mismatch_prediction_values:
-            test_wrong_predictions(
-                examples=examples,
-                strict=strict,
-                maximum_full_mismatch_score=maximum_full_mismatch_score,
-                full_mismatch_prediction_values=full_mismatch_prediction_values,
+    with settings.context(allow_unverified_code=True):
+        for template_card_index in template_card_indices:
+            examples = load_examples_from_dataset_recipe(
+                card, template_card_index=template_card_index, debug=debug, **kwargs
             )
+            if test_exact_match_score_when_predictions_equal_references:
+                test_correct_predictions(
+                    examples=examples,
+                    strict=strict,
+                    exact_match_score=exact_match_score,
+                )
+            if test_full_mismatch_score_with_full_mismatch_prediction_values:
+                test_wrong_predictions(
+                    examples=examples,
+                    strict=strict,
+                    maximum_full_mismatch_score=maximum_full_mismatch_score,
+                    full_mismatch_prediction_values=full_mismatch_prediction_values,
+                )
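
Since the whole template loop now runs inside `settings.context(allow_unverified_code=True)`, a card preparation script no longer has to flip that flag itself before calling `test_card`. A hedged usage sketch; `build_my_card()` is a placeholder for whatever TaskCard construction the script does, not a unitxt function:

    from unitxt.test_utils.card import test_card

    # Placeholder: stands in for building a TaskCard whose loader needs
    # trust_remote_code; not part of unitxt.
    card = build_my_card()

    # No allow_unverified_code toggling is needed here anymore; test_card wraps
    # the template loop in settings.context(allow_unverified_code=True).
    test_card(card)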

utils/.secrets.baseline (+1 -1)

@@ -151,7 +151,7 @@
       "filename": "src/unitxt/loaders.py",
       "hashed_secret": "840268f77a57d5553add023cfa8a4d1535f49742",
       "is_verified": false,
-      "line_number": 566,
+      "line_number": 565,
       "is_secret": false
     }
   ],
