Support qwen3-next MTP Training #1575
base: main
Changes from all commits
@@ -37,15 +37,20 @@ class Qwen3NextBridge(Qwen2MoEBridge):
                 "model.layers.{layer_number}.self_attn.k_proj.weight",
                 "model.layers.{layer_number}.self_attn.v_proj.weight",
             ],
+            "self_attention.linear_qgkv.layer_norm_weight": ["model.layers.{layer_number}.input_layernorm.weight"],
+            "self_attention.linear_qgkv.weight": [
+                "model.layers.{layer_number}.self_attn.q_proj.weight",
+                "model.layers.{layer_number}.self_attn.k_proj.weight",
+                "model.layers.{layer_number}.self_attn.v_proj.weight",
+            ],
         }
     )
 
     def _get_gptmodel_args(self) -> dict:
-        """Override to add MTP block spec if needed."""
+        """Override to add MTP block spec."""
         ret = super()._get_gptmodel_args()
         if getattr(self.config, "mtp_num_layers", None) is not None:
-            transformer_layer_spec = self.config
-            mtp_block_spec = get_gpt_mtp_block_spec(self.config, transformer_layer_spec, use_transformer_engine=True)
+            mtp_block_spec = get_gpt_mtp_block_spec(self.config, self.config, use_transformer_engine=True)
             ret["mtp_block_spec"] = mtp_block_spec
         return ret
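For orientation, here is a hedged sketch of how the extra kwarg produced by this override is consumed downstream. It assumes a recent Megatron-Core where `GPTModel` accepts an `mtp_block_spec` argument and `get_gpt_mtp_block_spec` lives in `gpt_layer_specs`; `build_model` and its parameters are illustrative, not part of this PR.

```python
# Sketch only: import paths and GPTModel's signature vary across Megatron versions.
from megatron.core.models.gpt import GPTModel
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec

def build_model(config, layer_spec, **gpt_kwargs):
    # Mirror the getattr guard above: only attach an MTP block spec
    # when MTP layers are actually configured.
    if getattr(config, "mtp_num_layers", None) is not None:
        gpt_kwargs["mtp_block_spec"] = get_gpt_mtp_block_spec(
            config, layer_spec, use_transformer_engine=True
        )
    return GPTModel(config=config, transformer_layer_spec=layer_spec, **gpt_kwargs)
```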
@@ -96,6 +101,39 @@ def _convert_mtp_param(self, name: str) -> list[str]:
     def _weight_to_mcore_format(
         self, mcore_weights_name: str, hf_weights: list[torch.Tensor]
     ) -> tuple[list[str], list[torch.Tensor]]:
+        if "self_attention.linear_qgkv." in mcore_weights_name and "layer_norm" not in mcore_weights_name:
+            # Gated attention: merge Q+G, K, V into per-group QGKV layout
+            # HF q_proj contains Q+G interleaved per head: [Q0, G0, Q1, G1, ...]
+            # Megatron expects: [Q_g0, G_g0, K_g0, V_g0, Q_g1, G_g1, K_g1, V_g1, ...]
+            assert len(hf_weights) == 3
+            qg, k, v = hf_weights
+
+            num_heads = self.hf_config.num_attention_heads
+            num_kv_heads = self.hf_config.num_key_value_heads
+            head_dim = self.hf_config.head_dim
+            hidden_size = self.hf_config.hidden_size
+            heads_per_group = num_heads // num_kv_heads
+
+            # Split Q and G from interleaved q_proj
+            qg = qg.view(num_heads, 2 * head_dim, hidden_size)
+            q = qg[:, :head_dim, :]  # [num_heads, head_dim, hidden]
+            g = qg[:, head_dim:, :]  # [num_heads, head_dim, hidden]
+
+            k = k.view(num_kv_heads, head_dim, hidden_size)
+            v = v.view(num_kv_heads, head_dim, hidden_size)
+
+            # Organize per query group: [Q_g, G_g, K_g, V_g]
+            q = q.view(num_kv_heads, heads_per_group, head_dim, hidden_size)
+            g = g.view(num_kv_heads, heads_per_group, head_dim, hidden_size)
+
+            groups = []
+            for i in range(num_kv_heads):
+                q_g = q[i].reshape(heads_per_group * head_dim, hidden_size)
+                g_g = g[i].reshape(heads_per_group * head_dim, hidden_size)
+                groups.append(torch.cat([q_g, g_g, k[i], v[i]], dim=0))
+
+            return torch.cat(groups, dim=0).contiguous()
+
         if "self_attention.linear_qkv." in mcore_weights_name and "layer_norm" not in mcore_weights_name:
             # merge qkv
             assert len(hf_weights) == 3
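As a sanity check on the re-layout above, here is a small self-contained sketch (not part of the PR) that runs the same reshapes on toy dimensions and spot-checks the resulting row order; all names are local to the example.

```python
import torch

# Toy dimensions: 4 query heads in 2 KV groups.
num_heads, num_kv_heads, head_dim, hidden_size = 4, 2, 3, 5
heads_per_group = num_heads // num_kv_heads

# HF q_proj packs Q and G interleaved per head: rows [Q0, G0, Q1, G1, ...]
qg = torch.randn(num_heads * 2 * head_dim, hidden_size)
k = torch.randn(num_kv_heads * head_dim, hidden_size)
v = torch.randn(num_kv_heads * head_dim, hidden_size)

qg3 = qg.view(num_heads, 2 * head_dim, hidden_size)
q = qg3[:, :head_dim, :].reshape(num_kv_heads, heads_per_group, head_dim, hidden_size)
g = qg3[:, head_dim:, :].reshape(num_kv_heads, heads_per_group, head_dim, hidden_size)
k3 = k.view(num_kv_heads, head_dim, hidden_size)
v3 = v.view(num_kv_heads, head_dim, hidden_size)

# Per group: [Q_g, G_g, K_g, V_g], matching the Megatron layout described above.
groups = [
    torch.cat([q[i].reshape(-1, hidden_size), g[i].reshape(-1, hidden_size), k3[i], v3[i]], dim=0)
    for i in range(num_kv_heads)
]
qgkv = torch.cat(groups, dim=0)

# Each group contributes (2 * heads_per_group + 2) * head_dim rows.
assert qgkv.shape == (num_kv_heads * (2 * heads_per_group + 2) * head_dim, hidden_size)
# Group 0 starts with head 0's Q rows; its G rows follow the group's Q block.
assert torch.equal(qgkv[:head_dim], qg3[0, :head_dim, :])
assert torch.equal(
    qgkv[heads_per_group * head_dim : heads_per_group * head_dim + head_dim],
    qg3[0, head_dim:, :],
)
```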
@@ -129,17 +167,11 @@ def _weight_to_mcore_format(
             return qgkv
 
         weight = super()._weight_to_mcore_format(mcore_weights_name, hf_weights)
-        if mcore_weights_name.endswith("eh_proj.weight"):
-            first_half, second_half = weight.chunk(2, dim=1)
-            weight = torch.cat([second_half, first_half], dim=1)
         return weight

Contributor: Good fix.

     def _weight_to_hf_format(
         self, mcore_weights_name: str, mcore_weights: torch.Tensor
     ) -> tuple[list[str], list[torch.Tensor]]:
-        if mcore_weights_name.endswith("eh_proj.weight"):
-            first_half, second_half = mcore_weights.chunk(2, dim=1)
-            mcore_weights = torch.cat([second_half, first_half], dim=1)
         return super()._weight_to_hf_format(mcore_weights_name, mcore_weights)
 
     def _build_config(self):
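For context on the deletion above: the removed code swapped the two column halves of `eh_proj`, and because that swap is its own inverse, the identical lines appeared in both conversion directions; with this PR the weight passes through unchanged. A minimal sketch (not part of the PR) of the old transform:

```python
import torch

w = torch.arange(12.0).reshape(2, 6)
first_half, second_half = w.chunk(2, dim=1)
swapped = torch.cat([second_half, first_half], dim=1)

# Swapping halves twice restores the original tensor (the transform is an involution).
a, b = swapped.chunk(2, dim=1)
assert torch.equal(torch.cat([b, a], dim=1), w)
```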
@@ -169,5 +201,6 @@ def _build_config(self):
             # Qwen3 Next specific
             attention_output_gate=True,
             moe_shared_expert_gate=True,
+            use_gated_attention=True,

Contributor: Maybe no need to modify this?
Author: If we don't set this, Megatron will instantiate the normal QKV attention instead of the gated QGKV attention.

             **mtp_args,
         )
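A hypothetical illustration of the author's point (not Megatron's actual code): the layer-spec construction branches on this flag, so omitting it silently selects the ungated path.

```python
def pick_attention_spec(config):
    # Illustrative only: stands in for Megatron's spec selection.
    if getattr(config, "use_gated_attention", False):
        return "gated QGKV attention spec"
    return "standard QKV attention spec"

class Cfg:
    use_gated_attention = True

assert pick_attention_spec(Cfg()) == "gated QGKV attention spec"
```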
(diff for a second file)
@@ -23,6 +23,7 @@ def __init__(
         layer_number: int,
         cp_comm_type: str = "p2p",
         pg_collection=None,
+        **kwargs,

Contributor: Maybe no need to modify this?
Author: This is needed to support the QGKV attention.

     ):
         super().__init__(config=config)
         self.args = args
(diff for a third file)
@@ -170,13 +170,15 @@ def __init__(
         layer_number: int,
         cp_comm_type: str = "p2p",
         pg_collection=None,
+        **kwargs,

Contributor: Maybe no need to modify this?
Author: This is needed to support the QGKV attention.

     ):
         super().__init__(
             args,
             config,
             layer_number,
             cp_comm_type,
             pg_collection,
+            **kwargs,
         )
         if Qwen3NextAttention is None:
             raise ImportError("Please install transformers>=4.35.0 to use Qwen3NextAttention.")
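The `**kwargs` pass-through in both files follows the same pattern; a minimal sketch (class and attribute names illustrative) of why it matters for the gated QGKV path:

```python
class BaseAttentionWrapper:
    def __init__(self, config, layer_number, cp_comm_type="p2p", pg_collection=None, **kwargs):
        # Unknown, newer arguments (e.g. gating-related ones) are tolerated here.
        self.extra_kwargs = kwargs

class Qwen3NextWrapper(BaseAttentionWrapper):
    def __init__(self, config, layer_number, cp_comm_type="p2p", pg_collection=None, **kwargs):
        # Forwarding **kwargs lets gated-attention arguments reach the base class
        # without updating every signature along the chain.
        super().__init__(config, layer_number, cp_comm_type, pg_collection, **kwargs)

layer = Qwen3NextWrapper(config=None, layer_number=0, attention_output_gate=True)
assert layer.extra_kwargs == {"attention_output_gate": True}
```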
Good job!