Skip to content

Commit fac064f

Browse files
authored
[Paged KV] Packed prefill with cu_seq_lens for multiple requests (#151)
## Summary <img width="1493" height="527" alt="截圖 2026-03-10 上午9 15 51" src="https://github.com/user-attachments/assets/60498140-d25a-448d-a160-0f9b34400307" /> - Pack multiple complete prefill requests into a single forward pass using cumulative sequence lengths (`cu_seqlens`) as separators - Build block-diagonal causal mask so packed requests don't cross-attend - Apply per-request RoPE position reset within packed sequences - Fall back to single-request path when only 1 prefill is scheduled This is the first step of Stage 3 (Chunked Prefilling & Continuous Batching) in the roadmap (#148). ## Changes - `paged_attention_common.py`: Add `cu_seqlens` field to `PagedAttentionContext`; add `prepare_prefill_packed()` - `paged_attention.py`: Add `_build_packed_causal_mask()`; update `_metal_kernel_prefill_attention()` for packed mode - `model_runner.py`: Add `_prefill_packed_paged()`; restructure Phase 1 to collect complete prefills and batch them - `test_paged_attention.py`: Add tests for packed slot_mapping, cu_seqlens, and block-diagonal causal mask isolation ## Test - [x] Unit tests: 9/9 pass (`pytest tests/test_paged_attention.py`) - [x] E2E: Qwen3-0.6B with `VLLM_METAL_USE_PAGED_ATTENTION=1`, 3 concurrent requests all return correct responses - [x] Verified packed path triggered via debug log: scheduler batches 2 complete requests into packed prefill --------- Signed-off-by: RickyChen / 陳昭儒 <ricky.chen@infinirc.com>
1 parent 9081de3 commit fac064f

5 files changed

Lines changed: 505 additions & 62 deletions

File tree

tests/test_paged_attention.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
get_context,
1414
prepare_decode,
1515
prepare_prefill,
16+
prepare_prefill_packed,
1617
)
1718

1819

@@ -47,6 +48,30 @@ def test_prepare_prefill_slot_mapping(self):
4748
assert ctx.is_prefill
4849
assert ctx.slot_mapping == [40, 41, 42, 43, 44]
4950

51+
def test_prepare_prefill_packed_slot_mapping(self):
    """Packed prefill maps each request's tokens into its own blocks."""
    # Two requests: 3 tokens in block 10, 2 tokens in block 20
    requests = [([10], 3), ([20], 2)]
    prepare_prefill_packed(requests, block_size=4)
    ctx = get_context()

    assert ctx is not None
    assert ctx.is_prefill
    # Slot = block_id * block_size + position within the block.
    # Request 0: block 10, slots 40,41,42
    # Request 1: block 20, slots 80,81
    assert ctx.slot_mapping == [40, 41, 42, 80, 81]
    # Cumulative lengths: [0, 3, 3+2] — one boundary per packed request.
    assert ctx.cu_seqlens == [0, 3, 5]
63+
64+
def test_prepare_prefill_packed_single_request(self):
    """A pack of one request still yields well-formed cu_seqlens."""
    # Single request should still produce valid cu_seqlens
    requests = [([5, 6], 5)]
    prepare_prefill_packed(requests, block_size=4)
    ctx = get_context()

    assert ctx is not None
    assert ctx.cu_seqlens == [0, 5]
    # Five tokens span two blocks (block_size=4):
    # block 5: slots 20,21,22,23; block 6: slot 24
    assert ctx.slot_mapping == [20, 21, 22, 23, 24]
