
Pad the chunked-prefill tail for sliding-window models#716

Merged: jakmro merged 13 commits into main from prefill-tail-padding on Jun 12, 2026

ncylich (Collaborator) commented Jun 11, 2026

Chunked-prefill tail padding for sliding-window models + chunked prefill for text-only LLMs

Problem

On sliding-window models (Gemma 4), prefill throughput sawtooths in prompt_length mod 128: the prompt remainder runs one token at a time through the step graph, because padding the tail chunk evicts real sliding-window cache rows — the root cause of the earlier chunked-prefill garbled-output bug. A 1023-token prompt pays 127 decode steps. Separately, text-only conversions of several families never emitted chunked-prefill components at all, so their prefill ran token-by-token from the start.

Fix

Run the remainder as one padded chunk through the existing 128 graph, made sliding-window-safe by a cache rollback:

  • The sliding caches compact chronologically on eviction, so a padded append just over-evicts up to pad_count of the oldest window rows. New graph API snapshot_cache_padded_append / rollback_cache_padded_append saves those rows before the padded execute and reinserts them after (also dropping the pad rows), leaving the persistent cache byte-identical to an unpadded prefill — proven by unit tests.
  • The chunk graph only emits the last row's logits (a pad's prediction), so the last real token is rolled back too and re-run through the step graph — the same mechanism the global-attention padding path already uses.
  • Applies to any sliding-window bundle with a chunked prefill graph, keyed on graph window metadata: no transpile or bundle changes — existing bundles speed up as-is. Guards skip too-small windows (sink-aware) and conv/recurrent caches. Kill switch: CACTUS_DISABLE_PREFILL_TAIL_PAD=1.
  • Tails of ≤8 tokens keep the scalar path: a padded 128-chunk costs more than a few decode steps, with measured breakeven at tail ≈9–10 (per the review benchmark sweep). Padding engages from tail 9 up.
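
The snapshot-and-rollback idea from the first bullet can be modeled in a few lines of Python. This is a toy sketch under stated assumptions, not the engine's code: the cache is a plain list with chronological eviction, the three function names are hypothetical stand-ins that only loosely echo snapshot_cache_padded_append / rollback_cache_padded_append, and the additional rollback of the last real token is not modeled.

```python
def append(cache, rows, window):
    """Append rows, then keep only the newest `window` rows (oldest evicted first)."""
    return (cache + rows)[-window:]


def snapshot_padded_append(cache, n_real, n_pad, window):
    """Save the rows a padded append evicts beyond what the real rows alone would evict."""
    if n_real + n_pad >= window:
        # Toy version of the "window too small" guard: the padded chunk must fit.
        raise ValueError("window too small for a padded append")
    evict_exact = max(0, len(cache) + n_real - window)
    evict_padded = max(0, len(cache) + n_real + n_pad - window)
    return cache[evict_exact:evict_padded]


def rollback_padded_append(cache, saved, n_pad):
    """Drop the trailing pad rows and reinsert the over-evicted rows at the front."""
    return saved + cache[:len(cache) - n_pad]


# Demo: 14 cached rows, window of 16, a tail of 3 real tokens padded with 5 pads.
window = 16
old = list(range(14))
real, pads = ["r0", "r1", "r2"], ["pad"] * 5

saved = snapshot_padded_append(old, len(real), len(pads), window)
padded = append(old, real + pads, window)
restored = rollback_padded_append(padded, saved, len(pads))

# The restored cache matches what an unpadded append would have produced.
assert restored == append(old, real, window)
```

In this model the invariant holds because the padded append evicts a prefix that is a superset of the exact append's prefix, and the snapshot is exactly the difference between the two.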

Chunked prefill for text-only LLMs

  • Text-only conversions now honor the inferred component plan and emit the chunked-prefill pipeline (lm_encoder_step / lm_encoder_text_chunk / decoder_prefill_chunk / decoder_step) instead of silently falling back to token-by-token prefill; component plans are declared on model profiles, and the engine warns when it loads an LM bundle without decoder_prefill_chunk.
  • LFM2 text-only models get a chunked pipeline reusing the VL backbone-only adapters; the LFM2 text and VL builders share one spec builder (which also carries the decoder_embed_chunk component). LFM2-350M on Mac: prefill 134 → 565+ tps, decode unchanged.
  • Tail padding is gated on conv-cache presence (not family name): pads in a conv ring displace exactly the rows the conv kernel's lookback reads, with no clamp to remove them — measured corrupted output on LFM2-350M before the gate.
  • Gemma 3 chunked prefill: component-pipeline adapters (chunk/step/encoders) mirroring the Qwen text path, plus a gemma3 convert family with the fp16-range weight rescale and a proper <start_of_turn> chat template (gemma3 previously fell through to the gemma4 <|turn> format and produced degenerate chat output; raw continuation was always fine).
  • Prefill chunk size decoupled from the calibration prompt: the chunk components used to slice their trace input from the calibration prompt, so a default convert (short prompt) silently produced a 46-token chunk instead of 128. The trace input is now tiled up to the configured chunk size (gemma3, qwen, LFM2 builders).
  • completion_token_ids in the benchmark JSON, used by the new tests to assert padded-vs-scalar first-token equality and run-to-run determinism. The padded tail fills KV through the chunk graph while the scalar tail uses the step graph — numerically distinct at ~1 ULP — so full-sequence equality is margin-dependent on small models and the test asserts the first token plus telemetry, with the expected chunk size derived from telemetry rather than hardcoded.
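
The chunk-size decoupling bullet amounts to repeating the calibration tokens until the trace input reaches the configured chunk size. A minimal sketch of that tiling, assuming the trace input is a flat list of token ids; the helper name tile_to_chunk is hypothetical and the real builders may tile tensors rather than lists:

```python
def tile_to_chunk(token_ids, chunk_size):
    """Repeat token_ids until there are chunk_size of them, then truncate."""
    if not token_ids:
        raise ValueError("calibration prompt produced no tokens")
    repeats = -(-chunk_size // len(token_ids))  # ceiling division
    return (token_ids * repeats)[:chunk_size]


# A 46-token calibration prompt still yields a 128-token trace input.
short_prompt = list(range(46))
trace_input = tile_to_chunk(short_prompt, 128)
assert len(trace_input) == 128
```

With this shape the traced chunk length depends only on the configured size, so a short default prompt can no longer shrink the bundle's prefill chunk.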

Results (M4 Pro, warm, unmodified gemma-4-e2b-it bundle)

| prompt length (tokens) | scalar tail (kill switch) | padded tail | TTFT |
|---|---|---|---|
| 1023 (worst case) | 184 tok/s | 352 tok/s | 5548 → 2903 ms |
| 1024 (aligned) | 372 tok/s | 370 tok/s (unchanged) | unchanged |

Generated tokens are identical between the two modes (completion_token_ids byte-equal).
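
As a sanity check on the worst-case row, the reported throughputs are consistent with dividing the prompt length by the time to first token. That relationship is an assumption about how the benchmark derives tok/s, not something the PR states:

```python
def prefill_tps(prompt_tokens, ttft_ms):
    """Prefill throughput in tokens per second, assuming tok/s = tokens / TTFT."""
    return prompt_tokens / (ttft_ms / 1000.0)


scalar_tps = prefill_tps(1023, 5548)  # scalar tail (kill switch)
padded_tps = prefill_tps(1023, 2903)  # padded tail
speedup = padded_tps / scalar_tps

print(f"scalar {scalar_tps:.0f} tok/s, padded {padded_tps:.0f} tok/s, {speedup:.2f}x")
```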

Verification (all local model families, LLM + VLM)

All bundles below are either pre-existing legacy bundles or freshly converted+transpiled with this branch's own pipeline, and every family was checked for real, coherent completions — not just suite pass/fail. (Earlier runs against packed-panel bundles from a kernel-format side branch turned out to load garbage on a mainline engine while still "passing" the content-free suite assertions; those bundles are excluded.)

  • Gemma 4 (gemma-4-e2b-it legacy bundle, untouched): LLM suite 14/14 incl. the padding test; VLM suite 4/4; 3-turn chat recalls earlier-turn facts; padded-tail benchmarks above on the unmodified bundle.
  • Gemma 3: google/gemma-3-270m-it converted+transpiled end-to-end with the default (short) calibration prompt — bundle still gets chunk=128 (decoupling fix), suite 11/14 with chunked_prefill_padding passing on the real model (the 3 failures are tool-call formatting, a 270M capability ceiling), "What is the capital of France?" → "The capital of France is Paris.", multi-turn recalls the user's name across turns. Hermetic tiny-model pipeline also re-run end-to-end.
  • Qwen3 (global attention, control family; legacy bundle): 13/14 with coherent outputs — all tool-call tests pass; the one failure (prefill warm-reuse) is a pre-existing main chat-template regression with a one-line fix that belongs in a separate PR.
  • LFM2 text (LFM2-350M converted+transpiled end-to-end): chunked pipeline emitted, conv-safe gating verified against the corrupting padded path; 13/14 with coherent outputs — the one failure is the 350M model answering a two-action prompt with a single tool call, identical upstream.
  • LFM2-VL: the unified text/VL spec builder is exercised by the LFM2 text path and emits the full VL component set including decoder_embed_chunk; end-to-end VL verification is blocked upstream — transformers ≥ 5.6 silently breaks LFM2-VL vision-weight loading on main (fix exists, belongs in a separate PR).
  • New unit tests: 4 graph-level rollback-invariant tests (padded append + rollback ≡ exact append byte-for-byte incl. eviction/overshoot/sink cases) and chunked_prefill_padding in test_llm (first-token equality vs scalar, telemetry-derived chunk size, determinism, kill switch). Cache suite 31/31; python convert tests 54/54.

Follow-ups (intentionally deferred): a small transpiled tail-chunk variant (e.g. 32 tokens) could shave ~85ms on remainders ≤32; the qwen prefill prefix-render fix; the LFM2-VL transformers ≥ 5.6 restore.

ncylich (Collaborator, Author) commented Jun 11, 2026

Ran a two-reviewer pass (adversarial re-derivation of the rollback math from the append code: confirmed correct; verdict approve). Hardening applied:

  • test_llm chunked_prefill_padding now asserts padded output == scalar-tail output (was determinism-only) — passes repeatedly at both test lengths.
  • Sink-aware padding guard: new CactusGraph::get_node_sink_size, and padded_tail_window_ok now also rejects window <= chunk + sink, so a large-sink bundle degrades to the scalar tail instead of reaching the snapshot's defensive throw.
  • Deduped the gemma3 fp16-range scale constant (reuses GEMMA4_WEIGHT_SCALE from naming.py; also removes the duplicate local in the gemma4 block) and noted the tied-lm_head assumption.
  • Removed the empty Gemma3LMEncoderTextChunkAdapter subclass; trimmed three oversized docstrings.
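
The guards mentioned in this thread (sink-aware window check, conv/recurrent caches, kill switch) can be summarized as a single predicate. This is a hedged sketch: the function name tail_padding_allowed and its parameters are hypothetical, only the window <= chunk + sink rejection is modeled for the window check, and only the documented value 1 of CACTUS_DISABLE_PREFILL_TAIL_PAD is treated as disabling.

```python
import os


def tail_padding_allowed(window, chunk, sink=0,
                         has_conv_or_recurrent_cache=False, env=None):
    """Return True if a padded tail chunk may be used for this bundle (toy model)."""
    env = os.environ if env is None else env
    if env.get("CACTUS_DISABLE_PREFILL_TAIL_PAD") == "1":
        return False  # kill switch for A/B comparisons
    if has_conv_or_recurrent_cache:
        return False  # no rollback support for these cache kinds
    # Reject window <= chunk + sink so the padded chunk cannot overrun the window.
    return window > chunk + sink


# The benchmarked configuration (window=512, chunk=128) passes the window check.
assert tail_padding_allowed(512, 128, env={})
```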

Declined with rationale: restoring cache bytes beyond current_seq_len after rollback (dead region by the same contract plain appends/evictions already rely on; no consumer reads past the length), and in-process fp16-KV test coverage (CACTUS_KV_CACHE_FP16 is latched at first use — would need a separate suite invocation).

Re-verified after the changes: cache suite 31/31, LLM suite 14/14 ×2 on the unmodified gemma-4 bundle, fresh tiny-gemma3 convert→transpile→engine pipeline, and real gemma-3-270m (padded vs scalar tokens match exactly; chat output correct).


Update (superseded in part by the review below): the full-sequence padded≡scalar assertion was later relaxed to first-token equality with the expected chunk size derived from telemetry — full-sequence equality is margin-dependent across the chunk/step graphs (see the thread with @kar-m). The padded_tail_window_ok helper was since inlined into the cache-node scan, and the padding policy gained a tail ≤8 scalar cutoff.

ncylich force-pushed the prefill-tail-padding branch from 28619cb to b98b682 (June 11, 2026 06:59)
ncylich force-pushed the prefill-tail-padding branch 4 times, most recently from 04c4fe1 to 63ef7d1 (June 11, 2026 08:45)
cactus-compute deleted a comment from kar-m (Jun 11, 2026)
ncylich changed the base branch from il-fused-gemv to main (June 11, 2026 20:53)
ncylich force-pushed the prefill-tail-padding branch 2 times, most recently from e17562c to b971ac8 (June 11, 2026 21:32)

kar-m (Collaborator) commented Jun 11, 2026

Looks mostly good, but always padding to the next 128 is not necessarily the best policy.

M4 Pro, gemma-4-e2b-it window=512/chunk=128, median of 8, alternating reps

| N | tail | padded tps | scalar tps | winner | padded TTFT (ms) | scalar TTFT (ms) |
|---|---|---|---|---|---|---|
| 128 | 0 | 440.3 | 440.4 | tie | 291 | 291 |
| 130 | 2 | 264.9 | 371.7 | scalar +40% | 491 | 350 |
| 132 | 4 | 266.3 | 332.9 | scalar +25% | 496 | 398 |
| 134 | 6 | 268.7 | 302.9 | scalar +13% | 499 | 443 |
| 136 | 8 | 279.3 | 286.3 | scalar +2.5% | 488 | 475 |
| 138 | 10 | 274.2 | 259.6 | padded +5.6% | 504 | 532 |
| 142 | 14 | 279.1 | 227.9 | padded +22% | 509 | 623 |
| 144 | 16 | 284.2 | 217.0 | padded +31% | 507 | 664 |
| 152 | 24 | 292.0 | 179.3 | padded +63% | 521 | 848 |
| 160 | 32 | 301.6 | 155.9 | padded +93% | 531 | 1029 |
| 168 | 40 | 310.4 | 139.1 | padded +123% | 542 | 1209 |
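
One way to read a breakeven off this sweep is to interpolate linearly between the last row where the scalar path wins on TTFT and the first row where the padded path wins. The sketch below does that with the TTFT columns copied from the table; the interpolation itself is an illustrative assumption, since the sweep has no measurement at tail 9.

```python
# (tail, padded_ttft_ms, scalar_ttft_ms), copied from the sweep above.
SWEEP = [
    (0, 291, 291), (2, 491, 350), (4, 496, 398), (6, 499, 443),
    (8, 488, 475), (10, 504, 532), (14, 509, 623), (16, 507, 664),
    (24, 521, 848), (32, 531, 1029), (40, 542, 1209),
]


def crossover_tail(rows):
    """Interpolate the tail length at which the padded path starts to win on TTFT."""
    for (t0, p0, s0), (t1, p1, s1) in zip(rows, rows[1:]):
        d0, d1 = s0 - p0, s1 - p1  # positive means the padded path is faster
        if d0 <= 0 < d1:
            return t0 + (t1 - t0) * (-d0) / (d1 - d0)
    return None


estimate = crossover_tail(SWEEP)
print(f"estimated crossover near tail = {estimate:.1f}")
```

The estimate falls between tail 8 and tail 10, which is where the winner column flips.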

Another thing: as currently written, the 16x activation scaling is also applied to gemma3. That is not necessarily bad or incorrect, but gemma3 doesn't have the same overflow issues as gemma4, and I don't know whether it has underflow issues that we could hit.

ncylich force-pushed the prefill-tail-padding branch from b971ac8 to b0a1c7b (June 12, 2026 02:26)

ncylich (Collaborator, Author) commented Jun 12, 2026

Great catches — all addressed, plus your sweep surfaced a deeper issue. What changed:

Padding policy: adopted your breakeven. Tails ≤8 now keep the scalar path; padding engages from 9 up (tail_tokens > 8). Verified a 132-token prompt runs scalar_tail=4, pads=0.
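
The resulting policy can be written down as a small planning function. This is a sketch under assumptions: PrefillPlan and plan_prefill are hypothetical names, the padded path is modeled as one chunk of tail + pads rows with exactly one token re-run through the step graph, and pad_allowed stands in for the kill switch and the cache guards.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PrefillPlan:
    full_chunks: int         # chunks of real tokens only
    tail_tokens: int         # prompt_tokens mod chunk
    padded: bool             # whether the tail runs as one padded chunk
    padding_tokens: int      # pad rows appended to the tail chunk
    scalar_tail_tokens: int  # tokens run one at a time through the step graph


def plan_prefill(prompt_tokens, chunk=128, scalar_cutoff=8, pad_allowed=True):
    """Split a prompt into full chunks plus a scalar or padded tail."""
    full_chunks, tail = divmod(prompt_tokens, chunk)
    if tail == 0:
        return PrefillPlan(full_chunks, 0, False, 0, 0)
    if pad_allowed and tail > scalar_cutoff:
        # One padded chunk; the last real token is re-run as a single step.
        return PrefillPlan(full_chunks, tail, True, chunk - tail, 1)
    # Short tails (or padding disabled): every tail token is a decode step.
    return PrefillPlan(full_chunks, tail, False, 0, tail)


print(plan_prefill(132))   # scalar tail of 4, no pads
print(plan_prefill(1023))  # padded tail: 1 pad, 1 scalar step
```

Under this model a 95-token prompt gets 33 pads and 1 scalar step, which matches one reading of the 94+33+1 telemetry quoted later in the thread.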

Chunk-size pinning: fixed by decoupling, as a dedicated commit. The chunk components no longer slice their trace input from the calibration prompt — the input is tiled up to the configured chunk size, so a default convert (46-token prompt) now still produces chunk=128. Applied to the gemma3, qwen (both builders), and the shared LFM2 helper. Verified by converting+transpiling gemma-3-270m-it with the default prompt: telemetry shows chunk=128.

Test fragility: chunked_prefill_padding now derives the expected chunk size from telemetry instead of hardcoding % 128, and the cross-mode assertion is first-token equality + per-mode determinism instead of full-sequence equality — your token-4 divergence is exactly the chunk-graph-vs-step-graph ~1-ULP tie-flip, so full-sequence equality was a too-strong invariant. With these changes the test passes on the real 270m.
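
The relaxed invariant can be expressed as a check over completion_token_ids lists from repeated runs in each mode. A minimal sketch, assuming each run is a list of integer token ids; the helper name padding_invariants_hold is hypothetical and the real test lives in test_llm:

```python
def padding_invariants_hold(padded_runs, scalar_runs):
    """Check per-mode determinism and cross-mode first-token equality.

    Each argument is a list of completion_token_ids lists, one per repeated run.
    """
    for runs in (padded_runs, scalar_runs):
        if not runs or not runs[0]:
            raise ValueError("need at least one non-empty run per mode")
        if any(run != runs[0] for run in runs):
            return False  # the mode is not deterministic run to run
    # Later tokens may legitimately differ between modes; only the first is compared.
    return padded_runs[0][0] == scalar_runs[0][0]


padded = [[5, 7, 9, 2], [5, 7, 9, 2]]
scalar = [[5, 7, 9, 4], [5, 7, 9, 4]]  # diverges at the fourth token
assert padding_invariants_hold(padded, scalar)
```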

The actual gemma3 chat bug: while re-verifying real completions I found gemma3 bundles fell through to the gemma4 <|turn> chat template (the tokenizer's model-type sniff matched any "gemma" to GEMMA4), producing degenerate <|turn>model loops — your coherent "Brainrot is internet slang…" line is verbatim the assistant turn already in test_prefill's message history, i.e. an echo. Raw continuation was always fine ("The capital of France is" → " Paris. However, the city of Paris"). Fixed with a proper <start_of_turn>/<end_of_turn> formatter for the classic-gemma model type; real chat now works ("What is the capital of France?" → "The capital of France is Paris.") and multi-turn recalls earlier-turn facts.
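
For reference, the classic Gemma turn format wraps each message in <start_of_turn>role ... <end_of_turn> and names the assistant role "model". The Python sketch below illustrates that format only; format_gemma_chat is a hypothetical helper, BOS handling and system messages are deliberately left out, and the engine's actual formatter may differ in details.

```python
def format_gemma_chat(messages, add_generation_prompt=True):
    """Render user/assistant messages in the classic Gemma turn format."""
    parts = []
    for message in messages:
        # Gemma calls the assistant role "model".
        role = "model" if message["role"] == "assistant" else message["role"]
        content = message["content"].strip()
        parts.append(f"<start_of_turn>{role}\n{content}<end_of_turn>\n")
    if add_generation_prompt:
        parts.append("<start_of_turn>model\n")
    return "".join(parts)


prompt = format_gemma_chat(
    [{"role": "user", "content": "What is the capital of France?"}]
)
print(prompt)
```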

On the 16× rescale: kept. Earlier experiments hit fp16-range overflow on the 270m residual stream without it, and I tensor-level verified the converted bundle matches the intended scheme exactly (norms in offset form bit-exact, gate/up ×16, tied embeddings ÷16, everything else 1×) with coherent output end-to-end through the legacy quantizer. No underflow symptoms in any of the verified outputs — if one shows up on another gemma3 size, the scale constant is a single shared knob.
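
The range argument can be checked with plain arithmetic. The sketch below uses the ~103K residual peak quoted in the gemma3 convert commit message and the IEEE 754 half-precision limits; treating the whole residual stream as uniformly scaled by 1/16 is a simplifying assumption.

```python
FP16_MAX = 65504.0          # largest finite half-precision value
FP16_MIN_NORMAL = 2.0 ** -14  # smallest normal half-precision value

residual_peak = 103_000.0   # approximate peak reported for gemma-3-270m
scale = 16.0

fits_unscaled = residual_peak <= FP16_MAX        # overflows without the rescale
fits_scaled = residual_peak / scale <= FP16_MAX  # fits after the 16x shrink
headroom = FP16_MAX / (residual_peak / scale)

# Cost at the small end: unscaled magnitudes below this become subnormal after / 16.
subnormal_threshold_unscaled = FP16_MIN_NORMAL * scale

print(fits_unscaled, fits_scaled, round(headroom, 1), subnormal_threshold_unscaled)
```

The shrink buys roughly a factor of ten of headroom at the top and gives up four bits of range at the bottom, which is the trade the underflow question is about.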

ncylich force-pushed the prefill-tail-padding branch from b0a1c7b to 5691574 (June 12, 2026 02:50)

ncylich (Collaborator, Author) commented Jun 12, 2026

Closing the loop on the review — every raised item is addressed, re-verified end-to-end on the final branch state, and CI is fully green. Status:

| Review item | Resolution | Verified by |
|---|---|---|
| Padding-to-128 not always optimal (benchmark sweep) | Tails ≤8 stay scalar; padding engages from 9 (tail_tokens > 8), matching the measured breakeven | 132-token prompt → scalar_tail=4, pads=0; 95/600 still take the padded path |
| Chunk size pinned to calibration-prompt length | Decoupled in a dedicated commit: trace input is tiled up to the configured chunk size (gemma3, qwen ×2, shared LFM2 helper) | Default-prompt convert+transpile of gemma-3-270m-it now yields chunk=128 (telemetry: 94+33+1) |
| chunked_prefill_padding fragile (hardcoded %128, too-strong token equality) | Expected scalar count derived from telemetry; cross-mode assertion is first-token equality + per-mode determinism (chunk vs step graphs differ at ~1 ULP, so full-sequence equality is margin-dependent) | Passes on real gemma-3-270m and gemma-4-e2b |
| 16× rescale applied to gemma3 | Kept: needed for fp16-range on the 270m; tensor-level audit confirms the converted bundle matches the scheme exactly (norms bit-exact in offset form, gate/up ×16, tied embeddings ÷16, rest 1×) | Coherent end-to-end output through the legacy quantizer; no underflow symptoms |
| Degenerate gemma3 chat output | Root-caused beyond the suite: gemma3 fell through to the gemma4 `<\|turn>` template (the earlier "coherent" suite line was an echo of `test_prefill`'s own assistant turn). Added a `<start_of_turn>`/`<end_of_turn>` formatter + stop for the classic-gemma model type | |

Final verification (all on freshly built bundles or untouched legacy bundles, with real-output checks): gemma4 LLM 14/14 + VLM 4/4 + 3-turn recall, gemma3-270m 11/14 with padding test passing, qwen 13/14 (sole failure is a pre-existing main template regression, fix queued as a separate PR), LFM2-350M 13/14, cache suite 31/31, convert tests 54/54. The diff itself was given a final pass: gemma3 adapter logic deduplicated behind shared backbone-forward/metadata helpers (mirroring the gemma4 pattern), zero added comments, 13 signed commits, +917/−198.

ncylich force-pushed the prefill-tail-padding branch from 5691574 to 7eeeb44 (June 12, 2026 17:10)
ncylich added 13 commits (June 12, 2026 10:11)
Lets benchmark runs compare generated tokens across engine/bundle variants
without a separate text-capture path.

Signed-off-by: Noah Cylich <noahcylich@gmail.com>
The prompt remainder (L mod 128) ran one token at a time at decode speed
(~17.7ms/token on M4): a 1023-token prompt paid 127 decode steps, 3.3x
slower prefill than 1024 (gemma-prefill-bug.md). Global-attention models
already pad the tail chunk, but a pad through a sliding-window cache evicts
real window rows, which is why sliding bundles were excluded — the original
chunked-prefill garbled-output bug.

Run the remainder as one padded chunk instead. The sliding caches compact
chronologically on eviction, so a padded append just over-evicts up to
pad_count of the oldest window rows: snapshot them before the padded execute
and reinsert them after (rollback also drops the pad rows), leaving the
persistent cache state identical to an unpadded prefill. The chunk graph
only emits the last row's logits — a pad's prediction — so the last real
token is rolled back too and re-run through the step graph, which is the
same mechanism the global-attention padding path uses. Worst case tail cost
drops from 127 decode steps (~2.2s) to one padded chunk and one decode step
(~150ms), with no bundle changes: existing sliding-window bundles speed up
as-is.

The path applies to any sliding-window bundle with a chunked prefill graph;
it skips graphs whose window is too small for the pad bound and bundles with
recurrent or conv caches (no rollback support), and those plus the
global-attention and no-chunk paths behave exactly as before. Telemetry:
prefill_tail_chunk_tokens / prefill_tail_padding_tokens added;
prefill_scalar_tail_tokens is now at most 1 on the padded path. Kill switch
for A/B: CACTUS_DISABLE_PREFILL_TAIL_PAD=1. Note: cloud-handoff probe
hiddens are only captured in run_step, so probe rollouts see at most one
sample on such bundles.

Signed-off-by: Noah Cylich <noahcylich@gmail.com>
Graph level: append(real + pads) followed by rollback must leave the cache
byte-identical to append(real) — covered for no-eviction, eviction with
overshoot, pads-only eviction, and empty-cache cases. Engine level: the
padded tail engages with correct telemetry, is deterministic, and the
CACTUS_DISABLE_PREFILL_TAIL_PAD kill switch restores the scalar tail.

Signed-off-by: Noah Cylich <noahcylich@gmail.com>
Gemma3 previously only had a single full-graph causal-LM capture, so the
engine had no CACHED_STEP route and prefilled token-by-token. Mirror the
qwen text component set (lm_encoder_step, lm_encoder_text_chunk,
decoder_prefill_chunk, decoder_step) with internal KV caches and
per-layer-type rotary/masks. A family-neutral optimize pass generalizes
the gemma4-only hint assignment so sliding layers get
window_size = sliding_window on their KV cache states.

Signed-off-by: Noah Cylich <noahcylich@gmail.com>
Gemma3 text checkpoints fell through to the generic family, so bundles
carried model_type=generic, which the engine treats as gemma4 and
rejects for missing gemma4-only config fields. Worse, gemma3's residual
stream peaks far above fp16 max (~103K on gemma-3-270m), so unscaled
weights produce garbage on the fp16-activation engine. Add a gemma3
family that keeps generic naming/policy but mirrors the gemma4 16x
residual rescale: embeddings (and the tied lm head) shrink 16x,
gate/up restore MLP scale, and the norm files fold the rescale through
gemma3's RMSNorm +1 offset, which the transpiled graph re-applies.
The engine maps model_type=gemma3 onto its non-strict GEMMA type.

Signed-off-by: Noah Cylich <noahcylich@gmail.com>
If an LM bundle provides no decoder_prefill_chunk component, prompts
prefill one token at a time and prefill throughput collapses to roughly
decode speed - silently. Emit a one-time warning at model load naming
the fallback and the fix (re-transpile with the current converter);
audio/whisper routes are exempt. Raise the benchmark fixture's log
level to WARN so the warning is visible during benchmarking, where an
unnoticed fallback invalidates prefill numbers (observed with a qwen3
bundle whose manifest omitted the chunk components: 21 tps prefill vs
62 with them).

Signed-off-by: Noah Cylich <noahcylich@gmail.com>
build_component_module_specs only ever received explicit --components;
the component plan that infer_component_plan_from_config computes was
consulted for task inference and then dropped, so default text-qwen3
transpiles emitted the monolithic decoder/decoder_step pair instead of
the chunked-prefill pipeline - prefill collapsed to decode speed (~21
vs ~62 tps at 512 tokens on Pixel 10a).

Pass the plan's components through when --components is absent. Spec
builders reject inapplicable plans with the new UnsupportedComponentsError
and the call site falls back to builder defaults with a warning;
explicit --components and unrelated failures still raise.

Verified: fresh default transpile of Qwen/Qwen3-0.6B emits
lm_encoder_text_chunk/decoder_prefill_chunk/lm_encoder_step/
decoder_media_step/decoder_step and benches at 229 prefill / 61 decode
tps on an M4 Pro (previously 67/51 token-by-token); transpile test
suite passes.

Signed-off-by: Noah Cylich <noahcylich@gmail.com>
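
The fallback behavior described in the commit above can be sketched as follows. Only UnsupportedComponentsError is a name taken from the commit message; resolve_components and the build_specs callable are hypothetical, and the real call site may be structured differently.

```python
import warnings


class UnsupportedComponentsError(Exception):
    """Raised by a spec builder when a component plan does not apply to it."""


def resolve_components(explicit, planned, build_specs):
    """Pick which components to build: explicit flag, inferred plan, or defaults."""
    if explicit is not None:
        # Explicit --components: any failure, including rejection, propagates.
        return build_specs(explicit)
    if planned:
        try:
            return build_specs(planned)
        except UnsupportedComponentsError as exc:
            warnings.warn(f"component plan rejected ({exc}); using builder defaults")
    # No plan, or the plan was rejected: fall back to the builder's defaults.
    return build_specs(None)


def toy_builder(components):
    """Stand-in builder that rejects one plan and has a monolithic default."""
    if components == ["unsupported"]:
        raise UnsupportedComponentsError("not applicable to this family")
    return components or ["decoder", "decoder_step"]


print(resolve_components(None, ["decoder_prefill_chunk", "decoder_step"], toy_builder))
```
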
The remedy only applies to families whose converter emits chunked
components (currently qwen3 text); for others (e.g. base LFM2) a
re-transpile changes nothing and the advice misleads.

Signed-off-by: Noah Cylich <noahcylich@gmail.com>
The lfm2 chunk adapters (lm_encoder_step, lm_encoder_text_chunk,
decoder_prefill_chunk and the embeds-based decoder_step) already power
the LFM2-VL pipeline and resolve the language backbone generically, so
text-only Lfm2ForCausalLM can reuse them directly. Route chunk-component
requests in the lfm2 text spec builder to a new chunked builder and make
the component plan select the chunked set for the Lfm2ForCausalLM
architecture, matching the qwen3 text default.

Verified on LFM2-350M: transpile with reference compare passes, the
bundle prefills 565 tps vs 134 token-by-token on an M4 Pro (4.2x) with
decode unchanged (141 vs 143), no prefill-fallback warning, and greedy
generation stays coherent and factually correct across spot checks.

Signed-off-by: Noah Cylich <noahcylich@gmail.com>
The four chunked component specs (lm_encoder_step, lm_encoder_text_chunk,
decoder_prefill_chunk, decoder_step) were emitted twice with identical
structure: once by the text-only builder and once inline in the LFM2-VL
multimodal builder. Generalize the text builder into
_lfm2_chunked_pipeline_specs, parameterized by graph/spec metadata and an
optional decoder-inputs provider, and emit both paths through it. The VL
path keeps its lazy vision-encoder + lm_encoder decoder-input computation
via a closure; the text path keeps deriving decoder inputs from
lm_encoder_text_chunk. Net -54 lines, no behavior change: transpiled
artifacts for LFM2-350M and LFM2-VL-450M are byte-identical before and
after.

Signed-off-by: Noah Cylich <noahcylich@gmail.com>
Move the qwen3/lfm2 text-only chunked component plans out of
_plan_from_profile's hardcoded architecture conditionals into a
text_component_plans field on ModelProfile, matching how multimodal
pipelines are declared via default_components. The plan layer now
consults only the matched profile's own declarations, so one family's
architecture markers can no longer select another family's components;
real checkpoints are unaffected since model_type and architectures
always agree on family.

Signed-off-by: Noah Cylich <noahcylich@gmail.com>
Text-only LFM2 bundles now ship chunked prefill graphs, so they reach the
tail-padding decision for the first time — and the family gate only covered
lfm2_vl. Pads appended to a conv ring displace exactly the rows the conv
kernel's lookback reads and no clamp removes them: padded prefill produced
corrupted output on LFM2-350M (tokens diverge from position 1 vs the scalar
tail). Gating on conv-cache presence fixes that and subsumes the lfm2_vl
special case.

Signed-off-by: Noah Cylich <noahcylich@gmail.com>
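
A toy model shows why pads are harmful in a conv ring. The sketch assumes the conv cache behaves like a fixed-size ring holding only the most recent rows that the kernel's lookback reads; conv_state_after is a hypothetical helper and the lookback of 3 is an arbitrary illustrative value.

```python
from collections import deque


def conv_state_after(tokens, kernel_lookback):
    """Return the ring contents after feeding tokens through a fixed-size ring."""
    ring = deque(maxlen=kernel_lookback)
    ring.extend(tokens)  # older rows fall out as newer ones arrive
    return list(ring)


real = ["t0", "t1", "t2", "t3", "t4"]
pads = ["<pad>"] * 3

exact_state = conv_state_after(real, kernel_lookback=3)
padded_state = conv_state_after(real + pads, kernel_lookback=3)

print(exact_state)   # the last three real rows
print(padded_state)  # only pad rows remain for the next step to read
```

Unlike the sliding-window KV case, nothing in this model preserves the displaced rows, so the next decode step would read pads where it expects real context.
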
The chunk components sliced their example input from the calibration
prompt, so a short prompt silently shrank the bundle's prefill chunk
(default convert: 46 tokens instead of 128). Tile the trace input up to
the configured chunk size instead.

Signed-off-by: Noah Cylich <noahcylich@gmail.com>
ncylich force-pushed the prefill-tail-padding branch from 7eeeb44 to e8ef8cd (June 12, 2026 17:15)
jakmro merged commit d51dbcf into main (Jun 12, 2026)
7 checks passed