1414import copy
1515import math
1616import os
17+ import random
1718import sys
1819import time
1920from dataclasses import dataclass , field
2021from typing import Optional
2122
23+ import numpy as np
2224import paddle
25+ import paddlefleet
2326
2427from paddleformers .data .causal_dataset import (
2528 build_train_valid_test_datasets ,
3437 StepFlexToken ,
3538 TrainingArguments ,
3639 get_last_checkpoint ,
37- set_seed ,
3840 speed_metrics ,
3941)
4042from paddleformers .trainer .trainer import Trainer
@@ -350,6 +352,31 @@ def _get_train_sampler(self) -> Optional[paddle.io.Sampler]:
350352 )
351353
352354
355+ def _set_random_seed (
356+ seed_ : int ,
357+ data_parallel_random_init : bool = False ,
358+ te_rng_tracker : bool = False ,
359+ inference_rng_tracker : bool = False ,
360+ use_cudagraphable_rng : bool = False ,
361+ ):
362+ """Set random seed for reproducability."""
363+ if seed_ is not None and seed_ > 0 :
364+ # Ensure that different pipeline MP stages get different seeds.
365+ seed = seed_ + (100 * paddlefleet .parallel_state .get_pipeline_model_parallel_rank ())
366+ # Ensure different data parallel ranks get different seeds
367+ if data_parallel_random_init :
368+ seed = seed + (10 * paddlefleet .parallel_state .get_data_parallel_rank ())
369+ random .seed (seed )
370+ np .random .seed (seed )
371+ paddle .manual_seed (seed )
372+ if paddle .cuda .device_count () > 0 :
373+ paddlefleet .tensor_parallel .model_parallel_cuda_manual_seed (
374+ seed , te_rng_tracker , inference_rng_tracker , use_cudagraphable_rng
375+ )
376+ else :
377+ raise ValueError ("Seed ({}) should be a positive integer." .format (seed_ ))
378+
379+
353380def main ():
354381 parser = PdArgumentParser ((ModelArguments , DataArguments , PreTrainingArguments ))
355382 # Support format as "args.json --arg1 value1 --arg2 value2.”
@@ -374,7 +401,7 @@ def main():
374401 os .makedirs (data_args .data_cache , exist_ok = True )
375402
376403 paddle .set_device (training_args .device )
377- set_seed ( seed = training_args .seed )
404+ _set_random_seed ( seed_ = training_args .seed )
378405
379406 training_args .eval_iters = 10
380407 training_args .test_iters = training_args .eval_iters * 10
0 commit comments