Skip to content

Commit 586b4ae

Browse files
committed
refresh loaders from just_lazy_loader
Signed-off-by: dafnapension <[email protected]>
1 parent ab37ba6 commit 586b4ae

File tree

2 files changed

+28
-29
lines changed

2 files changed

+28
-29
lines changed

src/unitxt/loaders.py

+26 -27
Original file line number | Diff line number | Diff line change
@@ -67,7 +67,7 @@
6767
from tqdm import tqdm
6868

6969
from .dataclass import OptionalField
70-
from .error_utils import UnitxtError
70+
from .error_utils import UnitxtError, UnitxtWarning
7171
from .fusion import FixedFusion
7272
from .generator_utils import ReusableGenerator
7373
from .logging_utils import get_logger
@@ -227,7 +227,7 @@ class LoadHF(Loader):
227227
Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
228228
] = None
229229
revision: Optional[str] = None
230-
streaming: bool = None
230+
streaming = None
231231
filtering_lambda: Optional[str] = None
232232
num_proc: Optional[int] = None
233233
requirements_list: List[str] = OptionalField(default_factory=list)
@@ -314,27 +314,25 @@ def load_dataset(
314314
next(iter(dataset[k]))
315315
break
316316

317-
except:
318-
try:
319-
current_streaming = kwargs["streaming"]
320-
logger.info(
321-
f"needed to swap streaming from {current_streaming} to {not current_streaming} for path {self.path}"
322-
)
323-
# try the opposite way of streaming
324-
kwargs["streaming"] = not kwargs["streaming"]
325-
dataset = hf_load_dataset(**kwargs)
326-
if isinstance(dataset, (Dataset, IterableDataset)):
327-
next(iter(dataset))
328-
else:
329-
for k in dataset.keys():
330-
next(iter(dataset[k]))
331-
break
332-
333-
except ValueError as e:
334-
if "trust_remote_code" in str(e):
335-
raise ValueError(
336-
f"{self.__class__.__name__} cannot run remote code from huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE."
337-
) from e
317+
except Exception as e:
318+
if e is ValueError and "trust_remote_code" in str(e):
319+
raise ValueError(
320+
f"{self.__class__.__name__} cannot run remote code from huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE."
321+
) from e
322+
323+
current_streaming = kwargs["streaming"]
324+
logger.info(
325+
f"needed to swap streaming from {current_streaming} to {not current_streaming} for path {self.path}"
326+
)
327+
# try the opposite way of streaming
328+
kwargs["streaming"] = not kwargs["streaming"]
329+
dataset = hf_load_dataset(**kwargs)
330+
if isinstance(dataset, (Dataset, IterableDataset)):
331+
next(iter(dataset))
332+
else:
333+
for k in dataset.keys():
334+
next(iter(dataset[k]))
335+
break
338336

339337
if self.filtering_lambda is not None:
340338
dataset = dataset.filter(eval(self.filtering_lambda))
@@ -373,6 +371,9 @@ def get_splits(self) -> List[str]:
373371
# split names are known before the split themselves are pulled from HF,
374372
# and we can postpone that pulling of the splits until actually demanded
375373
return list(dataset_info.splits.keys())
374+
UnitxtWarning(
375+
f'LoadHF(path="{self.path}", name="{self.name}") could not retrieve split names without loading the dataset. Consider defining "splits" in the LoadHF definition to improve loading time.'
376+
)
376377
return None
377378
except:
378379
return None
@@ -915,9 +916,9 @@ class LoadFromHFSpace(LoadHF):
915916
)
916917
"""
917918

919+
path = None
918920
space_name: str
919921
data_files: Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
920-
path: Optional[str] = None
921922
revision: Optional[str] = None
922923
use_token: Optional[bool] = None
923924
token_env: Optional[str] = None
@@ -1055,8 +1056,6 @@ def _maybe_set_classification_policy(self):
10551056
def load_data(self):
10561057
self._map_wildcard_path_to_full_paths()
10571058
self.path = self._download_data()
1058-
if self.splits is None and isinstance(self.data_files, dict):
1059-
self.splits = sorted(self.data_files.keys())
10601059

10611060
return super().load_data()
10621061

@@ -1091,7 +1090,7 @@ class LoadFromAPI(Loader):
10911090

10921091
urls: Dict[str, str]
10931092
chunksize: int = 100000
1094-
streaming: bool = False
1093+
streaming = False
10951094
api_key_env_var: str = "SQL_API_KEY"
10961095
headers: Optional[Dict[str, Any]] = None
10971096
data_field: str = "data"

utils/.secrets.baseline

+2 -2
Original file line number | Diff line number | Diff line change
@@ -151,7 +151,7 @@
151151
"filename": "src/unitxt/loaders.py",
152152
"hashed_secret": "840268f77a57d5553add023cfa8a4d1535f49742",
153153
"is_verified": false,
154-
"line_number": 629,
154+
"line_number": 630,
155155
"is_secret": false
156156
}
157157
],
@@ -184,5 +184,5 @@
184184
}
185185
]
186186
},
187-
"generated_at": "2025-02-08T13:56:45Z"
187+
"generated_at": "2025-02-09T12:07:07Z"
188188
}

0 commit comments

Comments
 (0)