File tree: 1 file changed, +8 -2 lines changed
primus/modules/trainer/megatron: 1 file changed, +8 -2 lines changed
Original file line number | Diff line number | Diff line change
@@ -446,19 +446,25 @@ def setup(self):
446446 # Data stuff.
447447 self .app_metrics ["app_build_dataiters_start_time" ] = one_logger_utils .get_timestamp_in_ms ()
448448 timers ("train/valid/test-data-iterators-setup" , log_level = 0 ).start (barrier = True )
449+
450+ def train_valid_test_datasets_provider_func (train_val_test_num_samples ):
451+ return self .train_valid_test_datasets_provider (train_val_test_num_samples )
452+
453+ train_valid_test_datasets_provider_func .is_distributed = True
454+
449455 if args .virtual_pipeline_model_parallel_size is not None :
450456 self .train_data_iterator = []
451457 self .valid_data_iterator = []
452458 self .test_data_iterator = []
453459 for i in range (len (self .model )):
454460 mpu .set_virtual_pipeline_model_parallel_rank (i )
455- iterators = build_train_valid_test_data_iterators (self . train_valid_test_datasets_provider )
461+ iterators = build_train_valid_test_data_iterators (train_valid_test_datasets_provider_func )
456462 self .train_data_iterator .append (iterators [0 ])
457463 self .valid_data_iterator .append (iterators [1 ])
458464 self .test_data_iterator .append (iterators [2 ])
459465 else :
460466 self .train_data_iterator , self .valid_data_iterator , self .test_data_iterator = (
461- build_train_valid_test_data_iterators (self . train_valid_test_datasets_provider )
467+ build_train_valid_test_data_iterators (train_valid_test_datasets_provider_func )
462468 )
463469 timers ("train/valid/test-data-iterators-setup" ).stop ()
464470 print_datetime ("after dataloaders are built" )
You can’t perform that action at this time.
0 commit comments