74+
5075
def test_prepare_decode(self):
5176
# Arrange
5277
requests = [([5, 6], 7)]
@@ -61,3 +86,137 @@ def test_prepare_decode(self):
6186
assert ctx.slot_mapping == [27]
6287
assert ctx.context_lens == [8]
6388
assert ctx.offsets == [7]
89+
90+
91+
class TestPackedCausalMask:
    """Tests for the block-diagonal causal mask used in packed prefill."""

    def test_single_sequence(self):
        """A single packed sequence degenerates to a plain causal mask."""
        from vllm_metal.metal_kernel_backend.packed_prefill_compat import (
            build_packed_causal_mask,
        )

        mask = build_packed_causal_mask([0, 3], total_len=3)
        # Standard causal: lower-triangular (0) with upper-triangular (-inf)
        assert mask.shape == (1, 1, 3, 3)
        m = mask[0, 0]
        # Diagonal and below should be 0
        assert m[0, 0].item() == 0.0
        assert m[1, 0].item() == 0.0
        assert m[1, 1].item() == 0.0
        # Above diagonal should be -inf
        assert m[0, 1].item() == float("-inf")
        assert m[0, 2].item() == float("-inf")

    def test_two_sequences_isolation(self):
        """Packed requests must not attend across their boundary."""
        from vllm_metal.metal_kernel_backend.packed_prefill_compat import (
            build_packed_causal_mask,
        )

        # Two sequences: [0,2) and [2,5)
        mask = build_packed_causal_mask([0, 2, 5], total_len=5)
        m = mask[0, 0]
        # Seq 0 tokens should not attend to seq 1 tokens
        assert m[0, 2].item() == float("-inf")
        assert m[0, 3].item() == float("-inf")
        assert m[1, 2].item() == float("-inf")
        # Seq 1 tokens should not attend to seq 0 tokens
        assert m[2, 0].item() == float("-inf")
        assert m[2, 1].item() == float("-inf")
        assert m[3, 0].item() == float("-inf")
        # Within seq 1: causal
        assert m[2, 2].item() == 0.0
        assert m[3, 2].item() == 0.0
        assert m[3, 3].item() == 0.0
        assert m[2, 3].item() == float("-inf")

    def test_mask_dtype_matches_request(self):
        """The mask is built directly in the requested dtype (no f32 cast)."""
        import mlx.core as mx

        from vllm_metal.metal_kernel_backend.packed_prefill_compat import (
            build_packed_causal_mask,
        )

        mask = build_packed_causal_mask([0, 3], total_len=3, dtype=mx.bfloat16)
        assert mask.dtype == mx.bfloat16
142+
143+
144+
class TestPackedRoPE:
    """Tests for per-request RoPE position reset in packed prefill."""

    def test_positions_reset_per_request(self):
        """Each packed request's RoPE should start from position 0."""
        import mlx.core as mx

        from vllm_metal.metal_kernel_backend.packed_prefill_compat import (
            apply_packed_rope,
        )

        # Minimal RoPE stub: returns input + offset so we can verify offsets
        class FakeRoPE:
            def rope(self, x, offset=0):
                return x + offset

        module = FakeRoPE()
        # Two requests packed: 3 tokens + 2 tokens
        # Shape: (1, heads=1, total_len=5, head_dim=2)
        q = mx.zeros((1, 1, 5, 2))
        k = mx.zeros((1, 1, 5, 2))
        cu_seqlens = [0, 3, 5]

        q_out, k_out = apply_packed_rope(module, q, k, cu_seqlens)

        # All values should be 0 (offset=0 for every request): any nonzero
        # value would mean a request was rotated with a nonzero offset.
        assert q_out.shape == (1, 1, 5, 2)
        assert mx.allclose(q_out, mx.zeros_like(q_out)).item()
        assert mx.allclose(k_out, mx.zeros_like(k_out)).item()
