Commit bed4d10

[bugfix] fix megatron lora TP all-reduce (modelscope#7911)
1 parent e6a6179 commit bed4d10

6 files changed: +27 −19 lines changed

swift/megatron/model/gpt_bridge.py

Lines changed: 7 additions & 6 deletions
@@ -469,8 +469,9 @@ def _set_state_dict(self,
         if to_mcore:
             assert mg_param is not None, f'mg_module: {mg_module}, mg_key: {mg_key}'
             hf_weight = hf_state_dict[hf_key].load()
-            if module_key in {'embedding.word_embeddings', 'output_layer'
-                              } and hf_weight.shape[0] < self.args.padded_vocab_size:
+            if module_key in {
+                    'embedding.word_embeddings', 'output_layer'
+            } and hf_weight.shape[0] < self.args.padded_vocab_size and self.args.task_type != 'seq_cls':
                 hf_weight = F.pad(hf_weight, (0, 0, 0, self.args.padded_vocab_size - hf_weight.shape[0]))
             hf_scale_inv = None
             if f'{hf_key}_scale_inv' in hf_state_dict:

@@ -1295,10 +1296,10 @@ def _convert_post_process(self, mg_model, hf_state_dict, hf_prefix: str, to_mcor
         lm_model = getattr(mg_model, 'language_model') if self.args.is_multimodal else mg_model
         if self.args.task_type != 'embedding':
             if self.args.untie_embeddings_and_output_weights:
-                if not to_mcore or self.args.task_type in {'causal_lm', 'generative_reranker'}:
-                    hf_lm_head_key = self.hf_lm_head_key
-                    if self.args.task_type == 'seq_cls':
-                        hf_lm_head_key = self.hf_score_key
+                hf_lm_head_key = self.hf_lm_head_key
+                if self.args.task_type == 'seq_cls':
+                    hf_lm_head_key = self.hf_score_key
+                if not to_mcore or hf_lm_head_key in hf_state_dict:
                     self._set_state_dict(lm_model, 'output_layer.weight', hf_state_dict, hf_lm_head_key, to_mcore)
             elif to_mcore and lm_model.output_layer.weight is not None:
                 self._set_state_dict(lm_model, 'output_layer.weight', hf_state_dict, self.hf_embed_key, to_mcore)
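
Both hunks handle seq_cls checkpoints: a score head has shape [num_labels, hidden_size], so it must not be zero-padded up to padded_vocab_size the way vocab-sized embedding/output weights are, and the HF head key is now resolved before the load decision, so when converting to mcore the head is loaded only if its key actually exists in hf_state_dict. A minimal sketch of the padding step itself (shapes hypothetical):

# Why only vocab-sized weights are padded (sketch; shapes are made up).
# Megatron pads the vocab dimension so it divides evenly across TP ranks;
# a seq_cls score head is [num_labels, hidden] and must be left alone.
import torch
import torch.nn.functional as F

hf_weight = torch.randn(151665, 4096)   # raw HF vocab rows (hypothetical)
padded_vocab_size = 151680              # rounded up for TP divisibility
# F.pad's pad spec for a 2-D tensor is (left, right, top, bottom),
# so (0, 0, 0, n) appends n zero rows along dim 0.
hf_weight = F.pad(hf_weight, (0, 0, 0, padded_vocab_size - hf_weight.shape[0]))
assert hf_weight.shape == (151680, 4096)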

swift/megatron/model/gpt_model.py

Lines changed: 1 addition & 0 deletions
@@ -142,6 +142,7 @@ def __init__(
                 parallel_mode=None,
                 skip_weight_param_allocation=False,
             )
+            self.output_layer.weight.average_gradients_across_tp_domain = True
         elif args.task_type == 'embedding' and self.post_process:
             self.output_layer = None

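
This output head is built with parallel_mode=None (visible in the context above), i.e. its weight is replicated on every TP rank rather than column-sharded, so the new attribute asks gradient finalization to keep the replicas in sync by averaging (not summing) their gradients across the TP group. An illustrative sketch of what that amounts to (not Megatron-core's actual implementation):

# Illustrative only: averaging a replicated parameter's gradient across TP.
import torch.distributed as dist

def average_grads_across_tp(params, tp_group):
    # All-reduce then divide: replicated params stay identical across TP ranks.
    for p in params:
        if getattr(p, 'average_gradients_across_tp_domain', False) and p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=tp_group)
            p.grad.div_(dist.get_world_size(group=tp_group))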

swift/megatron/model/mm_gpt_model.py

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@
 import torch
 from megatron.core import InferenceParams
 from megatron.core.packed_seq_params import PackedSeqParams
-from megatron.core.tensor_parallel import VocabParallelEmbedding, scatter_to_sequence_parallel_region
+from megatron.core.tensor_parallel import VocabParallelEmbedding, reduce_scatter_to_sequence_parallel_region
 from megatron.core.transformer.module import MegatronModule
 from megatron.core.transformer.spec_utils import ModuleSpec
 from megatron.core.transformer.transformer_config import TransformerConfig

@@ -70,7 +70,7 @@ def forward(_self, input_):
        if reduce_scatter_embeddings:
            res = res.transpose(0, 1).contiguous()
            group_kwargs = {'group': _self.tp_group} if mcore_013 else {}
-           res = scatter_to_sequence_parallel_region(res, **group_kwargs)
+           res = reduce_scatter_to_sequence_parallel_region(res, **group_kwargs) / args.tensor_model_parallel_size
        return res

    VocabParallelEmbedding.forward = forward
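
This is the heart of the TP fix in the multimodal embedding path: scatter_to_sequence_parallel_region only splits the sequence across TP ranks, while reduce_scatter_to_sequence_parallel_region also sums the per-rank contributions. Since res here appears to be replicated on every TP rank, the reduce-scatter sums tensor_model_parallel_size identical copies and the trailing division restores the original values, while routing the backward pass through a collective whose gradient reduction is correct. A sketch of the collective semantics in raw torch.distributed (assuming replicated input):

# Sketch only: with 'res' replicated on every TP rank, reduce_scatter sums
# the tp_size copies and splits the sequence, so dividing by tp_size
# recovers the original values for the local chunk.
import torch
import torch.distributed as dist

def reduce_scatter_then_average(res, tp_group):
    tp_size = dist.get_world_size(group=tp_group)
    out = torch.empty(res.shape[0] // tp_size, *res.shape[1:], dtype=res.dtype, device=res.device)
    # input list: the local sequence chunks destined for each rank
    dist.reduce_scatter(out, list(res.chunk(tp_size, dim=0)), group=tp_group)
    return out / tp_size  # undo the sum over identical replicas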

swift/megatron/tuners/lora.py

Lines changed: 4 additions & 7 deletions
@@ -16,7 +16,6 @@
                                   TERowParallelGroupedLinear, TERowParallelLinear)
 from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
 from megatron.core.parallel_state import get_expert_tensor_parallel_world_size, get_tensor_model_parallel_world_size
-from megatron.core.tensor_parallel import gather_from_sequence_parallel_region, scatter_to_sequence_parallel_region
 from megatron.core.tensor_parallel.random import get_cuda_rng_tracker, get_expert_parallel_rng_tracker_name
 from megatron.core.transformer.mlp import apply_swiglu_sharded_factory
 from megatron.core.transformer.module import MegatronModule

@@ -201,8 +200,10 @@ def update_layer(self, adapter_name, r, *, lora_alpha, lora_dropout, init_lora_w
             **kwargs,
         )
         lora_b.parallel_mode = self.base_layer.parallel_mode  # fix moe_shared_expert_overlap
-        lora_a.sequence_parallel = False
-        lora_b.sequence_parallel = False
+        for lora in [lora_a, lora_b]:
+            if getattr(lora, 'parallel_mode', None) is None and hasattr(lora, 'weight'):  # TODO: experts
+                sequence_parallel = True if isinstance(self.base_layer, TopKRouter) else self.sequence_parallel
+                lora.weight.sequence_parallel = sequence_parallel
         self.lora_A[adapter_name] = lora_a
         self.lora_B[adapter_name] = lora_b
         if hasattr(self, 'lora_bias'):

@@ -341,8 +342,6 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any):
         else:
             raise ValueError(f'Unsupported base layer type: {type(self.base_layer)}')
         if not isinstance(self.base_layer, TopKRouter) and not self.disable_adapters and not self.merged:
-            if self.sequence_parallel and self.base_layer.parallel_mode == 'column':
-                x = gather_from_sequence_parallel_region(x)
             for active_adapter in self.active_adapters:
                 if active_adapter not in self.lora_A.keys():
                     continue

@@ -362,8 +361,6 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any):
                 if isinstance(lora_result, tuple):
                     lora_result = lora_result[0]
                 lora_result = lora_result * scaling
-                if self.sequence_parallel and self.base_layer.parallel_mode == 'row':
-                    lora_result = scatter_to_sequence_parallel_region(lora_result)
                 result = result + lora_result

         result = result.to(previous_dtype)
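
The forward-path gather/scatter is gone: instead of forcing sequence_parallel=False on the LoRA sub-layers and shuffling activations by hand, update_layer now tags the LoRA weights themselves (only those with parallel_mode None, i.e. not TP-sharded) with the sequence_parallel attribute. Megatron-core's gradient finalization all-reduces the gradients of params carrying that attribute across the TP group, which is what keeps these replicated LoRA weights consistent. A simplified sketch of that mechanism (paraphrased, not the verbatim Megatron code):

# Simplified paraphrase of the mechanism being leveraged: params whose
# .sequence_parallel attribute is True get their grads summed across TP,
# because each rank only saw its own sequence chunk in the forward pass.
import torch.distributed as dist

def allreduce_sequence_parallel_grads(model, tp_group):
    for p in model.parameters():
        if p.grad is not None and getattr(p, 'sequence_parallel', False):
            dist.all_reduce(p.grad, group=tp_group)  # sum partial grads from each chunk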

swift/megatron/utils/convert_utils.py

Lines changed: 2 additions & 4 deletions
@@ -1,8 +1,6 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.

 import math
-import os
-import shutil
 from contextlib import contextmanager
 from typing import Any, Dict

@@ -25,7 +23,7 @@ def _test_params_sum(model):
     for n, p in model.named_parameters():
         n_parameter += 1
         sum_ = p.to(device='cuda', dtype=torch.float32).abs().sum().cpu().item()
-        if sum_ == 0:
+        if sum_ == 0 and '.lora_B.' not in n:
             zero_count += 1
             logger.warning(f'n: {n}, sum: {sum_}')
         elif math.isnan(sum_) or math.isinf(sum_) or sum_ > 1e10:

@@ -200,7 +198,7 @@ def test_convert_precision(hf_model, mg_model, template, torch_dtype=torch.float
     with torch.inference_mode(), _model_cpu_forward_context(
             mg_modules, torch_dtype, 'cuda', share_embedding=share_embedding, target_device=mg_device):
         mg_logits = forward_step_helper(mg_model, mg_inputs, dtype=torch_dtype)
-        if args.tensor_model_parallel_size > 1:
+        if args.tensor_model_parallel_size > 1 and args.task_type != 'seq_cls':
             from megatron.core.tensor_parallel.mappings import gather_from_tensor_model_parallel_region
             if mg_logits is not None:
                 mg_logits = gather_from_tensor_model_parallel_region(mg_logits)
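
Both tweaks follow from the LoRA and seq_cls setup: lora_B is zero-initialized by convention (so the adapter's initial delta B @ A is exactly zero), making an all-zero lora_B sum expected rather than suspicious, and a seq_cls head is replicated rather than TP-sharded, so its logits need no gather across the TP group. A sketch of the init convention the exemption relies on (peft-style, shapes hypothetical):

# Why '.lora_B.' is exempt from the zero-sum warning: standard LoRA init
# zeroes B so the initial update B @ A contributes nothing.
import math
import torch
import torch.nn as nn

hidden, r = 4096, 8
lora_A = nn.Linear(hidden, r, bias=False)
lora_B = nn.Linear(r, hidden, bias=False)
nn.init.kaiming_uniform_(lora_A.weight, a=math.sqrt(5))  # peft's default for lora_A
nn.init.zeros_(lora_B.weight)                            # B starts at zero by design
x = torch.randn(2, hidden)
assert lora_B(lora_A(x)).abs().sum().item() == 0.0       # adapter delta is zero at init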

swift/megatron/utils/utils.py

Lines changed: 11 additions & 0 deletions
@@ -15,6 +15,7 @@
 from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint, sharded_state_dict_default
 from megatron.training import checkpointing, get_args
 from packaging import version
+from peft.tuners.lora import Linear as LoraLinear
 from peft.utils.other import ModulesToSaveWrapper
 from torch import nn

@@ -156,6 +157,7 @@ def new_deepcopy(x, *args, **kwargs):


 def prepare_adapter(model):
+    from swift.megatron.tuners import LoraParallelLinear
     args = get_args()
     set_linear_is_expert(model)
     target_modules = get_target_modules(args, model)

@@ -179,6 +181,15 @@ def prepare_adapter(model):
     for n, p in model.named_parameters():
         if '.ref_adapter.' in n:
             p.requires_grad = False
+    # setting average_gradients_across_tp_domain
+    for m in model.modules():
+        if isinstance(m, LoraLinear):
+            # just check
+            assert args.is_multimodal or args.hf_model_type == 'qwen3_next'
+            assert not isinstance(m, LoraParallelLinear)
+            for p in m.parameters():
+                if p.requires_grad:
+                    p.average_gradients_across_tp_domain = True
     return model