@@ -57,7 +57,6 @@
     IterableDataset,
     IterableDatasetDict,
     get_dataset_split_names,
-    load_dataset_builder,
 )
 from datasets import load_dataset as hf_load_dataset
 from huggingface_hub import HfApi
@@ -168,7 +167,7 @@ def load_data(self) -> MultiStream:
         self.__class__._loader_cache.max_size = settings.loader_cache_size
         self.__class__._loader_cache[str(self)] = iterables
         if isoftype(iterables, Dict[str, ReusableGenerator]):
-            return MultiStream.from_generators(iterables)
+            return MultiStream.from_generators(iterables, copying=True)
         return MultiStream.from_iterables(iterables, copying=True)

     def process(self) -> MultiStream:
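Note on the `copying=True` change above: the cached generators are shared across runs via `_loader_cache`, so a consumer that mutates instances in place could corrupt the cache for everyone else. A minimal sketch of what the copying flag guards against, assuming it deep-copies each yielded instance (the real behavior lives in unitxt's `MultiStream`):

```python
from copy import deepcopy

def cached_instances():
    # Stands in for a generator backed by the shared _loader_cache.
    yield from [{"text": "a"}, {"text": "b"}]

def copying(gen):
    # Sketch of the copying behavior: each consumer gets its own copy,
    # so in-place edits never reach the cached objects.
    for item in gen:
        yield deepcopy(item)

for item in copying(cached_instances()):
    item["text"] += "!"  # safe: mutates the copy, not the cached instance
```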
@@ -259,9 +258,6 @@ def stream_dataset(self, split: str) -> Union[IterableDatasetDict, IterableDataset]:
             )
         except ValueError as e:
             if "trust_remote_code" in str(e):
-                logger.critical(
-                    f"while raising trust_remote error, settings.allow_unverified_code = {settings.allow_unverified_code}"
-                )
                 raise ValueError(
                     f"{self.__class__.__name__} cannot run remote code from huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE."
                 ) from e
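For reference, the two opt-ins named in the error message above. The environment variable is quoted directly from the message; the programmatic path is an assumption about how unitxt exposes its settings object:

```python
import os

# Option 1: environment variable, as quoted in the error message
# (must be set before unitxt reads its settings).
os.environ["UNITXT_ALLOW_UNVERIFIED_CODE"] = "True"

# Option 2 (assumed API, commented out): flip the settings flag in code.
# from unitxt.settings_utils import get_settings
# get_settings().allow_unverified_code = True
```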
@@ -319,30 +315,28 @@ def load_iterables(
         if self.get_limit() is not None:
             self.log_limited_loading()

-        if not isinstance(self, LoadFromHFSpace):
-            # try the following for LoadHF only
-            if self.split is not None:
-                return {
-                    self.split: ReusableGenerator(
-                        self.split_generator, gen_kwargs={"split": self.split}
-                    )
-                }
+        if self.split is not None:
+            return {
+                self.split: ReusableGenerator(
+                    self.split_generator, gen_kwargs={"split": self.split}
+                )
+            }

-            try:
-                split_names = get_dataset_split_names(
-                    path=self.path,
-                    config_name=self.name,
-                    trust_remote_code=settings.allow_unverified_code,
+        try:
+            split_names = get_dataset_split_names(
+                path=self.path,
+                config_name=self.name,
+                trust_remote_code=settings.allow_unverified_code,
+            )
+            return {
+                split_name: ReusableGenerator(
+                    self.split_generator, gen_kwargs={"split": split_name}
                 )
-            return {
-                split_name: ReusableGenerator(
-                    self.split_generator, gen_kwargs={"split": split_name}
-                )
-                for split_name in split_names
-            }
+                for split_name in split_names
+            }

-            except:
-                pass  # do nothing, and just continue to the usual load dataset
+        except:
+            pass  # do nothing, and just continue to the usual load dataset
         # self.split is None and
         # split names are not known before the splits themselves are loaded, and we need to load them here

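The hunk above drops the `LoadFromHFSpace` guard, so every loader in this family now returns per-split `ReusableGenerator`s. The point of the wrapper is that a raw generator is single-use; storing the generator function plus its kwargs lets each consumer restart iteration from scratch. A minimal sketch of the pattern (assumption: unitxt's real class differs in detail):

```python
class ReusableGenerator:
    """Sketch: re-create the generator from a factory on every iteration."""

    def __init__(self, generator, gen_kwargs=None):
        self.generator = generator          # a generator *function*, not a generator
        self.gen_kwargs = gen_kwargs or {}

    def __iter__(self):
        # Calling the generator function yields a fresh generator each time.
        return self.generator(**self.gen_kwargs)

def count_to(n):
    yield from range(n)

gen = ReusableGenerator(count_to, gen_kwargs={"n": 3})
assert list(gen) == [0, 1, 2]
assert list(gen) == [0, 1, 2]  # a raw generator would be exhausted here
```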
@@ -473,14 +467,24 @@ def prepare(self):
         self.downloader = getattr(sklearn_datatasets, f"fetch_{self.dataset_name}")

     def load_iterables(self):
-        with TemporaryDirectory() as temp_directory:
-            for split in self.splits:
-                split_data = self.downloader(subset=split)
-                targets = [split_data["target_names"][t] for t in split_data["target"]]
-                df = pd.DataFrame([split_data["data"], targets]).T
-                df.columns = ["data", "target"]
-                df.to_csv(os.path.join(temp_directory, f"{split}.csv"), index=None)
-            return hf_load_dataset(temp_directory, streaming=False)
+        return {
+            split_name: ReusableGenerator(
+                self.split_generator, gen_kwargs={"split": split_name}
+            )
+            for split_name in self.splits
+        }
+
+    def split_generator(self, split: str) -> Generator:
+        dataset = self.__class__._loader_cache.get(str(self) + "_" + split, None)
+        if dataset is None:
+            split_data = self.downloader(subset=split)
+            targets = [split_data["target_names"][t] for t in split_data["target"]]
+            df = pd.DataFrame([split_data["data"], targets]).T
+            df.columns = ["data", "target"]
+            dataset = df.to_dict("records")
+            self.__class__._loader_cache.max_size = settings.loader_cache_size
+            self.__class__._loader_cache[str(self) + "_" + split] = dataset
+        yield from dataset


 class MissingKaggleCredentialsError(ValueError):
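The sklearn loader now skips the CSV round-trip through a temporary directory and materializes each split lazily, memoizing the records in the class-level cache; note that `yield from dataset` sits outside the `if`, so cache hits are yielded too. A stripped-down sketch of the cache-or-compute generator, with a plain dict standing in for the size-bounded `_loader_cache`:

```python
_cache = {}

def split_generator(key, compute):
    records = _cache.get(key)
    if records is None:
        records = compute()    # expensive path, e.g. downloader(subset=split)
        _cache[key] = records  # memoize the materialized records
    # Outside the `if`, so both fresh and cached records are yielded.
    yield from records

rows = list(split_generator("train", lambda: [{"data": "x", "target": "y"}]))
assert rows == [{"data": "x", "target": "y"}]
```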