Commit ef29521

make sklearn loader too - a lazy loader
Signed-off-by: dafnapension <[email protected]>
1 parent 2cf85db commit ef29521

File tree: 4 files changed, +62 -68 lines

performance/bluebench_profiler.py (+1 -1)

```diff
@@ -76,7 +76,7 @@ def profiler_do_the_profiling(self, dataset_query: str, split: str, **kwargs):
         benchmark_recipe=benchmark_recipe, split=split, **kwargs
     )

-    logger.critical(f"length of bluegench generated dataset: {len(dataset)}")
+    logger.critical(f"length of bluebench generated dataset: {len(dataset)}")


 dataset_query = "benchmarks.bluebench[loader_limit=30,max_samples_per_subset=30]"
```

src/unitxt/fusion.py (+21 -32)

```diff
@@ -4,12 +4,9 @@
 from .dataclass import NonPositionalField
 from .operator import SourceOperator
 from .random_utils import new_random_generator
-from .settings_utils import get_settings
 from .stream import DynamicStream, MultiStream
 from .type_utils import isoftype

-settings = get_settings()
-

 class BaseFusion(SourceOperator):
     """BaseFusion operator that combines multiple multistreams into one.
@@ -37,11 +34,7 @@ def prepare_subsets(self):
             for i in range(len(self.subsets)):
                 self.named_subsets[i] = self.subsets[i]
         else:
-            for name, origin in self.subsets.items():
-                try:
-                    self.named_subsets[name] = origin
-                except Exception as e:
-                    raise RuntimeError(f"Exception in subset: {name}") from e
+            self.named_subsets = self.subsets

     def splits(self) -> List[str]:
         self.prepare_subsets()
@@ -78,30 +71,26 @@ def prepare(self):

     # flake8: noqa: C901
     def fusion_generator(self, split) -> Generator:
-        with settings.context(
-            disable_hf_datasets_cache=False,
-            allow_unverified_code=True,
-        ):
-            for origin_name, origin in self.named_subsets.items():
-                multi_stream = origin()
-                if split not in multi_stream:
-                    continue
-                emitted_from_this_split = 0
-                try:
-                    for instance in multi_stream[split]:
-                        if (
-                            self.max_instances_per_subset is not None
-                            and emitted_from_this_split >= self.max_instances_per_subset
-                        ):
-                            break
-                        if isinstance(origin_name, str):
-                            if "subset" not in instance:
-                                instance["subset"] = []
-                            instance["subset"].insert(0, origin_name)
-                        emitted_from_this_split += 1
-                        yield instance
-                except Exception as e:
-                    raise RuntimeError(f"Exception in subset: {origin_name}") from e
+        for origin_name, origin in self.named_subsets.items():
+            multi_stream = origin()
+            if split not in multi_stream:
+                continue
+            emitted_from_this_split = 0
+            try:
+                for instance in multi_stream[split]:
+                    if (
+                        self.max_instances_per_subset is not None
+                        and emitted_from_this_split >= self.max_instances_per_subset
+                    ):
+                        break
+                    if isinstance(origin_name, str):
+                        if "subset" not in instance:
+                            instance["subset"] = []
+                        instance["subset"].insert(0, origin_name)
+                    emitted_from_this_split += 1
+                    yield instance
+            except Exception as e:
+                raise RuntimeError(f"Exception in subset: {origin_name}") from e


 class WeightedFusion(BaseFusion):
```
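
Two things happen in this file: the module-level `settings` import goes away together with the `settings.context(disable_hf_datasets_cache=False, allow_unverified_code=True)` wrapper, so each subset's loader now runs under the caller's own settings instead of a forced override, and the no-op try/except around a plain dict assignment collapses to `self.named_subsets = self.subsets`. A minimal, self-contained sketch of what `fusion_generator` does, with hypothetical in-memory subsets standing in for the lazy, loader-backed multistreams of the real code:

```python
from typing import Callable, Dict, Generator, List, Optional


def fusion_generator(
    named_subsets: Dict[str, Callable[[], Dict[str, List[dict]]]],
    split: str,
    max_instances_per_subset: Optional[int] = None,
) -> Generator:
    """Mirrors the diff's logic, minus the removed settings.context wrapper."""
    for origin_name, origin in named_subsets.items():
        multi_stream = origin()  # lazy: subset work starts only here
        if split not in multi_stream:
            continue
        emitted = 0
        for instance in multi_stream[split]:
            if max_instances_per_subset is not None and emitted >= max_instances_per_subset:
                break
            # tag each instance with the subset it came from (outermost name first)
            instance.setdefault("subset", []).insert(0, origin_name)
            emitted += 1
            yield instance


subsets = {
    "cola": lambda: {"test": [{"text": "a"}, {"text": "b"}]},
    "rte": lambda: {"test": [{"text": "c"}]},
}
print(list(fusion_generator(subsets, "test", max_instances_per_subset=1)))
# -> [{'text': 'a', 'subset': ['cola']}, {'text': 'c', 'subset': ['rte']}]
```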

src/unitxt/loaders.py (+39 -34)

```diff
@@ -45,6 +45,7 @@
     Generator,
     Iterable,
     List,
+    Literal,
     Mapping,
     Optional,
     Sequence,
@@ -57,7 +58,6 @@
     IterableDataset,
     IterableDatasetDict,
     get_dataset_split_names,
-    load_dataset_builder,
 )
 from datasets import load_dataset as hf_load_dataset
 from huggingface_hub import HfApi
@@ -168,7 +168,7 @@ def load_data(self) -> MultiStream:
             self.__class__._loader_cache.max_size = settings.loader_cache_size
             self.__class__._loader_cache[str(self)] = iterables
         if isoftype(iterables, Dict[str, ReusableGenerator]):
-            return MultiStream.from_generators(iterables)
+            return MultiStream.from_generators(iterables, copying=True)
         return MultiStream.from_iterables(iterables, copying=True)

     def process(self) -> MultiStream:
@@ -259,9 +259,6 @@ def stream_dataset(self, split: str) -> Union[IterableDatasetDict, IterableDatas
             )
         except ValueError as e:
             if "trust_remote_code" in str(e):
-                logger.critical(
-                    f"while raising trust_remote error, settings.allow_unverified_code = {settings.allow_unverified_code}"
-                )
                 raise ValueError(
                     f"{self.__class__.__name__} cannot run remote code from huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE."
                 ) from e
@@ -319,30 +316,28 @@ def load_iterables(
         if self.get_limit() is not None:
             self.log_limited_loading()

-        if not isinstance(self, LoadFromHFSpace):
-            # try the following for LoadHF only
-            if self.split is not None:
-                return {
-                    self.split: ReusableGenerator(
-                        self.split_generator, gen_kwargs={"split": self.split}
-                    )
-                }
+        if self.split is not None:
+            return {
+                self.split: ReusableGenerator(
+                    self.split_generator, gen_kwargs={"split": self.split}
+                )
+            }

-            try:
-                split_names = get_dataset_split_names(
-                    path=self.path,
-                    config_name=self.name,
-                    trust_remote_code=settings.allow_unverified_code,
+        try:
+            split_names = get_dataset_split_names(
+                path=self.path,
+                config_name=self.name,
+                trust_remote_code=settings.allow_unverified_code,
+            )
+            return {
+                split_name: ReusableGenerator(
+                    self.split_generator, gen_kwargs={"split": split_name}
                 )
-                return {
-                    split_name: ReusableGenerator(
-                        self.split_generator, gen_kwargs={"split": split_name}
-                    )
-                    for split_name in split_names
-                }
+                for split_name in split_names
+            }

-            except:
-                pass  # do nothing, and just continue to the usual load dataset
+        except:
+            pass  # do nothing, and just continue to the usual load dataset
         # self.split is None and
         # split names are not known before the splits themselves are loaded, and we need to load them here
```
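
The pattern these hunks standardize: `load_iterables` returns a mapping from split name to a `ReusableGenerator` rather than materialized data, and `MultiStream.from_generators(..., copying=True)` (per the `load_data` hunk) turns it into streams, so no rows are fetched until a split is actually iterated. A runnable sketch under one assumption, namely that `ReusableGenerator` simply re-creates its generator from a stored function and kwargs (the stand-in below is simplified, not unitxt's actual class):

```python
from typing import Callable, Dict, Generator, Optional


class ReusableGenerator:
    """Simplified stand-in for unitxt's ReusableGenerator: it stores a
    generator function plus its kwargs, so a fresh generator can be built
    on every iteration and the stream stays re-consumable."""

    def __init__(self, generator: Callable, gen_kwargs: Optional[dict] = None):
        self.generator = generator
        self.gen_kwargs = gen_kwargs or {}

    def __iter__(self):
        return self.generator(**self.gen_kwargs)


def split_generator(split: str) -> Generator:
    # hypothetical stand-in for LoadHF.split_generator: rows are produced
    # only when this generator is actually driven
    print(f"fetching split {split!r} now")
    yield from ({"row": i, "split": split} for i in range(2))


def load_iterables(splits) -> Dict[str, ReusableGenerator]:
    # no download happens here; splits are wired up, not materialized
    return {
        name: ReusableGenerator(split_generator, gen_kwargs={"split": name})
        for name in splits
    }


iterables = load_iterables(["train", "test"])  # instant, nothing fetched
print(list(iterables["test"]))                 # first real work happens here
```

The last hunk of src/unitxt/loaders.py converts LoadFromSklearn to the same scheme, dropping the eager CSV round-trip through a TemporaryDirectory: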

```diff
@@ -495,14 +490,24 @@ def prepare(self):
         self.downloader = getattr(sklearn_datatasets, f"fetch_{self.dataset_name}")

     def load_iterables(self):
-        with TemporaryDirectory() as temp_directory:
-            for split in self.splits:
-                split_data = self.downloader(subset=split)
-                targets = [split_data["target_names"][t] for t in split_data["target"]]
-                df = pd.DataFrame([split_data["data"], targets]).T
-                df.columns = ["data", "target"]
-                df.to_csv(os.path.join(temp_directory, f"{split}.csv"), index=None)
-            return hf_load_dataset(temp_directory, streaming=False)
+        return {
+            split_name: ReusableGenerator(
+                self.split_generator, gen_kwargs={"split": split_name}
+            )
+            for split_name in self.splits
+        }
+
+    def split_generator(self, split: str) -> Generator:
+        dataset = self.__class__._loader_cache.get(str(self) + "_" + split, None)
+        if dataset is None:
+            split_data = self.downloader(subset=split)
+            targets = [split_data["target_names"][t] for t in split_data["target"]]
+            df = pd.DataFrame([split_data["data"], targets]).T
+            df.columns = ["data", "target"]
+            dataset = df.to_dict("records")
+            self.__class__._loader_cache.max_size = settings.loader_cache_size
+            self.__class__._loader_cache[str(self) + "_" + split] = dataset
+        yield from dataset


 class MissingKaggleCredentialsError(ValueError):
```
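
Note the cache keying in the new `split_generator`: the class-level `_loader_cache` is keyed by `str(self) + "_" + split`, so each (loader, split) pair is fetched from sklearn and converted to records once, then replayed from memory on every later iteration. A minimal sketch of that scheme, with a plain dict standing in for unitxt's size-bounded cache and a hypothetical `fetch` in place of the sklearn downloader:

```python
class SketchLoader:
    """Per-split caching sketch; the real unitxt cache is size-bounded."""

    _loader_cache = {}  # class-level: shared by all instances

    def __init__(self, dataset_name: str):
        self.dataset_name = dataset_name

    def __str__(self):
        return f"SketchLoader({self.dataset_name})"

    def fetch(self, split):
        # hypothetical expensive download, standing in for the sklearn fetcher
        print(f"downloading {self.dataset_name}/{split}")
        return [{"data": f"{split}-{i}", "target": "x"} for i in range(2)]

    def split_generator(self, split: str):
        key = str(self) + "_" + split            # one cache entry per split
        dataset = self.__class__._loader_cache.get(key, None)
        if dataset is None:
            dataset = self.fetch(split)          # runs once per (loader, split)
            self.__class__._loader_cache[key] = dataset
        yield from dataset


loader = SketchLoader("20newsgroups")
list(loader.split_generator("train"))  # triggers the download
list(loader.split_generator("train"))  # replayed from cache, no download
```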

utils/.secrets.baseline (+1 -1)

```diff
@@ -151,7 +151,7 @@
       "filename": "src/unitxt/loaders.py",
       "hashed_secret": "840268f77a57d5553add023cfa8a4d1535f49742",
       "is_verified": false,
-      "line_number": 565,
+      "line_number": 595,
       "is_secret": false
     }
   ],
```