173+
174+
175+
class TestBatchSplitting:
    """Tests for the packed-prefill batch splitting logic."""

    @staticmethod
    def _split_batches(
        entries: list[tuple[int, int]],
        max_tokens: int,
    ) -> list[list[tuple[int, int]]]:
        """Reproduce the batch splitting algorithm from _run_packed_prefill.

        entries: list of (index, num_tokens) for simplicity.
        """
        batches: list[list[tuple[int, int]]] = [[]]
        running_tokens = 0
        for item in entries:
            item_tokens = item[1]
            # Open a new batch when this item would overflow the cap —
            # unless the current batch is still empty, so an oversized
            # item lands in a batch of its own instead of being dropped.
            if batches[-1] and running_tokens + item_tokens > max_tokens:
                batches.append([])
                running_tokens = 0
            batches[-1].append(item)
            running_tokens += item_tokens
        return batches

    def test_all_fit_single_batch(self):
        """Entries whose total fits under the cap stay in one batch."""
        entries = [(0, 100), (1, 200), (2, 300)]
        assert self._split_batches(entries, max_tokens=4096) == [entries]

    def test_split_into_two_batches(self):
        """Crossing the token cap opens a second batch."""
        entries = [(0, 3000), (1, 2000)]
        result = self._split_batches(entries, max_tokens=4096)
        assert result == [[(0, 3000)], [(1, 2000)]]

    def test_single_large_request_not_dropped(self):
        """A request exceeding the cap still gets its own batch."""
        result = self._split_batches([(0, 5000)], max_tokens=4096)
        assert result == [[(0, 5000)]]

    def test_preserves_all_entries(self):
        """Splitting never drops or reorders entries."""
        entries = [(i, 1000) for i in range(10)]
        result = self._split_batches(entries, max_tokens=4096)
        assert [e for batch in result for e in batch] == entries
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SCAFFOLDING: remove when varlen kernel is ready.
3+
#
4+
# Dense causal mask and per-request RoPE helpers for packed prefill.
5+
# These are temporary — the varlen kernel will handle masking and
6+
# position encoding natively, making this module unnecessary.
7+
8+
from __future__ import annotations
9+
10+
import mlx.core as mx
11+
12+
13+
def build_packed_causal_mask(
    cu_seqlens: list[int],
    total_len: int,
    dtype: mx.Dtype = mx.float32,
) -> mx.array:
    """Build a block-diagonal causal mask for packed prefill.

    Each request only attends to its own tokens (causally). Returns an
    additive mask of shape ``(1, 1, total_len, total_len)`` with 0 for
    allowed positions and ``-inf`` for blocked positions, suitable for
    ``mx.fast.scaled_dot_product_attention``.

    Args:
        dtype: Construct the mask directly in this dtype to avoid a
            transient float32 allocation followed by a cast.

    SCAFFOLDING: remove when varlen kernel is ready.
    """
    blocked = mx.array(-mx.inf, dtype=dtype)
    # Everything starts blocked; each request opens its own causal window.
    mask = mx.full((total_len, total_len), blocked)
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        span = end - start
        # triu(k=1) keeps -inf strictly above the diagonal and zeroes the
        # rest, i.e. a standard causal mask for this request's tokens.
        window = mx.triu(mx.full((span, span), blocked), k=1)
        mask[start:end, start:end] = window
    return mask.reshape(1, 1, total_len, total_len)
42+
43+
44+
def apply_packed_rope(
    attn_module: object,
    queries: mx.array,
    keys: mx.array,
    cu_seqlens: list[int],
) -> tuple[mx.array, mx.array]:
    """Apply per-request RoPE with position reset for packed prefill.

    Each packed request is rotated independently with ``offset=0`` so its
    positions restart at zero, then the rotated pieces are stitched back
    together along the sequence axis.

    SCAFFOLDING: remove when varlen kernel is ready.
    """
    rope = attn_module.rope
    spans = list(zip(cu_seqlens[:-1], cu_seqlens[1:]))
    rotated_q = [rope(queries[:, :, s:e, :], offset=0) for s, e in spans]
    rotated_k = [rope(keys[:, :, s:e, :], offset=0) for s, e in spans]
    return mx.concatenate(rotated_q, axis=2), mx.concatenate(rotated_k, axis=2)

vllm_metal/metal_kernel_backend/paged_attention.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@
6868

