Long-context fixes: RoPE precision + Gemma chunked-prefill sliding-window cache #688
Closed
ncylich wants to merge 2 commits into
The transpiled graph computed the RoPE angle (position * inv_freq) and ran it through fp16, which cannot represent absolute positions > 2048 exactly. Past ~6k context the angle error reaches several radians and randomises cos/sin, corrupting generation (verified against the MLX reference, which stays correct through 16k).
Add an optimize_graph pass that precomputes cos/sin offline in fp64, materialises them as fp16 constant tables, and gathers them by position id, keeping the precision-critical angle off the runtime fp16 path with no new kernels. Gated to the text decoder so it does not touch vision/audio encoder rope. Handles models with multiple rope thetas (e.g. Gemma sliding vs global layers) by resolving each cos/sin node's own inv_freq.
Signed-off-by: Noah Cylich <noahcylich@gmail.com>
run_chunked_prefill pads the final partial chunk up to chunk_size to fit the fixed-size chunk graph. Global-attention layers trim the padding tokens from the KV cache, but sliding-window layers do not: the padded chunk evicts real recent tokens and shifts the entire sliding window by the padding count, corrupting deep-context attention (e.g. Gemma needle retrieval past ~2k).
Disable tail padding when the prefill graph has a sliding-window KV cache and process the tail token-by-token via run_step, which already produces an exact cache. Detect via a new get_node_window_size accessor reading window_size off the KV_CACHE_STATE nodes. Models without sliding-window caches (e.g. Qwen) are unaffected.
Signed-off-by: Noah Cylich <noahcylich@gmail.com>
Summary
Two independent long-context correctness fixes, each validated against the MLX reference (which stayed correct where the Cactus engine diverged).
1. RoPE precision (transpiler)
The transpiled graph computed the RoPE angle (position * inv_freq) in fp16, which cannot represent every integer position above 2048 exactly. Past ~6k context the angle error reaches several radians and randomises cos/sin, corrupting generation. A new optimize_graph pass precomputes cos/sin offline in fp64, stores them as fp16 constant tables, and gathers them by position id: no new kernels, no fp32 runtime path. The pass is gated to the text decoder (vision/audio encoder rope untouched) and resolves each cos/sin node's own inv_freq so multi-theta models (Gemma sliding vs global) work.
Before: Qwen3/Gemma generation garbles past ~6k. After: matches MLX through 16k.
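To make the precision argument concrete, here is a minimal stdlib-only sketch, not the actual optimize_graph pass. It emulates fp16 rounding with the struct module's half-precision format and compares a runtime fp16 angle against a table whose entries are computed in double precision and rounded only at the end. The helper names (to_fp16, runtime_fp16_cos, build_cos_table) and the inv_freq values are hypothetical and chosen for illustration.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round a Python float (fp64) to the nearest IEEE 754 binary16 value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def runtime_fp16_cos(position: int, inv_freq: float) -> float:
    """cos of the angle when position, inv_freq and their product all pass through fp16."""
    angle = to_fp16(to_fp16(position) * to_fp16(inv_freq))
    return to_fp16(math.cos(angle))


def build_cos_table(num_positions: int, inv_freqs: list) -> list:
    """Compute every angle and its cos in fp64; round only the final value to fp16."""
    return [
        [to_fp16(math.cos(pos * f)) for f in inv_freqs]
        for pos in range(num_positions)
    ]


# Hypothetical frequencies; a real model derives them from its rope theta.
inv_freqs = [1.0, 0.1]
cos_table = build_cos_table(8192, inv_freqs)

pos = 6001
exact = math.cos(pos * inv_freqs[0])

# Runtime path: the position itself is rounded before the multiply.
err_runtime = abs(runtime_fp16_cos(pos, inv_freqs[0]) - exact)

# Table path: gather the precomputed row by position id.
err_table = abs(cos_table[pos][0] - exact)

print("fp16(6001) =", to_fp16(6001.0))
print("runtime error:", err_runtime, "table error:", err_table)
```

In this sketch the only fp16 rounding on the table path is the final cast of a value in [-1, 1], so the error stays at the fp16 rounding level no matter how large the position is.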
2. Gemma chunked-prefill sliding-window cache (engine)
run_chunked_prefill pads the final partial chunk to chunk_size; global layers trim the padding from the KV cache but sliding-window layers do not, so the padding evicts real recent tokens and shifts the sliding window, corrupting deep-context attention. Fix: skip tail padding when the prefill graph has a sliding-window KV cache and process the tail token-by-token (already exact). Qwen (no sliding window) unaffected.
Before: Gemma 8k needle retrieval 3/9 (deep needles digit-garbled). After: 7/9 (≈ MLX's 8/9).
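The eviction effect can be reproduced with a toy model. The sketch below is an assumption-laden illustration, not the engine code: it models a sliding-window KV cache as a bounded deque of token ids, and the function names (prefill_padded, prefill_exact_tail) and the PAD sentinel are hypothetical.

```python
from collections import deque

PAD = -1  # hypothetical sentinel standing in for a padding token


def prefill_padded(tokens, chunk_size, window):
    """Pad the final partial chunk to chunk_size and push it into the window as-is."""
    cache = deque(maxlen=window)
    for start in range(0, len(tokens), chunk_size):
        chunk = tokens[start:start + chunk_size]
        chunk = chunk + [PAD] * (chunk_size - len(chunk))
        cache.extend(chunk)  # padding evicts real tokens from the window
    return list(cache)


def prefill_exact_tail(tokens, chunk_size, window):
    """Run only full chunks through the chunk path; feed the tail one token at a time."""
    cache = deque(maxlen=window)
    full = len(tokens) - len(tokens) % chunk_size
    for start in range(0, full, chunk_size):
        cache.extend(tokens[start:start + chunk_size])
    for tok in tokens[full:]:
        cache.append(tok)  # stands in for a single decode step
    return list(cache)


tokens = list(range(10))  # 10 real tokens, so the last chunk of 4 is half padding
padded = prefill_padded(tokens, chunk_size=4, window=4)
exact = prefill_exact_tail(tokens, chunk_size=4, window=4)
print("padded window:", padded)
print("exact window: ", exact)
```

With these inputs the padded path ends with two padding entries inside the window and has dropped two real recent tokens, while the token-by-token tail keeps the last four real tokens.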
Verification
test_precompute_rope_tables.py: dual-theta fp64 table correctness, cat-duplication layout, position binding, Qwen regression, and vision/lm-encoder inertness (skips when transpiled bundles are absent).
BLUE-42-FALCON at 2k/8k/16k on Qwen3-1.7B and Gemma-4-e2b.
Note
The RoPE table is sized to max_cache_seq_len. Absolute positions beyond it (rolling/streaming cache past the cache window) raise a fail-loud out-of-bounds error rather than silently corrupting; that regime produced garbage before this fix.
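As a rough illustration of the fail-loud behaviour, the sketch below shows a bounds-checked gather over a table sized to max_cache_seq_len. It is a hypothetical Python stand-in; the function name gather_rows and the tiny table size are assumptions, and the real check lives in the engine's gather path.

```python
def gather_rows(table, position_ids):
    """Return one table row per position id, refusing ids outside the table."""
    size = len(table)
    for pos in position_ids:
        # Python lists silently wrap negative indices, so check both ends explicitly.
        if not 0 <= pos < size:
            raise IndexError(f"position {pos} is outside the RoPE table of size {size}")
    return [table[pos] for pos in position_ids]


max_cache_seq_len = 4  # tiny hypothetical value for the example
table = [[float(p)] for p in range(max_cache_seq_len)]

rows = gather_rows(table, [0, 3])  # in range: rows are returned

try:
    gather_rows(table, [4])  # one past the end of the table
    raised = False
except IndexError:
    raised = True

print(rows, raised)
```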