Long-context fixes: RoPE precision + Gemma chunked-prefill sliding-window cache #688

Closed
ncylich wants to merge 2 commits into long-ctx-attn from longctx-rope-gemma-fix

ncylich commented Jun 4, 2026

Collaborator

Summary

Two independent long-context correctness fixes, each validated against the MLX reference (which stayed correct where the Cactus engine diverged).

1. RoPE precision (transpiler)

The transpiled graph computed the RoPE angle (position * inv_freq) in fp16, which cannot represent every integer position above 2048 exactly (the spacing between adjacent fp16 values is 2 from 2048 and 4 from 4096). Past ~6k context the angle error reaches several radians and randomises cos/sin, corrupting generation. A new optimize_graph pass precomputes cos/sin offline in fp64, stores them as fp16 constant tables, and gathers them by position id, so no new kernels and no fp32 runtime path are needed. The pass is gated to the text decoder (vision/audio encoder rope is untouched) and resolves each cos/sin node's own inv_freq, so multi-theta models (Gemma sliding vs global) work.

Before: Qwen3/Gemma generation garbles past ~6k. After: matches MLX through 16k.
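
The precision problem and the table approach can be illustrated with a minimal NumPy sketch. This is not the optimize_graph pass itself: `build_rope_tables`, the head dimension, and the theta value are hypothetical and chosen only to keep the example small.

```python
import numpy as np

def build_rope_tables(inv_freq, max_positions):
    """Compute cos/sin of position * inv_freq in float64, then store as float16."""
    positions = np.arange(max_positions, dtype=np.float64)
    angles = np.outer(positions, np.asarray(inv_freq, dtype=np.float64))
    return np.cos(angles).astype(np.float16), np.sin(angles).astype(np.float16)

# Illustrative values only; real models use larger head dimensions.
head_dim, theta, max_positions = 8, 10000.0, 16384
inv_freq = 1.0 / theta ** (np.arange(0, head_dim, 2, dtype=np.float64) / head_dim)
cos_table, sin_table = build_rope_tables(inv_freq, max_positions)

pos = 6001
# Runtime fp16 path: the position itself is rounded before the multiply.
angle_fp16 = np.float16(pos) * inv_freq.astype(np.float16)
# Angle error in radians for the fastest-rotating pair (inv_freq[0] == 1.0).
angle_error = abs(float(angle_fp16[0]) - pos * inv_freq[0])

# Table path: gather precomputed rows by position id.
position_ids = np.array([0, 2048, pos])
cos_rows, sin_rows = cos_table[position_ids], sin_table[position_ids]
table_error = abs(float(cos_rows[2, 0]) - np.cos(pos * inv_freq[0]))
```

In this sketch fp16 rounds position 6001 to 6000, so the fastest-rotating pair is already 1 radian off before any further rounding, while the table entry differs from the fp64 value only by the final fp16 rounding of cos.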

2. Gemma chunked-prefill sliding-window cache (engine)

run_chunked_prefill pads the final partial chunk to chunk_size. Global layers trim the padding from the KV cache, but sliding-window layers do not: the padding evicts real recent tokens and shifts the sliding window, corrupting deep-context attention. Fix: when the prefill graph has a sliding-window KV cache, skip tail padding and process the tail token-by-token via run_step (already exact). Qwen (no sliding window) is unaffected.

Before: Gemma 8k needle retrieval 3/9 (deep needles digit-garbled). After: 7/9 (≈ MLX's 8/9).
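
The eviction effect can be reproduced with a toy model of a sliding-window cache. This is only a sketch under simplifying assumptions: the cache is a bounded deque of token ids, and `toy_prefill` and `PAD` are hypothetical names that do not correspond to the engine's run_chunked_prefill.

```python
from collections import deque

PAD = -1  # hypothetical padding token id

def toy_prefill(tokens, chunk_size, window, pad_tail):
    """Return the contents of a toy sliding-window cache after prefill."""
    cache = deque(maxlen=window)  # keeps only the most recent `window` entries
    n_full = len(tokens) // chunk_size * chunk_size
    for start in range(0, n_full, chunk_size):
        cache.extend(tokens[start:start + chunk_size])
    tail = tokens[n_full:]
    if tail and pad_tail:
        # Buggy path: the padded chunk is written whole and never trimmed.
        cache.extend(tail + [PAD] * (chunk_size - len(tail)))
    else:
        # Fixed path: the tail is fed one token at a time, with no padding.
        for token in tail:
            cache.append(token)
    return list(cache)

tokens = list(range(10))
padded = toy_prefill(tokens, chunk_size=4, window=4, pad_tail=True)
exact = toy_prefill(tokens, chunk_size=4, window=4, pad_tail=False)
```

With ten tokens, a chunk size of 4, and a window of 4, the padded path leaves `[8, 9, -1, -1]` in the window (tokens 6 and 7 are evicted by padding), while the token-by-token tail leaves `[6, 7, 8, 9]`.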

Verification

  • test_precompute_rope_tables.py: dual-theta fp64 table correctness, cat-duplication layout, position binding, Qwen regression, and vision/lm-encoder inertness (skips when transpiled bundles are absent).
  • Coherence probe matches MLX BLUE-42-FALCON at 2k/8k/16k on Qwen3-1.7B and Gemma-4-e2b.
  • Gemma NIAH 8k: depth-50 0/2 -> 2/2 exact; depths 10/50/90 = 7/9. Chunked-prefill speed retained (~200 prefill tps).
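
As a rough illustration of what the table checks above might look like, the sketch below builds one cos table per theta and verifies the layout. It assumes "cat-duplication" means the half-dimension angles are concatenated with themselves along the last axis, as in the common rotate-half RoPE convention. The theta values, `inv_freq_for`, and `cos_table_for` are hypothetical and are not taken from test_precompute_rope_tables.py.

```python
import numpy as np

def inv_freq_for(theta, head_dim):
    """Standard RoPE inverse frequencies for one theta, in float64."""
    return 1.0 / theta ** (np.arange(0, head_dim, 2, dtype=np.float64) / head_dim)

def cos_table_for(theta, head_dim, max_positions):
    """fp64 angles, duplicated along the last axis, stored as an fp16 cos table."""
    positions = np.arange(max_positions, dtype=np.float64)
    angles = np.outer(positions, inv_freq_for(theta, head_dim))
    full = np.concatenate([angles, angles], axis=-1)  # (max_positions, head_dim)
    return np.cos(full).astype(np.float16)

# Illustrative thetas for a model with separate sliding and global rope bases.
thetas = {"sliding": 1e4, "global": 1e6}
tables = {name: cos_table_for(theta, head_dim=8, max_positions=32)
          for name, theta in thetas.items()}

half = 8 // 2
layout_ok = all(np.array_equal(t[:, :half], t[:, half:]) for t in tables.values())
tables_differ = not np.array_equal(tables["sliding"], tables["global"])
```

Each table must come from its own inv_freq; reusing one theta for both layer types would make `tables_differ` false in this sketch.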

Note

The RoPE table is sized to max_cache_seq_len. Absolute positions beyond it (a rolling/streaming cache that runs past the cache window) raise a fail-loud out-of-bounds error rather than silently corrupting output. That regime produced garbage before this fix.
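
The fail-loud behaviour can be sketched as a bounds-checked gather. `gather_rope` is a hypothetical helper, not the engine's gather kernel, and the table size of 4 stands in for max_cache_seq_len.

```python
import numpy as np

def gather_rope(table, position_ids):
    """Gather table rows by position id, rejecting ids outside the table."""
    position_ids = np.asarray(position_ids, dtype=np.int64)
    limit = table.shape[0]
    if position_ids.size and (position_ids.min() < 0 or position_ids.max() >= limit):
        # NumPy would silently wrap negative ids, so check both ends explicitly.
        raise IndexError(f"position id outside RoPE table of {limit} rows")
    return table[position_ids]

table = np.cos(np.arange(4, dtype=np.float64)).astype(np.float16).reshape(4, 1)
rows = gather_rope(table, [0, 3])  # in range: returns two rows

try:
    gather_rope(table, [4])        # one past the table end
    out_of_range_raised = False
except IndexError:
    out_of_range_raised = True
```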

ncylich added 2 commits June 4, 2026 11:00
The transpiled graph computed the RoPE angle (position * inv_freq) and ran it
through fp16, which cannot represent every integer position above 2048. Past
~6k context the angle error reaches several radians and randomises cos/sin,
corrupting generation (verified against the MLX reference, which stays correct
through 16k).

Add an optimize_graph pass that precomputes cos/sin offline in fp64,
materialises them as fp16 constant tables, and gathers them by position id,
keeping the precision-critical angle off the runtime fp16 path with no new
kernels. Gated to the text decoder so it does not touch vision/audio encoder
rope. Handles models with multiple rope thetas (e.g. Gemma sliding vs global
layers) by resolving each cos/sin node's own inv_freq.

Signed-off-by: Noah Cylich <noahcylich@gmail.com>

run_chunked_prefill pads the final partial chunk up to chunk_size to fit the
fixed-size chunk graph. Global-attention layers trim the padding tokens from
the KV cache, but sliding-window layers do not: the padded chunk evicts real
recent tokens and shifts the entire sliding window by the padding count,
corrupting deep-context attention (e.g. Gemma needle retrieval past ~2k).

Disable tail padding when the prefill graph has a sliding-window KV cache and
process the tail token-by-token via run_step, which already produces an exact
cache. Detect via a new get_node_window_size accessor reading window_size off
the KV_CACHE_STATE nodes. Models without sliding-window caches (e.g. Qwen) are
unaffected.

Signed-off-by: Noah Cylich <noahcylich@gmail.com>

ncylich commented Jun 4, 2026

Collaborator Author

Superseded by #686, which now carries these changes (RoPE precompute-table fix + Gemma chunked-prefill sliding-window fix) alongside the KeyDiff / rolling-KV-compaction work. Consolidating into #686 as the umbrella PR.

ncylich closed this Jun 4, 2026