@@ -430,35 +430,31 @@ def _postprocess(
430430 cp_group = self .cp_group ,
431431 packed_seq_params = packed_seq_params ,
432432 )
433- loss_mask , num_tokens = roll_tensor (
433+ loss_mask , _ = roll_tensor (
434434 loss_mask ,
435435 shifts = - 1 ,
436436 dims = - 1 ,
437437 cp_group = self .cp_group ,
438438 packed_seq_params = packed_seq_params ,
439439 )
440440 mtp_loss = self .compute_language_model_loss (mtp_labels , mtp_logits )
441- mtp_loss = loss_mask * mtp_loss
441+ loss_mask_ = (loss_mask & (mtp_labels != - 100 ))
442+ num_tokens = loss_mask_ .sum ()
443+ mtp_loss = loss_mask_ * mtp_loss
442444 if self .training :
443- # TODO(shifangx): remove the use of parallel_state here
444- # after moving loss logging to loss_func in pretrain_gpt.py
445+ mtp_loss_for_log = (
446+ torch . sum ( mtp_loss ) / num_tokens if num_tokens > 0 else mtp_loss . new_tensor ( 0.0 ))
445447 MTPLossLoggingHelper .save_loss_to_tracker (
446- torch . sum ( mtp_loss ) / num_tokens ,
448+ mtp_loss_for_log ,
447449 mtp_layer_number ,
448450 self .config .mtp_num_layers ,
449- avg_group = parallel_state .get_data_parallel_group (
450- with_context_parallel = True
451- ),
451+ avg_group = parallel_state .get_data_parallel_group (with_context_parallel = True ),
452452 )
453453 mtp_loss_scale = self .config .mtp_loss_scaling_factor / self .config .mtp_num_layers
454454 if self .config .calculate_per_token_loss :
455- hidden_states = MTPLossAutoScaler .apply (
456- hidden_states , mtp_loss_scale * mtp_loss
457- )
455+ hidden_states = MTPLossAutoScaler .apply (hidden_states , mtp_loss_scale * mtp_loss )
458456 else :
459- hidden_states = MTPLossAutoScaler .apply (
460- hidden_states , mtp_loss_scale * mtp_loss / num_tokens
461- )
457+ hidden_states = MTPLossAutoScaler .apply (hidden_states , mtp_loss_scale * mtp_loss / num_tokens )
462458 sequence_parallel_override = False
463459 if in_inference_mode and inference_context .materialize_only_last_token_logits :
464460 if inference_context .is_static_batching ():
0 commit comments