Skip to content

Commit 86f68d2

Browse files
authored
[megatron] fix megatron overlap_grad_reduce/overlap_param_gather (modelscope#8079)
1 parent 0ea5209 commit 86f68d2

File tree

10 files changed

+136
-32
lines changed

10 files changed

+136
-32
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ Running Environment:
134134
|--------------|--------------|---------------------|-------------------------------------------|
135135
| python | >=3.9 | 3.10/3.11 | |
136136
| cuda | | cuda12 | No need to install if using CPU, NPU, MPS |
137-
| torch | >=2.0 | 2.8.0/2.9.1 | |
137+
| torch | >=2.0 | 2.8.0/2.9.1 | torch2.9 [conv3d slow](https://swift.readthedocs.io/en/latest/BestPractices/Qwen3-VL-Best-Practice.html#environment-setup) |
138138
| transformers | >=4.33 | 4.57.6 | |
139139
| modelscope | >=1.23 | | |
140140
| peft | >=0.11,<0.19 | | |

README_CN.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ pip install -e .
129129
|--------------|--------------|---------------------|--------------------|
130130
| python | >=3.9 | 3.10/3.11 | |
131131
| cuda | | cuda12 | 使用cpu、npu、mps则无需安装 |
132-
| torch | >=2.0 | 2.8.0/2.9.1 | |
132+
| torch | >=2.0 | 2.8.0/2.9.1 | torch2.9 [conv3d 缓慢](https://swift.readthedocs.io/zh-cn/latest/BestPractices/Qwen3-VL-Best-Practice.html#id1) |
133133
| transformers | >=4.33 | 4.57.6 | |
134134
| modelscope | >=1.23 | | |
135135
| peft | >=0.11,<0.19 | | |

docs/source/GetStarted/SWIFT-installation.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu2
127127
|--------------|--------------|---------------------|--------------------|
128128
| python | >=3.9 | 3.10/3.11 | |
129129
| cuda | | cuda12 | 使用cpu、npu、mps则无需安装 |
130-
| torch | >=2.0 | 2.8.0/2.9.1 | |
130+
| torch | >=2.0 | 2.8.0/2.9.1 | torch2.9 [conv3d 缓慢](https://swift.readthedocs.io/zh-cn/latest/BestPractices/Qwen3-VL-Best-Practice.html#id1) |
131131
| transformers | >=4.33 | 4.57.6 | |
132132
| modelscope | >=1.23 | | |
133133
| peft | >=0.11,<0.19 | | |

docs/source_en/GetStarted/SWIFT-installation.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ More images can be found [here](https://modelscope.cn/docs/intro/environment-set
126126
|--------------|--------------|---------------------|-------------------------------------------|
127127
| python | >=3.9 | 3.10/3.11 | |
128128
| cuda | | cuda12 | No need to install if using CPU, NPU, MPS |
129-
| torch | >=2.0 | 2.8.0/2.9.1 | |
129+
| torch | >=2.0 | 2.8.0/2.9.1 | torch2.9 [conv3d slow](https://swift.readthedocs.io/en/latest/BestPractices/Qwen3-VL-Best-Practice.html#environment-setup) |
130130
| transformers | >=4.33 | 4.57.6 | |
131131
| modelscope | >=1.23 | | |
132132
| peft | >=0.11,<0.19 | | |

swift/megatron/arguments/megatron_args.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -426,9 +426,11 @@ class MegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin):
426426

427427
sequence_parallel: bool = False
428428
context_parallel_size: int = 1
429-
tp_comm_overlap: bool = False # TODO
430-
overlap_grad_reduce: bool = False # TODO
431-
overlap_param_gather: bool = False # TODO
429+
tp_comm_overlap: bool = False
430+
overlap_grad_reduce: bool = False
431+
overlap_param_gather: bool = False
432+
overlap_param_gather_with_optimizer_step: bool = False
433+
align_grad_reduce: bool = True
432434
virtual_pipeline_model_parallel_size: Optional[int] = None
433435
microbatch_group_size_per_vp_stage: Optional[int] = None
434436
pipeline_model_parallel_layout: Optional[str] = None

swift/megatron/arguments/megatron_base_args.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def __post_init__(self):
2020
if self.packing:
2121
self.padding_free = True
2222
BaseArguments.__post_init__(self)
23+
self.seq_length = self.packing_length or self.max_length
2324
self._init_megatron_args()
2425
if self.streaming:
2526
if self.dataloader_num_workers > 1:

swift/megatron/model/gpt_bridge.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -674,7 +674,6 @@ def _set_moe_state(
674674
hf_prefix: str,
675675
layer_idx: int,
676676
to_mcore: bool,
677-
is_mtp_layer: bool = False,
678677
):
679678
if to_mcore:
680679
hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
@@ -727,14 +726,14 @@ def _set_moe_state(
727726
layer_idx,
728727
to_mcore,
729728
ep_rank=ep_rank,
730-
is_mtp_layer=is_mtp_layer))
729+
))
731730
if to_mcore:
732731
hf_state_dict = {}
733732
else:
734733
hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix)
735734
return hf_state_dict
736735

737-
def _get_hf_grouped(self, is_mtp_layer: bool = False):
736+
def _get_hf_grouped(self):
738737
if self.model_type in {
739738
'qwen2_moe', 'qwen3_moe', 'deepseek_v2', 'deepseek_v3', 'dots1', 'ernie4_5_moe', 'glm4_moe',
740739
'glm4_moe_lite', 'glm4v_moe', 'minimax_m2', 'olmoe', 'qwen3_next', 'kimi_vl', 'qwen3_omni_moe',
@@ -758,7 +757,6 @@ def _set_mlp_state(
758757
to_mcore: bool,
759758
ep_rank: Optional[int] = None,
760759
hf_mlp=None,
761-
is_mtp_layer: bool = False,
762760
):
763761
if to_mcore:
764762
hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
@@ -786,7 +784,7 @@ def _set_mlp_state(
786784
is_gate_up = hasattr(hf_mlp, 'gate_up_proj')
787785
# transformers 5.0 compatibility
788786
if self.is_transformers_5 and not to_mcore and is_expert:
789-
_hf_grouped, _is_gate_up = self._get_hf_grouped(is_mtp_layer)
787+
_hf_grouped, _is_gate_up = self._get_hf_grouped()
790788
if _hf_grouped is not None:
791789
hf_grouped = _hf_grouped
792790
if _is_gate_up is not None:
@@ -1303,15 +1301,13 @@ def _set_layer_attn(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: boo
13031301
'input_layernorm.weight', to_mcore)
13041302
return hf_state_dict
13051303

1306-
def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool, is_mtp_layer: bool = False):
1304+
def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool):
13071305
hf_mlp_prefix = self.get_hf_mlp_prefix(layer_idx)
13081306
hf_mlp = self._get_hf_mlp(layer_idx)
13091307
is_moe = self._is_moe(hf_mlp.state_dict())
13101308
mg_mlp = None if mg_layer is None else mg_layer.mlp
13111309
if is_moe:
1312-
hf_state_dict.update(
1313-
self._set_moe_state(
1314-
mg_mlp, hf_state_dict, f'{hf_mlp_prefix}.', layer_idx, to_mcore, is_mtp_layer=is_mtp_layer))
1310+
hf_state_dict.update(self._set_moe_state(mg_mlp, hf_state_dict, f'{hf_mlp_prefix}.', layer_idx, to_mcore))
13151311
self._set_state_dict(mg_layer, 'pre_mlp_layernorm.weight', hf_state_dict, 'post_attention_layernorm.weight',
13161312
to_mcore)
13171313
else:
@@ -1503,7 +1499,7 @@ def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx:
15031499
self._set_state_dict(lm_model, 'output_layer.weight', hf_state_dict, 'shared_head.head.weight',
15041500
to_mcore)
15051501
hf_state_dict.update(self._set_layer_attn(transformer_layer, hf_state_dict, -1, to_mcore))
1506-
hf_state_dict.update(self._set_layer_mlp(transformer_layer, hf_state_dict, -1, to_mcore, is_mtp_layer=True))
1502+
hf_state_dict.update(self._set_layer_mlp(transformer_layer, hf_state_dict, -1, to_mcore))
15071503
if to_mcore:
15081504
hf_state_dict = {}
15091505
else:

swift/megatron/trainers/base.py

Lines changed: 68 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from contextlib import contextmanager, nullcontext
1212
from functools import partial
1313
from megatron.core import mpu
14+
from megatron.core.distributed import DistributedDataParallel as DDP
1415
from megatron.core.distributed import finalize_model_grads
1516
from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer
1617
from megatron.core.pipeline_parallel import get_forward_backward_func
@@ -26,10 +27,12 @@
2627
from swift.megatron.callbacks import megatron_callbacks_map
2728
from swift.megatron.model import get_mcore_model
2829
from swift.megatron.tuners import LoraParallelLinear
29-
from swift.megatron.utils import (copy_original_module_weight, get_optimizer_param_scheduler, get_padding_to,
30-
init_persistent_async_worker, load_mcore_checkpoint, maybe_finalize_async_save,
30+
from swift.megatron.utils import (copy_original_module_weight, disable_forward_pre_hook, enable_forward_pre_hook,
31+
get_optimizer_param_scheduler, get_padding_to, init_persistent_async_worker,
32+
initialize_tp_communicators, load_mcore_checkpoint,
33+
logical_and_across_model_parallel_group, maybe_finalize_async_save,
3134
prepare_mcore_model, reduce_max_stat_across_model_parallel_group,
32-
save_mcore_checkpoint, wrap_model)
35+
save_mcore_checkpoint, should_disable_forward_pre_hook, wrap_model)
3336
from swift.template import Template
3437
from swift.trainers import dynamic_gradient_checkpointing
3538
from swift.trainers.utils import patch_modelscope_hub_timeout
@@ -85,6 +88,9 @@ def __init__(self, args, template: Template):
8588
for callback in args.callbacks:
8689
self.callbacks.append(megatron_callbacks_map[callback](self))
8790

91+
if args.tp_comm_overlap:
92+
initialize_tp_communicators(args, self.config)
93+
8894
if args.async_save and args.use_persistent_ckpt_worker:
8995
init_persistent_async_worker()
9096

@@ -503,7 +509,33 @@ def train(self, train_dataset, val_dataset):
503509
self._prepare_vit_gradient_checkpointing(m)
504510

505511
config.grad_scale_func = self.optimizer.scale_loss
512+
if isinstance(self.wrapped_models[0], DDP) and args.overlap_grad_reduce:
513+
assert config.no_sync_func is None, ('When overlap_grad_reduce is True, config.no_sync_func must be None; '
514+
'a custom no_sync_func is not supported when overlapping grad-reduce')
515+
config.no_sync_func = [model_chunk.no_sync for model_chunk in self.wrapped_models]
516+
if len(self.wrapped_models) == 1:
517+
config.no_sync_func = config.no_sync_func[0]
518+
if args.align_grad_reduce:
519+
config.grad_sync_func = [model_chunk.start_grad_sync for model_chunk in self.wrapped_models]
520+
if len(self.wrapped_models) == 1:
521+
config.grad_sync_func = config.grad_sync_func[0]
522+
if args.overlap_param_gather and args.align_param_gather:
523+
config.param_sync_func = [model_chunk.start_param_sync for model_chunk in self.wrapped_models]
524+
if len(self.wrapped_models) == 1:
525+
config.param_sync_func = config.param_sync_func[0]
506526
config.finalize_model_grads_func = finalize_model_grads
527+
start_iteration = state.iteration
528+
pre_hook_enabled = False
529+
# Disable forward pre-hook to start training to ensure that errors in checkpoint loading
530+
# or random initialization don't propagate to all ranks in first all-gather (which is a
531+
# no-op if things work correctly).
532+
if should_disable_forward_pre_hook(args):
533+
disable_forward_pre_hook(self.wrapped_models, param_sync=False)
534+
# Also remove param_sync_func temporarily so that sync calls made in
535+
# `forward_backward_func` are no-ops.
536+
param_sync_func = config.param_sync_func
537+
config.param_sync_func = None
538+
pre_hook_enabled = False
507539

508540
self.call_event('on_train_begin')
509541
train_metrics = {}
@@ -517,8 +549,20 @@ def train(self, train_dataset, val_dataset):
517549
train_data_iterator, val_data_iterator = self._prepare_data_iterator(train_dataset, val_dataset)
518550
while state.iteration < args.train_iters:
519551
self.call_event('on_step_begin')
520-
metrics, grad_norm = self.train_step(train_data_iterator)
521552
maybe_finalize_async_save(args, blocking=False)
553+
metrics, grad_norm, update_successful = self.train_step(train_data_iterator)
554+
if state.iteration == start_iteration:
555+
if update_successful:
556+
# Enable forward pre-hook after training step has successfully run. All subsequent
557+
# forward passes will use the forward pre-hook / `param_sync_func` in
558+
# `forward_backward_func`.
559+
if should_disable_forward_pre_hook(args):
560+
enable_forward_pre_hook(self.wrapped_models)
561+
config.param_sync_func = param_sync_func
562+
pre_hook_enabled = True
563+
else:
564+
start_iteration = state.iteration + 1
565+
522566
state.iteration += 1
523567
self.call_event('on_step_end')
524568
self._aggregated_metrics(metrics, train_metrics)
@@ -538,16 +582,29 @@ def train(self, train_dataset, val_dataset):
538582
eval_metrics = None
539583
if state.should_eval:
540584
state.should_eval = False
585+
if should_disable_forward_pre_hook(args):
586+
disable_forward_pre_hook(self.wrapped_models)
587+
pre_hook_enabled = False
541588
eval_metrics = self.evaluate(val_data_iterator)
542589
for m in self.wrapped_models:
543590
m.train()
591+
if should_disable_forward_pre_hook(args):
592+
enable_forward_pre_hook(self.wrapped_models)
593+
pre_hook_enabled = True
544594

545595
if state.should_save:
546596
self._determine_best_metric(eval_metrics)
597+
if should_disable_forward_pre_hook(args):
598+
disable_forward_pre_hook(self.wrapped_models)
547599
state.should_save = False
548600
self.save_checkpoint()
601+
if should_disable_forward_pre_hook(args):
602+
enable_forward_pre_hook(self.wrapped_models)
549603

550604
self.call_event('on_train_end')
605+
# Close out pre-hooks if using distributed optimizer and overlapped param gather.
606+
if pre_hook_enabled:
607+
disable_forward_pre_hook(self.wrapped_models)
551608
maybe_finalize_async_save(args, blocking=True, terminate=True)
552609

553610
def _determine_best_metric(self, metrics) -> bool:
@@ -679,7 +736,7 @@ def evaluate(self, val_data_iterator):
679736
data_iterator=data_iterator,
680737
model=self.wrapped_models,
681738
num_microbatches=self.args.num_microbatches,
682-
seq_length=args.max_length,
739+
seq_length=args.seq_length,
683740
micro_batch_size=args.micro_batch_size,
684741
forward_only=True,
685742
)
@@ -713,16 +770,18 @@ def train_step(self, train_data_iterator):
713770
data_iterator=data_iterator,
714771
model=self.wrapped_models,
715772
num_microbatches=args.num_microbatches,
716-
seq_length=args.max_length,
773+
seq_length=args.seq_length,
717774
micro_batch_size=args.micro_batch_size,
718775
forward_only=False,
719776
)
720777

