1111from contextlib import contextmanager , nullcontext
1212from functools import partial
1313from megatron .core import mpu
14+ from megatron .core .distributed import DistributedDataParallel as DDP
1415from megatron .core .distributed import finalize_model_grads
1516from megatron .core .optimizer import OptimizerConfig , get_megatron_optimizer
1617from megatron .core .pipeline_parallel import get_forward_backward_func
2627from swift .megatron .callbacks import megatron_callbacks_map
2728from swift .megatron .model import get_mcore_model
2829from swift .megatron .tuners import LoraParallelLinear
29- from swift .megatron .utils import (copy_original_module_weight , get_optimizer_param_scheduler , get_padding_to ,
30- init_persistent_async_worker , load_mcore_checkpoint , maybe_finalize_async_save ,
30+ from swift .megatron .utils import (copy_original_module_weight , disable_forward_pre_hook , enable_forward_pre_hook ,
31+ get_optimizer_param_scheduler , get_padding_to , init_persistent_async_worker ,
32+ initialize_tp_communicators , load_mcore_checkpoint ,
33+ logical_and_across_model_parallel_group , maybe_finalize_async_save ,
3134 prepare_mcore_model , reduce_max_stat_across_model_parallel_group ,
32- save_mcore_checkpoint , wrap_model )
35+ save_mcore_checkpoint , should_disable_forward_pre_hook , wrap_model )
3336from swift .template import Template
3437from swift .trainers import dynamic_gradient_checkpointing
3538from swift .trainers .utils import patch_modelscope_hub_timeout
@@ -85,6 +88,9 @@ def __init__(self, args, template: Template):
8588 for callback in args .callbacks :
8689 self .callbacks .append (megatron_callbacks_map [callback ](self ))
8790
91+ if args .tp_comm_overlap :
92+ initialize_tp_communicators (args , self .config )
93+
8894 if args .async_save and args .use_persistent_ckpt_worker :
8995 init_persistent_async_worker ()
9096
@@ -503,7 +509,33 @@ def train(self, train_dataset, val_dataset):
503509 self ._prepare_vit_gradient_checkpointing (m )
504510
505511 config .grad_scale_func = self .optimizer .scale_loss
512+ if isinstance (self .wrapped_models [0 ], DDP ) and args .overlap_grad_reduce :
513+ assert config .no_sync_func is None , ('When overlap_grad_reduce is True, config.no_sync_func must be None; '
514+ 'a custom no_sync_func is not supported when overlapping grad-reduce' )
515+ config .no_sync_func = [model_chunk .no_sync for model_chunk in self .wrapped_models ]
516+ if len (self .wrapped_models ) == 1 :
517+ config .no_sync_func = config .no_sync_func [0 ]
518+ if args .align_grad_reduce :
519+ config .grad_sync_func = [model_chunk .start_grad_sync for model_chunk in self .wrapped_models ]
520+ if len (self .wrapped_models ) == 1 :
521+ config .grad_sync_func = config .grad_sync_func [0 ]
522+ if args .overlap_param_gather and args .align_param_gather :
523+ config .param_sync_func = [model_chunk .start_param_sync for model_chunk in self .wrapped_models ]
524+ if len (self .wrapped_models ) == 1 :
525+ config .param_sync_func = config .param_sync_func [0 ]
506526 config .finalize_model_grads_func = finalize_model_grads
527+ start_iteration = state .iteration
528+ pre_hook_enabled = False
529+ # Disable forward pre-hook to start training to ensure that errors in checkpoint loading
530+ # or random initialization don't propagate to all ranks in first all-gather (which is a
531+ # no-op if things work correctly).
532+ if should_disable_forward_pre_hook (args ):
533+ disable_forward_pre_hook (self .wrapped_models , param_sync = False )
534+ # Also remove param_sync_func temporarily so that sync calls made in
535+ # `forward_backward_func` are no-ops.
536+ param_sync_func = config .param_sync_func
537+ config .param_sync_func = None
538+ pre_hook_enabled = False
507539
508540 self .call_event ('on_train_begin' )
509541 train_metrics = {}
@@ -517,8 +549,20 @@ def train(self, train_dataset, val_dataset):
517549 train_data_iterator , val_data_iterator = self ._prepare_data_iterator (train_dataset , val_dataset )
518550 while state .iteration < args .train_iters :
519551 self .call_event ('on_step_begin' )
520- metrics , grad_norm = self .train_step (train_data_iterator )
521552 maybe_finalize_async_save (args , blocking = False )
553+ metrics , grad_norm , update_successful = self .train_step (train_data_iterator )
554+ if state .iteration == start_iteration :
555+ if update_successful :
556+ # Enable forward pre-hook after training step has successfully run. All subsequent
557+ # forward passes will use the forward pre-hook / `param_sync_func` in
558+ # `forward_backward_func`.
559+ if should_disable_forward_pre_hook (args ):
560+ enable_forward_pre_hook (self .wrapped_models )
561+ config .param_sync_func = param_sync_func
562+ pre_hook_enabled = True
563+ else :
564+ start_iteration = state .iteration + 1
565+
522566 state .iteration += 1
523567 self .call_event ('on_step_end' )
524568 self ._aggregated_metrics (metrics , train_metrics )
@@ -538,16 +582,29 @@ def train(self, train_dataset, val_dataset):
538582 eval_metrics = None
539583 if state .should_eval :
540584 state .should_eval = False
585+ if should_disable_forward_pre_hook (args ):
586+ disable_forward_pre_hook (self .wrapped_models )
587+ pre_hook_enabled = False
541588 eval_metrics = self .evaluate (val_data_iterator )
542589 for m in self .wrapped_models :
543590 m .train ()
591+ if should_disable_forward_pre_hook (args ):
592+ enable_forward_pre_hook (self .wrapped_models )
593+ pre_hook_enabled = True
544594
545595 if state .should_save :
546596 self ._determine_best_metric (eval_metrics )
597+ if should_disable_forward_pre_hook (args ):
598+ disable_forward_pre_hook (self .wrapped_models )
547599 state .should_save = False
548600 self .save_checkpoint ()
601+ if should_disable_forward_pre_hook (args ):
602+ enable_forward_pre_hook (self .wrapped_models )
549603
550604 self .call_event ('on_train_end' )
605+ # Close out pre-hooks if using distributed optimizer and overlapped param gather.
606+ if pre_hook_enabled :
607+ disable_forward_pre_hook (self .wrapped_models )
551608 maybe_finalize_async_save (args , blocking = True , terminate = True )
552609
553610 def _determine_best_metric (self , metrics ) -> bool :
@@ -679,7 +736,7 @@ def evaluate(self, val_data_iterator):
679736 data_iterator = data_iterator ,
680737 model = self .wrapped_models ,
681738 num_microbatches = self .args .num_microbatches ,
682- seq_length = args .max_length ,
739+ seq_length = args .seq_length ,
683740 micro_batch_size = args .micro_batch_size ,
684741 forward_only = True ,
685742 )
@@ -713,16 +770,18 @@ def train_step(self, train_data_iterator):
713770 data_iterator = data_iterator ,
714771 model = self .wrapped_models ,
715772 num_microbatches = args .num_microbatches ,
716- seq_length = args .max_length ,
773+ seq_length = args .seq_length ,
717774 micro_batch_size = args .micro_batch_size ,
718775 forward_only = False ,
719776 )
720777
721- _ , grad_norm , _ = self .optimizer .step ()
778+ update_successful , grad_norm , _ = self .optimizer .step ()
779+ update_successful = logical_and_across_model_parallel_group (update_successful )
722780 grad_norm = reduce_max_stat_across_model_parallel_group (grad_norm )
723- self .opt_param_scheduler .step (increment = args .global_batch_size )
781+ if update_successful :
782+ self .opt_param_scheduler .step (increment = args .global_batch_size )
724783
725- return metrics , grad_norm
784+ return metrics , grad_norm , update_successful
726785
727786 def _aggregated_metrics (self , metrics , total_metrics ):
728787 if 'n_steps' not in total_metrics :
0 commit comments