Commit ac7c198

increase loader cache

Signed-off-by: dafnapension <[email protected]>
1 parent 6c77678 commit ac7c198

6 files changed: +66 -52 lines


performance/bluebench_profiler.py (+13 -14)

@@ -59,26 +59,25 @@ def profiler_instantiate_benchmark_recipe(
     def profiler_generate_benchmark_dataset(
         self, benchmark_recipe: Benchmark, split: str, **kwargs
     ) -> List[Dict[str, Any]]:
+        stream = benchmark_recipe()[split]
+
+        # to charge here for the time of generating all instances of the split
+        return list(stream)
+
+    def profiler_do_the_profiling(self, dataset_query: str, split: str, **kwargs):
         with settings.context(
             disable_hf_datasets_cache=False,
             allow_unverified_code=True,
-            mock_inference_mode=True,
         ):
-            stream = benchmark_recipe()[split]
-
-            # to charge here for the time of generating all instances
-            return list(stream)
+            benchmark_recipe = self.profiler_instantiate_benchmark_recipe(
+                dataset_query=dataset_query, **kwargs
+            )
 
-    def profiler_do_the_profiling(self, dataset_query: str, split: str, **kwargs):
-        benchmark_recipe = self.profiler_instantiate_benchmark_recipe(
-            dataset_query=dataset_query, **kwargs
-        )
-
-        dataset = self.profiler_generate_benchmark_dataset(
-            benchmark_recipe=benchmark_recipe, split=split, **kwargs
-        )
+            dataset = self.profiler_generate_benchmark_dataset(
+                benchmark_recipe=benchmark_recipe, split=split, **kwargs
+            )
 
-        logger.critical(f"length of evaluation_result: {len(dataset)}")
+            logger.critical(f"length of bluegench generated dataset: {len(dataset)}")
 
 
 dataset_query = "benchmarks.bluebench[loader_limit=30,max_samples_per_subset=30]"

src/unitxt/fusion.py (+27 -20)

@@ -4,9 +4,12 @@
 from .dataclass import NonPositionalField
 from .operator import SourceOperator
 from .random_utils import new_random_generator
+from .settings_utils import get_settings
 from .stream import DynamicStream, MultiStream
 from .type_utils import isoftype
 
+settings = get_settings()
+
 
 class BaseFusion(SourceOperator):
     """BaseFusion operator that combines multiple multistreams into one.
@@ -75,26 +78,30 @@ def prepare(self):
 
     # flake8: noqa: C901
     def fusion_generator(self, split) -> Generator:
-        for origin_name, origin in self.named_subsets.items():
-            multi_stream = origin()
-            if split not in multi_stream:
-                continue
-            emitted_from_this_split = 0
-            try:
-                for instance in multi_stream[split]:
-                    if (
-                        self.max_instances_per_subset is not None
-                        and emitted_from_this_split >= self.max_instances_per_subset
-                    ):
-                        break
-                    if isinstance(origin_name, str):
-                        if "subset" not in instance:
-                            instance["subset"] = []
-                        instance["subset"].insert(0, origin_name)
-                    emitted_from_this_split += 1
-                    yield instance
-            except Exception as e:
-                raise RuntimeError(f"Exception in subset: {origin_name}") from e
+        with settings.context(
+            disable_hf_datasets_cache=False,
+            allow_unverified_code=True,
+        ):
+            for origin_name, origin in self.named_subsets.items():
+                multi_stream = origin()
+                if split not in multi_stream:
+                    continue
+                emitted_from_this_split = 0
+                try:
+                    for instance in multi_stream[split]:
+                        if (
+                            self.max_instances_per_subset is not None
+                            and emitted_from_this_split >= self.max_instances_per_subset
+                        ):
+                            break
+                        if isinstance(origin_name, str):
+                            if "subset" not in instance:
+                                instance["subset"] = []
+                            instance["subset"].insert(0, origin_name)
+                        emitted_from_this_split += 1
+                        yield instance
+                except Exception as e:
+                    raise RuntimeError(f"Exception in subset: {origin_name}") from e
 
 
 class WeightedFusion(BaseFusion):
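For context, settings.context is used here (and in the profiler above) as a scoped override: inside the with block the listed flags take the given values, presumably with the previous values restored on exit, so the override does not leak beyond the fusion run. A minimal usage sketch, assuming the same settings object returned by get_settings(); the stream contents are placeholders kept only so the sketch is self-contained:

    from unitxt.settings_utils import get_settings

    settings = get_settings()

    # Placeholder stand-in for a fused stream, for illustration only.
    fused_stream = [{"subset": ["demo"], "source": "..."}]

    # Temporarily allow remote dataset code and re-enable the HF datasets cache
    # while iterating; outside the block the previous setting values apply again.
    with settings.context(
        disable_hf_datasets_cache=False,
        allow_unverified_code=True,
    ):
        for instance in fused_stream:
            print(instance["subset"])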

src/unitxt/loaders.py (+6 -1)

@@ -259,6 +259,9 @@ def stream_dataset(self, split: str) -> Union[IterableDatasetDict, IterableDatas
             )
         except ValueError as e:
             if "trust_remote_code" in str(e):
+                logger.critical(
+                    f"while raising trust_remote error, settings.allow_unverified_code = {settings.allow_unverified_code}"
+                )
                 raise ValueError(
                     f"{self.__class__.__name__} cannot run remote code from huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE."
                 ) from e
@@ -327,7 +330,9 @@ def load_iterables(
 
         try:
             split_names = get_dataset_split_names(
-                path=self.path, config_name=self.name
+                path=self.path,
+                config_name=self.name,
+                trust_remote_code=settings.allow_unverified_code,
             )
             return {
                 split_name: ReusableGenerator(
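The second hunk forwards trust_remote_code to get_dataset_split_names, mirroring the allow_unverified_code check that stream_dataset already relies on. For reference, a standalone sketch of the same Hugging Face datasets call; the dataset path and config name are placeholders:

    from datasets import get_dataset_split_names

    # Placeholder dataset path and config, for illustration only.
    split_names = get_dataset_split_names(
        path="some_org/some_dataset",
        config_name="default",
        trust_remote_code=True,  # corresponds to settings.allow_unverified_code above
    )
    print(split_names)  # e.g. ["train", "validation", "test"]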

src/unitxt/settings_utils.py (+1 -1)

@@ -150,7 +150,7 @@ def __getattr__(self, key):
     settings.data_classification_policy = None
     settings.mock_inference_mode = (bool, False)
     settings.disable_hf_datasets_cache = (bool, True)
-    settings.loader_cache_size = (int, 1)
+    settings.loader_cache_size = (int, 25)
     settings.task_data_as_text = (bool, True)
     settings.default_provider = "watsonx"
     settings.default_format = None
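This is the change behind the commit title: the default loader cache size grows from 1 to 25 entries. If a different size is needed, the override mechanisms used elsewhere in this commit should apply; a sketch, where the environment-variable name is an assumption made by analogy with UNITXT_ALLOW_UNVERIFIED_CODE from loaders.py:

    import os

    from unitxt.settings_utils import get_settings

    settings = get_settings()

    # Scoped override for one block of work.
    with settings.context(loader_cache_size=50):
        pass  # loaders resolved here see the larger cache size

    # Or process-wide, set before unitxt loaders run (assumed variable name).
    os.environ["UNITXT_LOADER_CACHE_SIZE"] = "50"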

src/unitxt/test_utils/card.py (+17 -14)

@@ -291,18 +291,21 @@ def test_card(
     else:
         template_card_indices = range(len(card.templates))
 
-    for template_card_index in template_card_indices:
-        examples = load_examples_from_dataset_recipe(
-            card, template_card_index=template_card_index, debug=debug, **kwargs
-        )
-        if test_exact_match_score_when_predictions_equal_references:
-            test_correct_predictions(
-                examples=examples, strict=strict, exact_match_score=exact_match_score
-            )
-        if test_full_mismatch_score_with_full_mismatch_prediction_values:
-            test_wrong_predictions(
-                examples=examples,
-                strict=strict,
-                maximum_full_mismatch_score=maximum_full_mismatch_score,
-                full_mismatch_prediction_values=full_mismatch_prediction_values,
+    with settings.context(allow_unverified_code=True):
+        for template_card_index in template_card_indices:
+            examples = load_examples_from_dataset_recipe(
+                card, template_card_index=template_card_index, debug=debug, **kwargs
             )
+            if test_exact_match_score_when_predictions_equal_references:
+                test_correct_predictions(
+                    examples=examples,
+                    strict=strict,
+                    exact_match_score=exact_match_score,
+                )
+            if test_full_mismatch_score_with_full_mismatch_prediction_values:
+                test_wrong_predictions(
+                    examples=examples,
+                    strict=strict,
+                    maximum_full_mismatch_score=maximum_full_mismatch_score,
+                    full_mismatch_prediction_values=full_mismatch_prediction_values,
+                )
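test_card now wraps the per-template loop in settings.context(allow_unverified_code=True), so cards backed by Hub datasets that ship loading scripts can be tested without setting the flag externally. A sketch of a typical call from a card-preparation script; the card construction is elided, and the keyword names are taken from the variables visible in this hunk:

    from unitxt.test_utils.card import test_card

    # card = TaskCard(...)  # assumed to be built earlier in the preparation script
    test_card(
        card,
        strict=True,
        test_exact_match_score_when_predictions_equal_references=True,
        test_full_mismatch_score_with_full_mismatch_prediction_values=True,
    )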

utils/.secrets.baseline (+2 -2)

@@ -151,7 +151,7 @@
         "filename": "src/unitxt/loaders.py",
         "hashed_secret": "840268f77a57d5553add023cfa8a4d1535f49742",
         "is_verified": false,
-        "line_number": 560,
+        "line_number": 565,
         "is_secret": false
       }
     ],
@@ -184,5 +184,5 @@
       }
     ]
   },
-  "generated_at": "2025-01-22T08:29:46Z"
+  "generated_at": "2025-01-22T11:47:57Z"
 }
