Skip to content

Commit d00d246

Browse files
committed
chore: Update Megatron patch with MTP loss div-by-zero guard
Guard against division by zero in MTP loss computation when num_tokens is 0, which can happen with context parallelism when one CP rank has no response tokens after label rolling.
1 parent 7ba4af8 commit d00d246

File tree

4 files changed

+29
-24
lines changed

4 files changed

+29
-24
lines changed

docker/patch/v0.5.7/megatron.patch

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ index e21127b87..712793853 100755
379379
),
380380
)
381381
diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py
382-
index a1230568c..1fd52f65a 100644
382+
index a1230568c..b45e63237 100644
383383
--- a/megatron/core/models/gpt/gpt_model.py
384384
+++ b/megatron/core/models/gpt/gpt_model.py
385385
@@ -446,6 +446,7 @@ class GPTModel(LanguageModule):
@@ -437,7 +437,7 @@ index a1230568c..1fd52f65a 100644
437437
for mtp_layer_number in range(self.config.mtp_num_layers):
438438
# Calc loss for the current Multi-Token Prediction (MTP) layers.
439439
mtp_labels, _ = roll_tensor(
440-
@@ -595,7 +604,7 @@ class GPTModel(LanguageModule):
440+
@@ -595,17 +604,19 @@ class GPTModel(LanguageModule):
441441
sequence_parallel_enabled=self.output_layer.sequence_parallel,
442442
column_parallel_linear=self.output_layer,
443443
col_linear_kwargs={
@@ -446,6 +446,28 @@ index a1230568c..1fd52f65a 100644
446446
'runtime_gather_output': runtime_gather_output,
447447
},
448448
)
449+
450+
mtp_loss = loss_mask * mtp_loss
451+
+ # Guard against division by zero when num_tokens is 0
452+
+ safe_num_tokens = max(num_tokens, 1)
453+
if self.training:
454+
# TODO(shifangx): remove the use of parallel_state here
455+
# after moving loss logging to loss_func in pretrain_gpt.py
456+
MTPLossLoggingHelper.save_loss_to_tracker(
457+
- torch.sum(mtp_loss) / num_tokens,
458+
+ torch.sum(mtp_loss) / safe_num_tokens,
459+
mtp_layer_number,
460+
self.config.mtp_num_layers,
461+
avg_group=parallel_state.get_data_parallel_group(
462+
@@ -619,7 +630,7 @@ class GPTModel(LanguageModule):
463+
)
464+
else:
465+
hidden_states = MTPLossAutoScaler.apply(
466+
- hidden_states, mtp_loss_scale * mtp_loss / num_tokens
467+
+ hidden_states, mtp_loss_scale * mtp_loss / safe_num_tokens
468+
)
469+
sequence_parallel_override = False
470+
449471
diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py
450472
index 6e093f96f..eac21a3ea 100644
451473
--- a/megatron/core/optimizer/distrib_optimizer.py

slime/backends/megatron_utils/megatron_to_hf/qwen3_next.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,7 @@ def _convert_mtp_layer(args, name, param, layer_idx):
5353
if "final_layernorm.weight" in name:
5454
return [("mtp.norm.weight", param)]
5555
if "eh_proj.weight" in name:
56-
if param.dim() < 2:
57-
raise ValueError(f"eh_proj weight expects 2D tensor, got {param.shape}")
58-
first_half, second_half = param.chunk(2, dim=1)
59-
new_param = torch.cat([second_half, first_half], dim=1)
60-
return [("mtp.fc.weight", new_param)]
56+
return [("mtp.fc.weight", param)]
6157

6258
# MTP inner transformer layers (keep layer index)
6359
if "transformer_layer" in name:

slime/backends/megatron_utils/model_provider.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -194,11 +194,7 @@ def model_provider(pre_process: bool = True, post_process: bool = True, vp_stage
194194
if vp_stage is not None:
195195
mtp_kwargs["vp_stage"] = vp_stage
196196

197-
from dataclasses import replace
198-
199-
mtp_config = replace(config, use_gated_attention=True)
200-
object.__setattr__(config, "mtp_config", mtp_config)
201-
mtp_block_spec = get_gpt_mtp_block_spec(mtp_config, transformer_layer_spec, **mtp_kwargs)
197+
mtp_block_spec = get_gpt_mtp_block_spec(config, transformer_layer_spec, **mtp_kwargs)
202198
kwargs["mtp_block_spec"] = mtp_block_spec
203199

204200
with build_model_context(**build_model_context_args):

slime_plugins/mbridge/qwen3_next.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,10 @@ class Qwen3NextBridge(Qwen2MoEBridge):
4747
)
4848

4949
def _get_gptmodel_args(self) -> dict:
50-
"""Override to add MTP block spec with gated attention config."""
51-
from copy import deepcopy
52-
50+
"""Override to add MTP block spec."""
5351
ret = super()._get_gptmodel_args()
5452
if getattr(self.config, "mtp_num_layers", None) is not None:
55-
mtp_config = deepcopy(self.config)
56-
mtp_config.use_gated_attention = True
57-
mtp_block_spec = get_gpt_mtp_block_spec(mtp_config, mtp_config, use_transformer_engine=True)
53+
mtp_block_spec = get_gpt_mtp_block_spec(self.config, self.config, use_transformer_engine=True)
5854
ret["mtp_block_spec"] = mtp_block_spec
5955
return ret
6056

@@ -171,17 +167,11 @@ def _weight_to_mcore_format(
171167
return qgkv
172168

173169
weight = super()._weight_to_mcore_format(mcore_weights_name, hf_weights)
174-
if mcore_weights_name.endswith("eh_proj.weight"):
175-
first_half, second_half = weight.chunk(2, dim=1)
176-
weight = torch.cat([second_half, first_half], dim=1)
177170
return weight
178171

179172
def _weight_to_hf_format(
180173
self, mcore_weights_name: str, mcore_weights: torch.Tensor
181174
) -> tuple[list[str], list[torch.Tensor]]:
182-
if mcore_weights_name.endswith("eh_proj.weight"):
183-
first_half, second_half = mcore_weights.chunk(2, dim=1)
184-
mcore_weights = torch.cat([second_half, first_half], dim=1)
185175
return super()._weight_to_hf_format(mcore_weights_name, mcore_weights)
186176

187177
def _build_config(self):
@@ -211,5 +201,6 @@ def _build_config(self):
211201
# Qwen3 Next specific
212202
attention_output_gate=True,
213203
moe_shared_expert_gate=True,
204+
use_gated_attention=True,
214205
**mtp_args,
215206
)

0 commit comments

Comments (0)