Unfortunately I don't have the time or the technical expertise to do it myself, but here's a plethora of research that should allow someone with minimal training to get a good start at it:
Compared to Qwen2
First, comparing configuration_qwen2.py and modeling_qwen2.py (the Qwen2 architecture) against configuration_qwen3.py and modeling_qwen3.py (the Qwen3 architecture) in the transformers library, it appears that the only differences are:
- Qwen-3 adds self.q_norm and self.k_norm and applies them to the outputs of q_proj and k_proj; Qwen-2 has no such step.
- All four projection layers now take bias=config.attention_bias instead of hard-coded True/False (compare to Qwen-2’s fixed biases).
- head_dim (default 128) simply exposes the per-head dimension that Qwen2 already infers as hidden_size // num_attention_heads, so the default tensor shapes are unchanged.
- attention_bias (bool, default False) drives the bias=config.attention_bias switch that I pointed out in the updated attention block; Qwen2 keeps those biases hard-wired.
- Everything else is copy-paste identical: vocabulary size, layer counts, RoPE settings, sliding-window options, etc. match across the two configuration files.
In summary, head_dim defaults to the same value as before, attention_bias is the mechanism already discussed, and there are no further architectural changes to account for. A short illustrative sketch of the difference follows.
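To make the delta concrete, here is a minimal, self-contained sketch contrasting the two projection paths. It is illustrative only (simplified from the upstream modeling code; the class names are mine, and GQA and RoPE details are omitted):

```python
# Illustrative sketch only; simplified from the upstream modeling code.
# GQA (separate key/value head counts) and RoPE are omitted for brevity.
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Simplified RMSNorm, matching the Qwen3RMSNorm shown further below."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * (x * torch.rsqrt(variance + self.eps))


class Qwen2StyleProjections(nn.Module):
    """Qwen2: hard-coded projection biases, no Q/K normalization."""

    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.head_dim = hidden_size // num_heads  # inferred, not configurable
        self.q_proj = nn.Linear(hidden_size, num_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(hidden_size, num_heads * self.head_dim, bias=True)

    def forward(self, x):
        b, t, _ = x.shape
        q = self.q_proj(x).view(b, t, -1, self.head_dim)
        k = self.k_proj(x).view(b, t, -1, self.head_dim)
        return q, k  # used directly by the attention computation


class Qwen3StyleProjections(nn.Module):
    """Qwen3: bias driven by attention_bias, plus per-head q_norm/k_norm."""

    def __init__(self, hidden_size, num_heads, head_dim=None, attention_bias=False):
        super().__init__()
        self.head_dim = head_dim or hidden_size // num_heads  # now explicit in the config
        self.q_proj = nn.Linear(hidden_size, num_heads * self.head_dim, bias=attention_bias)
        self.k_proj = nn.Linear(hidden_size, num_heads * self.head_dim, bias=attention_bias)
        self.q_norm = RMSNorm(self.head_dim)  # the extra layers Qwen2 lacks
        self.k_norm = RMSNorm(self.head_dim)

    def forward(self, x):
        b, t, _ = x.shape
        q = self.q_norm(self.q_proj(x).view(b, t, -1, self.head_dim))
        k = self.k_norm(self.k_proj(x).view(b, t, -1, self.head_dim))
        return q, k
```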
Qwen2 Loader Can't be Re-used
- Qwen-3 introduces two additional RMSNorm layers applied to the Q and K projections, q_norm and k_norm, inside each attention block. Qwen2Loader maps only the four projection matrices and a single pre-attention layer-norm; there is nowhere in the current TransformerDecoderModelSpec to store the extra γ vectors, and they cannot be fused into the linear weights because RMSNorm is nonlinear.
- Optional biases are harmless. The new attention_bias flag just toggles whether each nn.Linear carries a bias. set_linear already copies a bias when it exists, so no change is needed there.
- Head-dim field likewise. head_dim is explicitly set in Qwen3Config but defaults to the same value, hidden_size // num_attention_heads. Omitting it in the spec is benign.
The additional RMSNorm operations are functionally significant (they scale Q and K before the dot-product). In short, Qwen3 adds query and key normalization layers (q_norm and k_norm) within the attention mechanism that don't exist in Qwen2.
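A quick way to see why the norm weights cannot simply be folded into q_proj/k_proj: RMSNorm rescales each token by a factor computed from that token's own values, so RMSNorm(Wx) is not linear in x and cannot equal W'x for any fixed matrix W'. A small PyTorch check (illustrative only):

```python
# RMSNorm(W @ x) is not a linear function of x, so no fixed matrix W' can
# reproduce it: linearity would require f(2x) == 2 * f(x), which fails here.
import torch

torch.manual_seed(0)


def rms_norm(x, eps=1e-6):
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)


W = torch.randn(8, 8)
x = torch.randn(8)


def f(v):
    return rms_norm(W @ v)  # projection followed by RMSNorm


print(torch.allclose(f(2 * x), 2 * f(x)))  # False: scaling the input does not scale the output
print(f(2 * x).norm().item(), (2 * f(x)).norm().item())  # roughly 1x vs 2x the RMS magnitude
```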
Ctranslate2 modifications needed?
After examining attention_spec.py, transformer_spec.py, attention.h, and attention_layer.h, the following modifications are needed.
Looking at the current Qwen2Loader.set_decoder() method:
```python
split_layers = [common_spec.LinearSpec() for _ in range(3)]
self.set_linear(split_layers[0], layer.self_attn.q_proj)  # Direct access
self.set_linear(split_layers[1], layer.self_attn.k_proj)  # Direct access
self.set_linear(split_layers[2], layer.self_attn.v_proj)  # Direct access
```

This code expects to directly access the linear projection layers. However, in Qwen3, these projections are wrapped with normalization:
```python
# Qwen3 forward pass
query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape))
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape))
```

Based on the attention_spec.py and transformer_spec.py files, the current ctranslate2 system does NOT support query/key normalization. Looking at MultiHeadAttentionSpec in attention_spec.py:
```python
def __init__(self, ...):
    self.queries_scale = model_spec.OPTIONAL
    self.layer_norm = common_spec.LayerNormSpec(rms_norm=rms_norm)  # Only ONE layer norm
    self.linear = [
        common_spec.LinearSpec() for _ in range(2 if self_attention else 3)
    ]
```

The attention spec only has:
- One layer_norm (the standard pre-attention layer norm)
- Standard linear layers (Q, K, V projections)
- No query/key normalization support
Qwen3's architecture requires:
```python
# Qwen3 forward pass
query_states = self.q_norm(self.q_proj(hidden_states))  # ← q_norm not supported
key_states = self.k_norm(self.k_proj(hidden_states))    # ← k_norm not supported
```

But ctranslate2's attention mechanism expects:
```
# Current ctranslate2 expectation
linear[0] = fused [Q, K, V] projections  # Direct projections, no normalization
```

Furthermore, examining the AttentionLayer class in attention_layer.h:
```cpp
protected:
  const std::vector<Dense> _linear;                    // Q, K, V projections
  const std::unique_ptr<const LayerNorm> _layer_norm;  // Only ONE layer norm
```

Thus, ctranslate2 currently has:
- One _layer_norm: The standard pre-attention layer normalization
- _linear vector: The Q, K, V projection layers
- No query/key normalization support
For Qwen3's architecture, you would need to add:
```cpp
// Additional members needed for Qwen3
const std::unique_ptr<const LayerNorm> _query_norm;
const std::unique_ptr<const LayerNorm> _key_norm;
```

And modify the attention computation to (see the sketch after this list):
- Apply Q, K, V projections
- Apply query normalization to Q
- Apply key normalization to K
- Continue with standard attention computation
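For orientation, here is the same order of operations written as a conceptual PyTorch sketch rather than the actual C++ change (the function and argument names are illustrative):

```python
# Conceptual sketch of the modified attention flow, in PyTorch rather than
# CTranslate2's C++; RoPE and KV caching are omitted.
import torch
import torch.nn.functional as F


def qk_normed_attention(x, q_proj, k_proj, v_proj, q_norm, k_norm, num_heads, head_dim):
    b, t, _ = x.shape
    # 1. Apply the Q, K, V projections.
    q = q_proj(x).view(b, t, num_heads, head_dim)
    k = k_proj(x).view(b, t, num_heads, head_dim)
    v = v_proj(x).view(b, t, num_heads, head_dim)
    # 2. Apply query normalization to Q (per head, over head_dim).
    q = q_norm(q)
    # 3. Apply key normalization to K.
    k = k_norm(k)
    # 4. Continue with the standard attention computation.
    q, k, v = (m.transpose(1, 2) for m in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    return out.transpose(1, 2).reshape(b, t, num_heads * head_dim)
```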
Summarized differently...
- C++ Backend Changes: Modify the attention implementation in include/ctranslate2/layers/attention.h to support query/key normalization
- Add q_norm and k_norm fields to MultiHeadAttentionSpec
- Create a Qwen3Loader that handles the normalization layers (a rough sketch follows below)
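Put together, the Python-side piece might look roughly like the sketch below. This is a hypothetical outline, not a tested patch: it assumes the converter's existing register_loader decorator and set_layer_norm helper behave as in the current loaders, the q_norm/k_norm slot names and the spec.layer / layer_spec.self_attention attribute paths are guesses modeled on the existing code, and the C++ attention layer would still need matching changes before such a model could run.

```python
# Hypothetical outline only: attribute names for the new norm slots are
# assumptions, and none of this works without the matching C++ changes.
from ctranslate2.converters.transformers import Qwen2Loader, register_loader
from ctranslate2.specs import attention_spec, common_spec


class QKNormAttentionSpec(attention_spec.MultiHeadAttentionSpec):
    """MultiHeadAttentionSpec extended with per-head Q/K RMSNorm weights."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.q_norm = common_spec.LayerNormSpec(rms_norm=True)
        self.k_norm = common_spec.LayerNormSpec(rms_norm=True)


@register_loader("Qwen3Config")
class Qwen3Loader(Qwen2Loader):
    @property
    def architecture_name(self):
        return "Qwen3ForCausalLM"

    def set_decoder(self, spec, module):
        # Reuse the Qwen2 mapping for everything that is unchanged...
        super().set_decoder(spec, module)
        # ...then copy the extra RMSNorm weights, assuming the decoder spec's
        # attention blocks expose q_norm / k_norm slots as sketched above.
        for layer_spec, layer in zip(spec.layer, module.layers):
            self.set_layer_norm(layer_spec.self_attention.q_norm, layer.self_attn.q_norm)
            self.set_layer_norm(layer_spec.self_attention.k_norm, layer.self_attn.k_norm)
```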
Can't Qwen2 be used because query/key normalization doesn't significantly affect inference, just training?
No.
- The original QK Norm paper shows a difference during inference, not just training.
- This and this source agree.
- Research regarding Meta's Chameleon models confirms this as well.
Further details
According to modeling_qwen3.py, this "norm" is "equivalent to T5LayerNorm":
```python
@use_kernel_forward_from_hub("RMSNorm")
class Qwen3RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Qwen3RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
```