Long-context: rolling KeyDiff KV-cache compaction + chunked prefill #686
Merged
Plain text causal LLM families (e.g. qwen3) previously transpiled to only a monolithic decoder + decoder_step, so the engine took the DIRECT_DECODER_STEP route and prefilled token-by-token (~27 tok/s). Emit the chunked-prefill component pipeline (lm_encoder_step, lm_encoder_text_chunk, decoder_media_step, decoder_prefill_chunk) for the qwen3 family via text-only embed/chunk adapters that mirror the multimodal qwen3.5 ones (minus vision), so the existing chunked-prefill engine route engages. ~5x faster prefill (27 -> 130-300 tok/s) with identical generation output. The default transpile (decoder + decoder_step) and multimodal families are unchanged. Signed-off-by: Noah Cylich <noahcylich@gmail.com>
Rolled additional long-context work into this PR (3 commits on top of KeyDiff):
b3fb568 to
1279d87
Compute-once-at-prefill KV compression: per (layer, kv-head), keep a budget of tokens = attention sink + recent window + the most distinctive middle tokens by the KeyDiff key-geometry score (s_i = -cos(k_i, mean key)), then physically compact the survivors and renumber their RoPE positions to a contiguous window. Two modes: one-shot (compact once at end of prefill) and rolling bounded cache (compact to target_len every time the cache reaches trigger_len). Default OFF (exact no-op); enabled via config or the CACTUS_KV_COMPRESS / CACTUS_KV_COMPRESS_ROLL env overrides. Supported on all-global-layer models (Qwen3); subset compression on mixed/KV-shared architectures is rejected with a warning (it would need per-layer positions). FP16 and INT8 caches handled. 22 unit/integration tests cross-check the keep-set math bit-for-bit against the reference and verify compaction/renumber/dense-check/rolling. Signed-off-by: Noah Cylich <noahcylich@gmail.com>
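The keep-set rule described in this commit can be sketched in a few lines. This is an illustration only, not the engine's code: the function names `keydiff_scores` and `keep_set` are hypothetical, the mean is assumed to be taken over all rows of one (layer, kv-head), and ties are broken by lower index as an arbitrary choice.

```python
import math

def keydiff_scores(keys):
    """Score each key row as s_i = -cos(k_i, mean key); higher means more distinctive."""
    dim = len(keys[0])
    mean = [sum(k[d] for k in keys) / len(keys) for d in range(dim)]
    mean_norm = math.sqrt(sum(m * m for m in mean))
    scores = []
    for k in keys:
        k_norm = math.sqrt(sum(x * x for x in k))
        denom = k_norm * mean_norm
        cos = sum(a * b for a, b in zip(k, mean)) / denom if denom > 0 else 0.0
        scores.append(-cos)
    return scores

def keep_set(keys, budget, sink, recent):
    """Return the sorted row indices kept for one (layer, kv-head)."""
    n = len(keys)
    if budget >= n:
        return list(range(n))  # nothing to drop
    sink = min(sink, budget)
    recent = min(recent, budget - sink)
    kept = set(range(sink)) | set(range(n - recent, n))
    middle = list(range(sink, n - recent))
    scores = keydiff_scores(keys)
    # Most distinctive middle rows first (assumed tie-break: lower index).
    middle.sort(key=lambda i: (-scores[i], i))
    kept.update(middle[: budget - len(kept)])
    return sorted(kept)

keys = [[1, 0], [1, 0.1], [0, 1], [1, 0], [-1, 0], [1, 0.05], [1, 0]]
survivors = keep_set(keys, budget=4, sink=1, recent=1)
new_positions = list(range(len(survivors)))  # renumbered to a contiguous window
```

After selection, the survivors would be physically compacted and given the contiguous positions in `new_positions`, matching the renumbering step described above.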
The transpiled graph computed the RoPE angle (position * inv_freq) and ran it through fp16, which cannot represent absolute positions > 2048 exactly. Past ~6k context the angle error reaches several radians and randomises cos/sin, corrupting generation (verified against the MLX reference, which stays correct through 16k). Add an optimize_graph pass that precomputes cos/sin offline in fp64, materialises them as fp16 constant tables, and gathers them by position id, keeping the precision-critical angle off the runtime fp16 path with no new kernels. Gated to the text decoder so it does not touch vision/audio encoder rope. Handles models with multiple rope thetas (e.g. Gemma sliding vs global layers) by resolving each cos/sin node's own inv_freq. Signed-off-by: Noah Cylich <noahcylich@gmail.com>
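The precision problem and the table-based fix can be reproduced outside the graph. The sketch below is an assumption-laden illustration, not the `optimize_graph` pass: `to_fp16`, `build_rope_tables`, and `runtime_fp16_cos` are hypothetical names, and fp16 rounding is emulated with the standard library's half-precision `struct` format.

```python
import math
import struct

def to_fp16(x):
    """Round a number to the nearest IEEE-754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", float(x)))[0]

def build_rope_tables(max_positions, inv_freq):
    """Precompute cos/sin in double precision, then store each entry as fp16."""
    cos_table, sin_table = [], []
    for pos in range(max_positions):
        angles = [pos * f for f in inv_freq]  # angle stays in fp64
        cos_table.append([to_fp16(math.cos(a)) for a in angles])
        sin_table.append([to_fp16(math.sin(a)) for a in angles])
    return cos_table, sin_table

def runtime_fp16_cos(pos, freq):
    """The failure mode: the position and the angle pass through fp16 first."""
    angle = to_fp16(to_fp16(pos) * to_fp16(freq))
    return to_fp16(math.cos(angle))

inv_freq = [1.0, 0.1]  # illustrative frequencies, not a real model's
cos_table, sin_table = build_rope_tables(6002, inv_freq)

exact = math.cos(6001.0)
table_error = abs(cos_table[6001][0] - exact)          # only the final value is rounded
runtime_error = abs(runtime_fp16_cos(6001, 1.0) - exact)  # the angle itself is wrong
```

Gathering `cos_table[pos]` by position id keeps the only fp16 rounding on a value in [-1, 1], whereas the runtime path first rounds the position itself (6001 becomes 6000 in fp16).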
run_chunked_prefill pads the final partial chunk up to chunk_size to fit the fixed-size chunk graph. Global-attention layers trim the padding tokens from the KV cache, but sliding-window layers do not: the padded chunk evicts real recent tokens and shifts the entire sliding window by the padding count, corrupting deep-context attention (e.g. Gemma needle retrieval past ~2k). Disable tail padding when the prefill graph has a sliding-window KV cache and process the tail token-by-token via run_step, which already produces an exact cache. Detect via a new get_node_window_size accessor reading window_size off the KV_CACHE_STATE nodes. Models without sliding-window caches (e.g. Qwen) are unaffected. Signed-off-by: Noah Cylich <noahcylich@gmail.com>
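A minimal sketch of the resulting prefill plan, assuming a hypothetical `plan_prefill` helper (the engine's actual control flow lives in run_chunked_prefill and run_step and is not reproduced here):

```python
def plan_prefill(num_tokens, chunk_size, has_sliding_window):
    """Return (full_chunks, padded_tail, stepwise_tail).

    full_chunks   : list of (start, end) token ranges run through the chunk graph
    padded_tail   : (start, end, padding) when the tail is padded, else None
    stepwise_tail : token indices processed one at a time, else []
    """
    full = num_tokens // chunk_size
    chunks = [(i * chunk_size, (i + 1) * chunk_size) for i in range(full)]
    tail_start = full * chunk_size
    tail_len = num_tokens - tail_start
    if tail_len == 0:
        return chunks, None, []
    if has_sliding_window:
        # Padding would evict real tokens from the window, so step instead.
        return chunks, None, list(range(tail_start, num_tokens))
    return chunks, (tail_start, num_tokens, chunk_size - tail_len), []

sliding_plan = plan_prefill(10, 4, has_sliding_window=True)
global_plan = plan_prefill(10, 4, has_sliding_window=False)
```

With a sliding-window cache the last two tokens are stepped individually; without one, the tail is padded by two and the padding is trimmed from the cache afterwards.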
KeyDiff compaction now defaults on (trigger 4096 -> target 2048) for causal-LM generation, bounding the KV cache without an env var. CACTUS_KV_COMPRESS_AT (trigger) and CACTUS_KV_COMPRESS_TO (target) override the defaults independently; CACTUS_KV_COMPRESS_AT=0 disables. The one-shot budget-fraction mode (never shipped) is removed. STT models are unaffected -- they transcribe via a separate API path that never reaches the compaction hook; sliding-window models cleanly no-op via the existing all-layers guard. Signed-off-by: Noah Cylich <noahcylich@gmail.com>
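The override behaviour could look roughly like the following. Only the two environment variable names, the defaults, and the AT=0 disable come from the commit; the helper name `resolve_kv_compress`, the fallback on unparsable values, and the exact definition of an invalid config (target not strictly below trigger) are assumptions for illustration.

```python
def resolve_kv_compress(env, default_trigger=4096, default_target=2048):
    """Return (enabled, trigger, target) after applying env overrides."""
    def read(name, default):
        raw = env.get(name)
        if raw is None:
            return default
        try:
            return int(raw)
        except ValueError:
            return default  # assumption: unparsable values fall back to the default

    trigger = read("CACTUS_KV_COMPRESS_AT", default_trigger)
    target = read("CACTUS_KV_COMPRESS_TO", default_target)
    if trigger == 0:
        return (False, 0, 0)  # explicit disable
    if trigger < 0 or target <= 0 or target >= trigger:
        return (False, 0, 0)  # assumption: treated as an invalid rolling config
    return (True, trigger, target)

defaults = resolve_kv_compress({})
disabled = resolve_kv_compress({"CACTUS_KV_COMPRESS_AT": "0"})
bigger = resolve_kv_compress({"CACTUS_KV_COMPRESS_AT": "8192"})
```

Passing a plain dict instead of reading `os.environ` directly keeps the sketch easy to exercise; each variable overrides its own default independently.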
1279d87 to
aa16cf3
…only)
Make rolling KV compaction work for Gemma-style hybrid models (interleaved
sliding-window + global layers) entirely in the engine, on a single position
frame, with no graph/transpiler change (works on existing bundles).
At each compaction, compress_kv_cache_keydiff now runs two passes over the
decoder cache states:
- Pass 1 (global/full-attention layers): KeyDiff compact + renumber to
0..B-1, as before.
- Pass 2 (sliding-window layers): rotate the recent K rows [sink_size,
current_seq_len) in place by -Δ (Δ = old frontier - B) using the layer's
LOCAL rope theta, so the window tracks the renumbered global frontier.
Sink rows stay fixed; V is never rotated. current_seq_len is untouched
(graph eviction already bounds it).
RoPE is relative, so a uniform shift preserves all query·key offsets; a single
shared position_ids = B then serves global (renumbered) and sliding (shifted)
caches alike. The old all-global guard is replaced by global-subset compaction.
New pure free functions in kv_compress: rerope_recent_fp16/int8 and the int8
helper rotate_int8_row; compact_int8's dequant->rotate->requant is factored
into a shared requant_row (no behavior change). Tests T1-T6 cover the fp16/int8
re-rope, the int8 refactor (bitwise), the combined sliding+global rolling
invariant over 3 cycles, the local-theta requirement, and the no-op guards.
Verified: 17/17 free-function tests; model_loading/llm/vlm green; live Gemma
gemma-4-e2b-it stays coherent across ~8 in-decode compaction cycles (no
re-transpile); Qwen all-global path unchanged.
Signed-off-by: Noah Cylich <noahcylich@gmail.com>
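The claim that a uniform shift preserves every query·key offset can be checked numerically. This sketch assumes an adjacent-pair (interleaved) rotation layout and a single theta of 10000; real kernels may pair dimensions differently, and `rope_rotate` is a hypothetical helper rather than the engine's rerope_recent_fp16.

```python
import math

def rope_rotate(vec, position, theta=10000.0):
    """Rotate an even-length vector by the RoPE angles for `position`."""
    dim = len(vec)
    out = []
    for i in range(0, dim, 2):
        angle = position * theta ** (-i / dim)
        c, s = math.cos(angle), math.sin(angle)
        x, y = vec[i], vec[i + 1]
        out += [x * c - y * s, x * s + y * c]
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

q = [0.3, -1.2, 0.7, 0.5]
k = [1.1, 0.4, -0.6, 0.9]
frontier, budget = 5000, 2048
delta = frontier - budget  # shift applied at compaction

# Before compaction: query at the frontier, key ten positions back.
before = dot(rope_rotate(q, frontier), rope_rotate(k, frontier - 10))

# After compaction: the cached (already rotated) key is rotated again by -delta,
# and the query uses the renumbered frontier position.
cached_key = rope_rotate(k, frontier - 10)
shifted_key = rope_rotate(cached_key, -delta)
after = dot(rope_rotate(q, frontier - delta), shifted_key)
```

Because rotations compose additively, rotating a cached key by -Δ is the same as having encoded it at position minus Δ, so `before` and `after` agree up to floating-point error. V carries no positional rotation, which is consistent with V never being rotated.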
995f37e to
165ed7c
Brings the text-only chunked-prefill adapters (Qwen3LMEncoderStep/TextChunk, Qwen3EmbedsCausalLMPrefillChunk) into this branch so text qwen3 bundles can transpile the chunked-prefill component pipeline (~5x faster prefill, identical output), combined with this branch's RoPE-precompute + KV-compaction work. Signed-off-by: Noah Cylich <noahcylich@gmail.com>
Builds on the merged #687 text chunked-prefill adapters: make the chunked component pipeline (lm_encoder_step, lm_encoder_text_chunk, decoder_prefill_chunk, decoder_media_step, decoder_step) the DEFAULT for text-only Qwen3ForCausalLM conversions, instead of the monolithic decoder/decoder_step. Gated on the Qwen3ForCausalLM architecture (the chunk adapters are qwen3-only); other families and multimodal paths are unchanged. ~5x faster prefill, identical output, and avoids the stale-bundle FP16-legalization crash on re-convert. Signed-off-by: Noah Cylich <noahcylich@gmail.com>
compress_kv_cache_keydiff iterated all of decoder_->cache_states and classified layers as compressible attention purely by the "sliding" substring in layer_types. Hybrid models (LFM2: conv + gated-deltanet) bind their conv and recurrent cache states into the SAME list, so those were KeyDiff-compacted as if they were attention KV -> state corruption and out-of-bounds heap r/w (the kv_heads/head_dim==0 guard is bypassed because conv metadata lands nonzero in those header slots). Compaction is default-on, so this corrupted any hybrid model past the trigger. Fix: in both compaction passes, skip any cache state whose key node op type is not KV_CACHE_STATE (the same op-type discrimination the engine already uses for recurrent/conv caches in run_chunked_prefill). Conv and gated-deltanet caches are now left untouched; pure-attention and sliding+global models are unaffected. Verified: LFM2 (lfm2.5-vl-1.6b) now generates coherently with compaction firing (was <|pad|> garbage at the trigger); Qwen all-global and Gemma sliding+global unchanged; kv unit tests 17/17. Signed-off-by: Noah Cylich <noahcylich@gmail.com>
Bring the long-context / KV-compaction / chunked-prefill work up to date with the latest main (v2 merge, jax transpiler path, kernel refactors, build fixes). Signed-off-by: Noah Cylich <noahcylich@gmail.com>
f888366 to
7427f59
Compaction pass 2 re-roped every non-compressible layer's recent keys with the local (sliding) theta, assuming the complement of the compressible set is all sliding-window layers. That is false for KV-sharing hybrids (Gemma): a global layer that is a KV-shared source is excluded from compaction yet is not sliding, so it was re-roped with the wrong theta -- desyncing the cache that the shared global layers attend to. Pick the re-rope theta from each layer's own type via is_sliding_layer (shared with physical_compressible_layers' is_full check) instead of from compressibility: sliding -> local theta, full-attention -> global theta. Signed-off-by: Noah Cylich <noahcylich@gmail.com>
validate_kv_compress zeroed trigger/target on an invalid config but left kv_compress=true, inconsistent with the explicit-disable path in parse_kv_compress_override. Set kv_compress=false here too so both disable paths leave the same state. Signed-off-by: Noah Cylich <noahcylich@gmail.com>
The MLX-fixture consumer test (test_kv_compress.cpp) and its generated fixtures were research artifacts. The self-contained free-functions test covers the KV-compress math without them. Signed-off-by: Noah Cylich <noahcylich@gmail.com>
Cut comments that restate self-evident code; keep only non-obvious WHY (banker's rounding edge case, per-layer re-rope theta source-layer subtlety, fp16 rope-angle precompute rationale). Signed-off-by: Noah Cylich <noahcylich@gmail.com>
…test
The standalone suite depended on gitignored transpiled weight bundles (skipped on CI). Replace it with a self-contained synthetic-graph test in test_optimize_gemma4_attention.py asserting precompute_rope_tables rewrites the runtime cos angle to an fp16 embedding-table lookup keyed by the position input, with positions past 2048 representable. Signed-off-by: Noah Cylich <noahcylich@gmail.com>
The s_i formula is already on the header declaration; keep only the double-accumulation rationale. Signed-off-by: Noah Cylich <noahcylich@gmail.com>
The old test_kv_compress.cpp (the MLX-fixture consumer) was removed, freeing the canonical name for the self-contained unit tests. Signed-off-by: Noah Cylich <noahcylich@gmail.com>
Trim multi-line comments to two focused lines and drop low-value restatements. Signed-off-by: Noah Cylich <noahcylich@gmail.com>
cmd_run forwards --confidence-threshold/--cloud-timeout-ms/--no-cloud-handoff (features on this branch), but the full-context test's args Namespace omitted them, raising AttributeError after the two were merged. Add the attributes. Signed-off-by: Noah Cylich <noahcylich@gmail.com>
The LFM2 decoder_step path was the last component still deriving max_cache_seq_len from the capture prompt (max(1024, prompt+512)) rather than the model context; its chunked sibling already used the resolver. Route it through _max_cache_seq_len so LFM2 reserves the full context like the other families. Signed-off-by: Noah Cylich <noahcylich@gmail.com>
Full-context rope-table precompute raises when a component has neither max_cache_seq_len nor max_position_embeddings in its meta. Non-cached components (e.g. LFM2's full-context decoder) set neither, so re-transpiling multimodal bundles failed at precompute. Seed common graph meta with the model's max_position_embeddings as the rope-table fallback. Signed-off-by: Noah Cylich <noahcylich@gmail.com>
steal_cache_buffer asserted the destination and source buffers share a precision, but the destination is the step node's not-yet-executed baked buffer (INT8) while the source is the runtime prefill buffer, which is FP16 under CACTUS_KV_CACHE_FP16. The move replaces the destination buffer wholesale, so its precision follows the source; only op_type is invariant. This aborted every fp16-cache run at the handoff. Signed-off-by: Noah Cylich <noahcylich@gmail.com>
Keep only the non-obvious rationale (the fp16 precision invariant in the cache move, the padded-prefill re-run, the rope-table meta fallback). Signed-off-by: Noah Cylich <noahcylich@gmail.com>
Drop the declaration comments that restate the signature and tighten the rest to one line, keeping only what the code can't express: the KeyDiff score, the POST-RoPE rotation convention, the K/V renumber semantics, and the per-family compressible-layer selection. Signed-off-by: Noah Cylich <noahcylich@gmail.com>
optimize_graph baked the precomputed rope tables into the IR it returns, so the component pipeline's saved optimized_ir.json carried them too -- which the rope-table tests can't read back (the values live in graph.cactus, not the JSON). Gate precompute behind a flag and run it on the deepcopy the component pipeline lowers, leaving the saved optimized IR un-baked. graph.cactus is still baked; runtime is unchanged. Other transpile paths keep baking via the default. Signed-off-by: Noah Cylich <noahcylich@gmail.com>
The prefill->decode cache move iterated key then value node ids; conv and recurrent caches serialize one node as both, so the second steal_cache_buffer moved the already-emptied source over the destination and blanked the buffer. Dedup the move so a shared node is transferred once. Also preflight-skip KeyDiff compaction when any compressible layer's V head dim differs from its K head dim (MLA-style): compaction strides both by one head_dim and would otherwise corrupt the value rows. Signed-off-by: Noah Cylich <noahcylich@gmail.com>
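The double-move bug and its fix can be shown with plain dictionaries standing in for the node buffers. This is a toy model under stated assumptions: `move_cache_buffers` is a hypothetical name, and `None` stands in for an emptied source buffer.

```python
def move_cache_buffers(src, dst, key_ids, value_ids):
    """Move each node's buffer from src to dst exactly once."""
    moved = set()
    for node_id in list(key_ids) + list(value_ids):
        if node_id in moved:
            continue  # conv/recurrent caches list one node as both key and value
        dst[node_id] = src[node_id]
        src[node_id] = None  # a move leaves the source empty
        moved.add(node_id)
    return moved

# Node 3 is a conv cache: it appears in both the key and the value lists.
src = {1: b"attn-k", 2: b"attn-v", 3: b"conv-state"}
dst = {1: None, 2: None, 3: None}
moved = move_cache_buffers(src, dst, key_ids=[1, 3], value_ids=[2, 3])
```

Without the `moved` check, the second visit to node 3 would copy the already-emptied source over the destination, which is the blanking described above.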
Special-token protect was built from the head-0-anchored cache_token_ids_ and applied as global positions, so after the first rolling compaction (which keeps per-head-divergent rows) non-head-0 heads protected the wrong rows and could drop a mid-context special across cycles. Track special rows per (compressible layer, head) in a SpecialRowTracker, rebuilt lazily from the still-accurate appended region of cache_token_ids_ and remapped through each head's keep set after compaction. Feed a per-head protect into keepsets so every head reserves its own specials; KeyDiff scoring stays per-head. Add a specials-fit preflight that skips the pass rather than letting a special be truncated, and disable tracking once an untracked compaction (e.g. media) diverges the heads. Signed-off-by: Noah Cylich <noahcylich@gmail.com>
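The per-head remapping step can be sketched as follows. The names `remap_special_rows` and `specials_fit` are hypothetical, and the fit rule (sink + recent + middle specials must not exceed the budget) is an assumed reading of the preflight, not the engine's exact condition.

```python
def remap_special_rows(special_rows, keep):
    """Map one head's pre-compaction special rows to their post-compaction rows."""
    new_index = {old: new for new, old in enumerate(keep)}
    return sorted(new_index[r] for r in special_rows if r in new_index)

def specials_fit(special_rows, n_rows, sink, recent, budget):
    """Assumed preflight: every middle special must fit beside sink and recent."""
    middle_specials = [r for r in special_rows if sink <= r < n_rows - recent]
    return sink + recent + len(middle_specials) <= budget

# Two heads keep different rows, so the same special token lands on different rows.
head0_keep = [0, 1, 5, 8, 9]
head1_keep = [0, 3, 5, 7, 9]
head0_specials = remap_special_rows([5, 9], head0_keep)
head1_specials = remap_special_rows([5, 9], head1_keep)
```

In this example both heads happen to map rows 5 and 9 to new rows 2 and 4, but a special at old row 8 would map to row 3 for head 0 and be absent for head 1, which is why a single head-0-anchored list cannot serve every head.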
keepsets received special_rows_.protect(li) unconditionally, so a compaction with per-head tracking off (media, preserve off, or post-invalidate) would still override Params::protect with stale tracker rows. Pass empty protect unless per_head_protect, clear the rows on invalidate(), and test the Params::protect fallback. Signed-off-by: Noah Cylich <noahcylich@gmail.com>
Signed-off-by: Noah Cylich <noahcylich@gmail.com>
Signed-off-by: Noah Cylich <noahcylich@gmail.com>
Compaction renumbers the cache, which the Gemma4 thinking-token strip can't follow, so previous-turn thoughts leak once a compaction has run. Until the strip is compaction-aware, suppress compaction for thinking requests (set in do_prefill, checked in maybe_roll_compact); non-thinking and other models are unaffected. Gemma4 is the only model that strips prior thoughts. Signed-off-by: Noah Cylich <noahcylich@gmail.com>
The fast prefill_and_sample_first_token path (text-only, empty cache) bypasses do_prefill, where compaction_suppressed_ was set, so Gemma4 thinking requests on that path still compacted. Set the flag before the fast path as well. Signed-off-by: Noah Cylich <noahcylich@gmail.com>
Signed-off-by: Noah Cylich <noahcylich@gmail.com>
The growable cache only ever grew to peak occupancy and was never reclaimed, so a long prefill (moved into decode) left the buffer over-allocated at prompt length even after compaction dropped occupancy to target. After compacting each layer, shrink its buffer to the smallest power of two >= trigger, which holds the decode occupancy oscillation [target, trigger] without re-growing. Factor a shared resize_cache_buffer used by grow and the new shrink_cache_buffer. Signed-off-by: Noah Cylich <noahcylich@gmail.com>
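The shrink target is simple arithmetic. A minimal sketch, with `shrink_capacity` as a hypothetical name for the size computation only (the buffer resize itself is not modelled):

```python
def shrink_capacity(trigger_len):
    """Smallest power of two that is >= trigger_len."""
    if trigger_len <= 1:
        return 1
    return 1 << (trigger_len - 1).bit_length()

default_capacity = shrink_capacity(4096)
odd_capacity = shrink_capacity(4097)
```

Sizing to the trigger rather than the target means occupancy can climb from target back up to trigger during decode without forcing a re-grow.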
Signed-off-by: Noah Cylich <noahcylich@gmail.com>
Stop stripping generated thinking for Gemma-4 so it persists across turns as ordinary context (Qwen already does this).
- Remove the in-cache strip path: strip_thinking_from_cache, Model::remove_thinking_tokens, find_channel_token_ranges, and the now-unread cache_renumbered_ member/writes.
- Remove the compaction-suppression stopgap (compaction_suppressed_, set_compaction_suppressed, set sites, and the maybe_roll_compact guard) so compaction runs normally for thinking context.
- format_gemma4_style keeps assistant content verbatim (drop strip_channel) so prior-turn thinking stays in the rendered prompt.
- Rename strip_thinking_block -> partition_thinking_response: it now only splits generated text into (thinking, content) for the API output.
- Emit a non-user-facing context_response (raw assistant text before partitioning) in the result JSON for conversation-history persistence.
Signed-off-by: Noah Cylich <noahcylich@gmail.com>
- chat.cpp and server.py store context_response (raw assistant text with thinking) as conversation history while still displaying the clean response, so thinking survives the next turn's re-render.
- Rewrite test_gemma4_thinking for the new contract: the formatter retains assistant channel/thinking content, the API response stays clean, and re-rendered multi-turn history prefix-matches the cache.
Signed-off-by: Noah Cylich <noahcylich@gmail.com>
server.py returned the raw context_response (channel tags) as public message content; return the clean response (context_response is only for stateful chat history). Delete the superseded v1 tests/test_gemma4_thinking.cpp, which still asserted the removed strip contract; the active test lives in cactus-engine/tests. Signed-off-by: Noah Cylich <noahcylich@gmail.com>
Signed-off-by: Noah Cylich <noahcylich@gmail.com>
find_channel_token_ranges was the only reader of channel_open/close_token_id; drop the now-unused config fields and their parse. Revert the one-line server.py refactor to the original inline form. Signed-off-by: Noah Cylich <noahcylich@gmail.com>
e992f75 to
dfbefff
…urns
Qwen's chat template injects a thinking opener (<think>\n, or <think>\n\n</think>\n\n when thinking is disabled) only in the generation-prompt branch. The model continues from that opener, so the opener tokens are stored in the KV cache as part of the assistant turn. Re-rendering that turn as history did not reproduce the opener, so the next turn's prompt diverged from the cached tokens at the assistant boundary, prompt_context_matches failed, and every turn did a full re-prefill. Emit the same opener per assistant history turn (selected from whether the content already closed </think>) so the re-templated history byte-matches the cache. Signed-off-by: Noah Cylich <noahcylich@gmail.com>
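One way to read the opener selection is sketched below. The two opener strings come from the commit message; the mapping (content that contains </think> gets the plain opener, everything else gets the empty think block) is an assumption about the selection rule, and `history_opener` and `render_assistant_turn` are hypothetical names that omit the rest of the chat template.

```python
THINK_OPEN = "<think>\n"
THINK_DISABLED = "<think>\n\n</think>\n\n"

def history_opener(assistant_content):
    """Pick the opener that generation would have injected for this turn (assumed rule)."""
    if "</think>" in assistant_content:
        return THINK_OPEN      # the model wrote its own thinking and closed it
    return THINK_DISABLED      # thinking was disabled, so the opener was pre-closed

def render_assistant_turn(assistant_content):
    """Re-render a history turn so it starts with the same tokens the cache holds."""
    return history_opener(assistant_content) + assistant_content

with_thinking = render_assistant_turn("Check units first.\n</think>\n\n42 km")
without_thinking = render_assistant_turn("42 km")
```

Either way the rendered turn begins with the opener that preceded the model's continuation, which is the condition needed for the cached prefix to match.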
Signed-off-by: Noah Cylich <noahcylich@gmail.com>
# Conflicts:
# python/cactus/cli/__init__.py
# python/cactus/cli/convert.py
# python/cactus/cli/model.py
# python/tests/test_cli_run.py
# python/tests/test_cli_transpile_defaults.py
# tests/ios/CactusTest/CactusTest/AppDelegate.mm
Signed-off-by: Noah Cylich <noahcylich@gmail.com>
Standalone manual benchmark harness, not part of the test suite or referenced anywhere; out of scope for the PR. Signed-off-by: Noah Cylich <noahcylich@gmail.com>
Move the four gemma4 thinking tests into test_llm.cpp (next to the prefill cache-reuse tests) and delete the standalone suite. The three model-gated tests run only when the chosen model is Gemma4 and warn-skip otherwise; partition_thinking_response is a pure unit test that always runs. Share a load_gemma4_or_skip helper instead of repeating the gate. Signed-off-by: Noah Cylich <noahcylich@gmail.com>
Summary
Umbrella PR for the long-context work: bounded rolling KV-cache compaction that works across architectures, plus the chunked-prefill and RoPE-precision fixes it depends on. Rolling compaction is default-on for causal LMs — when the live KV cache reaches the trigger length it KeyDiff-compacts back to the target length (sink + recent + distinctive-middle) and renumbers, so the cache stays bounded through long decodes.
What's included
KeyDiff KV-cache compression (cactus-engine/src/kv_compress.{h,cpp})
- … s_i = -cos(k_i, mean(k))) so the keep-set is position-frame-independent.
- … kv_compress_trigger_len / kv_compress_target_len (and kv_compress_recent_frac / kv_compress_sink), overridable at runtime with CACTUS_KV_COMPRESS_AT / CACTUS_KV_COMPRESS_TO (AT=0 disables). Invalid rolling configs disable the flag instead of corrupting the cache.
Sliding-window re-rope (pure engine, works on existing bundles)
- … 0..B-1; sliding-window layers rotate their recent keys in place by -Δ using the local RoPE theta. Since RoPE is relative, a single compressed position frame serves both, with no graph/transpiler/manifest change. Enables Gemma-style sliding+global hybrids to compact correctly. KV-shared global source layers are re-roped with the global theta.
Hybrid-cache safety (LFM2 etc.)
- … KV_CACHE_STATE, so conv/recurrent caches in hybrid models are left untouched instead of being misclassified as attention KV and corrupted.
Special-token / thinking-token handling
On-demand KV cache + prefill handoff
Chunked prefill
- … cactus-graph window-size plumbing + the engine's chunked-prefill path).
- … Qwen3ForCausalLM architecture so other families stay on the monolithic decoder.
RoPE precision (long-context correctness)
- … optimize_graph.precompute_rope_tables) so positions past the fp16-angle range no longer garble; max_position_embeddings is carried in common graph meta to size the table.
Testing
- test_kv_compress.cpp: keep-set selection, fp16/int8 compaction, int8 refactor bit-equivalence, sliding+global rolling over multiple cycles, local-theta load-bearing, no-op guards.
- python/tests/test_optimize_gemma4_attention.py, test_cli_transpile_defaults.py, test_cli_run.py, test_bindings.py for the transpiler/CLI wiring.
- transpile/tools/benchmark_gemma4_context_scaling.py for Gemma-4 context-scaling measurement.