Skip to content

Commit 9392ffd

Browse files
ch-wan, lch1475369, and claude
committed
feat(mem_cache): page-major (layer-major within a page) KV/state layout
Opt-in physical layout (--enable-page-major-kv-layout) that makes the page the outermost axis: each page is laid out layer-major in one contiguous byte buffer for the Mamba state, full-KV, and SWA-KV caches instead of per-layer tensors. At page_size=1 this is a token-granularity envelope; the layout is independent of any shared/virtual-slot allocator.

- mem_cache/layout/page_major.py: strided-view builders + byte geometry.
- PageMajorMHATokenToKVPool subclass (kv_cache_layout=page_major_layer_major) via _store_kv_layer / _move_kv_impl template hooks on MHATokenToKVPool; layout-incompatible methods raise instead of silently mis-indexing. MambaPool envelope branch for the conv/temporal state.
- Triton decode/extend + store_cache_4d kernels: page-aware strides, a byte-identical no-op at page_size=1 (PAGE_SIZE constexpr).
- GDN prefill gather/scatter in gdn_backend.forward_extend so the strided envelope state is persisted correctly to the pool (the prefill conv / chunk_gated_delta_rule kernels assume a contiguous slot layout).
- server_args flag + Triton-backend validator; model_runner_kv_cache_mixin routes the layout into the plain-MHA, SWA-hybrid, and Mamba-hybrid pools.
- Removed the dead enable_kvcache_transpose param.
- Tests: store_cache_4d / decode+extend parity, CPU view/move, and e2e page-major accuracy (gpt-oss, qwen) in the label-gated extra suite. Docs.

Co-Authored-By: lch1475369 <lch1475369@gmail.com>
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
1 parent da802dd commit 9392ffd

25 files changed

Lines changed: 2136 additions & 124 deletions