6969
from vllm_metal.metal import get_ops
7070
from vllm_metal.metal_kernel_backend.cache import MetalPagedKVCache
71+
from vllm_metal.metal_kernel_backend.packed_prefill_compat import (
72+
apply_packed_rope,
73+
build_packed_causal_mask,
74+
)
7175
from vllm_metal.paged_attention_common import (
7276
PagedAttentionContext,
7377
find_layers_and_attr,
@@ -89,25 +93,36 @@ def _metal_kernel_prefill_attention(
8993
ctx: PagedAttentionContext,
9094
offset_cache: Any,
9195
) -> mx.array:
92-
"""Prefill: B=1, L=prompt_len.
96+
"""Prefill: B=1, L=prompt_len (single) or L=total_tokens (packed).
9397
9498
Inline causal SDPA in MLX, then write K/V to paged cache via
95-
``reshape_and_cache``.
99+
``reshape_and_cache``. When ``ctx.cu_seqlens`` is set, builds a
100+
block-diagonal causal mask so packed requests don't cross-attend.
96101
"""
97102
B, _, L, _ = queries.shape # noqa: N806
98103

99-
# RoPE
104+
# RoPE — per-request position reset for packed prefill
100105
if not hasattr(attn_module, "rope"):
101106
raise NotImplementedError(
102107
f"Attention module {type(attn_module).__name__} does not have a 'rope' "
103108
"attribute. Only RoPE-based models are supported by paged attention."
104109
)
105-
offset = offset_cache.offset if offset_cache is not None else 0
106-
queries = attn_module.rope(queries, offset=offset)
107-
keys = attn_module.rope(keys, offset=offset)
108110

109-
# Causal SDPA (inline — K/V already in hand)
110-
attn_mask = "causal" if L > 1 else None
111+
# SCAFFOLDING: packed RoPE + mask — remove when varlen kernel is ready.
112+
if ctx.cu_seqlens is not None:
113+
queries, keys = apply_packed_rope(attn_module, queries, keys, ctx.cu_seqlens)
114+
else:
115+
offset = offset_cache.offset if offset_cache is not None else 0
116+
queries = attn_module.rope(queries, offset=offset)
117+
keys = attn_module.rope(keys, offset=offset)
118+
119+
# Causal SDPA
120+
# SCAFFOLDING: dense mask — remove when varlen kernel is ready.
121+
if ctx.cu_seqlens is not None and len(ctx.cu_seqlens) > 2:
122+
attn_mask = build_packed_causal_mask(ctx.cu_seqlens, L, dtype=queries.dtype)
123+
else:
124+
attn_mask = "causal" if L > 1 else None
125+
111126
output = mx.fast.scaled_dot_product_attention(
112127
queries, keys, values, scale=attn_module.scale, mask=attn_mask
113128
)

vllm_metal/paged_attention_common.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ class PagedAttentionContext:
4242
block_tables: list[list[int]] = field(default_factory=list)
4343
context_lens: list[int] = field(default_factory=list)
4444
offsets: list[int] = field(default_factory=list)
45+
# packed prefill fields — set when multiple requests are packed into
46+
# a single forward pass. cu_seqlens is a cumulative sequence length
47+
# array: [0, len0, len0+len1, ...] (length = num_requests + 1).
48+
cu_seqlens: list[int] | None = None
4549

4650

4751
def set_context(ctx: PagedAttentionContext) -> None:
@@ -164,6 +168,39 @@ def prepare_prefill(
164168
)
165169

166170

171+
def prepare_prefill_packed(
    requests: list[tuple[list[int], int]],
    block_size: int,
) -> None:
    """Compute slot_mapping and cu_seqlens for packed prefill.

    Packs multiple prefill requests into a single forward pass. The
    attention wrapper uses ``cu_seqlens`` to build a block-diagonal
    causal mask so that each request only attends to its own tokens.

    Args:
        requests: list of (block_ids, num_tokens) per request.
        block_size: tokens per block.
    """
    slot_mapping: list[int] = []
    cu_seqlens: list[int] = [0]

    for block_ids, num_tokens in requests:
        # Token at position `pos` lives in the (pos // block_size)-th
        # block of this request, at offset (pos % block_size) within it.
        slot_mapping.extend(
            block_ids[pos // block_size] * block_size + pos % block_size
            for pos in range(num_tokens)
        )
        # Running total — cu_seqlens has length num_requests + 1.
        cu_seqlens.append(cu_seqlens[-1] + num_tokens)

    set_context(
        PagedAttentionContext(
            is_prefill=True,
            slot_mapping=slot_mapping,
            cu_seqlens=cu_seqlens,
        )
    )
202+
203+
167204
def prepare_decode(
168205
requests: list[tuple[list[int], int]],
169206
block_size: int,

0 commit comments

Comments
 (0)