207 changes: 189 additions & 18 deletions docker/patch/latest/megatron.patch
@@ -93,19 +93,6 @@ index 860ee64a9..80944b702 100755
sharded_state_dict_keys_map={
"mlp.0.weight": "mlp.linear_fc1.layer_norm_weight",
"mlp.0.bias": "mlp.linear_fc1.layer_norm_bias",
diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py
index 6aec66e6d..3ac631935 100644
--- a/megatron/core/models/gpt/gpt_model.py
+++ b/megatron/core/models/gpt/gpt_model.py
@@ -446,7 +446,7 @@ class GPTModel(LanguageModule):
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()

- if mtp_in_postprocess:
+ if mtp_in_postprocess and labels is not None:
hidden_states = self.mtp(
input_ids=input_ids,
position_ids=position_ids,
diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py
index a40c85a88..86688c331 100644
--- a/megatron/core/parallel_state.py
@@ -149,6 +136,148 @@ index 63ee9d1f5..b90b744c1 100644
)
ops.append(recv_next_op)
if len(ops) > 0:
diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py
index c749bac43..dde8d50e7 100644
--- a/megatron/core/transformer/attention.py
+++ b/megatron/core/transformer/attention.py
@@ -670,7 +670,10 @@ class Attention(MegatronModule, ABC):
# Get the query, key and value tensors based on the type of attention -
# self or cross attn.
nvtx_range_push(suffix="qkv")
- query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)
+ if self.config.use_gated_attention:
+ query, gate, key, value = self.get_query_gate_key_value_tensors(hidden_states, key_value_states)
+ else:
+ query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)
nvtx_range_pop(suffix="qkv")

# ===================================================
@@ -842,6 +845,11 @@ class Attention(MegatronModule, ABC):
# Output. [sq, b, h]
# =================

+ if self.config.use_gated_attention:
+ nvtx_range_push(suffix="sigmoid_gate")
+ core_attn_out = core_attn_out * torch.sigmoid(gate)
+ nvtx_range_pop(suffix="sigmoid_gate")
+
nvtx_range_push(suffix="linear_proj")
output, bias = self.linear_proj(core_attn_out)
nvtx_range_pop(suffix="linear_proj")
@@ -879,19 +887,34 @@ class SelfAttention(Attention):
model_comm_pgs=model_comm_pgs,
)

- self.linear_qkv = build_module(
- submodules.linear_qkv,
- self.config.hidden_size,
- self.query_projection_size + 2 * self.kv_projection_size,
- config=self.config,
- init_method=self.config.init_method,
- gather_output=False,
- bias=self.config.add_bias_linear or self.config.add_qkv_bias,
- skip_bias_add=False,
- is_expert=False,
- tp_comm_buffer_name='qkv',
- tp_group=self.model_comm_pgs.tp,
- )
+ if self.config.use_gated_attention:
+ self.linear_qgkv = build_module(
+ submodules.linear_qkv,
+ self.config.hidden_size,
+ 2 * (self.query_projection_size + self.kv_projection_size),
+ config=self.config,
+ init_method=self.config.init_method,
+ gather_output=False,
+ bias=self.config.add_bias_linear or self.config.add_qkv_bias,
+ skip_bias_add=False,
+ is_expert=False,
+ tp_comm_buffer_name='qkv',
+ tp_group=self.model_comm_pgs.tp,
+ )
+ else:
+ self.linear_qkv = build_module(
+ submodules.linear_qkv,
+ self.config.hidden_size,
+ self.query_projection_size + 2 * self.kv_projection_size,
+ config=self.config,
+ init_method=self.config.init_method,
+ gather_output=False,
+ bias=self.config.add_bias_linear or self.config.add_qkv_bias,
+ skip_bias_add=False,
+ is_expert=False,
+ tp_comm_buffer_name='qkv',
+ tp_group=self.model_comm_pgs.tp,
+ )

if submodules.q_layernorm is not None:
self.q_layernorm = build_module(
@@ -1036,6 +1059,65 @@ class SelfAttention(Attention):

return query, key, value

+ # adapted from https://github.com/alibaba/Pai-Megatron-Patch/blob/8e6cbb0556ba09933ab4a4edb23c0af1d19d9960/megatron_patch/model/qwen3_next/gated_attention.py#L192
+ def get_query_gate_key_value_tensors(self, hidden_states, key_value_states=None):
+ """
+ Derives `query`, `gate`, `key` and `value` tensors from `hidden_states`.
+ """
+ # Attention heads [sq, b, h] --> [sq, b, ng * 2 * (np/ng + 1) * hn)]
+ mixed_qgkv, _ = self.linear_qgkv(hidden_states)
+
+ # [sq, b, hp] --> [sq, b, ng, 2 * (np/ng + 1) * hn]
+ new_tensor_shape = mixed_qgkv.size()[:-1] + (
+ self.num_query_groups_per_partition,
+ (
+ 2 * (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 1)
+ * self.hidden_size_per_attention_head
+ ),
+ )
+ mixed_qgkv = mixed_qgkv.view(*new_tensor_shape)
+
+ split_arg_list = [
+ (
+ self.num_attention_heads_per_partition
+ // self.num_query_groups_per_partition
+ * self.hidden_size_per_attention_head
+ ),
+ (
+ self.num_attention_heads_per_partition
+ // self.num_query_groups_per_partition
+ * self.hidden_size_per_attention_head
+ ),
+ self.hidden_size_per_attention_head,
+ self.hidden_size_per_attention_head,
+ ]
+
+ if SplitAlongDim is not None:
+
+ # [sq, b, ng, 2 * (np/ng + 1) * hn]
+ # --> [sq, b, ng, np/ng * hn], [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
+ (query, gate, key, value) = SplitAlongDim(mixed_qgkv, 3, split_arg_list)
+ else:
+
+ # [sq, b, ng, 2 * (np/ng + 1) * hn]
+ # --> [sq, b, ng, np/ng * hn], [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
+ (query, gate, key, value) = torch.split(mixed_qgkv, split_arg_list, dim=3)
+
+ # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
+ query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head)
+ gate = gate.reshape(query.size(0), query.size(1), -1)
+
+ if self.q_layernorm is not None:
+ query = self.q_layernorm(query)
+
+ if self.k_layernorm is not None:
+ key = self.k_layernorm(key)
+
+ if self.config.test_mode:
+ self.run_realtime_tests()
+
+ return query, gate, key, value
+
def backward_dw(self) -> NoReturn:
"""Execute weight update operations"""
self._backward_qkv_proj()
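A minimal, self-contained sketch of the gating added above (toy shapes, not the Megatron module wiring, and not part of the patch itself): the fused `linear_qgkv` projection emits a per-head gate alongside Q/K/V, and the attention output is scaled elementwise by `sigmoid(gate)` before `linear_proj`.

```python
import torch

# Toy sizes (assumptions for illustration only).
sq, b = 4, 2                        # sequence length, micro-batch size
ng, heads_per_group, hn = 2, 4, 8   # query groups, heads per group, head dim
np_ = ng * heads_per_group          # attention heads per partition

# Fused projection output per group: [q | gate | k | v] blocks,
# i.e. 2 * (np/ng + 1) * hn features per group.
mixed_qgkv = torch.randn(sq, b, ng, 2 * (heads_per_group + 1) * hn)
split_arg_list = [heads_per_group * hn, heads_per_group * hn, hn, hn]
query, gate, key, value = torch.split(mixed_qgkv, split_arg_list, dim=3)

query = query.reshape(sq, b, -1, hn)           # [sq, b, np, hn]
gate = gate.reshape(sq, b, -1)                 # [sq, b, np * hn]

core_attn_out = torch.randn(sq, b, np_ * hn)   # stand-in for the core attention output
gated = core_attn_out * torch.sigmoid(gate)    # elementwise sigmoid gating before linear_proj
assert gated.shape == core_attn_out.shape
```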
diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py
index 235b6f6af..fbcffe278 100644
--- a/megatron/core/transformer/moe/moe_utils.py
@@ -177,16 +306,42 @@ index 6b20b8622..459e65921 100644
def _maintain_float32_expert_bias(self):
"""
Maintain the expert bias in float32.
diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py
index b7884e18e..7ea47da8a 100755
--- a/megatron/core/transformer/multi_token_prediction.py
+++ b/megatron/core/transformer/multi_token_prediction.py
@@ -681,9 +681,6 @@ class MultiTokenPredictionLayer(MegatronModule):
[s, b, h], and optionally the updated context tensor if cross-attention is used.
"""
assert context is None, f"multi token prediction + cross attention is not yet supported."
- assert (
- packed_seq_params is None
- ), f"multi token prediction + sequence packing is not yet supported."

input_ids, position_ids, decoder_input, hidden_states = self._get_embeddings(
input_ids=input_ids,
@@ -910,9 +907,7 @@ class MultiTokenPredictionBlock(MegatronModule):
# to the hidden_states_list
hidden_states_list.append(hidden_states)

- # concat the hidden states of all mtp layers
- hidden_states = torch.cat(hidden_states_list, dim=0)
- return hidden_states
+ return hidden_states_list

def sharded_state_dict(
self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
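A brief sketch of the return-type change above (shapes are illustrative assumptions, not the Megatron code path): `MultiTokenPredictionBlock` now hands back the per-layer hidden states as a list instead of concatenating them along the sequence dimension, so callers iterate the list rather than slicing one tensor.

```python
import torch

s, b, h, num_mtp_layers = 4, 2, 16, 3
hidden_states_list = [torch.randn(s, b, h) for _ in range(num_mtp_layers)]

# Old behavior: one concatenated tensor of shape [num_mtp_layers * s, b, h].
legacy = torch.cat(hidden_states_list, dim=0)
assert legacy.shape == (num_mtp_layers * s, b, h)

# New behavior: the caller receives the list and handles each MTP depth separately,
# e.g. computing one loss term per predicted-ahead token.
for depth, hidden_states in enumerate(hidden_states_list):
    assert hidden_states.shape == (s, b, h)
```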
diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py
index d55bebe7e..1e1d9c781 100644
index d55bebe7e..1eecbbd38 100644
--- a/megatron/core/transformer/transformer_config.py
+++ b/megatron/core/transformer/transformer_config.py
@@ -173,6 +173,9 @@ class TransformerConfig(ModelParallelConfig):
@@ -173,6 +173,10 @@ class TransformerConfig(ModelParallelConfig):
qk_layernorm: bool = False
"""Whether to apply `normalization` type of normalization to the query and key embeddings."""

+ post_self_attn_layernorm: bool = False
+ post_mlp_layernorm: bool = False
+ use_gated_attention: bool = False
+
test_mode: bool = False
"""Whether to run real-time tests."""
@@ -262,7 +417,7 @@ index 84f22bdea..f0f3f8e86 100644
# discard the output of the pre-mlp layernorm and register the recompute
# as a gradient hook of mlp_output_with_bias[0]
diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index e3459c5ee..2a2fefac3 100644
index e3459c5ee..7346bf35b 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -937,8 +937,6 @@ def validate_args(args, defaults={}):
@@ -937,8 +935,6 @@ def validate_args(args, defaults={}):
if args.num_experts is not None and args.moe_ffn_hidden_size is None:
args.moe_ffn_hidden_size = args.ffn_hidden_size
print("Warning: moe_ffn_hidden_size is not set, using ffn_hidden_size for MoE instead.")
@@ -1198,6 +1196,9 @@ def core_transformer_config_from_args(args, config_class=None):
@@ -1198,6 +1196,10 @@ def core_transformer_config_from_args(args, config_class=None):
if args.is_hybrid_model:
kw_args['is_hybrid_model'] = args.is_hybrid_model

+ kw_args['post_self_attn_layernorm'] = args.post_self_attn_layernorm
+ kw_args['post_mlp_layernorm'] = args.post_mlp_layernorm
+ kw_args['use_gated_attention'] = args.use_gated_attention
+
# handle quantization config
# NOTE: Kitchen arguments are only added to the namespace when
# Kitchen library is available.
@@ -1488,6 +1489,10 @@ def _add_network_size_args(parser):
@@ -1488,6 +1490,12 @@ def _add_network_size_args(parser):
action='store_true',
help='If set, use original BERT residula connection '
'ordering.')
+ group.add_argument('--post-self-attn-layernorm', action='store_true',
+ help='If set, use post self attention layernorm.')
+ group.add_argument('--post-mlp-layernorm', action='store_true',
+ help='If set, use post MLP layernorm.')
+ group.add_argument('--use-gated-attention', action='store_true',
+ help='If set, use gated attention as in Qwen3Next')
group.add_argument('--openai-gelu', action='store_true',
help='Use OpenAIs GeLU implementation. This option'
'should not be used unless for backward compatibility'
diff --git a/megatron/training/tokenizer/tokenizer.py b/megatron/training/tokenizer/tokenizer.py
index 5cf222ccc..d1554ca4c 100644
--- a/megatron/training/tokenizer/tokenizer.py
+++ b/megatron/training/tokenizer/tokenizer.py
@@ -138,6 +138,8 @@ class _HuggingFaceTokenizer(MegatronTokenizer):
f"The transformers library must be installed to use huggingface_tokenizer_provider"
)

+ if "trust_remote_code" not in kwargs:
+ kwargs["trust_remote_code"] = True
# TODO(bnorick): download tokenizer once to lustre and use force offline to make sure all tasks read it from there
self._tokenizer = transformers.AutoTokenizer.from_pretrained(
pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
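The effective behavior of the tokenizer change, sketched outside the Megatron wrapper (the model id in the usage comment is an assumption, not taken from this PR): `trust_remote_code` now defaults to `True` unless the caller passes it explicitly, which tokenizers shipped as custom code on the Hub require.

```python
from transformers import AutoTokenizer

def load_tokenizer(pretrained_model_name_or_path, **kwargs):
    # Mirror of the patched default: only set trust_remote_code when the
    # caller has not made an explicit choice.
    if "trust_remote_code" not in kwargs:
        kwargs["trust_remote_code"] = True
    return AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)

# Hypothetical usage:
# tokenizer = load_tokenizer("Qwen/Qwen3-Next-80B-A3B-Instruct")
```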
1 change: 1 addition & 0 deletions scripts/models/qwen3-next-80B-A3B.sh
@@ -25,6 +25,7 @@ MODEL_ARGS=(
--num-layers 48
--hidden-size 2048
--ffn-hidden-size 5120
--use-gated-attention

--normalization RMSNorm
--apply-layernorm-1p
16 changes: 11 additions & 5 deletions slime/backends/megatron_utils/megatron_to_hf/qwen3_next.py
@@ -75,19 +75,25 @@ def convert_qwen3_next_to_hf(args, name, param):

if rest == "self_attention.linear_proj.weight":
return [(f"model.layers.{layer_idx}.self_attn.o_proj.weight", param)]
elif rest == "self_attention.linear_qkv.weight":
elif rest == "self_attention.linear_qgkv.weight":

param = param.view(args.num_query_groups, -1, head_dim, args.hidden_size)
q_param, k_param, v_param = torch.split(param, split_size_or_sections=[value_num_per_group, 1, 1], dim=1)
q_param = q_param.reshape(-1, args.hidden_size)
q_param, k_param, v_param = torch.split(
param, split_size_or_sections=[2 * value_num_per_group, 1, 1], dim=1
)
q_param = (
q_param.reshape(args.num_query_groups, 2, value_num_per_group, head_dim, args.hidden_size)
.transpose(1, 2)
.reshape(-1, args.hidden_size)
)
k_param = k_param.reshape(-1, args.hidden_size)
v_param = v_param.reshape(-1, args.hidden_size)
return [
(f"model.layers.{layer_idx}.self_attn.q_proj.weight", q_param),
(f"model.layers.{layer_idx}.self_attn.k_proj.weight", k_param),
(f"model.layers.{layer_idx}.self_attn.v_proj.weight", v_param),
]
elif rest == "self_attention.linear_qkv.bias":
elif rest == "self_attention.linear_qgkv.bias":
param = param.view(args.num_query_groups, -1)
q_bias, k_bias, v_bias = torch.split(
param,
@@ -110,7 +116,7 @@ def convert_qwen3_next_to_hf(args, name, param):
]
elif rest == "mlp.linear_fc2.weight":
return [(f"model.layers.{layer_idx}.mlp.down_proj.weight", param)]
elif rest == "self_attention.linear_qkv.layer_norm_weight":
elif rest == "self_attention.linear_qgkv.layer_norm_weight":
return [(f"model.layers.{layer_idx}.input_layernorm.weight", param)]
elif rest == "mlp.linear_fc1.layer_norm_weight":
return [(f"model.layers.{layer_idx}.post_attention_layernorm.weight", param)]