Skip to content

Commit 4bc89dd

Browse files
authored
Fix RoPE convention for NEOX-style models in quantized_llama.rs (huggingface#3411)
quantized_llama.rs always used rope_i (interleaved RoPE), which pairs dimensions (2i, 2i+1). This is correct for standard Llama (rope_type=0) but wrong for NEOX-style architectures such as Qwen2, Falcon, and Phi (rope_type=2), which pair dimensions (i, i+d/2). The wrong dimension pairing corrupts attention patterns in every layer. Over 48 layers this compounds to a logit inflation of +11.7 on special tokens relative to the llama.cpp reference output, causing repetition loops and degenerate text. The fix reads general.architecture from the GGUF metadata and dispatches to rope (non-interleaved) for NEOX-style models and to rope_i (interleaved) for NORM-style models, matching llama.cpp's llama_model_rope_type(). After the fix, logits match llama.cpp to within 0.01 across all 152K vocabulary tokens after 48 layers, with identical top-20 rankings. Fixes huggingface#3410
1 parent aff7c10 commit 4bc89dd

1 file changed

Lines changed: 35 additions & 3 deletions

File tree

candle-transformers/src/models/quantized_llama.rs

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,11 @@ struct LayerWeights {
154154
n_head: usize,
155155
n_kv_head: usize,
156156
head_dim: usize,
157+
/// RoPE convention: true = NEOX (non-interleaved, pairs i with i+d/2),
158+
/// false = NORM (interleaved, pairs 2i with 2i+1).
159+
/// Must match the model architecture — using the wrong convention corrupts
160+
/// attention patterns and causes severe output degradation.
161+
rope_is_neox: bool,
157162
cos: Tensor,
158163
sin: Tensor,
159164
neg_inf: Tensor,
@@ -175,9 +180,12 @@ impl LayerWeights {
175180
let (_b_sz, _n_head, seq_len, _n_embd) = x.dims4()?;
176181
let cos = self.cos.narrow(0, index_pos, seq_len)?;
177182
let sin = self.sin.narrow(0, index_pos, seq_len)?;
178-
// The call to contiguous below is only necessary when processing the prompt.
179-
// When the seq_len is 1 in the inference loop, this is a no-op.
180-
candle_nn::rotary_emb::rope_i(&x.contiguous()?, &cos, &sin)
183+
let x = x.contiguous()?;
184+
if self.rope_is_neox {
185+
candle_nn::rotary_emb::rope(&x, &cos, &sin)
186+
} else {
187+
candle_nn::rotary_emb::rope_i(&x, &cos, &sin)
188+
}
181189
}
182190

183191
fn forward_attn(
@@ -333,6 +341,7 @@ impl ModelWeights {
333341
n_head: ct.hparams.n_head as usize,
334342
n_kv_head: ct.hparams.n_head as usize / gqa,
335343
head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
344+
rope_is_neox: false, // GGML format = standard Llama = interleaved
336345
cos: cos.clone(),
337346
sin: sin.clone(),
338347
neg_inf: neg_inf.clone(),
@@ -383,6 +392,28 @@ impl ModelWeights {
383392
let rope_freq_base = md_get("llama.rope.freq_base")
384393
.and_then(|m| m.to_f32())
385394
.unwrap_or(10000f32);
395+
396+
// Determine RoPE convention from model architecture (matching llama.cpp).
397+
// NEOX (non-interleaved): pairs (i, i+d/2) — Qwen, Qwen2, Falcon, Phi, etc.
398+
// NORM (interleaved): pairs (2i, 2i+1) — Llama, Mistral, DeepSeek, etc.
399+
// See llama_model_rope_type() in llama.cpp for the authoritative mapping.
400+
let arch = ct
401+
.metadata
402+
.get("general.architecture")
403+
.and_then(|v| v.to_string().ok())
404+
.cloned()
405+
.unwrap_or_default();
406+
let rope_is_neox = matches!(
407+
arch.as_str(),
408+
"qwen" | "qwen2" | "qwen2moe" | "qwen3" | "qwen3moe"
409+
| "falcon" | "grok" | "dbrx"
410+
| "phi2" | "phi3" | "phimoe"
411+
| "stablelm" | "starcoder2"
412+
| "bert" | "nomic-bert" | "jina-bert-v2"
413+
| "olmo2" | "olmoe"
414+
| "codeshell" | "plamo"
415+
);
416+
386417
let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?;
387418
let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
388419

@@ -456,6 +487,7 @@ impl ModelWeights {
456487
n_head: head_count,
457488
n_kv_head: head_count_kv,
458489
head_dim: embedding_length / head_count,
490+
rope_is_neox,
459491
cos: cos.clone(),
460492
sin: sin.clone(),
461493
neg_inf: neg_inf.clone(),

0 commit comments

Comments
 (0)