Skip to content

Commit 8fdfe21

Browse files
authored
Fix a bug in loading without trust remote code (#1684)
* Fix a bug in loading without trust remote code Signed-off-by: elronbandel <[email protected]> * update Signed-off-by: elronbandel <[email protected]> --------- Signed-off-by: elronbandel <[email protected]>
1 parent ac9bdb3 commit 8fdfe21

File tree

2 files changed

+34
-28
lines changed

2 files changed

+34
-28
lines changed

src/unitxt/error_utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class Documentation:
1818
BENCHMARKS = "docs/benchmark.html"
1919
DATA_CLASSIFICATION_POLICY = "docs/data_classification_policy.html"
2020
CATALOG = "docs/saving_and_loading_from_catalog.html"
21+
SETTINGS = "docs/settings.html"
2122

2223

2324
def additional_info(path: str) -> str:

src/unitxt/loaders.py

+33-28
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@
6767
from tqdm import tqdm
6868

6969
from .dataclass import NonPositionalField
70-
from .error_utils import UnitxtError, UnitxtWarning
70+
from .error_utils import Documentation, UnitxtError, UnitxtWarning
7171
from .fusion import FixedFusion
7272
from .logging_utils import get_logger
7373
from .operator import SourceOperator
@@ -80,19 +80,27 @@
8080
logger = get_logger()
8181
settings = get_settings()
8282

83+
class UnitxtUnverifiedCodeError(UnitxtError):
84+
def __init__(self, path):
85+
super().__init__(f"Loader cannot load and run remote code from {path} in huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE.", Documentation.SETTINGS)
86+
8387
def hf_load_dataset(path: str, *args, **kwargs):
8488
if settings.hf_offline_datasets_path is not None:
8589
path = os.path.join(settings.hf_offline_datasets_path, path)
86-
return _hf_load_dataset(
87-
path,
88-
*args, **kwargs,
89-
download_config=DownloadConfig(
90-
max_retries=settings.loaders_max_retries,
91-
),
92-
verification_mode="no_checks",
93-
trust_remote_code=settings.allow_unverified_code,
94-
download_mode= "force_redownload" if settings.disable_hf_datasets_cache else "reuse_dataset_if_exists"
95-
)
90+
try:
91+
return _hf_load_dataset(
92+
path,
93+
*args, **kwargs,
94+
download_config=DownloadConfig(
95+
max_retries=settings.loaders_max_retries,
96+
),
97+
verification_mode="no_checks",
98+
trust_remote_code=settings.allow_unverified_code,
99+
download_mode= "force_redownload" if settings.disable_hf_datasets_cache else "reuse_dataset_if_exists"
100+
)
101+
except ValueError as e:
102+
if "trust_remote_code" in str(e):
103+
raise UnitxtUnverifiedCodeError(path) from e
96104

97105
class Loader(SourceOperator):
98106
"""A base class for all loaders.
@@ -288,22 +296,17 @@ def load_dataset(
288296
if dataset is None:
289297
if streaming is None:
290298
streaming = self.is_streaming()
291-
try:
292-
dataset = hf_load_dataset(
293-
self.path,
294-
name=self.name,
295-
data_dir=self.data_dir,
296-
data_files=self.data_files,
297-
revision=self.revision,
298-
streaming=streaming,
299-
split=split,
300-
num_proc=self.num_proc,
301-
)
302-
except ValueError as e:
303-
if "trust_remote_code" in str(e):
304-
raise ValueError(
305-
f"{self.__class__.__name__} cannot run remote code from huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE."
306-
) from e
299+
300+
dataset = hf_load_dataset(
301+
self.path,
302+
name=self.name,
303+
data_dir=self.data_dir,
304+
data_files=self.data_files,
305+
revision=self.revision,
306+
streaming=streaming,
307+
split=split,
308+
num_proc=self.num_proc,
309+
)
307310
self.__class__._loader_cache.max_size = settings.loader_cache_size
308311
if not disable_memory_caching:
309312
self.__class__._loader_cache[dataset_id] = dataset
@@ -333,7 +336,9 @@ def get_splits(self):
333336
extract_on_the_fly=True,
334337
),
335338
)
336-
except:
339+
except Exception as e:
340+
if "trust_remote_code" in str(e):
341+
raise UnitxtUnverifiedCodeError(self.path) from e
337342
UnitxtWarning(
338343
f'LoadHF(path="{self.path}", name="{self.name}") could not retrieve split names without loading the dataset. Consider defining "splits" in the LoadHF definition to improve loading time.'
339344
)

0 commit comments

Comments
 (0)