Commit f3f94ad

revert: Remove MTP loss div-by-zero guard from Megatron patch
1 parent d00d246 commit f3f94ad
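
This revert removes a guard that the patch had previously added around the Multi-Token Prediction (MTP) loss normalization in gpt_model.py: num_tokens was replaced by max(num_tokens, 1) before dividing. As a minimal standalone sketch of the failure mode the guard targeted (assuming num_tokens is derived from a loss mask; this is not the Megatron code itself):

import torch

# Minimal sketch, assuming num_tokens counts unmasked tokens.
# When every token in a microbatch is masked out, the unguarded
# normalization produces NaN (0/0) instead of a usable loss value.
loss_mask = torch.zeros(4)              # fully masked microbatch
mtp_loss = loss_mask * torch.randn(4)   # -> all zeros
num_tokens = int(loss_mask.sum())       # -> 0

print(torch.sum(mtp_loss) / num_tokens)       # tensor(nan)

safe_num_tokens = max(num_tokens, 1)          # the guard being reverted
print(torch.sum(mtp_loss) / safe_num_tokens)  # tensor(0.)

With the guard removed, the diff below restores the plain division by num_tokens on both the loss-logging and loss-scaling paths.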

File tree

1 file changed: +2 -24 lines

docker/patch/v0.5.7/megatron.patch

Lines changed: 2 additions & 24 deletions
@@ -379,7 +379,7 @@ index e21127b87..712793853 100755
         ),
     )
 diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py
-index a1230568c..b45e63237 100644
+index a1230568c..1fd52f65a 100644
 --- a/megatron/core/models/gpt/gpt_model.py
 +++ b/megatron/core/models/gpt/gpt_model.py
 @@ -446,6 +446,7 @@ class GPTModel(LanguageModule):
@@ -437,7 +437,7 @@ index a1230568c..b45e63237 100644
         for mtp_layer_number in range(self.config.mtp_num_layers):
             # Calc loss for the current Multi-Token Prediction (MTP) layers.
             mtp_labels, _ = roll_tensor(
-@@ -595,17 +604,19 @@ class GPTModel(LanguageModule):
+@@ -595,7 +604,7 @@ class GPTModel(LanguageModule):
             sequence_parallel_enabled=self.output_layer.sequence_parallel,
             column_parallel_linear=self.output_layer,
             col_linear_kwargs={
@@ -446,28 +446,6 @@ index a1230568c..b45e63237 100644
                 'runtime_gather_output': runtime_gather_output,
             },
         )
-
-        mtp_loss = loss_mask * mtp_loss
-+        # Guard against division by zero when num_tokens is 0
-+        safe_num_tokens = max(num_tokens, 1)
-        if self.training:
-            # TODO(shifangx): remove the use of parallel_state here
-            # after moving loss logging to loss_func in pretrain_gpt.py
-            MTPLossLoggingHelper.save_loss_to_tracker(
--                torch.sum(mtp_loss) / num_tokens,
-+                torch.sum(mtp_loss) / safe_num_tokens,
-                mtp_layer_number,
-                self.config.mtp_num_layers,
-                avg_group=parallel_state.get_data_parallel_group(
-@@ -619,7 +630,7 @@ class GPTModel(LanguageModule):
-            )
-        else:
-            hidden_states = MTPLossAutoScaler.apply(
--                hidden_states, mtp_loss_scale * mtp_loss / num_tokens
-+                hidden_states, mtp_loss_scale * mtp_loss / safe_num_tokens
-            )
-        sequence_parallel_override = False
-
 diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py
 index 6e093f96f..eac21a3ea 100644
 --- a/megatron/core/optimizer/distrib_optimizer.py
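
A note on the second arm of the removed hunk: MTPLossAutoScaler.apply(hidden_states, ...) follows the common pattern of attaching an auxiliary loss to an activation so it contributes gradients without being returned to the caller. A hedged sketch of that general pattern (the class name comes from the diff; the body is an assumption, not Megatron's actual implementation):

import torch

class MTPLossAutoScalerSketch(torch.autograd.Function):
    """Pass hidden_states through unchanged, but register the scaled
    auxiliary loss so it receives a gradient during backward."""

    @staticmethod
    def forward(ctx, hidden_states, scaled_mtp_loss):
        ctx.save_for_backward(scaled_mtp_loss)
        return hidden_states  # activations flow through untouched

    @staticmethod
    def backward(ctx, grad_output):
        (scaled_mtp_loss,) = ctx.saved_tensors
        # Gradient w.r.t. hidden_states passes through; the auxiliary
        # loss is treated as added to the objective with coefficient 1.
        return grad_output, torch.ones_like(scaled_mtp_loss)

In the removed lines, the tensor handed to apply() was mtp_loss_scale * mtp_loss / num_tokens, which is exactly where the safe_num_tokens guard had been substituted in.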
