@@ -94,7 +94,7 @@ index 860ee64a9..80944b702 100755
9494 "mlp.0.weight": "mlp.linear_fc1.layer_norm_weight",
9595 "mlp.0.bias": "mlp.linear_fc1.layer_norm_bias",
9696diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py
97- index 6aec66e6d..7aa4b2f7d 100644
97+ index 6aec66e6d..b660a2002 100644
9898--- a/megatron/core/models/gpt/gpt_model.py
9999+++ b/megatron/core/models/gpt/gpt_model.py
100100@@ -355,6 +355,7 @@ class GPTModel(LanguageModule):
@@ -143,8 +143,24 @@ index 6aec66e6d..7aa4b2f7d 100644
143143 hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
144144 hidden_states = hidden_states_list[0]
145145 if loss_mask is None:
146- @@ -480,9 +485,9 @@ class GPTModel(LanguageModule):
147- runtime_gather_output=runtime_gather_output,
146+ @@ -474,15 +479,21 @@ class GPTModel(LanguageModule):
147+ loss_mask = torch.ones_like(mtp_labels)
148+ for mtp_layer_number in range(self.config.mtp_num_layers):
149+ # output
150+ - mtp_logits, _ = self.output_layer(
151+ - hidden_states_list[mtp_layer_number + 1],
152+ - weight=output_weight,
153+ - runtime_gather_output=runtime_gather_output,
154+ + output_layer_params = {k: v.detach() for k, v in self.output_layer.named_parameters()}
155+ + output_layer_buffers = dict(self.output_layer.named_buffers())
156+ + mtp_logits, _ = torch.func.functional_call(
157+ + self.output_layer,
158+ + {**output_layer_params, **output_layer_buffers},
159+ + (hidden_states_list[mtp_layer_number + 1],),
160+ + {
161+ + "weight": output_weight.detach() if output_weight else None,
162+ + "runtime_gather_output": runtime_gather_output,
163+ + },
148164 )
149165 # Calc loss for the current Multi-Token Prediction (MTP) layers.
150166- mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group)
0 commit comments