@@ -53,7 +53,12 @@

import pandas as pd
import requests
-from datasets import IterableDataset, IterableDatasetDict, load_dataset_builder
+from datasets import (
+    IterableDataset,
+    IterableDatasetDict,
+    get_dataset_split_names,
+    load_dataset_builder,
+)
from datasets import load_dataset as hf_load_dataset
from huggingface_hub import HfApi
from tqdm import tqdm
@@ -232,11 +237,6 @@ def filter_load(self, dataset):
        logger.info(f"\nLoading filtered by: {self.filtering_lambda};")
        return dataset.filter(eval(self.filtering_lambda))

-    def log_limited_loading(self, split: str):
-        logger.info(
-            f"\nLoading of split {split} limited to {self.get_limit()} instances by setting {self.get_limiter()};"
-        )
-
    # returns Dict when split names are not known in advance, and just the single split dataset - if known
    def stream_dataset(self, split: str) -> Union[IterableDatasetDict, IterableDataset]:
        with tempfile.TemporaryDirectory() as dir_to_be_deleted:
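
The per-split `log_limited_loading(self, split)` helper is removed here; the later hunks instead log the limit once per loader through a no-argument `self.log_limited_loading()`. That no-argument variant is not shown in this diff (it presumably lives on the base `Loader` class); a hypothetical sketch of what it could look like, reusing `get_limit()` and `get_limiter()` from the removed method:

```python
# Hypothetical sketch only -- the real no-argument helper is defined outside this diff.
def log_limited_loading(self):
    logger.info(
        f"\nLoading limited to {self.get_limit()} instances by setting {self.get_limiter()};"
    )
```
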
@@ -259,6 +259,9 @@ def stream_dataset(self, split: str) -> Union[IterableDatasetDict, IterableDataset]:
                )
            except ValueError as e:
                if "trust_remote_code" in str(e):
+                    logger.critical(
+                        f"while raising trust_remote error, settings.allow_unverified_code = {settings.allow_unverified_code}"
+                    )
                    raise ValueError(
                        f"{self.__class__.__name__} cannot run remote code from huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE."
                    ) from e
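
The added `logger.critical` records the value of `settings.allow_unverified_code` at the moment the `trust_remote_code` error is re-raised, which helps debug why the opt-in did not take effect. As the error message itself states, a user can opt in either through the settings flag or the environment variable; a minimal sketch (the attribute path is taken from the message and assumed to be importable as shown):

```python
# Opt in programmatically before loading the dataset:
import unitxt

unitxt.settings.allow_unverified_code = True

# or set the environment variable before the process starts:
#   export UNITXT_ALLOW_UNVERIFIED_CODE=True
```
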
@@ -312,6 +315,10 @@ def _maybe_set_classification_policy(self):
    def load_iterables(
        self
    ) -> Union[Dict[str, ReusableGenerator], IterableDatasetDict]:
+        # log limit once for the whole data
+        if self.get_limit() is not None:
+            self.log_limited_loading()
+
        if not isinstance(self, LoadFromHFSpace):
            # try the following for LoadHF only
            if self.split is not None:
@@ -322,22 +329,17 @@ def load_iterables(
                }

            try:
-                ds_builder = load_dataset_builder(
-                    self.path,
-                    self.name,
+                split_names = get_dataset_split_names(
+                    path=self.path,
+                    config_name=self.name,
                    trust_remote_code=settings.allow_unverified_code,
                )
-                dataset_info = ds_builder.info
-                if dataset_info.splits is not None:
-                    # split names are known before the split themselves are pulled from HF,
-                    # and we can postpone that pulling of the splits until actually demanded
-                    split_names = list(dataset_info.splits.keys())
-                    return {
-                        split_name: ReusableGenerator(
-                            self.split_generator, gen_kwargs={"split": split_name}
-                        )
-                        for split_name in split_names
-                    }
+                return {
+                    split_name: ReusableGenerator(
+                        self.split_generator, gen_kwargs={"split": split_name}
+                    )
+                    for split_name in split_names
+                }

            except:
                pass  # do nothing, and just continue to the usual load dataset
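
This hunk swaps the `load_dataset_builder` lookup for `datasets.get_dataset_split_names`, which resolves just the split names without materializing the builder, and returns one `ReusableGenerator` per split so pulling the actual data is postponed until a split is demanded. If the split names cannot be resolved, the bare `except` falls through to the regular load path below. A rough comparison of the two lookups, using a public Hub dataset ("glue"/"mrpc") purely as an example:

```python
from datasets import get_dataset_split_names, load_dataset_builder

# New path: ask the Hub only for the split names.
split_names = get_dataset_split_names(path="glue", config_name="mrpc")
# e.g. ['train', 'validation', 'test']

# Old path: build the DatasetInfo and read the split names out of it.
builder = load_dataset_builder("glue", "mrpc")
if builder.info.splits is not None:
    split_names = list(builder.info.splits.keys())
```
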
@@ -356,8 +358,6 @@ def load_iterables(

        limit = self.get_limit()
        if limit is not None:
-            for split_name in dataset:
-                self.log_limited_loading(split_name)
            result = {}
            for split_name in dataset:
                result[split_name] = dataset[split_name].take(limit)
@@ -366,26 +366,24 @@ def load_iterables(
        return dataset

    def split_generator(self, split: str) -> Generator:
-        try:
-            dataset = self.stream_dataset(split)
-        except (
-            NotImplementedError
-        ):  # streaming is not supported for zipped files so we load without streaming
-            dataset = self.load_dataset(split)
+        dataset = self.__class__._loader_cache.get(str(self) + "_" + split, None)
+        if dataset is None:
+            try:
+                dataset = self.stream_dataset(split)
+            except NotImplementedError:  # streaming is not supported for zipped files so we load without streaming
+                dataset = self.load_dataset(split)

-        if self.filtering_lambda is not None:
-            dataset = self.filter_load(dataset)
+            if self.filtering_lambda is not None:
+                dataset = self.filter_load(dataset)

-        limit = self.get_limit()
-        if limit is not None:
-            self.log_limited_loading(split)
-            dataset = dataset.take(limit)
+            limit = self.get_limit()
+            if limit is not None:
+                dataset = dataset.take(limit)

-        yield from dataset
+            self.__class__._loader_cache.max_size = settings.loader_cache_size
+            self.__class__._loader_cache[str(self) + "_" + split] = dataset

-    def process(self) -> MultiStream:
-        self._maybe_set_classification_policy()
-        return self.add_data_classification(self.load_data())
+        yield from dataset


class LoadCSV(Loader):
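
`split_generator` above now consults a class-level `_loader_cache`, keyed by `str(self) + "_" + split`, so repeated iterations over the same split reuse the already filtered and limited dataset instead of streaming it again; `max_size` is refreshed from `settings.loader_cache_size` just before the entry is stored. The cache class itself is defined elsewhere in unitxt and is not part of this diff; a minimal sketch of a size-bounded mapping with the same surface (`get`, item assignment, a `max_size` attribute), offered only as an assumption of its behavior:

```python
from collections import OrderedDict


class SimpleLRUCache:
    """Illustrative stand-in for the loader cache; not unitxt's implementation."""

    def __init__(self, max_size: int = 10):
        self.max_size = max_size
        self._store = OrderedDict()

    def get(self, key, default=None):
        if key in self._store:
            self._store.move_to_end(key)  # mark as most recently used
            return self._store[key]
        return default

    def __setitem__(self, key, value):
        self._store[key] = value
        self._store.move_to_end(key)
        while len(self._store) > self.max_size:
            self._store.popitem(last=False)  # evict the least recently used entry
```
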
@@ -442,6 +440,9 @@ def get_args(self):
        return args

    def load_iterables(self):
+        # log once for the whole data
+        if self.get_limit() is not None:
+            self.log_limited_loading()
        iterables = {}
        for split_name in self.files.keys():
            iterables[split_name] = ReusableGenerator(
@@ -451,14 +452,9 @@ def load_iterables(self):
        return iterables

    def split_generator(self, split: str) -> Generator:
-        if self.get_limit() is not None:
-            self.log_limited_loading()
-            dataset = pd.read_csv(
-                self.files[split], nrows=self.get_limit(), sep=self.sep
-            ).to_dict("records")
-        else:
-            dataset = pd.read_csv(self.files[split], sep=self.sep).to_dict("records")
-
+        dataset = pd.read_csv(
+            self.files[split], nrows=self.get_limit(), sep=self.sep
+        ).to_dict("records")
        yield from dataset

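
The if/else in `LoadCSV.split_generator` collapses into a single call because `nrows=None` (what `get_limit()` returns when no limit is set) makes pandas read the whole file, so one `pd.read_csv` covers both the limited and the unlimited case; the once-per-loader log call moved up into `load_iterables`. A quick illustration with a throwaway file:

```python
import pandas as pd

# Tiny CSV written only for this illustration.
with open("/tmp/example.csv", "w") as f:
    f.write("a,b\n1,2\n3,4\n5,6\n")

# nrows=None reads every row; an integer caps how many rows are read.
print(pd.read_csv("/tmp/example.csv", nrows=None).to_dict("records"))
# [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}, {'a': 5, 'b': 6}]
print(pd.read_csv("/tmp/example.csv", nrows=2).to_dict("records"))
# [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}]
```
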