@@ -379,7 +379,7 @@ index e21127b87..712793853 100755
     ),
 )
 diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py
-index a1230568c..b45e63237 100644
+index a1230568c..1fd52f65a 100644
 --- a/megatron/core/models/gpt/gpt_model.py
 +++ b/megatron/core/models/gpt/gpt_model.py
 @@ -446,6 +446,7 @@ class GPTModel(LanguageModule):
@@ -437,7 +437,7 @@ index a1230568c..b45e63237 100644
         for mtp_layer_number in range(self.config.mtp_num_layers):
             # Calc loss for the current Multi-Token Prediction (MTP) layers.
             mtp_labels, _ = roll_tensor(
-@@ -595,17 +604,19 @@ class GPTModel(LanguageModule):
+@@ -595,7 +604,7 @@ class GPTModel(LanguageModule):
             sequence_parallel_enabled=self.output_layer.sequence_parallel,
             column_parallel_linear=self.output_layer,
             col_linear_kwargs={
@@ -446,28 +446,6 @@ index a1230568c..b45e63237 100644
                 'runtime_gather_output': runtime_gather_output,
             },
         )
-
-        mtp_loss = loss_mask * mtp_loss
-+       # Guard against division by zero when num_tokens is 0
-+       safe_num_tokens = max(num_tokens, 1)
-        if self.training:
-            # TODO(shifangx): remove the use of parallel_state here
-            # after moving loss logging to loss_func in pretrain_gpt.py
-            MTPLossLoggingHelper.save_loss_to_tracker(
--               torch.sum(mtp_loss) / num_tokens,
-+               torch.sum(mtp_loss) / safe_num_tokens,
-                mtp_layer_number,
-                self.config.mtp_num_layers,
-                avg_group=parallel_state.get_data_parallel_group(
-@@ -619,7 +630,7 @@ class GPTModel(LanguageModule):
-            )
-        else:
-            hidden_states = MTPLossAutoScaler.apply(
--               hidden_states, mtp_loss_scale * mtp_loss / num_tokens
-+               hidden_states, mtp_loss_scale * mtp_loss / safe_num_tokens
-            )
-        sequence_parallel_override = False
-
 diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py
 index 6e093f96f..eac21a3ea 100644
 --- a/megatron/core/optimizer/distrib_optimizer.py
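The chunk removed above drops a division-by-zero guard around the MTP loss normalization: clamp the token count before dividing a masked loss sum. A minimal, self-contained sketch of that pattern follows; the function name and the derivation of `num_tokens` from `loss_mask` are illustrative assumptions, not part of the patch.

```python
import torch

def normalize_mtp_loss(mtp_loss: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor:
    # Zero out loss at masked positions, as in the patch above.
    mtp_loss = loss_mask * mtp_loss
    # Illustrative assumption: num_tokens counts the unmasked tokens.
    num_tokens = int(loss_mask.sum().item())
    # Guard against division by zero when num_tokens is 0.
    safe_num_tokens = max(num_tokens, 1)
    return torch.sum(mtp_loss) / safe_num_tokens
```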