     IterableDataset,
     IterableDatasetDict,
     get_dataset_split_names,
-    load_dataset_builder,
 )
 from datasets import load_dataset as hf_load_dataset
 from huggingface_hub import HfApi
@@ -168,7 +167,7 @@ def load_data(self) -> MultiStream:
         self.__class__._loader_cache.max_size = settings.loader_cache_size
         self.__class__._loader_cache[str(self)] = iterables
         if isoftype(iterables, Dict[str, ReusableGenerator]):
-            return MultiStream.from_generators(iterables)
+            return MultiStream.from_generators(iterables, copying=True)
         return MultiStream.from_iterables(iterables, copying=True)

     def process(self) -> MultiStream:
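Passing copying=True means items read from a cached ReusableGenerator are handed out as copies rather than shared references, so one consumer mutating a record cannot corrupt the cache for later reads. A minimal sketch of the idea in plain Python (not unitxt's actual implementation; make_copying_generator is a hypothetical helper):

```python
import copy

def make_copying_generator(make_iterable):
    """Hypothetical helper: wrap an iterable factory so each read
    yields deep-copied items, isolating consumers from the cache."""
    def generator():
        for item in make_iterable():
            yield copy.deepcopy(item)  # consumers get private copies
    return generator

cached = [{"text": "hello"}]          # stands in for a cached split
gen = make_copying_generator(lambda: cached)

first = list(gen())
first[0]["text"] = "mutated"          # a consumer mutates its copy
assert cached[0]["text"] == "hello"   # the cached original is untouched
```

Deep-copying adds per-read cost, but it gives cached generators the same isolation that from_iterables already gets from its copying=True argument on the following line.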
|
@@ -476,11 +475,15 @@ def load_iterables(self):
         }

     def split_generator(self, split: str) -> Generator:
-        split_data = self.downloader(subset=split)
-        targets = [split_data["target_names"][t] for t in split_data["target"]]
-        df = pd.DataFrame([split_data["data"], targets]).T
-        df.columns = ["data", "target"]
-        dataset = df.to_dict("records")
+        dataset = self.__class__._loader_cache.get(str(self) + "_" + split, None)
+        if dataset is None:
+            split_data = self.downloader(subset=split)
+            targets = [split_data["target_names"][t] for t in split_data["target"]]
+            df = pd.DataFrame([split_data["data"], targets]).T
+            df.columns = ["data", "target"]
+            dataset = df.to_dict("records")
+            self.__class__._loader_cache.max_size = settings.loader_cache_size
+            self.__class__._loader_cache[str(self) + "_" + split] = dataset
         yield from dataset


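With this change, split_generator becomes a memoized read: the sklearn download and the DataFrame-to-records conversion run at most once per split, and later reads stream from the loader cache under the key str(self) + "_" + split. A self-contained sketch of the same pattern with hypothetical names (a plain dict stands in for the size-limited _loader_cache):

```python
class CachedSplitLoader:
    """Hypothetical loader illustrating per-split memoization."""

    _loader_cache = {}  # unitxt uses a size-limited cache; a dict sketches it

    def __init__(self, name, downloader):
        self.name = name
        self.downloader = downloader

    def __str__(self):
        return f"CachedSplitLoader({self.name})"

    def split_generator(self, split):
        key = str(self) + "_" + split   # one cache entry per loader + split
        dataset = self._loader_cache.get(key)
        if dataset is None:
            dataset = self.downloader(subset=split)  # expensive fetch, runs once
            self._loader_cache[key] = dataset
        yield from dataset

calls = []
def fake_downloader(subset):
    calls.append(subset)
    return [{"data": f"{subset}-item", "target": "demo"}]

loader = CachedSplitLoader("demo", fake_downloader)
list(loader.split_generator("train"))
list(loader.split_generator("train"))  # second read hits the cache
assert calls == ["train"]              # downloader ran exactly once
```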