    CrossProviderInferenceEngine,
    InferenceEngine,
)
+from unitxt.loaders import Loader
from unitxt.logging_utils import get_logger
from unitxt.operator import MultiStreamOperator
from unitxt.settings_utils import get_settings
@@ -120,7 +121,11 @@ def collect_loaded_dataset_iterators(self, recipe: Union[DatasetRecipe, Benchmar
            if recipe.steps[1].generators:
                for stream_name in recipe.steps[1].generators:
                    if recipe.steps[1].generators[stream_name].water_mark > -1:
-                        to_ret[stream_name] = (recipe.steps[1].generators[stream_name].measured_stream.gen_kwargs["stream"].gen_kwargs["stream"], recipe.steps[1].generators[stream_name].water_mark)
+                        stream = recipe.steps[1].generators[stream_name].measured_stream
+                        while not isinstance(stream.generator.__self__, Loader):
+                            assert "stream" in stream.gen_kwargs
+                            stream = stream.gen_kwargs["stream"]
+                        to_ret[stream_name] = (stream, recipe.steps[1].generators[stream_name].water_mark)
        else:
            # recipe is a benchmark
            for subset_name in recipe.subsets:
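Note (not part of the patch): the added loop generalizes the old hard-coded double gen_kwargs["stream"] hop by walking down the chain of wrapping streams until it reaches the stream whose generator is a bound method of a Loader. A minimal, self-contained sketch of that unwrapping idea; _Stream, _Wrapper and _FakeLoader are stand-ins invented here for illustration, not unitxt's real classes:

class _FakeLoader:  # plays the role of unitxt's Loader
    def load_data(self):
        yield from ({"x": i} for i in range(3))

class _Stream:  # mimics a stream object carrying its producing generator
    def __init__(self, generator, **gen_kwargs):
        self.generator = generator      # a bound method of the producer
        self.gen_kwargs = gen_kwargs    # wrappers keep the inner stream here

class _Wrapper:  # mimics an operator that wraps an existing stream
    def process(self, stream):
        yield from stream

def unwrap_to_loader_stream(stream, loader_cls):
    # Follow nested gen_kwargs["stream"] links until the generator's owner
    # is a loader instance, i.e. the stream produced right after loading.
    while not isinstance(stream.generator.__self__, loader_cls):
        assert "stream" in stream.gen_kwargs
        stream = stream.gen_kwargs["stream"]
    return stream

loader = _FakeLoader()
inner = _Stream(loader.load_data)
outer = _Stream(_Wrapper().process, stream=_Stream(_Wrapper().process, stream=inner))
assert unwrap_to_loader_stream(outer, _FakeLoader) is inner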
@@ -163,14 +168,23 @@ def profiler_do_the_profiling(self, dataset_query: str, **kwargs):
        t0 = time()
        recipe = load_recipe(dataset_query, **kwargs)
        t0_25 = time()
-        recipe()
-        t0_5 = time()
        self.equip_with_watermarker(recipe)
+        t0_5 = time()
+        ms = recipe()
        t1 = time()
+        water_marks = self.collect_water_marks(recipe)
+        logger.critical(f"water marks for query {dataset_query} following recipe(): {water_marks}")
+        t1_5 = time()
        dataset = _source_to_dataset(source=recipe)
        t2 = time()
+        water_marks = self.collect_water_marks(recipe)
+        logger.critical(f"water marks for query {dataset_query} following _source_to_dataset(recipe): {water_marks}")
+        t2_5 = time()
        dataset = self.list_from_dataset(dataset)
        t3 = time()
+        water_marks = self.collect_water_marks(recipe)
+        logger.critical(f"water marks for query {dataset_query} following list out all from dataset: {water_marks}")
+        t3_5 = time()
        model = self.profiler_instantiate_model()
        t4 = time()
        if isinstance(dataset, dict):
@@ -181,31 +195,31 @@ def profiler_do_the_profiling(self, dataset_query: str, **kwargs):
            dataset = dataset[split_name]
        predictions = model.infer(dataset=dataset)
        t5 = time()
-        evaluation_result = evaluate(predictions=predictions, data=dataset)
+        evaluate(predictions=predictions, data=dataset)
        t6 = time()
        # now just streaming through recipe, without generating an HF dataset:
        ms = recipe()
        total_production_length_of_recipe = {k: len(list(ms[k])) for k in ms}
        t7 = time()
        # now just loading the specific instances actually loaded above, and listing right after recipe.loader(),
        # to report the loading time from the total processing time.
-        water_marks = self.collect_water_marks(recipe)
+        # water_marks = self.collect_water_marks(recipe)
        pulling_dict = self.collect_loaded_dataset_iterators(recipe)
        t8 = time()
        self.enumerate_from_loaders(pulling_dict)
        t9 = time()
-        logger.critical(f"water marks = {water_marks}")
-        logger.critical(f"length of evaluation_result, over the returned dataset from Unitxt.load_dataset: {len(evaluation_result)}")
+        # logger.critical(f"water marks = {water_marks}")
+        # logger.critical(f"length of evaluation_result, over the returned dataset from Unitxt.load_dataset: {len(evaluation_result)}")
        logger.critical(f"lengths of total production of recipe: {total_production_length_of_recipe}")

        return {
            "load_recipe": t0_25 - t0,
-            "recipe()": t0_5 - t0_25,
-            "source_to_dataset": t2 - t1,
-            "list_out_dataset": t3 - t2,
+            "recipe()": t1 - t0_5,
+            "source_to_dataset": t2 - t1_5,
+            "list_out_dataset": t3 - t2_5,
            "just_load_and_list": t9 - t8,
            "just_stream_through_recipe": t7 - t6,
-            "instantiate_model": t4 - t3,
+            "instantiate_model": t4 - t3_5,
            "inference_time": t5 - t4,
            "evaluation_time": t6 - t5,
        }
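Aside (not part of the patch): the new t1_5/t2_5/t3_5 checkpoints exist so that the water-mark collection and the logger.critical calls inserted between phases are not billed to the neighbouring phases, and "recipe()" is now measured after equip_with_watermarker rather than before it. A tiny self-contained sketch of that checkpoint pattern; phase_a, diagnostics and phase_b are placeholders, not unitxt calls:

from time import sleep, time

def phase_a():      # stands in for ms = recipe()
    sleep(0.05)

def diagnostics():  # stands in for collect_water_marks(...) + logger.critical(...)
    sleep(0.2)

def phase_b():      # stands in for _source_to_dataset(source=recipe)
    sleep(0.05)

t0_5 = time(); phase_a(); t1 = time()
diagnostics()
t1_5 = time(); phase_b(); t2 = time()

# The diagnostics between t1 and t1_5 are charged to neither phase:
print({"phase_a": round(t1 - t0_5, 3), "phase_b": round(t2 - t1_5, 3)})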
@@ -239,7 +253,6 @@ def profile_no_cprofile():
            res[k] += dsq_time[k]
    return {k: round(res[k], 3) for k in res}

-
def find_cummtime_of(func_name: str, file_name: str, pst_printout: str) -> float:
    relevant_lines = list(
        filter(
@@ -312,13 +325,34 @@ def main():
        action="store_true",
        help="whether to employ cProfile or just time diffs.",
    )
+    parser.add_argument(
+        "--populate_fs_cache",
+        action="store_true",
+        help="whether to save the downloaded datasets to a file-system cache.",
+    )
    args = parser.parse_args()

    # Ensure the directory for the output file exists
    output_dir = os.path.dirname(args.output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

+    if args.populate_fs_cache:
+        assert os.path.exists(settings.hf_offline_datasets_path)
+        assert settings.hf_save_to_offline
+        t0 = time()
+        queries = dataset_query if isinstance(dataset_query, list) else [dataset_query]
+        for dsq in queries:
+            recipe = load_recipe(dsq)
+            ms = recipe()
+            for split in ms:
+                list(ms[split])
+        t1 = time()
+        print(f"Time to fetch the needed datasets from their hubs and save them in the local file-system: {round(t1 - t0, 3)} seconds")
+        return
+
+    if settings.hf_load_from_offline:
+        assert os.path.exists(settings.hf_offline_datasets_path)

    dict_to_print = profile_no_cprofile()
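Note (not part of the patch): with --populate_fs_cache, each requested recipe is loaded and every split fully listed, which forces the datasets to be fetched from their hubs and, since hf_save_to_offline must be enabled, written under hf_offline_datasets_path; the script then exits so that a later profiling run can read from that cache via hf_load_from_offline. A rough sketch of that two-phase flow driven programmatically; the import location of load_recipe, the ability to toggle the settings by attribute assignment, the cache path, and the example dataset query are assumptions, not taken from this patch:

import os
from unitxt.api import load_recipe              # assumed import location
from unitxt.settings_utils import get_settings

settings = get_settings()
settings.hf_offline_datasets_path = "/tmp/unitxt_fs_cache"   # hypothetical path
os.makedirs(settings.hf_offline_datasets_path, exist_ok=True)

# Phase 1: populate the file-system cache (what --populate_fs_cache does).
settings.hf_save_to_offline = True
recipe = load_recipe("card=cards.wnli,template_card_index=0")  # hypothetical query
multi_stream = recipe()
for split in multi_stream:
    list(multi_stream[split])        # materialize every split to force the download

# Phase 2: subsequent profiling runs read from the cache instead of the hub.
settings.hf_load_from_offline = True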