@@ -391,6 +391,8 @@ def _postprocess(
391391 output_weight = None
392392 if self .share_embeddings_and_output_weights :
393393 output_weight = self .shared_embedding_or_output_weight ()
394+ if self .config .is_multimodal and self .config .context_parallel_size > 1 :
395+ input_ids = split_cp_inputs (input_ids , getattr (packed_seq_params , 'cu_seqlens_q' , None ), 1 )
394396
395397 if self .mtp_process :
396398 hidden_states = self .mtp (
@@ -407,55 +409,56 @@ def _postprocess(
407409 embedding = self .embedding ,
408410 ** (extra_block_kwargs or {}),
409411 )
412+ mtp_labels = labels .clone ()
410413 hidden_states_list = torch .chunk (hidden_states , 1 + self .config .mtp_num_layers , dim = 0 )
411414 hidden_states = hidden_states_list [0 ]
412-
413- if labels is not None :
414- mtp_labels = labels .clone ()
415- if loss_mask is None :
416- # if loss_mask is not provided, use all ones as loss_mask
417- if packed_seq_params is None :
418- loss_mask = torch .ones_like (mtp_labels )
419- else :
420- loss_mask = mtp_labels .new_ones ((1 , packed_seq_params .cu_seqlens_q [- 1 ]))
421- cu_seqlens = packed_seq_params .cu_seqlens_q if packed_seq_params is not None else None
422- for mtp_layer_number in range (self .config .mtp_num_layers ):
423- # output
424- mtp_logits , _ = self .output_layer (
425- hidden_states_list [mtp_layer_number + 1 ],
426- weight = output_weight ,
427- runtime_gather_output = runtime_gather_output ,
415+ if loss_mask is None :
416+ # if loss_mask is not provided, use all ones as loss_mask
417+ loss_mask = torch .ones_like (mtp_labels )
418+ for mtp_layer_number in range (self .config .mtp_num_layers ):
419+ # output
420+ mtp_logits , _ = self .output_layer (
421+ hidden_states_list [mtp_layer_number + 1 ],
422+ weight = output_weight ,
423+ runtime_gather_output = runtime_gather_output ,
424+ )
425+ # Calc loss for the current Multi-Token Prediction (MTP) layers.
426+ mtp_labels , _ = roll_tensor (
427+ mtp_labels ,
428+ shifts = - 1 ,
429+ dims = - 1 ,
430+ cp_group = self .cp_group ,
431+ packed_seq_params = packed_seq_params ,
432+ )
433+ loss_mask , num_tokens = roll_tensor (
434+ loss_mask ,
435+ shifts = - 1 ,
436+ dims = - 1 ,
437+ cp_group = self .cp_group ,
438+ packed_seq_params = packed_seq_params ,
439+ )
440+ mtp_loss = self .compute_language_model_loss (mtp_labels , mtp_logits )
441+ mtp_loss = loss_mask * mtp_loss
442+ if self .training :
443+ # TODO(shifangx): remove the use of parallel_state here
444+ # after moving loss logging to loss_func in pretrain_gpt.py
445+ MTPLossLoggingHelper .save_loss_to_tracker (
446+ torch .sum (mtp_loss ) / num_tokens ,
447+ mtp_layer_number ,
448+ self .config .mtp_num_layers ,
449+ avg_group = parallel_state .get_data_parallel_group (
450+ with_context_parallel = True
451+ ),
452+ )
453+ mtp_loss_scale = self .config .mtp_loss_scaling_factor / self .config .mtp_num_layers
454+ if self .config .calculate_per_token_loss :
455+ hidden_states = MTPLossAutoScaler .apply (
456+ hidden_states , mtp_loss_scale * mtp_loss
457+ )
458+ else :
459+ hidden_states = MTPLossAutoScaler .apply (
460+ hidden_states , mtp_loss_scale * mtp_loss / num_tokens
428461 )
429- # Calc loss for the current Multi-Token Prediction (MTP) layers.
430- mtp_labels , _ = roll_tensor (mtp_labels , shifts = - 1 , dims = - 1 , cp_group = self .cp_group )
431- if cu_seqlens is None :
432- loss_mask , _ = roll_tensor (loss_mask , shifts = - 1 , dims = - 1 , cp_group = self .cp_group )
433- loss_mask_ = loss_mask
434- else :
435- loss_mask [:, cu_seqlens [:- 1 ]] = 0
436- loss_mask , _ = roll_tensor (loss_mask , shifts = - 1 , dims = - 1 )
437- if self .config .context_parallel_size > 1 :
438- loss_mask_ = split_cp_inputs (loss_mask , cu_seqlens , dim = 1 )
439- else :
440- loss_mask_ = loss_mask .clone ()
441- mtp_loss = self .compute_language_model_loss (mtp_labels , mtp_logits )
442- loss_mask_ = loss_mask_ & (mtp_labels != - 100 )
443- mtp_loss = loss_mask_ * mtp_loss
444- num_tokens = loss_mask_ .sum ()
445- if self .training :
446- mtp_loss_for_log = (
447- torch .sum (mtp_loss ) / num_tokens if num_tokens > 0 else mtp_loss .new_tensor (0.0 ))
448- MTPLossLoggingHelper .save_loss_to_tracker (
449- mtp_loss_for_log ,
450- mtp_layer_number ,
451- self .config .mtp_num_layers ,
452- avg_group = parallel_state .get_data_parallel_group (with_context_parallel = True ),
453- )
454- mtp_loss_scale = self .config .mtp_loss_scaling_factor / self .config .mtp_num_layers
455- if self .config .calculate_per_token_loss :
456- hidden_states = MTPLossAutoScaler .apply (hidden_states , mtp_loss_scale * mtp_loss )
457- else :
458- hidden_states = MTPLossAutoScaler .apply (hidden_states , mtp_loss_scale * mtp_loss / num_tokens )
459462 sequence_parallel_override = False
460463 if in_inference_mode and inference_context .materialize_only_last_token_logits :
461464 if inference_context .is_static_batching ():
0 commit comments