
Commit 4cf6a69

make sklearn loader too - a lazy loader
Signed-off-by: dafnapension <[email protected]>
1 parent 5f117b6 commit 4cf6a69

File tree

4 files changed (+62 -69 lines)

performance/bluebench_profiler.py (+1 -1)

@@ -76,7 +76,7 @@ def profiler_do_the_profiling(self, dataset_query: str, split: str, **kwargs):
             benchmark_recipe=benchmark_recipe, split=split, **kwargs
         )

-        logger.critical(f"length of bluegench generated dataset: {len(dataset)}")
+        logger.critical(f"length of bluebench generated dataset: {len(dataset)}")


 dataset_query = "benchmarks.bluebench[loader_limit=30,max_samples_per_subset=30]"

src/unitxt/fusion.py (+21 -32)

@@ -4,12 +4,9 @@
 from .dataclass import NonPositionalField
 from .operator import SourceOperator
 from .random_utils import new_random_generator
-from .settings_utils import get_settings
 from .stream import DynamicStream, MultiStream
 from .type_utils import isoftype

-settings = get_settings()
-

 class BaseFusion(SourceOperator):
     """BaseFusion operator that combines multiple multistreams into one.

@@ -37,11 +34,7 @@ def prepare_subsets(self):
             for i in range(len(self.subsets)):
                 self.named_subsets[i] = self.subsets[i]
         else:
-            for name, origin in self.subsets.items():
-                try:
-                    self.named_subsets[name] = origin
-                except Exception as e:
-                    raise RuntimeError(f"Exception in subset: {name}") from e
+            self.named_subsets = self.subsets

     def splits(self) -> List[str]:
         self.prepare_subsets()

@@ -78,30 +71,26 @@ def prepare(self):

     # flake8: noqa: C901
     def fusion_generator(self, split) -> Generator:
-        with settings.context(
-            disable_hf_datasets_cache=False,
-            allow_unverified_code=True,
-        ):
-            for origin_name, origin in self.named_subsets.items():
-                multi_stream = origin()
-                if split not in multi_stream:
-                    continue
-                emitted_from_this_split = 0
-                try:
-                    for instance in multi_stream[split]:
-                        if (
-                            self.max_instances_per_subset is not None
-                            and emitted_from_this_split >= self.max_instances_per_subset
-                        ):
-                            break
-                        if isinstance(origin_name, str):
-                            if "subset" not in instance:
-                                instance["subset"] = []
-                            instance["subset"].insert(0, origin_name)
-                        emitted_from_this_split += 1
-                        yield instance
-                except Exception as e:
-                    raise RuntimeError(f"Exception in subset: {origin_name}") from e
+        for origin_name, origin in self.named_subsets.items():
+            multi_stream = origin()
+            if split not in multi_stream:
+                continue
+            emitted_from_this_split = 0
+            try:
+                for instance in multi_stream[split]:
+                    if (
+                        self.max_instances_per_subset is not None
+                        and emitted_from_this_split >= self.max_instances_per_subset
+                    ):
+                        break
+                    if isinstance(origin_name, str):
+                        if "subset" not in instance:
+                            instance["subset"] = []
+                        instance["subset"].insert(0, origin_name)
+                    emitted_from_this_split += 1
+                    yield instance
+            except Exception as e:
+                raise RuntimeError(f"Exception in subset: {origin_name}") from e


 class WeightedFusion(BaseFusion):
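Note: dropping the settings.context wrapper means each subset now loads lazily under the caller's own settings rather than forced overrides. Below is a minimal standalone sketch of the fusion pattern the new code follows; the function name and toy subsets are illustrative, not unitxt's actual API.

from typing import Callable, Dict, Generator, Iterable, Optional


def fusion_generator(
    named_subsets: Dict[str, Callable[[], Dict[str, Iterable[dict]]]],
    split: str,
    max_instances_per_subset: Optional[int] = None,
) -> Generator[dict, None, None]:
    for origin_name, origin in named_subsets.items():
        multi_stream = origin()  # each subset is materialized lazily, on demand
        if split not in multi_stream:
            continue
        emitted = 0
        for instance in multi_stream[split]:
            if max_instances_per_subset is not None and emitted >= max_instances_per_subset:
                break
            # tag each instance with the subset it came from
            instance.setdefault("subset", []).insert(0, origin_name)
            emitted += 1
            yield instance


# toy usage
subsets = {
    "a": lambda: {"test": iter([{"x": 1}, {"x": 2}])},
    "b": lambda: {"test": iter([{"x": 3}])},
}
for inst in fusion_generator(subsets, "test", max_instances_per_subset=1):
    print(inst)  # {'x': 1, 'subset': ['a']} then {'x': 3, 'subset': ['b']}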

src/unitxt/loaders.py (+38 -34)

@@ -57,7 +57,6 @@
     IterableDataset,
     IterableDatasetDict,
     get_dataset_split_names,
-    load_dataset_builder,
 )
 from datasets import load_dataset as hf_load_dataset
 from huggingface_hub import HfApi

@@ -168,7 +167,7 @@ def load_data(self) -> MultiStream:
             self.__class__._loader_cache.max_size = settings.loader_cache_size
             self.__class__._loader_cache[str(self)] = iterables
         if isoftype(iterables, Dict[str, ReusableGenerator]):
-            return MultiStream.from_generators(iterables)
+            return MultiStream.from_generators(iterables, copying=True)
         return MultiStream.from_iterables(iterables, copying=True)

     def process(self) -> MultiStream:
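Note: copying=True is now also passed when the streams come from cached ReusableGenerators, presumably so that consumers mutate copies rather than the cached instances themselves (our reading; the commit does not say). A toy sketch of that idea, with a stand-in ReusableGenerator, not unitxt's implementation:

from copy import deepcopy


class ReusableGenerator:
    """Wraps a generator function so it can be re-invoked on every iteration."""

    def __init__(self, generator, gen_kwargs=None):
        self.generator = generator
        self.gen_kwargs = gen_kwargs or {}

    def __iter__(self):
        return iter(self.generator(**self.gen_kwargs))


_cache = [{"text": "hello"}]  # shared cached instances


def read_split(split):
    for instance in _cache:
        yield instance


stream = ReusableGenerator(read_split, gen_kwargs={"split": "train"})

# assumed copying=True semantics: hand each consumer its own deep copy
for instance in (deepcopy(x) for x in stream):
    instance["text"] = "mutated"  # safe: does not touch the cache

print(_cache[0]["text"])  # still "hello"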
@@ -259,9 +258,6 @@ def stream_dataset(self, split: str) -> Union[IterableDatasetDict, IterableDatas
             )
         except ValueError as e:
             if "trust_remote_code" in str(e):
-                logger.critical(
-                    f"while raising trust_remote error, settings.allow_unverified_code = {settings.allow_unverified_code}"
-                )
                 raise ValueError(
                     f"{self.__class__.__name__} cannot run remote code from huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE."
                 ) from e
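Note: only the debug logging is removed here; the error-wrapping pattern stays. A minimal sketch of that pattern (the helper name is illustrative): catch the library's ValueError, detect the trust_remote_code case, and re-raise with a unitxt-specific hint while preserving the original traceback via `from e`.

def load_or_explain(load_fn):
    try:
        return load_fn()
    except ValueError as e:
        if "trust_remote_code" in str(e):
            # chain the new error onto the original one
            raise ValueError(
                "set unitxt.settings.allow_unverified_code=True or the "
                "UNITXT_ALLOW_UNVERIFIED_CODE environment variable to run remote dataset code"
            ) from e
        raise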
@@ -319,30 +315,28 @@ def load_iterables(
         if self.get_limit() is not None:
             self.log_limited_loading()

-        if not isinstance(self, LoadFromHFSpace):
-            # try the following for LoadHF only
-            if self.split is not None:
-                return {
-                    self.split: ReusableGenerator(
-                        self.split_generator, gen_kwargs={"split": self.split}
-                    )
-                }
+        if self.split is not None:
+            return {
+                self.split: ReusableGenerator(
+                    self.split_generator, gen_kwargs={"split": self.split}
+                )
+            }

-            try:
-                split_names = get_dataset_split_names(
-                    path=self.path,
-                    config_name=self.name,
-                    trust_remote_code=settings.allow_unverified_code,
+        try:
+            split_names = get_dataset_split_names(
+                path=self.path,
+                config_name=self.name,
+                trust_remote_code=settings.allow_unverified_code,
+            )
+            return {
+                split_name: ReusableGenerator(
+                    self.split_generator, gen_kwargs={"split": split_name}
                 )
-                return {
-                    split_name: ReusableGenerator(
-                        self.split_generator, gen_kwargs={"split": split_name}
-                    )
-                    for split_name in split_names
-                }
+                for split_name in split_names
+            }

-            except:
-                pass # do nothing, and just continue to the usual load dataset
+        except:
+            pass # do nothing, and just continue to the usual load dataset
         # self.split is None and
         # split names are not known before the splits themselves are loaded, and we need to load them here
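Note: with the LoadFromHFSpace guard removed, every LoadHF path now tries the lazy route: discover the split names cheaply, then defer loading each split until it is iterated. A hedged sketch of the same pattern against the public `datasets` API, using plain one-shot generators instead of ReusableGenerator; "rotten_tomatoes" is just an example dataset without configs.

from datasets import get_dataset_split_names, load_dataset


def split_generator(path: str, split: str):
    # nothing is downloaded until this generator is actually consumed
    yield from load_dataset(path, split=split, streaming=True)


def load_iterables(path: str):
    # metadata-only call: returns e.g. ["train", "validation", "test"]
    split_names = get_dataset_split_names(path)
    # build one lazy generator per split; creating them triggers no loading
    return {name: split_generator(path, name) for name in split_names}


iterables = load_iterables("rotten_tomatoes")
print(next(iter(iterables["train"])))  # first instance, fetched on demand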

@@ -473,14 +467,24 @@ def prepare(self):
         self.downloader = getattr(sklearn_datatasets, f"fetch_{self.dataset_name}")

     def load_iterables(self):
-        with TemporaryDirectory() as temp_directory:
-            for split in self.splits:
-                split_data = self.downloader(subset=split)
-                targets = [split_data["target_names"][t] for t in split_data["target"]]
-                df = pd.DataFrame([split_data["data"], targets]).T
-                df.columns = ["data", "target"]
-                df.to_csv(os.path.join(temp_directory, f"{split}.csv"), index=None)
-            return hf_load_dataset(temp_directory, streaming=False)
+        return {
+            split_name: ReusableGenerator(
+                self.split_generator, gen_kwargs={"split": split_name}
+            )
+            for split_name in self.splits
+        }
+
+    def split_generator(self, split: str) -> Generator:
+        dataset = self.__class__._loader_cache.get(str(self) + "_" + split, None)
+        if dataset is None:
+            split_data = self.downloader(subset=split)
+            targets = [split_data["target_names"][t] for t in split_data["target"]]
+            df = pd.DataFrame([split_data["data"], targets]).T
+            df.columns = ["data", "target"]
+            dataset = df.to_dict("records")
+            self.__class__._loader_cache.max_size = settings.loader_cache_size
+            self.__class__._loader_cache[str(self) + "_" + split] = dataset
+        yield from dataset


 class MissingKaggleCredentialsError(ValueError):
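Note: this hunk is the core of the commit. Instead of materializing all sklearn splits through CSV files and hf_load_dataset up front, each split is fetched only when its generator is first consumed, and the records are kept in the loader cache for reuse. A simplified standalone sketch of the same pattern, using a plain dict instead of unitxt's bounded cache and a list comprehension instead of the pandas round-trip:

from sklearn.datasets import fetch_20newsgroups

_loader_cache = {}  # toy stand-in for unitxt's size-limited loader cache


def split_generator(split: str):
    dataset = _loader_cache.get(split)
    if dataset is None:
        bunch = fetch_20newsgroups(subset=split)  # downloads on first use only
        targets = [bunch["target_names"][t] for t in bunch["target"]]
        # one record per example, matching the shape of df.to_dict("records")
        dataset = [{"data": d, "target": t} for d, t in zip(bunch["data"], targets)]
        _loader_cache[split] = dataset
    yield from dataset


def load_iterables():
    # nothing is fetched here; each split materializes lazily on iteration
    return {split: split_generator(split) for split in ("train", "test")}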

utils/.secrets.baseline (+2 -2)

@@ -151,7 +151,7 @@
         "filename": "src/unitxt/loaders.py",
         "hashed_secret": "840268f77a57d5553add023cfa8a4d1535f49742",
         "is_verified": false,
-        "line_number": 565,
+        "line_number": 572,
         "is_secret": false
       }
     ],

@@ -184,5 +184,5 @@
       }
     ]
   },
-  "generated_at": "2025-01-22T11:47:57Z"
+  "generated_at": "2025-01-23T10:07:40Z"
 }
