|
67 | 67 | from tqdm import tqdm
|
68 | 68 |
|
69 | 69 | from .dataclass import OptionalField
|
70 |
| -from .error_utils import UnitxtError |
| 70 | +from .error_utils import UnitxtError, UnitxtWarning |
71 | 71 | from .fusion import FixedFusion
|
72 | 72 | from .generator_utils import ReusableGenerator
|
73 | 73 | from .logging_utils import get_logger
|
@@ -227,7 +227,7 @@ class LoadHF(Loader):
|
227 | 227 | Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
|
228 | 228 | ] = None
|
229 | 229 | revision: Optional[str] = None
|
230 |
| - streaming: bool = None |
| 230 | + streaming = None |
231 | 231 | filtering_lambda: Optional[str] = None
|
232 | 232 | num_proc: Optional[int] = None
|
233 | 233 | requirements_list: List[str] = OptionalField(default_factory=list)
|
@@ -314,27 +314,25 @@ def load_dataset(
|
314 | 314 | next(iter(dataset[k]))
|
315 | 315 | break
|
316 | 316 |
|
317 |
| - except: |
318 |
| - try: |
319 |
| - current_streaming = kwargs["streaming"] |
320 |
| - logger.info( |
321 |
| - f"needed to swap streaming from {current_streaming} to {not current_streaming} for path {self.path}" |
322 |
| - ) |
323 |
| - # try the opposite way of streaming |
324 |
| - kwargs["streaming"] = not kwargs["streaming"] |
325 |
| - dataset = hf_load_dataset(**kwargs) |
326 |
| - if isinstance(dataset, (Dataset, IterableDataset)): |
327 |
| - next(iter(dataset)) |
328 |
| - else: |
329 |
| - for k in dataset.keys(): |
330 |
| - next(iter(dataset[k])) |
331 |
| - break |
332 |
| - |
333 |
| - except ValueError as e: |
334 |
| - if "trust_remote_code" in str(e): |
335 |
| - raise ValueError( |
336 |
| - f"{self.__class__.__name__} cannot run remote code from huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE." |
337 |
| - ) from e |
| 317 | + except Exception as e: |
| 318 | + if e is ValueError and "trust_remote_code" in str(e): |
| 319 | + raise ValueError( |
| 320 | + f"{self.__class__.__name__} cannot run remote code from huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE." |
| 321 | + ) from e |
| 322 | + |
| 323 | + current_streaming = kwargs["streaming"] |
| 324 | + logger.info( |
| 325 | + f"needed to swap streaming from {current_streaming} to {not current_streaming} for path {self.path}" |
| 326 | + ) |
| 327 | + # try the opposite way of streaming |
| 328 | + kwargs["streaming"] = not kwargs["streaming"] |
| 329 | + dataset = hf_load_dataset(**kwargs) |
| 330 | + if isinstance(dataset, (Dataset, IterableDataset)): |
| 331 | + next(iter(dataset)) |
| 332 | + else: |
| 333 | + for k in dataset.keys(): |
| 334 | + next(iter(dataset[k])) |
| 335 | + break |
338 | 336 |
|
339 | 337 | if self.filtering_lambda is not None:
|
340 | 338 | dataset = dataset.filter(eval(self.filtering_lambda))
|
@@ -373,6 +371,9 @@ def get_splits(self) -> List[str]:
|
373 | 371 | # split names are known before the split themselves are pulled from HF,
|
374 | 372 | # and we can postpone that pulling of the splits until actually demanded
|
375 | 373 | return list(dataset_info.splits.keys())
|
| 374 | + UnitxtWarning( |
| 375 | + f'LoadHF(path="{self.path}", name="{self.name}") could not retrieve split names without loading the dataset. Consider defining "splits" in the LoadHF definition to improve loading time.' |
| 376 | + ) |
376 | 377 | return None
|
377 | 378 | except:
|
378 | 379 | return None
|
@@ -915,9 +916,9 @@ class LoadFromHFSpace(LoadHF):
|
915 | 916 | )
|
916 | 917 | """
|
917 | 918 |
|
| 919 | + path = None |
918 | 920 | space_name: str
|
919 | 921 | data_files: Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
|
920 |
| - path: Optional[str] = None |
921 | 922 | revision: Optional[str] = None
|
922 | 923 | use_token: Optional[bool] = None
|
923 | 924 | token_env: Optional[str] = None
|
@@ -1055,8 +1056,6 @@ def _maybe_set_classification_policy(self):
|
1055 | 1056 | def load_data(self):
|
1056 | 1057 | self._map_wildcard_path_to_full_paths()
|
1057 | 1058 | self.path = self._download_data()
|
1058 |
| - if self.splits is None and isinstance(self.data_files, dict): |
1059 |
| - self.splits = sorted(self.data_files.keys()) |
1060 | 1059 |
|
1061 | 1060 | return super().load_data()
|
1062 | 1061 |
|
@@ -1091,7 +1090,7 @@ class LoadFromAPI(Loader):
|
1091 | 1090 |
|
1092 | 1091 | urls: Dict[str, str]
|
1093 | 1092 | chunksize: int = 100000
|
1094 |
| - streaming: bool = False |
| 1093 | + streaming = False |
1095 | 1094 | api_key_env_var: str = "SQL_API_KEY"
|
1096 | 1095 | headers: Optional[Dict[str, Any]] = None
|
1097 | 1096 | data_field: str = "data"
|
|
0 commit comments