
Commit d2f4e89

fix the dp bug
1 parent f91faca commit d2f4e89

2 files changed

Lines changed: 9 additions & 1 deletion


β€Žpaddlenlp/trainer/trainer.pyβ€Ž

Lines changed: 6 additions & 1 deletion
@@ -2816,7 +2816,12 @@ def _save_checkpoint(self, model, metrics=None):
                     or "remove_master_weight" not in self.args.unified_checkpoint_config
                 ):
                     paddle.save(global_rank, os.path.join(signal_dir, f".master_weight.done.{global_rank}"))
-        if self.args.should_save or self.args.use_expert_parallel:
+        if (
+            self.args.should_save
+            or self.args.use_expert_parallel
+            or (self.args.data_parallel_degree > 1 and not self.args.use_hybrid_parallel)
+        ):
             if not self.args.use_hybrid_parallel:
                 logger.info("Saving optimizer files.")
             if self.args.unified_checkpoint:
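In effect, this hunk widens the guard around the optimizer-save path: pure data-parallel runs (data_parallel_degree > 1 with hybrid parallel disabled) now enter it as well, instead of only ranks where should_save is set or expert parallelism is on. A minimal sketch of the widened condition, using a hypothetical helper name (the real check is inline in _save_checkpoint):

    def should_save_optimizer(args) -> bool:
        # Hypothetical helper mirroring the guard introduced by this commit.
        return (
            args.should_save
            or args.use_expert_parallel
            # New case: pure data parallelism without hybrid parallel.
            or (args.data_parallel_degree > 1 and not args.use_hybrid_parallel)
        )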

β€Žpaddlenlp/trainer/training_args.pyβ€Ž

Lines changed: 3 additions & 0 deletions
@@ -1807,6 +1807,9 @@ def is_segment_parallel_supported():
                 # DP use hybrid group
                 strategy = fleet.DistributedStrategy()
                 fleet.init(is_collective=True, strategy=strategy)
+            elif self.using_flex_checkpoint:
+                strategy = fleet.DistributedStrategy()
+                fleet.init(is_collective=True, strategy=strategy)
             else:
                 paddle.distributed.init_parallel_env()
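The training_args.py change pairs with the trainer fix: when using_flex_checkpoint is set but no hybrid-parallel strategy is configured, the process group is now initialized through fleet rather than plain init_parallel_env(), presumably so the fleet communication groups the checkpoint path expects are available. A minimal sketch of the resulting branch, with the surrounding branches elided and the flag name taken from the diff:

    import paddle
    from paddle.distributed import fleet

    def init_distributed_env(args):
        # Sketch only: other branches (hybrid parallel, DP with hybrid group) elided.
        if args.using_flex_checkpoint:
            # Flex checkpoint: initialize fleet collectively even without hybrid parallel.
            strategy = fleet.DistributedStrategy()
            fleet.init(is_collective=True, strategy=strategy)
        else:
            # Plain data parallelism keeps the old initialization path.
            paddle.distributed.init_parallel_env()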
