@@ -93,19 +93,6 @@ index 860ee64a9..80944b702 100755
9393 sharded_state_dict_keys_map={
9494 "mlp.0.weight": "mlp.linear_fc1.layer_norm_weight",
9595 "mlp.0.bias": "mlp.linear_fc1.layer_norm_bias",
96- diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py
97- index 6aec66e6d..3ac631935 100644
98- --- a/megatron/core/models/gpt/gpt_model.py
99- +++ b/megatron/core/models/gpt/gpt_model.py
100- @@ -446,7 +446,7 @@ class GPTModel(LanguageModule):
101- if self.share_embeddings_and_output_weights:
102- output_weight = self.shared_embedding_or_output_weight()
103-
104- - if mtp_in_postprocess:
105- + if mtp_in_postprocess and labels is not None:
106- hidden_states = self.mtp(
107- input_ids=input_ids,
108- position_ids=position_ids,
10996diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py
11097index a40c85a88..86688c331 100644
11198--- a/megatron/core/parallel_state.py
@@ -149,6 +136,148 @@ index 63ee9d1f5..b90b744c1 100644
149136 )
150137 ops.append(recv_next_op)
151138 if len(ops) > 0:
139+ diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py
140+ index c749bac43..dde8d50e7 100644
141+ --- a/megatron/core/transformer/attention.py
142+ +++ b/megatron/core/transformer/attention.py
143+ @@ -670,7 +670,10 @@ class Attention(MegatronModule, ABC):
144+ # Get the query, key and value tensors based on the type of attention -
145+ # self or cross attn.
146+ nvtx_range_push(suffix="qkv")
147+ - query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)
148+ + if self.config.use_gated_attention:
149+ + query, gate, key, value = self.get_query_gate_key_value_tensors(hidden_states, key_value_states)
150+ + else:
151+ + query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)
152+ nvtx_range_pop(suffix="qkv")
153+
154+ # ===================================================
155+ @@ -842,6 +845,11 @@ class Attention(MegatronModule, ABC):
156+ # Output. [sq, b, h]
157+ # =================
158+
159+ + if self.config.use_gated_attention:
160+ + nvtx_range_push(suffix="sigmoid_gate")
161+ + core_attn_out = core_attn_out * torch.sigmoid(gate)
162+ + nvtx_range_pop(suffix="sigmoid_gate")
163+ +
164+ nvtx_range_push(suffix="linear_proj")
165+ output, bias = self.linear_proj(core_attn_out)
166+ nvtx_range_pop(suffix="linear_proj")
167+ @@ -879,19 +887,34 @@ class SelfAttention(Attention):
168+ model_comm_pgs=model_comm_pgs,
169+ )
170+
171+ - self.linear_qkv = build_module(
172+ - submodules.linear_qkv,
173+ - self.config.hidden_size,
174+ - self.query_projection_size + 2 * self.kv_projection_size,
175+ - config=self.config,
176+ - init_method=self.config.init_method,
177+ - gather_output=False,
178+ - bias=self.config.add_bias_linear or self.config.add_qkv_bias,
179+ - skip_bias_add=False,
180+ - is_expert=False,
181+ - tp_comm_buffer_name='qkv',
182+ - tp_group=self.model_comm_pgs.tp,
183+ - )
184+ + if self.config.use_gated_attention:
185+ + self.linear_qgkv = build_module(
186+ + submodules.linear_qkv,
187+ + self.config.hidden_size,
188+ + 2 * (self.query_projection_size + self.kv_projection_size),
189+ + config=self.config,
190+ + init_method=self.config.init_method,
191+ + gather_output=False,
192+ + bias=self.config.add_bias_linear or self.config.add_qkv_bias,
193+ + skip_bias_add=False,
194+ + is_expert=False,
195+ + tp_comm_buffer_name='qkv',
196+ + tp_group=self.model_comm_pgs.tp,
197+ + )
198+ + else:
199+ + self.linear_qkv = build_module(
200+ + submodules.linear_qkv,
201+ + self.config.hidden_size,
202+ + self.query_projection_size + 2 * self.kv_projection_size,
203+ + config=self.config,
204+ + init_method=self.config.init_method,
205+ + gather_output=False,
206+ + bias=self.config.add_bias_linear or self.config.add_qkv_bias,
207+ + skip_bias_add=False,
208+ + is_expert=False,
209+ + tp_comm_buffer_name='qkv',
210+ + tp_group=self.model_comm_pgs.tp,
211+ + )
212+
213+ if submodules.q_layernorm is not None:
214+ self.q_layernorm = build_module(
215+ @@ -1036,6 +1059,65 @@ class SelfAttention(Attention):
216+
217+ return query, key, value
218+
219+ + # adapted from https://github.com/alibaba/Pai-Megatron-Patch/blob/8e6cbb0556ba09933ab4a4edb23c0af1d19d9960/megatron_patch/model/qwen3_next/gated_attention.py#L192
220+ + def get_query_gate_key_value_tensors(self, hidden_states, key_value_states=None):
221+ + """
222+ + Derives `query`, `key` and `value` tensors from `hidden_states`.
223+ + """
224+ + # Attention heads [sq, b, h] --> [sq, b, ng * 2 * (np/ng + 1) * hn]
225+ + mixed_qgkv, _ = self.linear_qgkv(hidden_states)
226+ +
227+ + # [sq, b, hp] --> [sq, b, ng, 2 * (np/ng + 1) * hn]
228+ + new_tensor_shape = mixed_qgkv.size()[:-1] + (
229+ + self.num_query_groups_per_partition,
230+ + (
231+ + 2 * (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 1)
232+ + * self.hidden_size_per_attention_head
233+ + ),
234+ + )
235+ + mixed_qgkv = mixed_qgkv.view(*new_tensor_shape)
236+ +
237+ + split_arg_list = [
238+ + (
239+ + self.num_attention_heads_per_partition
240+ + // self.num_query_groups_per_partition
241+ + * self.hidden_size_per_attention_head
242+ + ),
243+ + (
244+ + self.num_attention_heads_per_partition
245+ + // self.num_query_groups_per_partition
246+ + * self.hidden_size_per_attention_head
247+ + ),
248+ + self.hidden_size_per_attention_head,
249+ + self.hidden_size_per_attention_head,
250+ + ]
251+ +
252+ + if SplitAlongDim is not None:
253+ +
254+ + # [sq, b, ng, 2 * (np/ng + 1) * hn]
255+ + # --> [sq, b, ng, np/ng * hn], [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
256+ + (query, gate, key, value) = SplitAlongDim(mixed_qgkv, 3, split_arg_list)
257+ + else:
258+ +
259+ + # [sq, b, ng, 2 * (np/ng + 1) * hn]
260+ + # --> [sq, b, ng, np/ng * hn], [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
261+ + (query, gate, key, value) = torch.split(mixed_qgkv, split_arg_list, dim=3)
262+ +
263+ + # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
264+ + query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head)
265+ + gate = gate.reshape(query.size(0), query.size(1), -1)
266+ +
267+ + if self.q_layernorm is not None:
268+ + query = self.q_layernorm(query)
269+ +
270+ + if self.k_layernorm is not None:
271+ + key = self.k_layernorm(key)
272+ +
273+ + if self.config.test_mode:
274+ + self.run_realtime_tests()
275+ +
276+ + return query, gate, key, value
277+ +
278+ def backward_dw(self) -> NoReturn:
279+ """Execute weight update operations"""
280+ self._backward_qkv_proj()
152281diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py
153282index 235b6f6af..fbcffe278 100644
154283--- a/megatron/core/transformer/moe/moe_utils.py
@@ -177,16 +306,42 @@ index 6b20b8622..459e65921 100644
177306 def _maintain_float32_expert_bias(self):
178307 """
179308 Maintain the expert bias in float32.
309+ diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py
310+ index b7884e18e..7ea47da8a 100755
311+ --- a/megatron/core/transformer/multi_token_prediction.py
312+ +++ b/megatron/core/transformer/multi_token_prediction.py
313+ @@ -681,9 +681,6 @@ class MultiTokenPredictionLayer(MegatronModule):
314+ [s, b, h], and optionally the updated context tensor if cross-attention is used.
315+ """
316+ assert context is None, f"multi token prediction + cross attention is not yet supported."
317+ - assert (
318+ - packed_seq_params is None
319+ - ), f"multi token prediction + sequence packing is not yet supported."
320+
321+ input_ids, position_ids, decoder_input, hidden_states = self._get_embeddings(
322+ input_ids=input_ids,
323+ @@ -910,9 +907,7 @@ class MultiTokenPredictionBlock(MegatronModule):
324+ # to the hidden_states_list
325+ hidden_states_list.append(hidden_states)
326+
327+ - # concat the hidden states of all mtp layers
328+ - hidden_states = torch.cat(hidden_states_list, dim=0)
329+ - return hidden_states
330+ + return hidden_states_list
331+
332+ def sharded_state_dict(
333+ self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
180334diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py
181- index d55bebe7e..1e1d9c781 100644
335+ index d55bebe7e..1eecbbd38 100644
182336--- a/megatron/core/transformer/transformer_config.py
183337+++ b/megatron/core/transformer/transformer_config.py
184- @@ -173,6 +173,9 @@ class TransformerConfig(ModelParallelConfig):
338+ @@ -173,6 +173,10 @@ class TransformerConfig(ModelParallelConfig):
185339 qk_layernorm: bool = False
186340 """Whether to apply `normalization` type of normalization to the query and key embeddings."""
187341
188342+ post_self_attn_layernorm: bool = False
189343+ post_mlp_layernorm: bool = False
344+ + use_gated_attention: bool = False
190345+
191346 test_mode: bool = False
192347 """Whether to run real-time tests."""
@@ -262,7 +417,7 @@ index 84f22bdea..f0f3f8e86 100644
262417 # discard the output of the pre-mlp layernorm and register the recompute
263418 # as a gradient hook of mlp_output_with_bias[0]
264419diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
265- index e3459c5ee..2a2fefac3 100644
420+ index e3459c5ee..7346bf35b 100644
266421--- a/megatron/training/arguments.py
267422+++ b/megatron/training/arguments.py
268423@@ -937,8 +937,6 @@ def validate_args(args, defaults={}):
@@ -274,24 +429,40 @@ index e3459c5ee..2a2fefac3 100644
274429 if args.num_experts is not None and args.moe_ffn_hidden_size is None:
275430 args.moe_ffn_hidden_size = args.ffn_hidden_size
276431 print("Warning: moe_ffn_hidden_size is not set, using ffn_hidden_size for MoE instead.")
277- @@ -1198,6 +1196,9 @@ def core_transformer_config_from_args(args, config_class=None):
432+ @@ -1198,6 +1196,10 @@ def core_transformer_config_from_args(args, config_class=None):
278433 if args.is_hybrid_model:
279434 kw_args['is_hybrid_model'] = args.is_hybrid_model
280435
281436+ kw_args['post_self_attn_layernorm'] = args.post_self_attn_layernorm
282437+ kw_args['post_mlp_layernorm'] = args.post_mlp_layernorm
438+ + kw_args['use_gated_attention'] = args.use_gated_attention
283439+
284440 # handle quantization config
285441 # NOTE: Kitchen arguments are only added to the namespace when
286442 # Kitchen library is available.
287- @@ -1488,6 +1489,10 @@ def _add_network_size_args(parser):
443+ @@ -1488,6 +1490,12 @@ def _add_network_size_args(parser):
288444 action='store_true',
289445 help='If set, use original BERT residula connection '
290446 'ordering.')
291447+ group.add_argument('--post-self-attn-layernorm', action='store_true',
292448+ help='If set, use post self attention layernorm.')
293449+ group.add_argument('--post-mlp-layernorm', action='store_true',
294450+ help='If set, use post MLP layernorm.')
451+ + group.add_argument('--use-gated-attention', action='store_true',
452+ + help='If set, use gated attention as in Qwen3Next')
295453 group.add_argument('--openai-gelu', action='store_true',
296454 help='Use OpenAIs GeLU implementation. This option'
297455 'should not be used unless for backward compatibility'
456+ diff --git a/megatron/training/tokenizer/tokenizer.py b/megatron/training/tokenizer/tokenizer.py
457+ index 5cf222ccc..d1554ca4c 100644
458+ --- a/megatron/training/tokenizer/tokenizer.py
459+ +++ b/megatron/training/tokenizer/tokenizer.py
460+ @@ -138,6 +138,8 @@ class _HuggingFaceTokenizer(MegatronTokenizer):
461+ f"The transformers library must be installed to use huggingface_tokenizer_provider"
462+ )
463+
464+ + if "trust_remote_code" not in kwargs:
465+ + kwargs["trust_remote_code"] = True
466+ # TODO(bnorick): download tokenizer once to lustre and use force offline to make sure all tasks read it from there
467+ self._tokenizer = transformers.AutoTokenizer.from_pretrained(
468+ pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
0 commit comments