Skip to content

Commit fd89934

Browse files
Merge branch 'main' into fix-azure-openai
2 parents 41ef04a + 287a801 commit fd89934

File tree

5 files changed

+38
-35
lines changed

5 files changed

+38
-35
lines changed

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -161,3 +161,5 @@ benchmark_output/*
161161
src.lock
162162
docs/_static/data.js
163163
cache
164+
165+
inference_engine_cache/

pyproject.toml

+1-4
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,6 @@ docs = [
6060
"datasets",
6161
"evaluate",
6262
"nltk",
63-
"sacrebleu",
64-
"absl-py",
6563
"rouge_score",
6664
"scikit-learn",
6765
"jiwer",
@@ -89,8 +87,7 @@ tests = [
8987
"editdistance",
9088
"rouge-score",
9189
"nltk",
92-
"mecab-python3",
93-
"sacrebleu[ko]",
90+
"sacrebleu[ko,ja]",
9491
"scikit-learn<=1.5.2",
9592
"jiwer",
9693
"conllu",

src/unitxt/error_utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class Documentation:
1818
BENCHMARKS = "docs/benchmark.html"
1919
DATA_CLASSIFICATION_POLICY = "docs/data_classification_policy.html"
2020
CATALOG = "docs/saving_and_loading_from_catalog.html"
21+
SETTINGS = "docs/settings.html"
2122

2223

2324
def additional_info(path: str) -> str:

src/unitxt/llm_as_judge.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ class LLMJudge(BulkInstanceMetric):
5757
# option_selection_strategy: OptionSelectionStrategyEnum = (
5858
# OptionSelectionStrategyEnum.PARSE_OUTPUT_TEXT
5959
# )
60-
evaluator_name: EvaluatorNameEnum = None
60+
evaluator_name: Optional[Union[str,EvaluatorNameEnum]] = None
6161
check_positional_bias: bool = True
6262
context_fields: Union[str, List[str], Dict[str, str]] = ["context"]
6363
generate_summaries: bool = True
@@ -78,8 +78,6 @@ def prepare(self):
7878

7979
if self.evaluator_name is None:
8080
self.evaluator_name = self.inference_engine.get_engine_id()
81-
elif not isinstance(self.evaluator_name, EvaluatorNameEnum):
82-
self.evaluator_name = EvaluatorNameEnum[self.evaluator_name]
8381

8482
def before_process_multi_stream(self):
8583
super().before_process_multi_stream()

src/unitxt/loaders.py

+33-28
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@
6767
from tqdm import tqdm
6868

6969
from .dataclass import NonPositionalField
70-
from .error_utils import UnitxtError, UnitxtWarning
70+
from .error_utils import Documentation, UnitxtError, UnitxtWarning
7171
from .fusion import FixedFusion
7272
from .logging_utils import get_logger
7373
from .operator import SourceOperator
@@ -80,19 +80,27 @@
8080
logger = get_logger()
8181
settings = get_settings()
8282

83+
class UnitxtUnverifiedCodeError(UnitxtError):
84+
def __init__(self, path):
85+
super().__init__(f"Loader cannot load and run remote code from {path} in huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE.", Documentation.SETTINGS)
86+
8387
def hf_load_dataset(path: str, *args, **kwargs):
8488
if settings.hf_offline_datasets_path is not None:
8589
path = os.path.join(settings.hf_offline_datasets_path, path)
86-
return _hf_load_dataset(
87-
path,
88-
*args, **kwargs,
89-
download_config=DownloadConfig(
90-
max_retries=settings.loaders_max_retries,
91-
),
92-
verification_mode="no_checks",
93-
trust_remote_code=settings.allow_unverified_code,
94-
download_mode= "force_redownload" if settings.disable_hf_datasets_cache else "reuse_dataset_if_exists"
95-
)
90+
try:
91+
return _hf_load_dataset(
92+
path,
93+
*args, **kwargs,
94+
download_config=DownloadConfig(
95+
max_retries=settings.loaders_max_retries,
96+
),
97+
verification_mode="no_checks",
98+
trust_remote_code=settings.allow_unverified_code,
99+
download_mode= "force_redownload" if settings.disable_hf_datasets_cache else "reuse_dataset_if_exists"
100+
)
101+
except ValueError as e:
102+
if "trust_remote_code" in str(e):
103+
raise UnitxtUnverifiedCodeError(path) from e
96104

97105
class Loader(SourceOperator):
98106
"""A base class for all loaders.
@@ -288,22 +296,17 @@ def load_dataset(
288296
if dataset is None:
289297
if streaming is None:
290298
streaming = self.is_streaming()
291-
try:
292-
dataset = hf_load_dataset(
293-
self.path,
294-
name=self.name,
295-
data_dir=self.data_dir,
296-
data_files=self.data_files,
297-
revision=self.revision,
298-
streaming=streaming,
299-
split=split,
300-
num_proc=self.num_proc,
301-
)
302-
except ValueError as e:
303-
if "trust_remote_code" in str(e):
304-
raise ValueError(
305-
f"{self.__class__.__name__} cannot run remote code from huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE."
306-
) from e
299+
300+
dataset = hf_load_dataset(
301+
self.path,
302+
name=self.name,
303+
data_dir=self.data_dir,
304+
data_files=self.data_files,
305+
revision=self.revision,
306+
streaming=streaming,
307+
split=split,
308+
num_proc=self.num_proc,
309+
)
307310
self.__class__._loader_cache.max_size = settings.loader_cache_size
308311
if not disable_memory_caching:
309312
self.__class__._loader_cache[dataset_id] = dataset
@@ -333,7 +336,9 @@ def get_splits(self):
333336
extract_on_the_fly=True,
334337
),
335338
)
336-
except:
339+
except Exception as e:
340+
if "trust_remote_code" in str(e):
341+
raise UnitxtUnverifiedCodeError(self.path) from e
337342
UnitxtWarning(
338343
f'LoadHF(path="{self.path}", name="{self.name}") could not retrieve split names without loading the dataset. Consider defining "splits" in the LoadHF definition to improve loading time.'
339344
)

0 commit comments

Comments (0)