
Qwen3 support please?? #1902

@BBC-Esq

Description

Unfortunately I don't have the time or the technical expertise to do it, but here's a plethora of research that should allow someone with minimal training to get a good start at it:

Compared to Qwen2

First, comparing configuration_qwen2.py and modeling_qwen2.py (the Qwen2 architecture) against configuration_qwen3.py and modeling_qwen3.py (the Qwen3 architecture) within the transformers library, it appears that the only differences are:

  • Qwen3 adds self.q_norm and self.k_norm and applies them to the outputs of q_proj and k_proj; Qwen2 has no such step.
  • All four projection layers now take bias=config.attention_bias instead of hard-coded True/False (compare to Qwen2's fixed biases).
  • head_dim (default 128) simply exposes the per-head dimension that Qwen2 already infers as hidden_size // num_attention_heads, so the default tensor shapes are unchanged.
  • attention_bias (bool, default False) drives the bias=config.attention_bias switch mentioned in the previous bullets; Qwen2 keeps those biases hard-wired.
  • Everything else is copy-paste identical: vocabulary size, layer counts, RoPE settings, sliding-window options, etc. match across the two configuration files.

In summary, head_dim defaults to the same value as before, attention_bias is the mechanism already discussed, and there are no further architectural changes to account for.
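
For reference, here is a paraphrased sketch of the Qwen3 attention constructor as I read it in modeling_qwen3.py (not a verbatim copy; nn.RMSNorm from recent PyTorch stands in for the Qwen3RMSNorm class quoted near the end of this post). The q_norm/k_norm modules over head_dim and the bias=config.attention_bias switch are the delta described above:

from torch import nn

class Qwen3AttentionSketch(nn.Module):
    """Paraphrased from modeling_qwen3.py; illustrative only, not a verbatim copy."""

    def __init__(self, config):
        super().__init__()
        self.head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )
        bias = config.attention_bias  # new flag, default False
        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=bias)
        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=bias)
        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=bias)
        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=bias)
        # New in Qwen3: RMSNorm over head_dim applied to Q and K after projection.
        self.q_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)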

Qwen2 Loader Can't be Re-used

  • Qwen3 introduces two additional RMSNorm layers, q_norm and k_norm, applied to the Q and K projections inside each attention block. Qwen2Loader maps only the four projection matrices and a single pre-attention layer norm; there is nowhere in the current TransformerDecoderModelSpec to store the extra γ vectors, and they cannot be fused into the linear weights because RMSNorm is nonlinear.
  • Optional biases are harmless. The new attention_bias flag just toggles whether each nn.Linear carries a bias, and set_linear already copies a bias when it exists, so no change is needed there.
  • The head_dim field is likewise harmless. head_dim is explicitly set in Qwen3Config but defaults to the same value, hidden_size // num_attention_heads, so omitting it in the spec is benign.

The additional RMSNorm operations are functionally significant (they scale Q and K before the dot-product). In short, Qwen3 adds query and key normalization layers (q_norm and k_norm) within the attention mechanism that don't exist in Qwen2.
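
As a quick self-contained check (my own toy example, not code from either project) that the RMS scaling really is input-dependent and therefore cannot be baked into the projection weights:

import torch

torch.manual_seed(0)
W = torch.randn(8, 8)   # stand-in for a q_proj/k_proj weight
gamma = torch.ones(8)   # stand-in for the norm's learned weight

def rms_norm(x, eps=1e-6):
    return gamma * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

x1, x2 = torch.randn(8), torch.randn(8)
# The effective scale applied to W @ x changes with every input, so no fixed
# rescaling of W (or of gamma) can reproduce rms_norm(W @ x) for all x.
print((rms_norm(W @ x1) / (W @ x1))[0].item())
print((rms_norm(W @ x2) / (W @ x2))[0].item())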

Ctranslate2 modifications needed?

After examining attention_spec.py, transformer_spec.py, attention.h, and attention_layer.h, the following modifications are needed.

Looking at the current Qwen2Loader.set_decoder() method:

split_layers = [common_spec.LinearSpec() for _ in range(3)]
self.set_linear(split_layers[0], layer.self_attn.q_proj)  # Direct access
self.set_linear(split_layers[1], layer.self_attn.k_proj)  # Direct access
self.set_linear(split_layers[2], layer.self_attn.v_proj)  # Direct access

This code expects to directly access the linear projection layers. However, in Qwen3, these projections are wrapped with normalization:

# Qwen3 forward pass
query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape))
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape))

Based on the attention_spec.py and transformer_spec.py files, the current ctranslate2 system does NOT support query/key normalization. Looking at MultiHeadAttentionSpec in attention_spec.py:

def __init__(self, ...):
    self.queries_scale = model_spec.OPTIONAL
    self.layer_norm = common_spec.LayerNormSpec(rms_norm=rms_norm)  # Only ONE layer norm
    self.linear = [
        common_spec.LinearSpec() for _ in range(2 if self_attention else 3)
    ]

The attention spec only has:

  • One layer_norm (the standard pre-attention layer norm)
  • Standard linear layers (Q, K, V projections)
  • No query/key normalization support (a sketch of what the missing fields could look like follows below)
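
To make the gap concrete, here is a hypothetical sketch of the spec side with the missing fields added (the qk_norm flag and the q_norm/k_norm attribute names are my own invention, not existing CTranslate2 attributes; the concrete change list is summarized further down):

from ctranslate2.specs import common_spec, model_spec

class QKNormAttentionSpec(model_spec.LayerSpec):
    """Hypothetical variant of MultiHeadAttentionSpec with optional Q/K norms."""

    def __init__(self, self_attention=True, rms_norm=True, qk_norm=False):
        self.queries_scale = model_spec.OPTIONAL
        self.layer_norm = common_spec.LayerNormSpec(rms_norm=rms_norm)
        # New: per-head Q/K norms, only populated for Qwen3-style checkpoints so
        # existing Qwen2/Llama conversions serialize exactly as before.
        self.q_norm = common_spec.LayerNormSpec(rms_norm=True) if qk_norm else model_spec.OPTIONAL
        self.k_norm = common_spec.LayerNormSpec(rms_norm=True) if qk_norm else model_spec.OPTIONAL
        self.linear = [
            common_spec.LinearSpec() for _ in range(2 if self_attention else 3)
        ]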

Qwen3's architecture requires:

# Qwen3 forward pass
query_states = self.q_norm(self.q_proj(hidden_states))  # ← q_norm not supported
key_states = self.k_norm(self.k_proj(hidden_states))    # ← k_norm not supported

But ctranslate2's attention mechanism expects:

# Current ctranslate2 expectation
linear[0] = fused [Q, K, V] projections  # Direct projections, no normalization

Furthermore, examining the AttentionLayer class in attention_layer.h:

protected:
  const std::vector<Dense> _linear;  // Q, K, V projections
  const std::unique_ptr<const LayerNorm> _layer_norm;  // Only ONE layer norm

Thus, ctranslate2 currently has:

  1. One _layer_norm: The standard pre-attention layer normalization
  2. _linear vector: The Q, K, V projection layers
  3. No query/key normalization support

For Qwen3's architecture, you would need to add:

// Additional members needed for Qwen3
const std::unique_ptr<const LayerNorm> _query_norm;
const std::unique_ptr<const LayerNorm> _key_norm;

And modify the attention computation to do the following (a Python-level sketch of this ordering follows the list):

  1. Apply Q, K, V projections
  2. Apply query normalization to Q
  3. Apply key normalization to K
  4. Continue with standard attention computation
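
To make the required ordering concrete, here is a small PyTorch-level sketch of steps 1-4 (my own illustration, not CTranslate2 or transformers code; RoPE and grouped-query attention are omitted for brevity):

import torch
import torch.nn.functional as F

def qk_norm_attention(hidden, w_q, w_k, w_v, q_gamma, k_gamma,
                      num_heads, head_dim, eps=1e-6):
    """1) project, 2-3) RMS-normalize Q and K per head, 4) standard attention.

    q_gamma / k_gamma are the learned norm weights, each of shape (head_dim,).
    """

    def rms_norm(x, gamma):
        return gamma * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

    batch, seq, _ = hidden.shape
    # 1. Q, K, V projections (biases omitted; attention_bias defaults to False).
    q = (hidden @ w_q.T).view(batch, seq, num_heads, head_dim).transpose(1, 2)
    k = (hidden @ w_k.T).view(batch, seq, num_heads, head_dim).transpose(1, 2)
    v = (hidden @ w_v.T).view(batch, seq, num_heads, head_dim).transpose(1, 2)
    # 2-3. The new Qwen3 step: normalize Q and K over the head dimension.
    q = rms_norm(q, q_gamma)
    k = rms_norm(k, k_gamma)
    # 4. Standard scaled dot-product attention, then merge the heads back.
    out = F.scaled_dot_product_attention(q, k, v)
    return out.transpose(1, 2).reshape(batch, seq, num_heads * head_dim)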

Summarized differently...

  • C++ Backend Changes: Modify the attention implementation in include/ctranslate2/layers/attention.h to support query/key normalization
  • Add q_norm and k_norm fields to MultiHeadAttentionSpec
  • Create a Qwen3Loader that handles the normalization layers (a rough converter-side sketch follows below)
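
On the converter side, a hypothetical Qwen3Loader could inherit from the existing Qwen2Loader and only add the two extra weights, assuming the spec gains q_norm/k_norm fields as sketched earlier. The register_loader decorator, the set_layer_norm helper, and the spec attribute names below are written by analogy with the other loaders in ctranslate2/converters/transformers.py and should be double-checked against the actual code:

# Hypothetical addition to ctranslate2/converters/transformers.py,
# placed next to the existing Qwen2Loader.

@register_loader("Qwen3Config")
class Qwen3Loader(Qwen2Loader):
    @property
    def architecture_name(self):
        return "Qwen3ForCausalLM"

    def set_decoder(self, spec, module):
        # Re-use the Qwen2 mapping for embeddings, projections, MLPs, and the
        # existing pre-attention / pre-FFN norms...
        super().set_decoder(spec, module)
        # ...then copy the two extra RMSNorm weights that Qwen2 does not have.
        for layer_spec, layer in zip(spec.layer, module.layers):
            self.set_layer_norm(layer_spec.self_attention.q_norm, layer.self_attn.q_norm)
            self.set_layer_norm(layer_spec.self_attention.k_norm, layer.self_attn.k_norm)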

Can't the Qwen2 loader be re-used on the grounds that query/key normalization doesn't significantly affect inference, just training?

No.

  1. The original QK Norm paper shows a difference during inference, not just training.
  2. This and this source agree.
  3. Research regarding Meta's Chameleon models confirms this as well.

Further details

According to modeling_qwen3.py, this norm is "equivalent to T5LayerNorm":

@use_kernel_forward_from_hub("RMSNorm")
class Qwen3RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Qwen3RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"

That's as far as I got...let me know.