docs_new/docs/advanced_features/server_arguments.mdx

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,12 @@ Please consult the documentation below and [server_args.py](https://github.com/s
509509
<td style={{padding: "9px 12px", backgroundColor: "rgba(255,255,255,0.05)"}}>The number of tokens in a page.</td>
510510
<td style={{padding: "9px 12px", backgroundColor: "rgba(255,255,255,0.02)"}}>`1`</td>
511511
<td style={{padding: "9px 12px", backgroundColor: "rgba(255,255,255,0.05)"}}>Type: int</td>
512+
</tr>
513+
<tr>
514+
<td style={{padding: "9px 12px", fontWeight: 500, backgroundColor: "rgba(255,255,255,0.02)"}}>`--enable-page-major-kv-layout`</td>
515+
<td style={{padding: "9px 12px", backgroundColor: "rgba(255,255,255,0.05)"}}>Enable the page-major KV layout: lay out the Mamba state and full/SWA KV caches in a page-granularity envelope (page is the outermost axis, layer-major within a page) instead of the default per-layer (layer-major) layout. Requires the Triton attention / linear-attn / Mamba backends (`--attention-backend triton`, and for hybrid models `--linear-attn-backend triton --mamba-backend triton`).</td>
516+
<td style={{padding: "9px 12px", backgroundColor: "rgba(255,255,255,0.02)"}}>`False`</td>
517+
<td style={{padding: "9px 12px", backgroundColor: "rgba(255,255,255,0.05)"}}>bool flag (set to enable)</td>
512518
</tr>
513519
<tr>
514520
<td style={{padding: "9px 12px", fontWeight: 500, backgroundColor: "rgba(255,255,255,0.02)"}}>`--swa-full-tokens-ratio`</td>

python/sglang/srt/layers/attention/linear/gdn_backend.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,30 @@ def forward_extend(
425425
else:
426426
has_initial_states = forward_batch.extend_prefix_lens > 0
427427

428+
# Page-major envelope: the prefill kernels (CUDA causal_conv1d_fwd,
429+
# chunk_gated_delta_rule) write state back in place assuming a contiguous
430+
# slot layout, so they silently drop the write to the strided envelope
431+
# pool. Run them on contiguous per-sequence copies (identity-indexed) and
432+
# scatter the result back. No-op for the default contiguous pool.
433+
# TODO(ch-wan): drop these .contiguous() copies by making the prefill conv
434+
# and chunk_gated_delta_rule kernels honor the pool's real slot stride +
435+
# int64 indexing, like packed_decode / causal_conv1d_update already do.
436+
gather_mamba_state = (not is_target_verify) and (
437+
not conv_states.is_contiguous() or not ssm_states.is_contiguous()
438+
)
439+
if gather_mamba_state:
440+
conv_states_run = conv_states[cache_indices].contiguous()
441+
ssm_states_run = ssm_states[cache_indices].contiguous()
442+
state_cache_indices = torch.arange(
443+
cache_indices.shape[0],
444+
device=cache_indices.device,
445+
dtype=cache_indices.dtype,
446+
)
447+
else:
448+
conv_states_run = conv_states
449+
ssm_states_run = ssm_states
450+
state_cache_indices = cache_indices
451+
428452
if is_target_verify:
429453
batch_size = seq_len // forward_batch.spec_info.draft_token_num
430454
draft_token_num = forward_batch.spec_info.draft_token_num
@@ -460,9 +484,9 @@ def forward_extend(
460484
layer.conv_weights,
461485
layer.bias,
462486
activation=layer.activation,
463-
conv_states=conv_states,
487+
conv_states=conv_states_run,
464488
has_initial_state=has_initial_states,
465-
cache_indices=cache_indices,
489+
cache_indices=state_cache_indices,
466490
query_start_loc=query_start_loc,
467491
seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
468492
).transpose(0, 1)[:seq_len]
@@ -514,8 +538,8 @@ def forward_extend(
514538
v=value,
515539
g=g,
516540
beta=beta,
517-
ssm_states=ssm_states,
518-
cache_indices=cache_indices,
541+
ssm_states=ssm_states_run,
542+
cache_indices=state_cache_indices,
519543
query_start_loc=query_start_loc,
520544
)
521545

@@ -525,6 +549,12 @@ def forward_extend(
525549
)
526550
ssm_states[cache_indices] = last_recurrent_state
527551

552+
if gather_mamba_state:
553+
# Scatter the in-place-updated contiguous copies back to the
554+
# strided envelope pool (advanced indexing handles the strides).
555+
conv_states[cache_indices] = conv_states_run
556+
ssm_states[cache_indices] = ssm_states_run
557+
528558
if h is not None:
529559
self._track_mamba_state_extend(
530560
forward_batch, h, ssm_states, forward_metadata

python/sglang/srt/layers/attention/mamba/mamba_state_scatter_triton.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,13 @@ def track_mamba_state_if_needed_kernel(
4343
if not track_mask:
4444
return
4545

46-
# Load source and destination indices
47-
src_idx = tl.load(cache_indices_ptr + batch_idx)
48-
dst_idx = tl.load(mamba_track_indices_ptr + batch_idx)
46+
# Cast indices to int64 before they multiply the row stride. The
47+
# page-granularity envelope layout makes the conv/ssm row stride large
48+
# (stride_0 = entry_bytes / itemsize), so an int32 `idx * stride_0` can
49+
# overflow for moderately large idx and wrap to an illegal address. int64 is
50+
# harmless for the small-stride (per-layer) case.
51+
src_idx = tl.load(cache_indices_ptr + batch_idx).to(tl.int64)
52+
dst_idx = tl.load(mamba_track_indices_ptr + batch_idx).to(tl.int64)
4953

5054
# Copy conv_states
5155
# Each thread handles BLOCK_SIZE elements

python/sglang/srt/layers/attention/triton_backend.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,12 @@ def __init__(
147147
self.req_to_token = model_runner.req_to_token_pool.req_to_token
148148
self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
149149
self.use_sliding_window_kv_pool = isinstance(self.token_to_kv_pool, SWAKVPool)
150+
# Pass-through to the Triton attention wrappers so they can extract the
151+
# KV view strides and specialize on the PAGE_SIZE constexpr. At
152+
# page_size=1 the kernel path matches the slot-based envelope addresses.
153+
# `model_runner.page_size` defaults to 1 when `server_args.page_size` is
154+
# None, avoiding the Optional case here.
155+
self.page_size = getattr(model_runner, "page_size", 1) or 1
150156
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
151157
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
152158
self.topk = model_runner.server_args.speculative_eagle_topk or 0
@@ -1306,6 +1312,7 @@ def forward_extend(
13061312
sinks=sinks,
13071313
window_kv_offsets=window_kv_offsets,
13081314
xai_temperature_len=layer.xai_temperature_len,
1315+
page_size=self.page_size,
13091316
)
13101317
return o
13111318

@@ -1575,6 +1582,7 @@ def _forward_extend_unified(
15751582
sinks=sinks,
15761583
window_start_pos=window_start_pos,
15771584
xai_temperature_len=layer.xai_temperature_len,
1585+
page_size=self.page_size,
15781586
)
15791587

15801588
return o
@@ -1710,6 +1718,7 @@ def forward_decode(
17101718
xai_temperature_len=layer.xai_temperature_len,
17111719
has_mla=self.use_mla,
17121720
use_pdl=self.use_pdl,
1721+
page_size=self.page_size,
17131722
)
17141723
return o
17151724

0 commit comments

Comments
 (0)