Skip to content

Commit b1b42f9

Browse files
committed
Fix TP (tensor-parallel) stuck issue
1 parent 7d2f4ec commit b1b42f9

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

primus/modules/trainer/megatron/trainer.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -446,19 +446,25 @@ def setup(self):
446446
# Data stuff.
447447
self.app_metrics["app_build_dataiters_start_time"] = one_logger_utils.get_timestamp_in_ms()
448448
timers("train/valid/test-data-iterators-setup", log_level=0).start(barrier=True)
449+
450+
def train_valid_test_datasets_provider_func(train_val_test_num_samples):
451+
return self.train_valid_test_datasets_provider(train_val_test_num_samples)
452+
453+
train_valid_test_datasets_provider_func.is_distributed = True
454+
449455
if args.virtual_pipeline_model_parallel_size is not None:
450456
self.train_data_iterator = []
451457
self.valid_data_iterator = []
452458
self.test_data_iterator = []
453459
for i in range(len(self.model)):
454460
mpu.set_virtual_pipeline_model_parallel_rank(i)
455-
iterators = build_train_valid_test_data_iterators(self.train_valid_test_datasets_provider)
461+
iterators = build_train_valid_test_data_iterators(train_valid_test_datasets_provider_func)
456462
self.train_data_iterator.append(iterators[0])
457463
self.valid_data_iterator.append(iterators[1])
458464
self.test_data_iterator.append(iterators[2])
459465
else:
460466
self.train_data_iterator, self.valid_data_iterator, self.test_data_iterator = (
461-
build_train_valid_test_data_iterators(self.train_valid_test_datasets_provider)
467+
build_train_valid_test_data_iterators(train_valid_test_datasets_provider_func)
462468
)
463469
timers("train/valid/test-data-iterators-setup").stop()
464470
print_datetime("after dataloaders are built")

0 commit comments

Comments
 (0)