Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 41 additions & 5 deletions slime/backends/megatron_utils/megatron_to_hf/qwen3_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,42 @@
import torch


def _convert_qgkv_weight_to_hf(args, param, head_dim, prefix):
    """Convert a gated-attention QGKV weight from Megatron per-group layout to HF format.

    Megatron stores rows as: [Q_g0, G_g0, K_g0, V_g0, Q_g1, G_g1, K_g1, V_g1, ...]
    HF stores: q_proj=[Q+G interleaved per head], k_proj=[K flat], v_proj=[V flat]

    Args:
        args: Object exposing ``num_attention_heads``, ``num_query_groups`` and
            ``hidden_size``.
        param: 2D weight of shape
            ``(num_query_groups * (2 * heads_per_group + 2) * head_dim, hidden_size)``.
        head_dim: Number of rows occupied by a single head.
        prefix: HF name prefix placed before ``.self_attn.*``.

    Returns:
        A list of ``(hf_name, tensor)`` pairs for q_proj, k_proj and v_proj, in
        that order.

    Raises:
        ValueError: If the head counts are not divisible or ``param`` does not
            have the shape implied by ``args`` and ``head_dim``.
    """
    num_heads = args.num_attention_heads
    num_kv_heads = args.num_query_groups
    hidden_size = args.hidden_size

    if not num_kv_heads or num_heads % num_kv_heads != 0:
        raise ValueError(
            f"num_attention_heads ({num_heads}) must be a positive multiple of "
            f"num_query_groups ({num_kv_heads})"
        )
    heads_per_group = num_heads // num_kv_heads

    qg_rows = heads_per_group * head_dim  # rows taken by Q (and equally by G) in one group
    group_rows = 2 * qg_rows + 2 * head_dim  # Q + G + K + V per group
    expected_rows = num_kv_heads * group_rows

    # Check the shape explicitly: with a wrong hidden_size but a matching element
    # count the reshape below would succeed and slice on the wrong boundaries.
    if param.dim() != 2 or param.shape[0] != expected_rows or param.shape[1] != hidden_size:
        raise ValueError(
            f"linear_qgkv weight expects shape ({expected_rows}, {hidden_size}), "
            f"got {tuple(param.shape)}"
        )

    groups = param.reshape(num_kv_heads, group_rows, hidden_size)
    q, g, k, v = groups.split([qg_rows, qg_rows, head_dim, head_dim], dim=1)

    # Flatten (group, head-in-group) into a single head axis, then place each
    # head's gate rows directly after its query rows.
    q = q.reshape(num_heads, head_dim, hidden_size)
    g = g.reshape(num_heads, head_dim, hidden_size)
    qg = torch.cat([q, g], dim=1).reshape(num_heads * 2 * head_dim, hidden_size)
    k = k.reshape(num_kv_heads * head_dim, hidden_size).contiguous()
    v = v.reshape(num_kv_heads * head_dim, hidden_size).contiguous()

    return [
        (f"{prefix}.self_attn.q_proj.weight", qg),
        (f"{prefix}.self_attn.k_proj.weight", k),
        (f"{prefix}.self_attn.v_proj.weight", v),
    ]


