Skip to content

Commit 011892e

Browse files
committed
clear mtp code
1 parent 641dc17 commit 011892e

1 file changed

Lines changed: 49 additions & 46 deletions

File tree

src/mcore_bridge/model/gpt_model.py

Lines changed: 49 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,8 @@ def _postprocess(
391391
output_weight = None
392392
if self.share_embeddings_and_output_weights:
393393
output_weight = self.shared_embedding_or_output_weight()
394+
if self.config.is_multimodal and self.config.context_parallel_size > 1:
395+
input_ids = split_cp_inputs(input_ids, getattr(packed_seq_params, 'cu_seqlens_q', None), 1)
394396

395397
if self.mtp_process:
396398
hidden_states = self.mtp(
@@ -407,55 +409,56 @@ def _postprocess(
407409
embedding=self.embedding,
408410
**(extra_block_kwargs or {}),
409411
)
412+
mtp_labels = labels.clone()
410413
hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
411414
hidden_states = hidden_states_list[0]
412-
413-
if labels is not None:
414-
mtp_labels = labels.clone()
415-
if loss_mask is None:
416-
# if loss_mask is not provided, use all ones as loss_mask
417-
if packed_seq_params is None:
418-
loss_mask = torch.ones_like(mtp_labels)
419-
else:
420-
loss_mask = mtp_labels.new_ones((1, packed_seq_params.cu_seqlens_q[-1]))
421-
cu_seqlens = packed_seq_params.cu_seqlens_q if packed_seq_params is not None else None
422-
for mtp_layer_number in range(self.config.mtp_num_layers):
423-
# output
424-
mtp_logits, _ = self.output_layer(
425-
hidden_states_list[mtp_layer_number + 1],
426-
weight=output_weight,
427-
runtime_gather_output=runtime_gather_output,
415+
if loss_mask is None:
416+
# if loss_mask is not provided, use all ones as loss_mask
417+
loss_mask = torch.ones_like(mtp_labels)
418+
for mtp_layer_number in range(self.config.mtp_num_layers):
419+
# output
420+
mtp_logits, _ = self.output_layer(
421+
hidden_states_list[mtp_layer_number + 1],
422+
weight=output_weight,
423+
runtime_gather_output=runtime_gather_output,
424+
)
425+
# Calc loss for the current Multi-Token Prediction (MTP) layers.
426+
mtp_labels, _ = roll_tensor(
427+
mtp_labels,
428+
shifts=-1,
429+
dims=-1,
430+
cp_group=self.cp_group,
431+
packed_seq_params=packed_seq_params,
432+
)
433+
loss_mask, num_tokens = roll_tensor(
434+
loss_mask,
435+
shifts=-1,
436+
dims=-1,
437+
cp_group=self.cp_group,
438+
packed_seq_params=packed_seq_params,
439+
)
440+
mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits)
441+
mtp_loss = loss_mask * mtp_loss
442+
if self.training:
443+
# TODO(shifangx): remove the use of parallel_state here
444+
# after moving loss logging to loss_func in pretrain_gpt.py
445+
MTPLossLoggingHelper.save_loss_to_tracker(
446+
torch.sum(mtp_loss) / num_tokens,
447+
mtp_layer_number,
448+
self.config.mtp_num_layers,
449+
avg_group=parallel_state.get_data_parallel_group(
450+
with_context_parallel=True
451+
),
452+
)
453+
mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
454+
if self.config.calculate_per_token_loss:
455+
hidden_states = MTPLossAutoScaler.apply(
456+
hidden_states, mtp_loss_scale * mtp_loss
457+
)
458+
else:
459+
hidden_states = MTPLossAutoScaler.apply(
460+
hidden_states, mtp_loss_scale * mtp_loss / num_tokens
428461
)
429-
# Calc loss for the current Multi-Token Prediction (MTP) layers.
430-
mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group)
431-
if cu_seqlens is None:
432-
loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group)
433-
loss_mask_ = loss_mask
434-
else:
435-
loss_mask[:, cu_seqlens[:-1]] = 0
436-
loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1)
437-
if self.config.context_parallel_size > 1:
438-
loss_mask_ = split_cp_inputs(loss_mask, cu_seqlens, dim=1)
439-
else:
440-
loss_mask_ = loss_mask.clone()
441-
mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits)
442-
loss_mask_ = loss_mask_ & (mtp_labels != -100)
443-
mtp_loss = loss_mask_ * mtp_loss
444-
num_tokens = loss_mask_.sum()
445-
if self.training:
446-
mtp_loss_for_log = (
447-
torch.sum(mtp_loss) / num_tokens if num_tokens > 0 else mtp_loss.new_tensor(0.0))
448-
MTPLossLoggingHelper.save_loss_to_tracker(
449-
mtp_loss_for_log,
450-
mtp_layer_number,
451-
self.config.mtp_num_layers,
452-
avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
453-
)
454-
mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
455-
if self.config.calculate_per_token_loss:
456-
hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss)
457-
else:
458-
hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens)
459462
sequence_parallel_override = False
460463
if in_inference_mode and inference_context.materialize_only_last_token_logits:
461464
if inference_context.is_static_batching():

0 commit comments

Comments
 (0)