Commit ed2cefa

[varlen Kernel] Extend paged attention v2 to varlen [4/n] (#166)
## Summary

- Add a `find_seq_idx` binary search to the v2 Metal kernel so each threadgroup discovers its sequence from a flat `cu_seqlens_q` array, enabling variable-length queries (prefill + decode in one launch).
- **This PR does not take effect in production yet.** Production still uses `mx.sdpa` for prefill and this PR's v2 kernel for decode. With some parameters frozen, the kernels_v2 path is identical to the previous v1.
- All triangle tests pass; it is safe to move forward to stage 3 (continuous batching).

Notes:
- Vendored features from upstream vLLM: sliding window support and soft capping in the v2 kernel.
- Updated the production decode path to match the new function signature (default params, no behavior change).
- **These features are NOT TESTED in end-to-end production usage; they are expected to be bound to specific models such as early versions of the Mistral models.**

## Triangle Test Status

```
        ref (pure-MLX naive)
       /                  \
  edge 1                edge 3
     /                      \
   v1 ────── edge 2 ────── v2
```

- **Edge 1** (v1 == ref): 6 pass (unchanged)
- **Edge 2** (v2 == v1): 6 pass (unchanged)
- **Edge 3** (v2 == ref): **24 pass** (was 3 pass + 21 xfail)
  - varlen (q_len > 1): now passing
  - sliding window (128): now passing
  - soft capping (50.0): now passing

**Before:** 15 passed + 21 xfail → **After:** 36 passed + 0 xfail

## What's NOT in this PR

The kernel now supports unified prefill+decode, but production still uses the split path (MLX SDPA for prefill, v2 kernel for decode). Wiring `metal_unified_attention()` into `model_runner.py` is a follow-up.

## Numeric Stability

```
python -m pytest tests/test_paged_deterministic.py -v -s
```

* Before this PR: 5/6 match the mlx_lm path
* After this PR: 6/6 match the mlx_lm path

The test itself is left unchanged for now; its results will flip back and forth in the following PRs.
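The `find_seq_idx` lookup on `cu_seqlens_q` described in the summary can be sketched in Python with the standard `bisect` module. This is an illustrative restatement of the kernel's semantics, not the Metal code itself:

```python
import bisect

def find_seq_idx(cu_seqlens_q, q_token_idx):
    """Return seq_idx such that
    cu_seqlens_q[seq_idx] <= q_token_idx < cu_seqlens_q[seq_idx + 1].

    cu_seqlens_q is the ascending prefix-sum array
    [0, q_len_0, q_len_0 + q_len_1, ...] over the packed (ragged) query batch.
    """
    return bisect.bisect_right(cu_seqlens_q, q_token_idx) - 1

# Three sequences with query lengths 3, 1, 5 packed into 9 flat tokens.
cu = [0, 3, 4, 9]
assert [find_seq_idx(cu, t) for t in range(9)] == [0, 0, 0, 1, 2, 2, 2, 2, 2]
```

The Metal kernel implements the same predicate with an explicit binary search loop, since there is no `bisect` on device.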
## Benchmark

Ran the same benchmark script as #136.

<details>
<summary>This PR:</summary>

```
============ Serving Benchmark Result ============
Successful requests:                     100
Failed requests:                         0
Maximum request concurrency:             32
Request rate configured (RPS):           10.00
Benchmark duration (s):                  107.24
Total input tokens:                      23260
Total generated tokens:                  22061
Request throughput (req/s):              0.93
Output token throughput (tok/s):         205.71
Peak output token throughput (tok/s):    319.00
Peak concurrent requests:                35.00
Total token throughput (tok/s):          422.60
---------------Time to First Token----------------
Mean TTFT (ms):                          593.33
Median TTFT (ms):                        386.44
P99 TTFT (ms):                           2147.57
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          134.30
Median TPOT (ms):                        127.79
P99 TPOT (ms):                           477.35
---------------Inter-token Latency----------------
Mean ITL (ms):                           117.58
Median ITL (ms):                         104.05
P99 ITL (ms):                            473.33
==================================================
```

</details>

<details>
<summary>Before this PR:</summary>

```
============ Serving Benchmark Result ============
Successful requests:                     100
Failed requests:                         0
Maximum request concurrency:             32
Request rate configured (RPS):           10.00
Benchmark duration (s):                  106.42
Total input tokens:                      23260
Total generated tokens:                  22061
Request throughput (req/s):              0.94
Output token throughput (tok/s):         207.30
Peak output token throughput (tok/s):    320.00
Peak concurrent requests:                35.00
Total token throughput (tok/s):          425.87
---------------Time to First Token----------------
Mean TTFT (ms):                          982.74
Median TTFT (ms):                        452.35
P99 TTFT (ms):                           3030.71
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          132.63
Median TPOT (ms):                        124.33
P99 TPOT (ms):                           442.78
---------------Inter-token Latency----------------
Mean ITL (ms):                           115.19
Median ITL (ms):                         101.69
P99 ITL (ms):                            440.37
==================================================
```

</details>

This PR has no effect on performance. It paves the way for continuous batching.
## Possible Limitations

* The binary search is translated from the Triton kernel, but it may not be necessary. Triton uses it to avoid a CPU-GPU data copy, but we are on unified memory, so we could prebuild the reverse map instead. At these data sizes, O(log n) performs the same as O(1) while using less space, so the binary search is kept.
* Partitioning on vs. off was not checked.

---------

Signed-off-by: ran <hzz5361@psu.edu>
1 parent 95ad433 commit ed2cefa

5 files changed

Lines changed: 136 additions & 51 deletions

tests/test_metal_unified_attention.py

Lines changed: 0 additions & 6 deletions
```diff
@@ -357,12 +357,6 @@ def test_metal_unified_attn(
     query_lens = [x[0] for x in seq_lens]
     kv_lens = [x[1] for x in seq_lens]
 
-    # xfail cases that need features not yet in the v2 kernel:
-    # varlen (q_len > 1), sliding window, or soft capping.
-    # Decode-only cases with no extras already work and should pass.
-    max_query_len_val = max(query_lens)
-    if max_query_len_val > 1 or sliding_window is not None or soft_cap is not None:
-        pytest.xfail("v2 varlen/sliding-window/soft-cap not yet implemented")
     num_query_heads = num_heads[0]
     num_kv_heads = num_heads[1]
     assert num_query_heads % num_kv_heads == 0
```
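The test sweeps per-sequence `query_lens`; in the varlen API these lengths are flattened into the `cu_seqlens_q` prefix-sum array the kernel searches. A minimal sketch of that packing (`build_cu_seqlens` is an illustrative helper, not a function from this repo):

```python
from itertools import accumulate

def build_cu_seqlens(query_lens):
    # cu_seqlens_q[i] = number of query tokens before sequence i;
    # the final entry is the total token count (the grid.y extent).
    return [0] + list(accumulate(query_lens))

# e.g. a mixed batch: prefill (3 tokens) + decode (1 token) + prefill (5 tokens)
assert build_cu_seqlens([3, 1, 5]) == [0, 3, 4, 9]
```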

vllm_metal/metal/__init__.py

Lines changed: 19 additions & 18 deletions
```diff
@@ -87,35 +87,33 @@ def metal_unified_attention(
 ) -> None:
     """Unified varlen paged attention for Metal.
 
-    Currently supports decode-only (max_seqlen_q=1). Sliding window and
-    soft capping are not yet supported. These will be enabled when the v2
-    kernel is extended to handle variable-length queries (prefill + decode).
+    Supports variable-length queries (prefill + decode) with online softmax,
+    paged KV cache, causal masking, sliding window, and soft capping.
+
+    Grid: one threadgroup per (head, query_token). Each threadgroup uses
+    binary search on cu_seqlens_q to find its sequence and computes causal
+    attention against the paged KV cache.
     """
+    assert causal, "Only causal attention is supported"
     import mlx.core as mx
 
-    if max_seqlen_q != 1:
-        raise NotImplementedError(
-            f"metal_unified_attention only supports decode (max_seqlen_q=1), "
-            f"got {max_seqlen_q}"
-        )
-    if window_size != (-1, -1):
-        raise NotImplementedError(
-            f"Sliding window not yet supported, got window_size={window_size}"
-        )
-    if softcap != 0:
-        raise NotImplementedError(
-            f"Soft capping not yet supported, got softcap={softcap}"
-        )
-
     # Extract dimensions from cache shape
     # k shape: [num_blocks, block_size, num_kv_heads, head_size]
     num_kv_heads = k.shape[2]
     block_size = k.shape[1]
 
+    # Convert window_size tuple to a single sliding_window int.
+    # window_size = (left, right) where left = sw-1, right = 0 for causal.
+    # sliding_window = left + 1 = total window size. -1 = disabled.
+    if window_size == (-1, -1):
+        sliding_window = -1
+    else:
+        sliding_window = window_size[0] + 1
+
     ops = get_ops()
 
     # Ensure all inputs are evaluated before raw Metal dispatch
-    mx.eval(out, q, k, v, block_table, seqused_k)
+    mx.eval(out, q, k, v, block_table, seqused_k, cu_seqlens_q)
 
     ops.paged_attention_v2_online(
         out,
@@ -124,10 +122,13 @@ def metal_unified_attention(
         v,
         num_kv_heads,
         softmax_scale,
+        softcap,
         block_table,
         seqused_k,
+        cu_seqlens_q,
         block_size,
         max_seqlen_k,
+        sliding_window,
     )
     mx.synchronize()
```
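The `window_size` → `sliding_window` conversion and the soft-cap transform shown in this file can be restated in plain Python. Both helpers below are illustrative sketches whose names mirror the diff, not functions exported by the package:

```python
import math

def window_size_to_sliding_window(window_size):
    # window_size = (left, right); causal sliding-window callers pass
    # (sw - 1, 0), so the total window is left + 1. (-1, -1) disables it.
    if window_size == (-1, -1):
        return -1
    return window_size[0] + 1

def soft_cap(qk, softcapping):
    # Smoothly bounds an attention logit into (-softcapping, softcapping);
    # softcapping <= 0 means disabled (logit passes through unchanged).
    if softcapping > 0.0:
        return math.tanh(qk / softcapping) * softcapping
    return qk

assert window_size_to_sliding_window((-1, -1)) == -1
assert window_size_to_sliding_window((127, 0)) == 128   # a 128-token window
assert abs(soft_cap(1e6, 50.0) - 50.0) < 1e-6           # tanh saturates at the cap
```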

vllm_metal/metal/kernels_v2/pagedattention.metal

Lines changed: 84 additions & 20 deletions
```diff
@@ -756,6 +756,34 @@ inline float block_sum(threadgroup float *red_smem, float sum, uint simd_tid,
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
 
+// Binary search to find which sequence a global query token belongs to.
+//
+// In varlen (ragged-batch) attention, queries from multiple sequences are
+// packed contiguously into a flat array:
+//   q[0..q_len_0-1] → seq 0, q[q_len_0..q_len_0+q_len_1-1] → seq 1, ...
+// The kernel launches one threadgroup per (head, query_token) in a flat grid.
+// Each threadgroup needs to discover which sequence it belongs to so it can
+// look up the correct block_table row, kv_len, and causal mask boundary.
+//
+// This is the same approach used by the upstream vLLM unified Triton kernel
+// (triton_unified_attention.py:find_seq_idx) and FlashAttention's varlen API.
+//
+// cu_seqlens_q is sorted ascending: [0, q_len_0, q_len_0+q_len_1, ...].
+// Returns seq_idx such that cu_seqlens_q[seq_idx] <= q_token_idx < cu_seqlens_q[seq_idx+1].
+inline int find_seq_idx(const device int32_t *cu_seqlens_q,
+                        int q_token_idx, int num_seqs) {
+  int lo = 0, hi = num_seqs;
+  while (lo < hi) {
+    int mid = (lo + hi + 1) / 2;
+    if (cu_seqlens_q[mid] <= q_token_idx) {
+      lo = mid;
+    } else {
+      hi = mid - 1;
+    }
+  }
+  return lo;
+}
+
 constant bool use_partitioning [[function_constant(10)]];
 constant bool use_alibi [[function_constant(20)]];
 constant bool use_fp8_scales [[function_constant(30)]];
@@ -795,24 +823,41 @@ template <typename T, typename CACHE_T, int HEAD_SIZE, int BLOCK_SIZE,
     const constant int &kv_head_stride [[buffer(17)]],
     const device float *sinks
         [[buffer(18), function_constant(use_sinks)]], // [num_heads]
+    device const int32_t *cu_seqlens_q [[buffer(19)]], // [num_seqs + 1]
+    const constant int &num_seqs [[buffer(20)]],
+    const constant int &sliding_window [[buffer(21)]], // -1 = disabled
     threadgroup char *shared_mem [[threadgroup(0)]],
     uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]],
     uint3 threadgroups_per_grid [[threadgroups_per_grid]],
     uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]],
     uint simd_tid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
-  const int seq_idx = threadgroup_position_in_grid.y;
+  // Varlen: each threadgroup handles one query token.
+  // Use binary search on cu_seqlens_q to find which sequence it belongs to.
+  const int q_token_idx = threadgroup_position_in_grid.y;
+  const int seq_idx = find_seq_idx(cu_seqlens_q, q_token_idx, num_seqs);
+  const int q_seq_start = cu_seqlens_q[seq_idx];
+  const int q_len = cu_seqlens_q[seq_idx + 1] - q_seq_start;
+  const int q_pos_in_seq = q_token_idx - q_seq_start;
   const int partition_idx = threadgroup_position_in_grid.z;
   const int max_num_partitions = threadgroups_per_grid.z;
   const int thread_idx = thread_position_in_threadgroup.x;
   constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
-  const uint32_t context_len = context_lens[seq_idx];
-  if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
+  const uint32_t context_len = context_lens[seq_idx]; // total KV length for this seq
+
+  // Causal: this query token can attend to KV positions [0, effective_context_len).
+  const int effective_context_len = (int)context_len - q_len + q_pos_in_seq + 1;
+  if (effective_context_len <= 0) {
+    // No KV tokens to attend to. Caller guarantees out is zero-initialized.
+    return;
+  }
+
+  if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= effective_context_len) {
     // No work to do. Terminate the thread block.
     return;
   }
 
-  const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
+  const int num_context_blocks = DIVIDE_ROUND_UP(effective_context_len, BLOCK_SIZE);
   const int num_blocks_per_partition =
       USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;
 
@@ -867,7 +912,7 @@ template <typename T, typename CACHE_T, int HEAD_SIZE, int BLOCK_SIZE,
   // For example, if the thread group size is 4, then the first thread in the
   // group has 0, 4, 8, ... th vectors of the query, and the second thread has
   // 1, 5, 9, ... th vectors of the query, and so on.
-  const device T *q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
+  const device T *q_ptr = q + q_token_idx * q_stride + head_idx * HEAD_SIZE;
   threadgroup Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
 #pragma unroll
   for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
@@ -955,15 +1000,20 @@ template <typename T, typename CACHE_T, int HEAD_SIZE, int BLOCK_SIZE,
       float qk = scale * Qk_dot<T, THREAD_GROUP_SIZE>::dot(
                      q_vecs[thread_group_offset], k_vecs);
 
-      if (softcapping != 1.0) {
+      if (softcapping > 0.0f) {
        qk = tanh(qk / softcapping) * softcapping;
      }
 
      qk +=
-          (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
+          (alibi_slope != 0) ? alibi_slope * (token_idx - effective_context_len + 1) : 0;
 
      if (thread_group_offset == 0) {
-        const bool mask = token_idx >= context_len;
+        // Causal mask: only attend to KV positions < effective_context_len.
+        bool mask = token_idx >= effective_context_len;
+        // Sliding window mask: skip positions too far in the past.
+        if (sliding_window >= 0) {
+          mask = mask || (token_idx < effective_context_len - sliding_window);
+        }
        warp_scores[physical_block_offset] = mask ? -FLT_MAX : qk;
      }
    }
@@ -981,7 +1031,7 @@ template <typename T, typename CACHE_T, int HEAD_SIZE, int BLOCK_SIZE,
    // Valid tokens in this block:
    const int block_start_token = block_idx * BLOCK_SIZE;
    const int block_valid_tokens =
-        MIN(BLOCK_SIZE, (int)context_len - block_start_token);
+        MIN(BLOCK_SIZE, effective_context_len - block_start_token);
 
    // Find max score in this block (all lanes participate for speed).
    float block_max = -FLT_MAX;
@@ -1058,13 +1108,14 @@ template <typename T, typename CACHE_T, int HEAD_SIZE, int BLOCK_SIZE,
  // reached by all threads in the threadgroup).
 
  // If partitioning is enabled, store the partial result for the reduce kernel.
+  // Indexed by q_token_idx (not seq_idx) for varlen compatibility.
  if (USE_PARTITIONING && thread_idx == 0 && use_partitioning) {
    device float *max_logits_ptr =
-        max_logits + seq_idx * num_heads * max_num_partitions +
+        max_logits + q_token_idx * num_heads * max_num_partitions +
        head_idx * max_num_partitions + partition_idx;
    *max_logits_ptr = warp_m;
    device float *exp_sums_ptr = exp_sums +
-                                 seq_idx * num_heads * max_num_partitions +
+                                 q_token_idx * num_heads * max_num_partitions +
                                 head_idx * max_num_partitions + partition_idx;
    *exp_sums_ptr = warp_l;
  }
@@ -1143,7 +1194,7 @@ template <typename T, typename CACHE_T, int HEAD_SIZE, int BLOCK_SIZE,
  const float inv_l = 1.f / (warp_l + 1e-6f);
 
  device T *out_ptr =
-      out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+      out + q_token_idx * num_heads * max_num_partitions * HEAD_SIZE +
      head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
 #pragma unroll
  for (int j = 0; j < V_ELEMS_PER_THREAD; j++) {
@@ -1165,6 +1216,8 @@ template <typename T, int HEAD_SIZE, int NUM_THREADS, int NUM_SIMD_LANES,
    const constant int &max_num_partitions [[buffer(5)]],
    const device float *sinks
        [[buffer(6), function_constant(use_sinks)]], // [num_heads]
+    device const int32_t *cu_seqlens_q [[buffer(7)]], // [num_seqs + 1]
+    const constant int &num_seqs [[buffer(8)]],
    threadgroup char *shared_mem [[threadgroup(0)]],
    uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]],
    uint3 threadgroups_per_grid [[threadgroups_per_grid]],
@@ -1174,15 +1227,21 @@ template <typename T, int HEAD_SIZE, int NUM_THREADS, int NUM_SIMD_LANES,
    uint simd_lid [[thread_index_in_simdgroup]]) {
  const int num_heads = threadgroups_per_grid.x;
  const int head_idx = threadgroup_position_in_grid.x;
-  const int seq_idx = threadgroup_position_in_grid.y;
+  // Varlen: grid.y is q_token_idx (one per query token), not seq_idx.
+  const int q_token_idx = threadgroup_position_in_grid.y;
+  const int seq_idx = find_seq_idx(cu_seqlens_q, q_token_idx, num_seqs);
+  const int q_seq_start = cu_seqlens_q[seq_idx];
+  const int q_len = cu_seqlens_q[seq_idx + 1] - q_seq_start;
+  const int q_pos_in_seq = q_token_idx - q_seq_start;
  const uint32_t context_len = context_lens[seq_idx];
-  const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
+  const int effective_context_len = (int)context_len - q_len + q_pos_in_seq + 1;
+  const int num_partitions = DIVIDE_ROUND_UP(effective_context_len, PARTITION_SIZE);
  if (num_partitions == 1 && !use_sinks) {
    // No need to reduce. Only copy tmp_out to out.
    device T *out_ptr =
-        out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+        out + q_token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
    const device T *tmp_out_ptr =
-        tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+        tmp_out + q_token_idx * num_heads * max_num_partitions * HEAD_SIZE +
        head_idx * max_num_partitions * HEAD_SIZE;
    for (int i = thread_position_in_threadgroup.x; i < HEAD_SIZE;
         i += threads_per_threadgroup.x) {
@@ -1203,7 +1262,7 @@ template <typename T, int HEAD_SIZE, int NUM_THREADS, int NUM_SIMD_LANES,
  threadgroup float *shared_max_logits =
      reinterpret_cast<threadgroup float *>(shared_mem);
  const device float *max_logits_ptr =
-      max_logits + seq_idx * num_heads * max_num_partitions +
+      max_logits + q_token_idx * num_heads * max_num_partitions +
      head_idx * max_num_partitions;
  float max_logit = -FLT_MAX;
  for (int i = thread_position_in_threadgroup.x; i < num_partitions;
@@ -1242,7 +1301,7 @@ template <typename T, int HEAD_SIZE, int NUM_THREADS, int NUM_SIMD_LANES,
  threadgroup float *shared_exp_sums = reinterpret_cast<threadgroup float *>(
      shared_mem + sizeof(float) * num_partitions);
  const device float *exp_sums_ptr = exp_sums +
-                                     seq_idx * num_heads * max_num_partitions +
+                                     q_token_idx * num_heads * max_num_partitions +
                                     head_idx * max_num_partitions;
  float global_exp_sum = 0.0f;
  for (int i = thread_position_in_threadgroup.x; i < num_partitions;
@@ -1265,10 +1324,10 @@ template <typename T, int HEAD_SIZE, int NUM_THREADS, int NUM_SIMD_LANES,
 
  // Aggregate tmp_out to out.
  const device T *tmp_out_ptr =
-      tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+      tmp_out + q_token_idx * num_heads * max_num_partitions * HEAD_SIZE +
      head_idx * max_num_partitions * HEAD_SIZE;
  device T *out_ptr =
-      out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+      out + q_token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
 #pragma unroll
  for (int i = thread_position_in_threadgroup.x; i < HEAD_SIZE;
       i += NUM_THREADS) {
@@ -1313,6 +1372,9 @@ template <typename T, int HEAD_SIZE, int NUM_THREADS, int NUM_SIMD_LANES,
      const constant int &kv_block_stride [[buffer(16)]], \
      const constant int &kv_head_stride [[buffer(17)]], \
      const device float *sinks [[buffer(18), function_constant(use_sinks)]], \
+      device const int32_t *cu_seqlens_q [[buffer(19)]], \
+      const constant int &num_seqs [[buffer(20)]], \
+      const constant int &sliding_window [[buffer(21)]], \
      threadgroup char *shared_mem [[threadgroup(0)]], \
      uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
      uint3 threadgroups_per_grid [[threadgroups_per_grid]], \
@@ -1334,6 +1396,8 @@ template <typename T, int HEAD_SIZE, int NUM_THREADS, int NUM_SIMD_LANES,
      device uint32_t *context_lens [[buffer(4)]], \
      const constant int &max_num_partitions [[buffer(5)]], \
      const device float *sinks [[buffer(6), function_constant(use_sinks)]], \
+      device const int32_t *cu_seqlens_q [[buffer(7)]], \
+      const constant int &num_seqs [[buffer(8)]], \
      threadgroup char *shared_mem [[threadgroup(0)]], \
      uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
      uint3 threadgroups_per_grid [[threadgroups_per_grid]], \
```
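The causal and sliding-window masking in this kernel reduces to a per-token predicate on `effective_context_len`. A hedged Python restatement of that logic (`is_masked` is an illustrative helper, not the Metal code):

```python
def is_masked(token_idx, context_len, q_len, q_pos_in_seq, sliding_window=-1):
    # A query at position q_pos_in_seq (0-based within its sequence's q_len
    # new tokens) may attend to KV positions [0, effective_context_len).
    effective_context_len = context_len - q_len + q_pos_in_seq + 1
    masked = token_idx >= effective_context_len
    if sliding_window >= 0:
        # Also mask positions that have fallen out of the window.
        masked = masked or (token_idx < effective_context_len - sliding_window)
    return masked  # True = score replaced by -FLT_MAX

# Decode (q_len=1): the single query sees the whole 10-token context.
assert not is_masked(9, context_len=10, q_len=1, q_pos_in_seq=0)
# Prefill (q_len=3): the first new query must not see the two newer tokens.
assert is_masked(8, context_len=10, q_len=3, q_pos_in_seq=0)
assert not is_masked(7, context_len=10, q_len=3, q_pos_in_seq=0)
```

Note that with `q_len=1` and `q_pos_in_seq=0`, `effective_context_len` collapses to `context_len`, which is exactly why freezing these parameters makes the varlen kernel behave like the old decode-only v1.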