def _convert_mtp_layer(args, name, param, layer_idx):
"""Convert MTP layer parameters from Megatron to HuggingFace format.

Expand All @@ -17,11 +53,7 @@ def _convert_mtp_layer(args, name, param, layer_idx):
if "final_layernorm.weight" in name:
return [("mtp.norm.weight", param)]
if "eh_proj.weight" in name:
if param.dim() < 2:
raise ValueError(f"eh_proj weight expects 2D tensor, got {param.shape}")
first_half, second_half = param.chunk(2, dim=1)
new_param = torch.cat([second_half, first_half], dim=1)
return [("mtp.fc.weight", new_param)]
return [("mtp.fc.weight", param)]

# MTP inner transformer layers (keep layer index)
if "transformer_layer" in name:
Expand Down Expand Up @@ -146,6 +178,10 @@ def convert_qwen3_next_to_hf(args, name, param):
(f"model.layers.{layer_idx}.self_attn.k_proj.bias", k_bias),
(f"model.layers.{layer_idx}.self_attn.v_proj.bias", v_bias),
]
elif rest == "self_attention.linear_qgkv.weight":
return _convert_qgkv_weight_to_hf(args, param, head_dim, f"model.layers.{layer_idx}")
elif rest == "self_attention.linear_qgkv.layer_norm_weight":
return [(f"model.layers.{layer_idx}.input_layernorm.weight", param)]
elif rest == "mlp.linear_fc1.weight":
gate_weight, up_weight = param.chunk(2, dim=0)
return [
Expand Down
51 changes: 42 additions & 9 deletions slime_plugins/mbridge/qwen3_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,20 @@ class Qwen3NextBridge(Qwen2MoEBridge):
"model.layers.{layer_number}.self_attn.k_proj.weight",
"model.layers.{layer_number}.self_attn.v_proj.weight",
],
"self_attention.linear_qgkv.layer_norm_weight": ["model.layers.{layer_number}.input_layernorm.weight"],
"self_attention.linear_qgkv.weight": [
"model.layers.{layer_number}.self_attn.q_proj.weight",
"model.layers.{layer_number}.self_attn.k_proj.weight",
"model.layers.{layer_number}.self_attn.v_proj.weight",
],
}
)

def _get_gptmodel_args(self) -> dict:
    """Override to add MTP block spec.

    Starts from the parent's GPT model arguments and, when
    ``self.config.mtp_num_layers`` is set (not ``None``), adds an
    ``"mtp_block_spec"`` entry built by ``get_gpt_mtp_block_spec``.

    Returns:
        The argument dict from the parent class, optionally extended with
        ``"mtp_block_spec"``.
    """
    ret = super()._get_gptmodel_args()
    if getattr(self.config, "mtp_num_layers", None) is not None:
        # self.config is passed twice: as the config and in the position that
        # was previously held by a local named transformer_layer_spec.
        ret["mtp_block_spec"] = get_gpt_mtp_block_spec(self.config, self.config, use_transformer_engine=True)
    return ret

Expand Down Expand Up @@ -96,6 +101,39 @@ def _convert_mtp_param(self, name: str) -> list[str]:
def _weight_to_mcore_format(
self, mcore_weights_name: str, hf_weights: list[torch.Tensor]
) -> tuple[list[str], list[torch.Tensor]]:
if "self_attention.linear_qgkv." in mcore_weights_name and "layer_norm" not in mcore_weights_name:
# Gated attention: merge Q+G, K, V into per-group QGKV layout
# HF q_proj contains Q+G interleaved per head: [Q0, G0, Q1, G1, ...]
# Megatron expects: [Q_g0, G_g0, K_g0, V_g0, Q_g1, G_g1, K_g1, V_g1, ...]
assert len(hf_weights) == 3
qg, k, v = hf_weights

num_heads = self.hf_config.num_attention_heads
num_kv_heads = self.hf_config.num_key_value_heads
head_dim = self.hf_config.head_dim
hidden_size = self.hf_config.hidden_size
heads_per_group = num_heads // num_kv_heads

# Split Q and G from interleaved q_proj
qg = qg.view(num_heads, 2 * head_dim, hidden_size)
q = qg[:, :head_dim, :] # [num_heads, head_dim, hidden]
g = qg[:, head_dim:, :] # [num_heads, head_dim, hidden]

k = k.view(num_kv_heads, head_dim, hidden_size)
v = v.view(num_kv_heads, head_dim, hidden_size)

# Organize per query group: [Q_g, G_g, K_g, V_g]
q = q.view(num_kv_heads, heads_per_group, head_dim, hidden_size)
g = g.view(num_kv_heads, heads_per_group, head_dim, hidden_size)

groups = []
for i in range(num_kv_heads):
q_g = q[i].reshape(heads_per_group * head_dim, hidden_size)
g_g = g[i].reshape(heads_per_group * head_dim, hidden_size)
groups.append(torch.cat([q_g, g_g, k[i], v[i]], dim=0))

return torch.cat(groups, dim=0).contiguous()

if "self_attention.linear_qkv." in mcore_weights_name and "layer_norm" not in mcore_weights_name:
# merge qkv
assert len(hf_weights) == 3
Expand Down Expand Up @@ -129,17 +167,11 @@ def _weight_to_mcore_format(
return qgkv

weight = super()._weight_to_mcore_format(mcore_weights_name, hf_weights)
if mcore_weights_name.endswith("eh_proj.weight"):
first_half, second_half = weight.chunk(2, dim=1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good fix.

weight = torch.cat([second_half, first_half], dim=1)
return weight

def _weight_to_hf_format(
    self, mcore_weights_name: str, mcore_weights: torch.Tensor
) -> tuple[list[str], list[torch.Tensor]]:
    """Delegate to the parent conversion, swapping the halves of ``eh_proj.weight`` first.

    For names ending in ``eh_proj.weight`` the tensor is split into two chunks
    along dim 1 and re-concatenated in the opposite order; every other weight
    is forwarded unchanged.
    """
    if not mcore_weights_name.endswith("eh_proj.weight"):
        return super()._weight_to_hf_format(mcore_weights_name, mcore_weights)

    # NOTE(review): this listing looks like a diff with its +/- markers stripped;
    # the swap below may belong to the removed side — confirm it is still wanted.
    left, right = mcore_weights.chunk(2, dim=1)
    swapped = torch.cat([right, left], dim=1)
    return super()._weight_to_hf_format(mcore_weights_name, swapped)

def _build_config(self):
Expand Down Expand Up @@ -169,5 +201,6 @@ def _build_config(self):
# Qwen3 Next specific
attention_output_gate=True,
moe_shared_expert_gate=True,
use_gated_attention=True,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe no need to modify this?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without this flag, Megatron will instantiate the normal QKV attention instead of the gated QGKV attention.

**mtp_args,
)
1 change: 1 addition & 0 deletions slime_plugins/models/hf_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(
layer_number: int,
cp_comm_type: str = "p2p",
pg_collection=None,
**kwargs,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe no need to modify this?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is to support qgkv attention

):
super().__init__(config=config)
self.args = args
Expand Down
2 changes: 2 additions & 0 deletions slime_plugins/models/qwen3_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,13 +170,15 @@ def __init__(
layer_number: int,
cp_comm_type: str = "p2p",
pg_collection=None,
**kwargs,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe no need to modify this?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is to support qgkv attention

):
super().__init__(
args,
config,
layer_number,
cp_comm_type,
pg_collection,
**kwargs,
)
if Qwen3NextAttention is None:
raise ImportError("Please install transformers>=4.35.0 to use Qwen3NextAttention.")
Expand Down