|
67 | 67 | from tqdm import tqdm
|
68 | 68 |
|
69 | 69 | from .dataclass import NonPositionalField
|
70 |
| -from .error_utils import UnitxtError, UnitxtWarning |
| 70 | +from .error_utils import Documentation, UnitxtError, UnitxtWarning |
71 | 71 | from .fusion import FixedFusion
|
72 | 72 | from .logging_utils import get_logger
|
73 | 73 | from .operator import SourceOperator
|
|
80 | 80 | logger = get_logger()
|
81 | 81 | settings = get_settings()
|
82 | 82 |
|
| 83 | +class UnitxtUnverifiedCodeError(UnitxtError): |
| 84 | + def __init__(self, path): |
| 85 | + super().__init__(f"Loader cannot load and run remote code from {path} in huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE.", Documentation.SETTINGS) |
| 86 | + |
83 | 87 | def hf_load_dataset(path: str, *args, **kwargs):
|
84 | 88 | if settings.hf_offline_datasets_path is not None:
|
85 | 89 | path = os.path.join(settings.hf_offline_datasets_path, path)
|
86 |
| - return _hf_load_dataset( |
87 |
| - path, |
88 |
| - *args, **kwargs, |
89 |
| - download_config=DownloadConfig( |
90 |
| - max_retries=settings.loaders_max_retries, |
91 |
| - ), |
92 |
| - verification_mode="no_checks", |
93 |
| - trust_remote_code=settings.allow_unverified_code, |
94 |
| - download_mode= "force_redownload" if settings.disable_hf_datasets_cache else "reuse_dataset_if_exists" |
95 |
| - ) |
| 90 | + try: |
| 91 | + return _hf_load_dataset( |
| 92 | + path, |
| 93 | + *args, **kwargs, |
| 94 | + download_config=DownloadConfig( |
| 95 | + max_retries=settings.loaders_max_retries, |
| 96 | + ), |
| 97 | + verification_mode="no_checks", |
| 98 | + trust_remote_code=settings.allow_unverified_code, |
| 99 | + download_mode= "force_redownload" if settings.disable_hf_datasets_cache else "reuse_dataset_if_exists" |
| 100 | + ) |
| 101 | + except ValueError as e: |
| 102 | + if "trust_remote_code" in str(e): |
| 103 | + raise UnitxtUnverifiedCodeError(path) from e |
96 | 104 |
|
97 | 105 | class Loader(SourceOperator):
|
98 | 106 | """A base class for all loaders.
|
@@ -288,22 +296,17 @@ def load_dataset(
|
288 | 296 | if dataset is None:
|
289 | 297 | if streaming is None:
|
290 | 298 | streaming = self.is_streaming()
|
291 |
| - try: |
292 |
| - dataset = hf_load_dataset( |
293 |
| - self.path, |
294 |
| - name=self.name, |
295 |
| - data_dir=self.data_dir, |
296 |
| - data_files=self.data_files, |
297 |
| - revision=self.revision, |
298 |
| - streaming=streaming, |
299 |
| - split=split, |
300 |
| - num_proc=self.num_proc, |
301 |
| - ) |
302 |
| - except ValueError as e: |
303 |
| - if "trust_remote_code" in str(e): |
304 |
| - raise ValueError( |
305 |
| - f"{self.__class__.__name__} cannot run remote code from huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE." |
306 |
| - ) from e |
| 299 | + |
| 300 | + dataset = hf_load_dataset( |
| 301 | + self.path, |
| 302 | + name=self.name, |
| 303 | + data_dir=self.data_dir, |
| 304 | + data_files=self.data_files, |
| 305 | + revision=self.revision, |
| 306 | + streaming=streaming, |
| 307 | + split=split, |
| 308 | + num_proc=self.num_proc, |
| 309 | + ) |
307 | 310 | self.__class__._loader_cache.max_size = settings.loader_cache_size
|
308 | 311 | if not disable_memory_caching:
|
309 | 312 | self.__class__._loader_cache[dataset_id] = dataset
|
@@ -333,7 +336,9 @@ def get_splits(self):
|
333 | 336 | extract_on_the_fly=True,
|
334 | 337 | ),
|
335 | 338 | )
|
336 |
| - except: |
| 339 | + except Exception as e: |
| 340 | + if "trust_remote_code" in str(e): |
| 341 | + raise UnitxtUnverifiedCodeError(self.path) from e |
337 | 342 | UnitxtWarning(
|
338 | 343 | f'LoadHF(path="{self.path}", name="{self.name}") could not retrieve split names without loading the dataset. Consider defining "splits" in the LoadHF definition to improve loading time.'
|
339 | 344 | )
|
|
0 commit comments