
Commit 2824472

cjx0709 and jimoosciuc authored
feat(swa): support SWA mempool with paged allocation, eviction, and dual-pool scheduling (#921)
* feat: enhance SWA mempool with paged allocation, eviction, and dual-pool scheduling
  - Add SWA dual-pool architecture with full/SWA sub-allocators
  - Support paged allocation (alloc_extend/alloc_decode) for SWA hybrid KV cache
  - Add SWA eviction logic (maybe_evict_swa, _evict_swa) with overlap safety
  - Fix scheduler cache init order for hybrid models (#202)
  - Restore SWA-aware status reporting in available_and_evictable_str (#225)
  - Fix swa_protected_size_ leak via count_swa_mapped + adjust_swa_protected_size (#226)
  - Fix decode eviction interval: use min(sliding_window_size, page_size) (#227)
  - Fix mapping OOB: mapping_size = size + page_size (#231)
  - Fix evict_interval==1 edge case: x%1 is always 0, handle separately
  - Add adjust_layer_num for hybrid model memory estimation
  - Add comprehensive test coverage: paged allocator, eviction, overlap safety, scheduler cache init regression tests

Co-authored-by: jimoosciuc <33337387+jimoosciuc@users.noreply.github.com>
1 parent da29e8b commit 2824472

13 files changed

Lines changed: 1136 additions & 85 deletions


Lines changed: 152 additions & 0 deletions
@@ -0,0 +1,152 @@
# SWA Eviction Strategy

## 1. Overview

Hybrid models (e.g., MiMo-V2-Flash with 9 full-attention + 39 SWA layers) mix full-attention and sliding-window-attention (SWA) layers. Full-attention layers retain all historical KV data, while SWA layers only need the most recent `W` tokens. This document describes the dual-pool KV cache architecture and eviction strategy that exploits this difference to reduce memory usage.

### Support Status

| Cache Mode | SWA Support | Description |
|-----------|-------------|-------------|
| **ChunkCache** (`--disable-radix-cache`) | **Supported** | Per-request proactive eviction. Each request owns its KV slots; SWA slots outside the window are freed during extend/decode. |
| **RadixCache** (default) | **Not supported** | RadixCache with SWA-aware eviction (tombstone strategies, dual LRU lists) is not implemented. Hybrid models must use `--disable-radix-cache`. |
## 2. Dual-Pool Architecture

Two separate KV cache pools are maintained:

| Pool | Serves | Lifecycle | Eviction |
|------|--------|-----------|----------|
| **Full Pool** | Full-attention layers | Retains all historical KV data for the request lifetime | Freed on request completion |
| **SWA Pool** | SWA layers | Only the most recent `W` tokens are needed | Proactively freed as tokens fall outside the sliding window |

These pools are linked via `full_to_swa_index_mapping`, a numpy array that maps a full-pool index to its corresponding SWA-pool index. A mapping value of 0 means the SWA slot has been freed.

### Key Term

| Term | Definition |
|------|-----------|
| `swa_evicted_seqlen` | Per-request watermark. SWA slots in `[0, swa_evicted_seqlen)` have already been freed. Monotonically increasing. |
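The mapping-based bookkeeping can be sketched with plain numpy. This is a minimal illustration, not the real allocator: the pool sizes are toy values, and `count_swa_mapped` / `free_swa` here only mimic the documented semantics (nonzero mapping = active SWA slot, freeing zeroes the mapping).

```python
import numpy as np

# Hypothetical sizes for illustration; real pools are much larger.
FULL_POOL_SIZE = 16
full_to_swa_index_mapping = np.zeros(FULL_POOL_SIZE, dtype=np.int64)

# Allocate 4 tokens: full indices 1..4 map to SWA indices 1..4.
# (Slot 0 is reserved so that a mapping value of 0 means "freed".)
full_indices = np.array([1, 2, 3, 4])
swa_indices = np.array([1, 2, 3, 4])
full_to_swa_index_mapping[full_indices] = swa_indices

def count_swa_mapped(indices):
    """Count indices that still hold an active SWA mapping (read-only)."""
    return int(np.count_nonzero(full_to_swa_index_mapping[indices]))

def free_swa(indices):
    """Free SWA slots for the given full-pool indices and zero the mapping."""
    mapped = full_to_swa_index_mapping[indices]
    to_free = mapped[mapped != 0]  # entries already freed are 0 and are skipped
    full_to_swa_index_mapping[indices] = 0
    return to_free  # in the real allocator these return to the SWA free list

assert count_swa_mapped(full_indices) == 4
freed = free_swa(np.array([1, 2]))
assert list(freed) == [1, 2]
assert count_swa_mapped(full_indices) == 2  # only 3 and 4 remain mapped
```

Note that `free_swa` is idempotent over already-freed indices, which is why the eviction path can safely pass any slice of `req_to_token`.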
### Memory Layout Example

For MiMo-V2-Flash on TPU v6e-16 (TP=16, 1 KV head, head_dim=192+128=320, bf16):

```
Full pool:            104,704 tokens x  9 FA layers  x 640 bytes/token = ~603 MB
SWA pool:              83,840 tokens x 39 SWA layers x 640 bytes/token = ~2.1 GB
SWA held per request:    ~256 tokens (sliding_window=128 + page_size=128 alignment)
```
## 3. Allocator: `SWATokenToKVPoolAllocator`

The allocator maintains two independent sub-allocators (one for each pool) and the mapping array.

### Allocation

```
alloc(need_size):
  1. Check both pools have capacity >= need_size
  2. Allocate full_indices from full pool
  3. Allocate swa_indices from SWA pool
  4. Update mapping: full_to_swa_index_mapping[full_indices] = swa_indices
  5. Return full_indices (SWA indices are transparent to callers)
```

For paged mode (`page_size > 1`), `alloc_extend` and `alloc_decode` follow the same pattern but use page-level allocation with **atomic rollback**: if SWA allocation fails after full allocation succeeds, the full pages are rolled back to prevent partial allocation.
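The rollback behavior can be demonstrated with a toy free-list allocator. `PoolAllocator` and `dual_pool_alloc` are illustrative stand-ins, not the real classes; the point is only the all-or-nothing allocation across two pools.

```python
class PoolAllocator:
    """Minimal free-list sub-allocator (sketch)."""

    def __init__(self, size):
        self.free_slots = list(range(1, size + 1))  # slot 0 reserved as "freed"

    def alloc(self, n):
        if len(self.free_slots) < n:
            return None
        out, self.free_slots = self.free_slots[:n], self.free_slots[n:]
        return out

    def free(self, slots):
        self.free_slots.extend(slots)

def dual_pool_alloc(full_pool, swa_pool, need_size):
    """Allocate from both pools; roll back the full slots if SWA runs out."""
    full_indices = full_pool.alloc(need_size)
    if full_indices is None:
        return None
    swa_indices = swa_pool.alloc(need_size)
    if swa_indices is None:
        full_pool.free(full_indices)  # atomic rollback: no partial allocation
        return None
    return full_indices, swa_indices

full_pool, swa_pool = PoolAllocator(8), PoolAllocator(4)
assert dual_pool_alloc(full_pool, swa_pool, 4) is not None
# SWA pool is now empty: the next alloc fails and the full slots are rolled back.
assert dual_pool_alloc(full_pool, swa_pool, 2) is None
assert len(full_pool.free_slots) == 4
```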
### Freeing

| Method | What it frees | When called |
|--------|--------------|-------------|
| `free(indices)` | Both full and SWA pools | Request completion |
| `free_swa(indices)` | SWA pool only (looks up mapping, frees non-zero entries, zeroes mapping) | Per-request SWA eviction |
| `count_swa_mapped(indices)` | Nothing (read-only); counts indices with an active SWA mapping | Bookkeeping before `free_swa` |
## 4. Per-Request SWA Eviction (`_evict_swa`)

This function frees SWA slots that fall outside the sliding window from a request's `req_to_token` buffer.

### Algorithm

```
_evict_swa(req, pre_len, sliding_window_size, page_size):
  1. new_evicted = max(req.swa_evicted_seqlen, pre_len - sliding_window_size)
  2. If page_size > 1: align new_evicted down to page boundary
  3. If new_evicted <= req.swa_evicted_seqlen: return (nothing to evict)
  4. Read full-pool indices from req_to_token[swa_evicted_seqlen : new_evicted]
  5. Count actual SWA slots to free (count_swa_mapped)
  6. free_swa(those indices)
  7. Update req.swa_evicted_seqlen = new_evicted
```

### Example

With `sliding_window=128`, `page_size=256`, and `seqlen=2049`:

```
new_evicted  = max(0, 2049 - 128) = 1921
page-aligned = (1921 // 256) * 256 = 1792
Free SWA slots in [0, 1792); retain [1792, 2049) within the window.
```
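The watermark arithmetic above (steps 1-3) is self-contained and easy to check in isolation. A minimal sketch, with `compute_new_evicted` as a hypothetical helper name:

```python
def compute_new_evicted(swa_evicted_seqlen, pre_len, sliding_window_size, page_size):
    """Advance the per-request eviction watermark (the _evict_swa math, steps 1-3)."""
    new_evicted = max(swa_evicted_seqlen, pre_len - sliding_window_size)
    if page_size > 1:
        new_evicted = (new_evicted // page_size) * page_size  # align down to page
    if new_evicted <= swa_evicted_seqlen:
        return None  # nothing to evict
    return new_evicted

# The worked example above: sliding_window=128, page_size=256, seqlen=2049.
assert compute_new_evicted(0, 2049, 128, 256) == 1792
# A short sequence still inside the window: no eviction.
assert compute_new_evicted(0, 100, 128, 256) is None
# The watermark is monotonic: re-running with the same pre_len frees nothing more.
assert compute_new_evicted(1792, 2049, 128, 256) is None
```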
## 5. Extend Phase Behavior

With overlap scheduling enabled, `maybe_evict_swa` is gated by `extend_batch_idx` to prevent freeing SWA pages that a previous extend batch may still be reading on device:

| Condition | Action | Reason |
|-----------|--------|--------|
| `extend_batch_idx < 2` | Eviction **skipped** | Previous extend batch may still be executing |
| `extend_batch_idx >= 2` | Eviction proceeds with `pre_len -= chunked_prefill_size` | Safe: previous batch has completed |

This creates a one-chunk safety delay: chunk N+1 evicts chunk N-1's outdated SWA cache.

**Example** (8K tokens, chunk_size=2048, sliding_window=128, page_size=256, overlap enabled):

| Chunk | `extend_batch_idx` | Action | SWA slots freed |
|-------|-------------------|--------|-----------------|
| 1 | 0 | Skipped | 0 |
| 2 | 1 | Skipped | 0 |
| 3 | 2 | `pre_len=4096-2048=2048`, evicts `[0, 1792)` | 1792 |
| 4 | 3 | `pre_len=6144-2048=4096`, evicts `[1792, 3840)` | 2048 |

Without overlap scheduling, there is no `extend_batch_idx` gate and no `pre_len` adjustment:

| Chunk | Action | SWA slots freed |
|-------|--------|-----------------|
| 1 | `pre_len=0`, nothing to evict | 0 |
| 2 | `pre_len=2048`, evicts `[0, 1792)` | 1792 |
| 3 | `pre_len=4096`, evicts `[1792, 3840)` | 2048 |
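The no-overlap schedule in the second table can be replayed with a short simulation. `extend_eviction_schedule` is a hypothetical helper that applies the watermark math once per chunk; with overlap scheduling the same frees simply shift two chunks later.

```python
def extend_eviction_schedule(total_tokens, chunk_size, sliding_window, page_size):
    """Replay per-chunk extend eviction without overlap scheduling (sketch)."""
    evicted = 0  # the swa_evicted_seqlen watermark
    freed_per_chunk = []
    for pre_len in range(0, total_tokens, chunk_size):
        new_evicted = max(evicted, pre_len - sliding_window)
        new_evicted = (new_evicted // page_size) * page_size  # align down to page
        if new_evicted > evicted:
            freed_per_chunk.append(new_evicted - evicted)
            evicted = new_evicted
        else:
            freed_per_chunk.append(0)
    return freed_per_chunk

# Matches the table above: chunks at pre_len 0 / 2048 / 4096 free 0 / 1792 / 2048 slots.
assert extend_eviction_schedule(8192, 2048, 128, 256)[:3] == [0, 1792, 2048]
```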
## 6. Decode Phase Behavior

### Eviction Interval

```python
evict_interval = max(min(sliding_window_size, page_size), 1)
```

| Scenario | Interval | Rationale |
|----------|----------|-----------|
| `page_size >= sliding_window_size` | Every step | Window advances past a full page each step |
| `page_size < sliding_window_size` | Every `page_size` steps | Avoid partial-page eviction |
| `evict_interval == 1` | Every step | Handled by a separate `evict_interval <= 1` branch, since `x % 1 == 1` is always false and would otherwise disable eviction |
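The interval and the decode-step gate can be shown together. `should_evict` is an illustrative name for the per-request condition used in `maybe_evict_swa`'s decode branch:

```python
def evict_interval(sliding_window_size, page_size):
    # max(..., 1) keeps the interval positive; the ==1 case is then handled by a
    # separate `evict_interval <= 1` branch, because x % 1 == 1 is never true.
    return max(min(sliding_window_size, page_size), 1)

def should_evict(decode_batch_idx, interval):
    """Decode-step gate (sketch): always skip batch 0 for overlap safety."""
    return decode_batch_idx > 0 and (interval <= 1 or decode_batch_idx % interval == 1)

assert evict_interval(128, 256) == 128   # window smaller than a page
assert evict_interval(128, 64) == 64     # evict every page_size steps
assert evict_interval(1, 1) == 1
# interval == 1: evict every step after batch 0, despite x % 1 never equaling 1.
assert [should_evict(i, 1) for i in range(4)] == [False, True, True, True]
# interval == 64: triggers at decode_batch_idx 1, 65, 129, ...
assert should_evict(1, 64) and should_evict(65, 64) and not should_evict(2, 64)
```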
### Overlap Safety

| Condition | Action | Reason |
|-----------|--------|--------|
| `decode_batch_idx == 0` | Eviction **skipped** | Previous decode batch may still be reading SWA pages on device |
| `decode_batch_idx > 0` | Eviction triggers on `decode_batch_idx % evict_interval == 1` (or every step when `evict_interval <= 1`) | Safe: previous batch has completed |

## 7. Summary

| Component | Description |
|-----------|-------------|
| **Dual-pool architecture** | Full pool (full-attention layers, all history) + SWA pool (SWA layers, window only) |
| **Index mapping** | `full_to_swa_index_mapping` translates full-pool indices to SWA-pool indices |
| **Allocation** | Atomic dual-pool alloc with rollback on SWA exhaustion |
| **Extend eviction** | Proactive per-chunk; skips first 2 chunks for overlap safety |
| **Decode eviction** | Periodic during decode; skips batch 0 for overlap safety |
| **Eviction algorithm** | `_evict_swa`: advance watermark, page-align, free SWA slots in `[old, new)` |
| **RadixCache SWA** | Not supported; hybrid models must use `--disable-radix-cache` |

python/sgl_jax/srt/layers/attention/flashattention_backend.py

Lines changed: 52 additions & 12 deletions
```diff
@@ -41,6 +41,7 @@ class FlashAttentionMetadata:
     seq_lens: jax.Array = None
     distribution: jax.Array = None
     custom_mask: jax.Array = None
+    swa_page_indices: jax.Array = None

     def tree_flatten(self):
         children = (
@@ -51,6 +52,7 @@ def tree_flatten(self):
             self.seq_lens,
             self.distribution,
             self.custom_mask,
+            self.swa_page_indices,
         )

         aux_data = {}
@@ -67,6 +69,7 @@ def tree_unflatten(cls, aux_data, children):
         obj.seq_lens = children[4]
         obj.distribution = children[5]
         obj.custom_mask = children[6]
+        obj.swa_page_indices = children[7]

         return obj
@@ -96,6 +99,7 @@ def __init__(
         self.kv_partition_axis = kv_partition_axis
         self.forward_metadata = nnx.data(FlashAttentionMetadata())
         self.mesh = mesh
+        self.swa_index_mapping = None

     def get_forward_metadata(
         self,
@@ -151,17 +155,47 @@ def get_forward_metadata(
         else:
             raise ValueError(f"Invalid forward mode: {batch.forward_mode}")

-        (
-            metadata.num_seqs,
-            metadata.cu_q_lens,
-            metadata.cu_kv_lens,
-            metadata.page_indices,
-            metadata.seq_lens,
-            metadata.distribution,
-        ) = device_array(
-            (num_seqs, cu_q_lens, cu_kv_lens, page_indices, seq_lens, distribution),
-            sharding=(NamedSharding(self.mesh, P()) if jax.process_count() == 1 else None),
-        )
+        # Compute swa_page_indices if SWA index mapping is available
+        swa_page_indices = None
+        if self.swa_index_mapping is not None:
+            swa_cache_loc = self.swa_index_mapping[batch.cache_loc]
+            swa_indices = np.arange(0, len(swa_cache_loc), self.page_size)
+            swa_selected = swa_cache_loc[swa_indices]
+            swa_page_indices = (swa_selected // self.page_size).astype(np.int32)
+
+        if swa_page_indices is not None:
+            (
+                metadata.num_seqs,
+                metadata.cu_q_lens,
+                metadata.cu_kv_lens,
+                metadata.page_indices,
+                metadata.seq_lens,
+                metadata.distribution,
+                metadata.swa_page_indices,
+            ) = device_array(
+                (
+                    num_seqs,
+                    cu_q_lens,
+                    cu_kv_lens,
+                    page_indices,
+                    seq_lens,
+                    distribution,
+                    swa_page_indices,
+                ),
+                sharding=(NamedSharding(self.mesh, P()) if jax.process_count() == 1 else None),
+            )
+        else:
+            (
+                metadata.num_seqs,
+                metadata.cu_q_lens,
+                metadata.cu_kv_lens,
+                metadata.page_indices,
+                metadata.seq_lens,
+                metadata.distribution,
+            ) = device_array(
+                (num_seqs, cu_q_lens, cu_kv_lens, page_indices, seq_lens, distribution),
+                sharding=(NamedSharding(self.mesh, P()) if jax.process_count() == 1 else None),
+            )
         return metadata

     def get_eagle_forward_metadata(self, batch: ModelWorkerBatch):
@@ -454,7 +488,13 @@ def __call__(
             causal = 0
         # Select page indices and remap to SWA pool if KV cache supports it
         page_indices_arg = self.forward_metadata.page_indices
-        if hasattr(token_to_kv_pool, "remap_cache_loc") and self.page_size == 1:
+        if self.forward_metadata.swa_page_indices is not None and hasattr(
+            token_to_kv_pool, "layers_mapping"
+        ):
+            _, is_swa = token_to_kv_pool.layers_mapping[layer.layer_id]
+            if is_swa:
+                page_indices_arg = self.forward_metadata.swa_page_indices
+        elif hasattr(token_to_kv_pool, "remap_cache_loc") and self.page_size == 1:
             page_indices_arg = token_to_kv_pool.remap_cache_loc(page_indices_arg, layer.layer_id)

         in_specs = (
```
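The `swa_page_indices` computation in `get_forward_metadata` above is three numpy steps: remap full-pool slots through the mapping, take one slot per page, and divide by the page size. A standalone sketch with toy values (the identity mapping and the concrete `cache_loc` are illustrative, not real runtime data):

```python
import numpy as np

# Illustrative inputs; at runtime the mapping and cache_loc come from the allocator.
page_size = 4
swa_index_mapping = np.arange(64)  # identity mapping, just for the sketch
cache_loc = np.array([8, 9, 10, 11, 20, 21, 22, 23])  # full-pool token slots

# Same three steps as the diff above:
swa_cache_loc = swa_index_mapping[cache_loc]                # full -> SWA token slots
swa_indices = np.arange(0, len(swa_cache_loc), page_size)   # first slot of each page
swa_selected = swa_cache_loc[swa_indices]
swa_page_indices = (swa_selected // page_size).astype(np.int32)

assert swa_page_indices.tolist() == [2, 5]  # SWA pages 2 and 5
```

Because only the first token of each page is sampled, this assumes page-aligned `cache_loc` runs, which is what the paged allocator produces.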

python/sgl_jax/srt/managers/schedule_batch.py

Lines changed: 76 additions & 0 deletions
```diff
@@ -65,6 +65,7 @@
     "speculative_accept_threshold_single",
     "speculative_accept_threshold_acc",
     "enable_deterministic_sampling",
+    "chunked_prefill_size",
 ]

 PADDING_BUCKETS = [1 << i for i in range(6, 21)]
@@ -306,6 +307,11 @@ def __init__(
         ) = None
         self.hidden_states: list[list[float]] = []

+        # SWA eviction tracking
+        self.swa_evicted_seqlen: int = 0
+        self.extend_batch_idx: int = 0
+        self.decode_batch_idx: int = 0
+
         # The number of cached tokens that were already cached in the KV cache
         self.cached_tokens = 0
         self.already_computed = 0
@@ -741,6 +747,63 @@ def mix_with_running(self, running_batch: ScheduleBatch):
         self.extend_num_tokens += running_bs
         self.extend_logprob_start_lens.extend([0] * running_bs)

+    def _evict_swa(self, req: Req, pre_len: int, sliding_window_size: int, page_size: int):
+        """Evict SWA pool tokens outside the sliding window for a single request."""
+        new_evicted = max(req.swa_evicted_seqlen, pre_len - sliding_window_size)
+        if page_size > 1:
+            new_evicted = (new_evicted // page_size) * page_size
+        if new_evicted <= req.swa_evicted_seqlen:
+            return
+        free_slots = self.req_to_token_pool.req_to_token[
+            req.req_pool_idx, req.swa_evicted_seqlen : new_evicted
+        ]
+        # Count actual SWA slots that will be freed (those with active mapping)
+        num_swa_freed = self.token_to_kv_pool_allocator.count_swa_mapped(free_slots)
+        self.token_to_kv_pool_allocator.free_swa(free_slots)
+        # Notify cache layer: these slots were protected (node is locked),
+        # so adjust swa_protected_size_ to prevent bookkeeping leak.
+        if num_swa_freed > 0 and isinstance(self.tree_cache, SWARadixCache):
+            self.tree_cache.adjust_swa_protected_size(-num_swa_freed)
+        req.swa_evicted_seqlen = new_evicted
+
+    def maybe_evict_swa(self, sliding_window_size=None):
+        """Evict SWA pool tokens for all requests if hybrid model."""
+        if not self.is_hybrid:
+            return
+        if sliding_window_size is None:
+            sliding_window_size = getattr(self.model_config, "sliding_window", None)
+        if sliding_window_size is None or sliding_window_size <= 0:
+            return
+        page_size = getattr(
+            self.token_to_kv_pool_allocator,
+            "_page_size",
+            getattr(self.token_to_kv_pool_allocator, "page_size", 1),
+        )
+
+        if self.forward_mode is not None and self.forward_mode.is_decode():
+            # Evict at the smaller of sliding_window_size and page_size to avoid
+            # stale SWA slot accumulation.
+            evict_interval = max(min(sliding_window_size, page_size), 1)
+            for req in self.reqs:
+                if req.decode_batch_idx > 0 and (
+                    evict_interval <= 1 or req.decode_batch_idx % evict_interval == 1
+                ):
+                    self._evict_swa(req, req.seqlen - 1, sliding_window_size, page_size)
+            return
+
+        if self.forward_mode is None or not self.forward_mode.is_extend():
+            return
+
+        for i, req in enumerate(self.reqs):
+            pre_len = self.prefix_lens[i] if self.prefix_lens is not None else 0
+            if self.enable_overlap and req.is_chunked > 0:
+                if req.extend_batch_idx < 2:
+                    continue
+                chunked_prefill_size = global_server_args_dict.get("chunked_prefill_size")
+                if chunked_prefill_size is not None and chunked_prefill_size > 0:
+                    pre_len -= chunked_prefill_size
+            self._evict_swa(req, pre_len, sliding_window_size, page_size)
+
     def prepare_for_extend(self):
         self.forward_mode = ForwardMode.EXTEND

@@ -776,6 +839,7 @@ def prepare_for_extend(self):
             req.cached_tokens += pre_len - req.already_computed
             req.already_computed = seq_len
             req.is_retracted = False
+            req.extend_batch_idx += 1

             # Compute the relative logprob_start_len in an extend batch
             if req.logprob_start_len >= pre_len:
@@ -878,6 +942,10 @@ def prepare_for_extend(self):
             )
             pt += extend_lens[i]

+        # Evict SWA tokens outside sliding window
+        if self.is_hybrid:
+            self.maybe_evict_swa()
+
         # Build sampling info
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
             self,
@@ -1019,6 +1087,11 @@ def prepare_for_idle(self):
     def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE
         bs = len(self.reqs)
+
+        # Evict SWA tokens outside sliding window
+        if self.is_hybrid:
+            self.maybe_evict_swa()
+
         if self.spec_algorithm is not None and self.spec_algorithm.is_eagle():
             # if spec decoding is used, the decode batch is prepared inside
             # `forward_batch_speculative_generation` after running draft models.
@@ -1068,6 +1141,9 @@ def prepare_for_decode(self):
             (self.req_pool_indices, locs), self.out_cache_loc.astype(np.int32)
         )

+        for req in self.reqs:
+            req.decode_batch_idx += 1
+
     def filter_batch(
         self,
         chunked_req_to_exclude: Req | list[Req] | None = None,
```
