Skip to content

Commit e3ce224

Browse files
committed
fix
1 parent 011892e commit e3ce224

1 file changed

Lines changed: 10 additions & 14 deletions

File tree

src/mcore_bridge/model/gpt_model.py

Lines changed: 10 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -430,35 +430,31 @@ def _postprocess(
430430
cp_group=self.cp_group,
431431
packed_seq_params=packed_seq_params,
432432
)
433-
loss_mask, num_tokens = roll_tensor(
433+
loss_mask, _ = roll_tensor(
434434
loss_mask,
435435
shifts=-1,
436436
dims=-1,
437437
cp_group=self.cp_group,
438438
packed_seq_params=packed_seq_params,
439439
)
440440
mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits)
441-
mtp_loss = loss_mask * mtp_loss
441+
loss_mask_ = (loss_mask & (mtp_labels != -100))
442+
num_tokens = loss_mask_.sum()
443+
mtp_loss = loss_mask_ * mtp_loss
442444
if self.training:
443-
# TODO(shifangx): remove the use of parallel_state here
444-
# after moving loss logging to loss_func in pretrain_gpt.py
445+
mtp_loss_for_log = (
446+
torch.sum(mtp_loss) / num_tokens if num_tokens > 0 else mtp_loss.new_tensor(0.0))
445447
MTPLossLoggingHelper.save_loss_to_tracker(
446-
torch.sum(mtp_loss) / num_tokens,
448+
mtp_loss_for_log,
447449
mtp_layer_number,
448450
self.config.mtp_num_layers,
449-
avg_group=parallel_state.get_data_parallel_group(
450-
with_context_parallel=True
451-
),
451+
avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
452452
)
453453
mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
454454
if self.config.calculate_per_token_loss:
455-
hidden_states = MTPLossAutoScaler.apply(
456-
hidden_states, mtp_loss_scale * mtp_loss
457-
)
455+
hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss)
458456
else:
459-
hidden_states = MTPLossAutoScaler.apply(
460-
hidden_states, mtp_loss_scale * mtp_loss / num_tokens
461-
)
457+
hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens)
462458
sequence_parallel_override = False
463459
if in_inference_mode and inference_context.materialize_only_last_token_logits:
464460
if inference_context.is_static_batching():

0 commit comments

Comments
 (0)