
Commit 08118ec

[model] use self attn in megatron for gated attn (#624)
1 parent 687d18e commit 08118ec

File tree

6 files changed: +312 -90 lines changed

docker/patch/latest/megatron.patch

Lines changed: 189 additions & 18 deletions
@@ -93,19 +93,6 @@ index 860ee64a9..80944b702 100755
              sharded_state_dict_keys_map={
                  "mlp.0.weight": "mlp.linear_fc1.layer_norm_weight",
                  "mlp.0.bias": "mlp.linear_fc1.layer_norm_bias",
-diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py
-index 6aec66e6d..3ac631935 100644
---- a/megatron/core/models/gpt/gpt_model.py
-+++ b/megatron/core/models/gpt/gpt_model.py
-@@ -446,7 +446,7 @@ class GPTModel(LanguageModule):
-         if self.share_embeddings_and_output_weights:
-             output_weight = self.shared_embedding_or_output_weight()
- 
--        if mtp_in_postprocess:
-+        if mtp_in_postprocess and labels is not None:
-             hidden_states = self.mtp(
-                 input_ids=input_ids,
-                 position_ids=position_ids,
 diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py
 index a40c85a88..86688c331 100644
 --- a/megatron/core/parallel_state.py
@@ -149,6 +136,148 @@ index 63ee9d1f5..b90b744c1 100644
          )
          ops.append(recv_next_op)
      if len(ops) > 0:
+diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py
+index c749bac43..dde8d50e7 100644
+--- a/megatron/core/transformer/attention.py
++++ b/megatron/core/transformer/attention.py
+@@ -670,7 +670,10 @@ class Attention(MegatronModule, ABC):
+         # Get the query, key and value tensors based on the type of attention -
+         # self or cross attn.
+         nvtx_range_push(suffix="qkv")
+-        query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)
++        if self.config.use_gated_attention:
++            query, gate, key, value = self.get_query_gate_key_value_tensors(hidden_states, key_value_states)
++        else:
++            query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)
+         nvtx_range_pop(suffix="qkv")
+ 
+         # ===================================================
+@@ -842,6 +845,11 @@ class Attention(MegatronModule, ABC):
+         # Output. [sq, b, h]
+         # =================
+ 
++        if self.config.use_gated_attention:
++            nvtx_range_push(suffix="sigmoid_gate")
++            core_attn_out = core_attn_out * torch.sigmoid(gate)
++            nvtx_range_pop(suffix="sigmoid_gate")
++
+         nvtx_range_push(suffix="linear_proj")
+         output, bias = self.linear_proj(core_attn_out)
+         nvtx_range_pop(suffix="linear_proj")
+@@ -879,19 +887,34 @@ class SelfAttention(Attention):
+             model_comm_pgs=model_comm_pgs,
+         )
+ 
+-        self.linear_qkv = build_module(
+-            submodules.linear_qkv,
+-            self.config.hidden_size,
+-            self.query_projection_size + 2 * self.kv_projection_size,
+-            config=self.config,
+-            init_method=self.config.init_method,
+-            gather_output=False,
+-            bias=self.config.add_bias_linear or self.config.add_qkv_bias,
+-            skip_bias_add=False,
+-            is_expert=False,
+-            tp_comm_buffer_name='qkv',
+-            tp_group=self.model_comm_pgs.tp,
+-        )
++        if self.config.use_gated_attention:
++            self.linear_qgkv = build_module(
++                submodules.linear_qkv,
++                self.config.hidden_size,
++                2 * (self.query_projection_size + self.kv_projection_size),
++                config=self.config,
++                init_method=self.config.init_method,
++                gather_output=False,
++                bias=self.config.add_bias_linear or self.config.add_qkv_bias,
++                skip_bias_add=False,
++                is_expert=False,
++                tp_comm_buffer_name='qkv',
++                tp_group=self.model_comm_pgs.tp,
++            )
++        else:
++            self.linear_qkv = build_module(
++                submodules.linear_qkv,
++                self.config.hidden_size,
++                self.query_projection_size + 2 * self.kv_projection_size,
++                config=self.config,
++                init_method=self.config.init_method,
++                gather_output=False,
++                bias=self.config.add_bias_linear or self.config.add_qkv_bias,
++                skip_bias_add=False,
++                is_expert=False,
++                tp_comm_buffer_name='qkv',
++                tp_group=self.model_comm_pgs.tp,
++            )
+ 
+         if submodules.q_layernorm is not None:
+             self.q_layernorm = build_module(
+@@ -1036,6 +1059,65 @@ class SelfAttention(Attention):
+ 
+         return query, key, value
+ 
++    # adapted from https://github.com/alibaba/Pai-Megatron-Patch/blob/8e6cbb0556ba09933ab4a4edb23c0af1d19d9960/megatron_patch/model/qwen3_next/gated_attention.py#L192
++    def get_query_gate_key_value_tensors(self, hidden_states, key_value_states=None):
++        """
++        Derives `query`, `gate`, `key` and `value` tensors from `hidden_states`.
++        """
++        # Attention heads [sq, b, h] --> [sq, b, ng * 2 * (np/ng + 1) * hn]
++        mixed_qgkv, _ = self.linear_qgkv(hidden_states)
++
++        # [sq, b, hp] --> [sq, b, ng, 2 * (np/ng + 1) * hn]
++        new_tensor_shape = mixed_qgkv.size()[:-1] + (
++            self.num_query_groups_per_partition,
++            (
++                2 * (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 1)
++                * self.hidden_size_per_attention_head
++            ),
++        )
++        mixed_qgkv = mixed_qgkv.view(*new_tensor_shape)
++
++        split_arg_list = [
++            (
++                self.num_attention_heads_per_partition
++                // self.num_query_groups_per_partition
++                * self.hidden_size_per_attention_head
++            ),
++            (
++                self.num_attention_heads_per_partition
++                // self.num_query_groups_per_partition
++                * self.hidden_size_per_attention_head
++            ),
++            self.hidden_size_per_attention_head,
++            self.hidden_size_per_attention_head,
++        ]
++
++        if SplitAlongDim is not None:
++
++            # [sq, b, ng, 2 * (np/ng + 1) * hn]
++            # --> [sq, b, ng, np/ng * hn], [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
++            (query, gate, key, value) = SplitAlongDim(mixed_qgkv, 3, split_arg_list)
++        else:
++
++            # [sq, b, ng, 2 * (np/ng + 1) * hn]
++            # --> [sq, b, ng, np/ng * hn], [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
++            (query, gate, key, value) = torch.split(mixed_qgkv, split_arg_list, dim=3)
++
++        # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
++        query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head)
++        gate = gate.reshape(query.size(0), query.size(1), -1)
++
++        if self.q_layernorm is not None:
++            query = self.q_layernorm(query)
++
++        if self.k_layernorm is not None:
++            key = self.k_layernorm(key)
++
++        if self.config.test_mode:
++            self.run_realtime_tests()
++
++        return query, gate, key, value
++
+     def backward_dw(self) -> NoReturn:
+         """Execute weight update operations"""
+         self._backward_qkv_proj()
 diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py
 index 235b6f6af..fbcffe278 100644
 --- a/megatron/core/transformer/moe/moe_utils.py
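
Note on the attention.py changes above: with use_gated_attention, the fused QKV GEMM also produces one gate head per query head, and the core attention output is multiplied by sigmoid(gate) before the output projection, as in Qwen3-Next. A minimal single-device sketch of the same computation in plain PyTorch (sizes are illustrative; this is not the tensor-parallel Megatron code path):

import torch
import torch.nn.functional as F

sq, b, hidden = 8, 2, 512      # sequence length, batch, hidden size (illustrative)
np_, ng, hn = 8, 4, 64         # query heads, query groups (GQA), head dim

# Fused projection: per query group, np/ng query heads, np/ng gate heads,
# one key head and one value head -> ng * 2 * (np/ng + 1) * hn outputs.
linear_qgkv = torch.nn.Linear(hidden, ng * 2 * (np_ // ng + 1) * hn, bias=False)

x = torch.randn(sq, b, hidden)
mixed = linear_qgkv(x).view(sq, b, ng, 2 * (np_ // ng + 1) * hn)

q_size = np_ // ng * hn
query, gate, key, value = torch.split(mixed, [q_size, q_size, hn, hn], dim=3)
query = query.reshape(sq, b, np_, hn)
gate = gate.reshape(sq, b, np_ * hn)   # flattened per-channel gate

# Plain GQA attention; Megatron calls its core attention kernel here instead.
repeat = np_ // ng
q = query.permute(1, 2, 0, 3)                                 # [b, np, sq, hn]
k = key.repeat_interleave(repeat, dim=2).permute(1, 2, 0, 3)
v = value.repeat_interleave(repeat, dim=2).permute(1, 2, 0, 3)
out = F.scaled_dot_product_attention(q, k, v)                 # [b, np, sq, hn]
out = out.permute(2, 0, 1, 3).reshape(sq, b, np_ * hn)        # [sq, b, np*hn]

out = out * torch.sigmoid(gate)   # the gating step this patch adds

The patched Megatron code performs the same split with SplitAlongDim when available and leaves the attention kernel untouched; only the extra gate tensor and the sigmoid multiply are new.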
@@ -177,16 +306,42 @@ index 6b20b8622..459e65921 100644
      def _maintain_float32_expert_bias(self):
          """
          Maintain the expert bias in float32.
+diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py
+index b7884e18e..7ea47da8a 100755
+--- a/megatron/core/transformer/multi_token_prediction.py
++++ b/megatron/core/transformer/multi_token_prediction.py
+@@ -681,9 +681,6 @@ class MultiTokenPredictionLayer(MegatronModule):
+             [s, b, h], and optionally the updated context tensor if cross-attention is used.
+         """
+         assert context is None, f"multi token prediction + cross attention is not yet supported."
+-        assert (
+-            packed_seq_params is None
+-        ), f"multi token prediction + sequence packing is not yet supported."
+ 
+         input_ids, position_ids, decoder_input, hidden_states = self._get_embeddings(
+             input_ids=input_ids,
+@@ -910,9 +907,7 @@ class MultiTokenPredictionBlock(MegatronModule):
+             # to the hidden_states_list
+             hidden_states_list.append(hidden_states)
+ 
+-        # concat the hidden states of all mtp layers
+-        hidden_states = torch.cat(hidden_states_list, dim=0)
+-        return hidden_states
++        return hidden_states_list
+ 
+     def sharded_state_dict(
+         self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
 diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py
-index d55bebe7e..1e1d9c781 100644
+index d55bebe7e..1eecbbd38 100644
 --- a/megatron/core/transformer/transformer_config.py
 +++ b/megatron/core/transformer/transformer_config.py
-@@ -173,6 +173,9 @@ class TransformerConfig(ModelParallelConfig):
+@@ -173,6 +173,10 @@ class TransformerConfig(ModelParallelConfig):
      qk_layernorm: bool = False
      """Whether to apply `normalization` type of normalization to the query and key embeddings."""
 
 +    post_self_attn_layernorm: bool = False
 +    post_mlp_layernorm: bool = False
++    use_gated_attention: bool = False
 +
      test_mode: bool = False
      """Whether to run real-time tests."""
@@ -262,7 +417,7 @@ index 84f22bdea..f0f3f8e86 100644
          # discard the output of the pre-mlp layernorm and register the recompute
          # as a gradient hook of mlp_output_with_bias[0]
 diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
-index e3459c5ee..2a2fefac3 100644
+index e3459c5ee..7346bf35b 100644
 --- a/megatron/training/arguments.py
 +++ b/megatron/training/arguments.py
 @@ -937,8 +937,6 @@ def validate_args(args, defaults={}):
@@ -274,24 +429,40 @@ index e3459c5ee..2a2fefac3 100644
      if args.num_experts is not None and args.moe_ffn_hidden_size is None:
          args.moe_ffn_hidden_size = args.ffn_hidden_size
          print("Warning: moe_ffn_hidden_size is not set, using ffn_hidden_size for MoE instead.")
-@@ -1198,6 +1196,9 @@ def core_transformer_config_from_args(args, config_class=None):
+@@ -1198,6 +1196,10 @@ def core_transformer_config_from_args(args, config_class=None):
      if args.is_hybrid_model:
          kw_args['is_hybrid_model'] = args.is_hybrid_model
  
 +    kw_args['post_self_attn_layernorm'] = args.post_self_attn_layernorm
 +    kw_args['post_mlp_layernorm'] = args.post_mlp_layernorm
++    kw_args['use_gated_attention'] = args.use_gated_attention
 +
      # handle quantization config
      # NOTE: Kitchen arguments are only added to the namespace when
      # Kitchen library is available.
-@@ -1488,6 +1489,10 @@ def _add_network_size_args(parser):
+@@ -1488,6 +1490,12 @@ def _add_network_size_args(parser):
                         action='store_true',
                         help='If set, use original BERT residula connection '
                         'ordering.')
 +    group.add_argument('--post-self-attn-layernorm', action='store_true',
 +                       help='If set, use post self attention layernorm.')
 +    group.add_argument('--post-mlp-layernorm', action='store_true',
 +                       help='If set, use post MLP layernorm.')
++    group.add_argument('--use-gated-attention', action='store_true',
++                       help='If set, use gated attention as in Qwen3Next')
      group.add_argument('--openai-gelu', action='store_true',
                         help='Use OpenAIs GeLU implementation. This option'
                         'should not be used unless for backward compatibility'
+diff --git a/megatron/training/tokenizer/tokenizer.py b/megatron/training/tokenizer/tokenizer.py
+index 5cf222ccc..d1554ca4c 100644
+--- a/megatron/training/tokenizer/tokenizer.py
++++ b/megatron/training/tokenizer/tokenizer.py
+@@ -138,6 +138,8 @@ class _HuggingFaceTokenizer(MegatronTokenizer):
+                 f"The transformers library must be installed to use huggingface_tokenizer_provider"
+             )
+ 
++        if "trust_remote_code" not in kwargs:
++            kwargs["trust_remote_code"] = True
+         # TODO(bnorick): download tokenizer once to lustre and use force offline to make sure all tasks read it from there
+         self._tokenizer = transformers.AutoTokenizer.from_pretrained(
+             pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
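
Taken together, the arguments.py hunks plumb the three new flags from the command line into TransformerConfig, and the tokenizer hunk independently defaults trust_remote_code to True unless the caller already set it. A condensed sketch of the flag flow (simplified stand-ins for Megatron's parser and config, not the actual classes):

import argparse
from dataclasses import dataclass

@dataclass
class Config:  # stand-in for TransformerConfig
    post_self_attn_layernorm: bool = False
    post_mlp_layernorm: bool = False
    use_gated_attention: bool = False

parser = argparse.ArgumentParser()
parser.add_argument('--post-self-attn-layernorm', action='store_true')
parser.add_argument('--post-mlp-layernorm', action='store_true')
parser.add_argument('--use-gated-attention', action='store_true')
args = parser.parse_args(['--use-gated-attention'])

# Mirrors core_transformer_config_from_args: copy parsed args into kwargs.
config = Config(
    post_self_attn_layernorm=args.post_self_attn_layernorm,
    post_mlp_layernorm=args.post_mlp_layernorm,
    use_gated_attention=args.use_gated_attention,
)
assert config.use_gated_attention and not config.post_mlp_layernorm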

scripts/models/qwen3-next-80B-A3B.sh

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@ MODEL_ARGS=(
     --num-layers 48
     --hidden-size 2048
     --ffn-hidden-size 5120
+    --use-gated-attention
 
     --normalization RMSNorm
     --apply-layernorm-1p
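
Enabling the flag only changes the width of the fused attention projection. Back-of-the-envelope arithmetic with illustrative GQA sizes (the head counts below are hypothetical, not the actual Qwen3-Next-80B-A3B configuration):

hidden_size = 2048                            # matches the script above
num_heads, num_groups, head_dim = 16, 2, 256  # illustrative only

query_projection_size = num_heads * head_dim  # 4096
kv_projection_size = num_groups * head_dim    # 512

plain_qkv = query_projection_size + 2 * kv_projection_size     # 5120
gated_qgkv = 2 * (query_projection_size + kv_projection_size)  # 9216

The gate adds one head-dim-sized channel per query head, so the query side of the GEMM doubles while the key/value side is unchanged.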

slime/backends/megatron_utils/megatron_to_hf/qwen3_next.py

Lines changed: 11 additions & 5 deletions
@@ -75,19 +75,25 @@ def convert_qwen3_next_to_hf(args, name, param):
 
     if rest == "self_attention.linear_proj.weight":
         return [(f"model.layers.{layer_idx}.self_attn.o_proj.weight", param)]
-    elif rest == "self_attention.linear_qkv.weight":
+    elif rest == "self_attention.linear_qgkv.weight":
 
         param = param.view(args.num_query_groups, -1, head_dim, args.hidden_size)
-        q_param, k_param, v_param = torch.split(param, split_size_or_sections=[value_num_per_group, 1, 1], dim=1)
-        q_param = q_param.reshape(-1, args.hidden_size)
+        q_param, k_param, v_param = torch.split(
+            param, split_size_or_sections=[2 * value_num_per_group, 1, 1], dim=1
+        )
+        q_param = (
+            q_param.reshape(args.num_query_groups, 2, value_num_per_group, head_dim, args.hidden_size)
+            .transpose(1, 2)
+            .reshape(-1, args.hidden_size)
+        )
         k_param = k_param.reshape(-1, args.hidden_size)
         v_param = v_param.reshape(-1, args.hidden_size)
         return [
             (f"model.layers.{layer_idx}.self_attn.q_proj.weight", q_param),
             (f"model.layers.{layer_idx}.self_attn.k_proj.weight", k_param),
             (f"model.layers.{layer_idx}.self_attn.v_proj.weight", v_param),
         ]
-    elif rest == "self_attention.linear_qkv.bias":
+    elif rest == "self_attention.linear_qgkv.bias":
         param = param.view(args.num_query_groups, -1)
         q_bias, k_bias, v_bias = torch.split(
             param,
@@ -110,7 +116,7 @@ def convert_qwen3_next_to_hf(args, name, param):
         ]
     elif rest == "mlp.linear_fc2.weight":
         return [(f"model.layers.{layer_idx}.mlp.down_proj.weight", param)]
-    elif rest == "self_attention.linear_qkv.layer_norm_weight":
+    elif rest == "self_attention.linear_qgkv.layer_norm_weight":
         return [(f"model.layers.{layer_idx}.input_layernorm.weight", param)]
     elif rest == "mlp.linear_fc1.layer_norm_weight":
         return [(f"model.layers.{layer_idx}.post_attention_layernorm.weight", param)]
