Skip to content

Commit c8dbded

Browse files
committed
add cache to sklearn
Signed-off-by: dafnapension <[email protected]>
1 parent 9c78791 commit c8dbded

File tree

2 files changed

+12
-9
lines changed

2 files changed

+12
-9
lines changed

src/unitxt/loaders.py

+10 -7
Original file line number | Diff line number | Diff line change
@@ -57,7 +57,6 @@
5757
IterableDataset,
5858
IterableDatasetDict,
5959
get_dataset_split_names,
60-
load_dataset_builder,
6160
)
6261
from datasets import load_dataset as hf_load_dataset
6362
from huggingface_hub import HfApi
@@ -168,7 +167,7 @@ def load_data(self) -> MultiStream:
168167
self.__class__._loader_cache.max_size = settings.loader_cache_size
169168
self.__class__._loader_cache[str(self)] = iterables
170169
if isoftype(iterables, Dict[str, ReusableGenerator]):
171-
return MultiStream.from_generators(iterables)
170+
return MultiStream.from_generators(iterables, copying=True)
172171
return MultiStream.from_iterables(iterables, copying=True)
173172

174173
def process(self) -> MultiStream:
@@ -476,11 +475,15 @@ def load_iterables(self):
476475
}
477476

478477
def split_generator(self, split: str) -> Generator:
479-
split_data = self.downloader(subset=split)
480-
targets = [split_data["target_names"][t] for t in split_data["target"]]
481-
df = pd.DataFrame([split_data["data"], targets]).T
482-
df.columns = ["data", "target"]
483-
dataset = df.to_dict("records")
478+
dataset = self.__class__._loader_cache.get(str(self) + "_" + split, None)
479+
if dataset is None:
480+
split_data = self.downloader(subset=split)
481+
targets = [split_data["target_names"][t] for t in split_data["target"]]
482+
df = pd.DataFrame([split_data["data"], targets]).T
483+
df.columns = ["data", "target"]
484+
dataset = df.to_dict("records")
485+
self.__class__._loader_cache.max_size = settings.loader_cache_size
486+
self.__class__._loader_cache[str(self) + "_" + split] = dataset
484487
yield from dataset
485488

486489

utils/.secrets.baseline

+2 -2
Original file line number | Diff line number | Diff line change
@@ -151,7 +151,7 @@
151151
"filename": "src/unitxt/loaders.py",
152152
"hashed_secret": "840268f77a57d5553add023cfa8a4d1535f49742",
153153
"is_verified": false,
154-
"line_number": 566,
154+
"line_number": 572,
155155
"is_secret": false
156156
}
157157
],
@@ -184,5 +184,5 @@
184184
}
185185
]
186186
},
187-
"generated_at": "2025-01-22T20:27:31Z"
187+
"generated_at": "2025-01-23T10:07:40Z"
188188
}

0 commit comments

Comments
 (0)