Add paged KV cache with HF Metal kernel for kv cache read/write by-reference decode (#92)
### Usage
Paged KV cache on:
```
VLLM_METAL_USE_PAGED_ATTENTION=1 VLLM_METAL_MEMORY_FRACTION=0.3 vllm serve Qwen/Qwen3-0.6B --max-model-len 2048
```
```
vllm bench serve --backend vllm --model Qwen/Qwen3-0.6B \
--endpoint /v1/completions \
--dataset-name sharegpt \
--dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
--num-prompts 100 \
--request-rate 10 \
--max-concurrency 32
```
Baseline: default (paged KV cache off, mlx_lm):
```
vllm serve Qwen/Qwen3-0.6B --max-model-len 2048
```
### Benchmark
Apple M2 Max 36GB, ShareGPT, num-prompts 100, request-rate 10,
max-concurrency 32
<img width="2225" height="766" alt="bench_comparison"
src="https://github.com/user-attachments/assets/c66c4847-f522-4cec-947c-ef5321c36a0b"
/>
* TTFT: better
* output throughput: better
* Mean ITL: worse

Output equivalence (paged KV cache vs mlx_lm), both runs report:
* Total input tokens: 23260
* Total generated tokens: 22061
Memory allocation:
* mlx_lm uses `auto` to allocate only as much memory as needed
* the paged KV cache uses `VLLM_METAL_MEMORY_FRACTION` and allocates as much memory as possible up front
The paged KV cache trades higher memory usage for better concurrency, making it faster system-wide than mlx_lm. Whether it is also faster at the kernel level is unclear, but advanced features like continuous batching and chunked prefill are infeasible to support with mlx_lm alone.
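As a rough illustration of fraction-based allocation, a minimal sketch follows; the formula and all constants are assumptions for illustration, not the PR's actual accounting.

```python
# Hypothetical sketch: how a memory fraction could translate into a fixed
# pool of KV blocks. All shapes/constants below are illustrative.
GIB = 1024**3
total_bytes = 36 * GIB              # e.g. unified memory on an M2 Max 36GB
fraction = 0.3                      # VLLM_METAL_MEMORY_FRACTION

num_layers, num_kv_heads, head_dim = 28, 8, 128   # Qwen3-0.6B-like config
block_size, dtype_bytes = 16, 2                    # fp16 K and V entries
# one block stores K and V for `block_size` tokens across all layers
bytes_per_block = 2 * num_layers * block_size * num_kv_heads * head_dim * dtype_bytes

num_blocks = int(total_bytes * fraction) // bytes_per_block
print(f"{num_blocks} KV blocks of {bytes_per_block} bytes each")
```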
### PR Summary
<details>
<summary> Patch the mlx_lm models with a paged attention kernel. </summary>
- mlx_lm requires a contiguous KV cache, while this PR uses a paged (non-contiguous) KV cache.
- A paged KV cache is a prerequisite for future continuous batching and real chunked prefill.
- Integrates the https://huggingface.co/kernels-community/paged-attention Metal shader for the paged KV cache on Apple Silicon (this can be replaced by an MLX-native paged attention kernel or other better kernels in the future).
- Patches existing mlx_lm model attention layers at runtime with a
wrapper that routes to the external Metal kernel for cache read/write,
while keeping MLX for projections and other layers.
- Prefill: standard MLX causal SDPA, then writes K/V to the MPS paged cache via reshape_and_cache
- Decode: zero-copy attention via paged_attention_v1, which reads K/V directly from block tables on the GPU, eliminating the O(seq_len) gather/copy per layer per step
- Falls back to the original mlx_lm attention when the env var is not set
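For intuition, here is a minimal, self-contained sketch (not the PR's code) of the paged-cache addressing that reshape_and_cache and the decode kernel rely on. Note the real kernel resolves the block table on the GPU during attention rather than materializing a contiguous copy like the gather below; all names and shapes are illustrative assumptions.

```python
# Paged KV cache addressing in miniature: a slot mapping scatters new rows
# into physical blocks, and a block table maps logical positions back.
import numpy as np

BLOCK_SIZE, NUM_BLOCKS, HEAD_DIM = 4, 8, 2
key_cache = np.zeros((NUM_BLOCKS, BLOCK_SIZE, HEAD_DIM))  # physical blocks

def reshape_and_cache(new_keys, slot_mapping):
    """Write each token's K row into its physical slot (block, offset)."""
    for row, slot in zip(new_keys, slot_mapping):
        key_cache[slot // BLOCK_SIZE, slot % BLOCK_SIZE] = row

def gather_keys(block_table, seq_len):
    """Resolve logical positions through the block table (the real kernel
    does this inside attention instead of copying)."""
    rows = []
    for pos in range(seq_len):
        block = block_table[pos // BLOCK_SIZE]
        rows.append(key_cache[block, pos % BLOCK_SIZE])
    return np.stack(rows)

# A sequence owning physical blocks 5 and 2, with 6 tokens cached so far.
block_table = [5, 2]
slots = [5 * BLOCK_SIZE + i for i in range(4)] + [2 * BLOCK_SIZE + i for i in range(2)]
reshape_and_cache(np.arange(12).reshape(6, 2), slots)
print(gather_keys(block_table, 6))  # tokens come back in logical order
```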
</details>
<details>
<summary> Implement the model runner <--> vllm scheduler contract, so
they are aligned. </summary>
- chunked prefill, chunks 0 to n-1: sample, then drop the last token
- chunked prefill, final chunk n: sample and keep the last token
- decoding: generate exactly one new token
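A tiny illustrative sketch of that contract (hypothetical names, not the PR's code):

```python
# Hypothetical runner-side view of the scheduler contract above.
def postprocess(step_kind: str, logits, sample):
    token = sample(logits[-1])          # always sample at the last position
    if step_kind == "prefill_chunk":    # chunks 0..n-1: sample then drop
        return None                     # token discarded; only the KV cache advanced
    # final prefill chunk or a decode step: emit exactly one new token
    return token
```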
</details>
### Known Limitations & Planned Future PRs
* **[High Priority]** Setting `VLLM_METAL_MEMORY_FRACTION=0.1` (too small) hits `RuntimeError: Not enough free blocks: need 21, have 0`. This happens because all KV blocks have been consumed while prefilling/decoding has not yet finished, causing a deadlock. We need to implement vllm-metal's paged KV cache preemption to align with the vllm scheduler contract.
* **[High Priority]** torch_to_mlx in the tensor bridge may not be truly zero-copy.
* [Bug] not working with HuggingFaceTB/SmolVLM-Instruct,
#114 might be the fix
* [Feature] re-enable prefix caching under paged kv cache. Prior
version: #80
* [Medium Priority] Real chunked prefill. This PR's implementation is wasteful. Expected: chunk n's prefill reads the 0:n-1 KV cache and prefills only chunk n. Actual: it re-prefills the entire 0:n KV cache each time, so the cost is quadratic in the number of chunks (e.g. with 4 chunks of c tokens it processes c + 2c + 3c + 4c = 10c tokens instead of 4c). Why? Just to satisfy the vllm scheduler.
* [Medium Priority] Real continuous batching to align with upstream vllm. This requires var-len prefilling & decoding operating on `[total_num_tokens, *]` instead of the current `[batch, seq, *]`.
* [Refactor] #97
* [Refactor] Five separate forward paths (_prefill_single, _prefill_single_request_paged, etc.) share the same pre/post-processing. The duplicated code could be merged, though mainly for aesthetics.
* [Doc] Readme architecture figure is no longer accurate.
* ~~[Testing] Need to test on macOS 14/15, Metal 3.2. It is expected to work.~~ It works.
### FAQ
<details>
<summary>Why hack the paged KV cache as a global variable?</summary>
The model's `__call__` signature is `(input_ids, cache=...)` — and
`mlx_lm`'s call requires contiguous tensors with no additional
parameters. There's no way to pass `slot_mapping`, `block_tables`, or
any other per-forward metadata down to the attention layers. This design
is inspired by [nano-vllm](https://github.com/GeeeekExplorer/nano-vllm).
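A minimal sketch of the pattern (assumed names; nano-vllm uses a similar set_context/get_context pair):

```python
# Illustrative module-level forward context (not the PR's exact code).
# The runner stashes per-step metadata here before calling the model, and
# the patched attention layers read it back, since mlx_lm's
# __call__(input_ids, cache=...) cannot carry extra arguments.
from dataclasses import dataclass, field

@dataclass
class ForwardContext:
    is_prefill: bool = True
    slot_mapping: list = field(default_factory=list)
    block_tables: list = field(default_factory=list)

_CTX = ForwardContext()

def set_forward_ctx(**kwargs):
    for k, v in kwargs.items():
        setattr(_CTX, k, v)

def get_forward_ctx() -> ForwardContext:
    return _CTX
```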
</details>
<details>
<summary>Why use this attention kernel?</summary>
This kernel supports variable-length prefilling and decoding, so
attention can be computed over `[total_tokens, *]` instead of `[batch,
seq, *]`. This is essential for supporting real continuous batching in
the future.
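For example, the packed varlen layout looks roughly like this (`cu_seqlens` is the FlashAttention-style convention; the exact interface of this kernel may differ):

```python
# Illustrative packing for variable-length attention: all sequences are
# concatenated along one token axis and boundaries are described by
# cumulative sequence lengths, instead of padding to [batch, seq, *].
seq_lens = [5, 2, 9]
cu_seqlens = [0]
for n in seq_lens:
    cu_seqlens.append(cu_seqlens[-1] + n)
print(cu_seqlens)   # [0, 5, 7, 16]; tokens live in a [16, head_dim] tensor
```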
</details>
<details>
<summary>Each call to the attention kernel triggers an MLX ↔ Torch round
trip?</summary>
Yes, but it's okay as long as the MLX-to-Torch round trip is implemented
in zero-copy mode. Besides, this kernel can be replaced by better ones
in the future if they become available. The most important thing is to
get the whole system working end-to-end (chunked prefill, continuous
batching) first; then we can swap in better modules later.
</details>
<details>
<summary>What would a future paged attention kernel look like?</summary>
It would need to support variable sequence lengths, `slot_map`, etc. —
similar to `flash_attn_varlen_func` and `flash_attn_with_kvcache` from
FlashAttention. The difficulties are:
1. HuggingFace kernel libraries only expose PyTorch bindings, which
require type conversion from our MLX tensors.
2. As far as I understand, a proper FlashAttention-style implementation
would need to be written directly in Metal, not in MLX.
</details>
### Acknowledgement
Early prototype: #71
---------
Signed-off-by: ran <hzz5361@psu.edu>