721-
_, grad_norm, _ = self.optimizer.step()
778+
update_successful, grad_norm, _ = self.optimizer.step()
779+
update_successful = logical_and_across_model_parallel_group(update_successful)
722780
grad_norm = reduce_max_stat_across_model_parallel_group(grad_norm)
723-
self.opt_param_scheduler.step(increment=args.global_batch_size)
781+
if update_successful:
782+
self.opt_param_scheduler.step(increment=args.global_batch_size)
724783

725-
return metrics, grad_norm
784+
return metrics, grad_norm, update_successful
726785

727786
def _aggregated_metrics(self, metrics, total_metrics):
728787
if 'n_steps' not in total_metrics:

swift/megatron/utils/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
# Copyright (c) ModelScope Contributors. All rights reserved.
22

33
from .convert_utils import test_convert_precision
4-
from .megatron_lm_utils import (get_optimizer_param_scheduler, init_persistent_async_worker, initialize_megatron,
4+
from .megatron_lm_utils import (disable_forward_pre_hook, enable_forward_pre_hook, get_optimizer_param_scheduler,
5+
init_persistent_async_worker, initialize_megatron, initialize_tp_communicators,
56
load_mcore_checkpoint, maybe_finalize_async_save, save_mcore_checkpoint,
6-
set_random_seed, unwrap_model, wrap_model)
7+
set_random_seed, should_disable_forward_pre_hook, unwrap_model, wrap_model)
78
from .parallel_utils import (logical_and_across_model_parallel_group, reduce_max_stat_across_model_parallel_group,
89
split_cp_inputs)
910
from .patcher import patch_merge_fn, patch_torch_dist_shard

0 commit comments

Comments (0)