
feat(mem_cache): page-major (layer-major within a page) KV/state layout #29533

Open: ch-wan wants to merge 1 commit into main from feature/page-local-kv-layout

ch-wan commented Jun 28, 2026

Collaborator

Motivation

SGLang stores the KV cache (and Mamba conv/SSM state) as per-layer tensors — layer 0's slots, then layer 1's slots, etc. (a layer-major layout). Each token's K/V for a given layer is contiguous, but a single token/page is scattered across num_layers separate allocations.

This PR adds an opt-in physical layout, --enable-page-major-kv-layout, that flips the outermost axis to the page: each page's whole depth — all layers' K/V (and all Mamba conv/temporal state) — lives in one contiguous byte buffer, laid out layer-major within the page. At page_size=1 this is a per-token envelope. Co-locating a page's whole depth is a building block for page-granular KV operations (movement, transfer, offload, allocation) and improves locality for those paths.
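
To make the two layouts concrete, here is a minimal byte-offset sketch in plain Python. It is an illustration only: the `KVGeometry` record, its field names, and the helper functions are hypothetical and are not taken from `page_major.py`, and it models a single tensor kind rather than separate K and V buffers.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class KVGeometry:
    """Hypothetical geometry record; names are illustrative, not SGLang's."""
    num_layers: int
    num_pages: int
    page_size: int    # tokens per page
    token_bytes: int  # bytes stored for one token in one layer


def layer_major_offset(g, layer, page, tok):
    """Per-layer layout: the layer is the outermost axis."""
    slot = page * g.page_size + tok
    return (layer * g.num_pages * g.page_size + slot) * g.token_bytes


def page_major_offset(g, layer, page, tok):
    """Page-major layout: page outermost, then layer, then token in the page."""
    layer_bytes = g.page_size * g.token_bytes
    page_bytes = g.num_layers * layer_bytes
    return page * page_bytes + layer * layer_bytes + tok * g.token_bytes


def page_span(g, offset_fn, page):
    """Byte range [start, end) covering every layer and token of one page."""
    offsets = [
        offset_fn(g, layer, page, tok)
        for layer in range(g.num_layers)
        for tok in range(g.page_size)
    ]
    return min(offsets), max(offsets) + g.token_bytes


g = KVGeometry(num_layers=4, num_pages=8, page_size=2, token_bytes=16)
pm_start, pm_end = page_span(g, page_major_offset, page=3)
lm_start, lm_end = page_span(g, layer_major_offset, page=3)
```

Under these assumptions the page-major span of a page is exactly `num_layers * page_size * token_bytes` bytes, while the layer-major span of the same page stretches across every layer's allocation.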

The layout is off by default and behavior-preserving when off — the hot paths are byte-identical (the page-aware Triton kernels constexpr-fold to the legacy addressing at page_size=1).
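
The folding claim can be illustrated with ordinary integer arithmetic. The sketch below is plain Python, not the Triton kernel code, and the function and parameter names are hypothetical; it only shows why page-aware addressing collapses to a single multiply when the page size is 1.

```python
def legacy_offset(loc, stride_slot):
    """Legacy addressing: one multiply per slot index."""
    return loc * stride_slot


def page_aware_offset(loc, page_size, stride_page, stride_tok):
    """Page-aware addressing: split the slot into (page, token-in-page).

    When page_size is the compile-time constant 1, `loc // 1` is `loc` and
    `loc % 1` is 0, so a compiler can drop the second term entirely.
    """
    page = loc // page_size
    tok = loc % page_size
    return page * stride_page + tok * stride_tok


# At page_size=1 the result does not depend on stride_tok at all.
same = all(
    page_aware_offset(loc, 1, 256, 12345) == legacy_offset(loc, 256)
    for loc in range(1024)
)
```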

Modifications

  • mem_cache/layout/page_major.py — standalone strided-view builders (build_page_major_mha_views, build_page_major_mamba_views) + byte geometry; hold no allocator state.
  • PageMajorMHATokenToKVPool — a subclass of MHATokenToKVPool (not in-class branching) selected via new _store_kv_layer / _move_kv_cache_impl template hooks. Layout-incompatible inherited methods (contiguous-buf-infos, CPU offload, prefix-commit) raise NotImplementedError instead of silently mis-indexing the 4-D strided views. MambaPool gains an envelope branch for conv/temporal state.
  • Triton decode/extend + store_cache_4d kernels — page-aware strides behind a PAGE_SIZE constexpr; at page_size=1 the page math is dead-code-eliminated, so the SASS is identical to today.
  • GDN prefill gather/scatter (gdn_backend.forward_extend) — the prefill conv (causal_conv1d_fwd) and chunk_gated_delta_rule kernels write state back assuming a contiguous slot layout; under the strided envelope they silently dropped the write. The hybrid prefill now runs on contiguous per-sequence copies and scatters the updated state back. (TODO(ch-wan) left to make those kernels stride-aware and drop the copies.)
  • server_args: --enable-page-major-kv-layout flag + a validator requiring the Triton attention / linear-attn / Mamba backends; model_runner_kv_cache_mixin routes the layout into the plain-MHA, SWA-hybrid, and Mamba-hybrid pools.
  • Removed the dead enable_kvcache_transpose parameter (was always False).
  • Tests + docs — kernel parity (store_cache_4d, decode/extend), CPU view/move tests, two e2e accuracy tests (gpt-oss, qwen) in the label-gated extra suite, and the server-arg doc.
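
The subclass-plus-hooks design described in the list above can be sketched as follows. This is a toy model under stated assumptions: `BasePool` and `PageMajorPool` are hypothetical stand-ins backed by Python lists, and only the hook name `_store_kv_layer` comes from the PR; `_load_kv_layer` and `contiguous_buf_infos` are illustrative names.

```python
class BasePool:
    """Hypothetical per-layer pool; not the real MHATokenToKVPool."""

    def __init__(self, num_layers, num_slots):
        self.store = [[None] * num_slots for _ in range(num_layers)]

    # Public entry points live in the base class; only the hooks vary.
    def set_kv(self, layer, loc, value):
        self._store_kv_layer(layer, loc, value)

    def get_kv(self, layer, loc):
        return self._load_kv_layer(layer, loc)

    def _store_kv_layer(self, layer, loc, value):
        self.store[layer][loc] = value

    def _load_kv_layer(self, layer, loc):
        return self.store[layer][loc]

    def contiguous_buf_infos(self):
        # Only meaningful when each layer owns one contiguous buffer.
        return [(layer, len(buf)) for layer, buf in enumerate(self.store)]


class PageMajorPool(BasePool):
    """Hypothetical page-major variant at page_size=1 (per-token envelope)."""

    def __init__(self, num_layers, num_slots):
        self.num_layers = num_layers
        self.flat = [None] * (num_layers * num_slots)

    def _store_kv_layer(self, layer, loc, value):
        self.flat[loc * self.num_layers + layer] = value

    def _load_kv_layer(self, layer, loc):
        return self.flat[loc * self.num_layers + layer]

    def contiguous_buf_infos(self):
        # Fail loudly rather than return per-layer extents that do not exist.
        raise NotImplementedError(
            "per-layer buffers are not contiguous under the page-major layout"
        )


pool = PageMajorPool(num_layers=2, num_slots=4)
pool.set_kv(layer=1, loc=3, value="kv")
```

Overriding hooks instead of branching inside the base class keeps the default path untouched, and raising in layout-incompatible methods turns a silent mis-index into an immediate error.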

Accuracy Tests

GSM8K (5-shot/300, Triton backend), page-major vs baseline:

| Model | Path | Baseline | Page-major |
|---|---|---|---|
| Llama-2-7b-chat | plain MHA | 0.243 | 0.243 |
| Qwen3.5-4B | hybrid GDN (Mamba) | 0.863 | 0.863 |
| gpt-oss-20b | hybrid-SWA MoE | 0.56 | 0.52 (within noise; reasoning model underperforms few-shot completion) |

A GDN-prefill state-persistence bug (page-major dropped Qwen3.5 to ~0.61) was found and fixed; the table reflects the fix. A dedicated review confirmed the disabled path is behavior-preserving with no measurable overhead.

Speed Tests and Profiling

Not yet benchmarked — this PR is a layout/correctness foundation, off by default. The page_size=1 path is a verified no-op (constexpr-folded), so no regression is expected when disabled. Throughput/locality benchmarking of the enabled path is a follow-up.

Notes / follow-ups

  • --enable-page-major-kv-layout is not yet supported with fp4 KV cache (asserted) or the speculative-decode target-verify path.
  • The GDN prefill still pays a .contiguous() gather/scatter under the envelope (TODO(ch-wan)); making the conv / chunk_gated_delta_rule kernels stride-aware would remove it.
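
The gather/scatter workaround mentioned in the notes above follows a simple pattern, sketched here in plain Python over a flat list. All names are hypothetical, and `contiguous_only_update` merely stands in for a kernel that assumes contiguous state; it is not the conv or chunk_gated_delta_rule kernel.

```python
def gather(flat, base, stride, n):
    """Copy n strided elements into a fresh contiguous list."""
    return [flat[base + i * stride] for i in range(n)]


def scatter(flat, base, stride, values):
    """Write contiguous values back to their strided locations."""
    for i, v in enumerate(values):
        flat[base + i * stride] = v


def contiguous_only_update(state):
    """Stand-in for a kernel that updates a contiguous buffer in place."""
    for i in range(len(state)):
        state[i] += 1.0


def update_strided_state(flat, base, stride, n):
    tmp = gather(flat, base, stride, n)   # contiguous per-sequence copy
    contiguous_only_update(tmp)           # kernel sees the layout it expects
    scatter(flat, base, stride, tmp)      # persist the result to the pool


flat = [0.0] * 12
update_strided_state(flat, base=1, stride=3, n=4)
```

The cost is two extra copies per call, which is what making the kernels stride-aware would remove.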

Checklist

  • Format code with pre-commit
  • Add unit tests (kernel parity + CPU view/move + e2e accuracy), registered to CI
  • Update documentation (server_arguments.mdx)
  • Accuracy results provided (above); speed benchmarking is a follow-up
  • Follow SGLang code style

Co-Authored-By: lch1475369 <lch1475369@gmail.com>


CI States

Latest PR Test (Base): ✅ Run #28309127030
Latest PR Test (Extra): ✅ Run #28309127002



ch-wan commented Jun 28, 2026

Collaborator Author

/tag-and-rerun-ci

ch-wan force-pushed the feature/page-local-kv-layout branch from 192a875 to 9392ffd on June 28, 2026 02:21
ch-wan changed the title from "feat(mem_cache): page-local layer-major (page-granularity envelope) KV/state layout" to "feat(mem_cache): page-major (layer-major within a page) KV/state layout" on Jun 28, 2026

Opt-in physical layout (--enable-page-major-kv-layout) that makes the page the
outermost axis: each page is laid out layer-major in one contiguous byte buffer
for the Mamba state, full-KV, and SWA-KV caches instead of per-layer tensors. At
page_size=1 this is a token-granularity envelope; independent of any
shared/virtual-slot allocator.

- mem_cache/layout/page_major.py: strided-view builders + byte geometry.
- PageMajorMHATokenToKVPool subclass (kv_cache_layout=page_major_layer_major)
  via _store_kv_layer / _move_kv_impl template hooks on MHATokenToKVPool;
  layout-incompatible methods raise instead of silently mis-indexing.
  MambaPool envelope branch for the conv/temporal state.
- Triton decode/extend + store_cache_4d kernels: page-aware strides, a
  byte-identical no-op at page_size=1 (PAGE_SIZE constexpr).
- GDN prefill gather/scatter in gdn_backend.forward_extend so the strided
  envelope state is persisted correctly to the pool (the prefill conv /
  chunk_gated_delta_rule kernels assume a contiguous slot layout).
- server_args flag + Triton-backend validator; model_runner_kv_cache_mixin
  routes the layout into the plain-MHA, SWA-hybrid, and Mamba-hybrid pools.
- Removed the dead enable_kvcache_transpose param.
- Tests: store_cache_4d / decode+extend parity, CPU view/move, and e2e
  page-major accuracy (gpt-oss, qwen) in the label-gated extra suite. Docs.

Co-Authored-By: lch1475369 <lch1475369@gmail.com>
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

Labels: bypass-fastfail, documentation, run-ci, run-ci-extra
