Commit fba263c
Add paged KV cache with HF Metal kernel for KV cache read/write and by-reference decode (#92)
### Usage

Paged KV cache on:

```
VLLM_METAL_USE_PAGED_ATTENTION=1 VLLM_METAL_MEMORY_FRACTION=0.3 vllm serve Qwen/Qwen3-0.6B --max-model-len 2048
```

```
vllm bench serve --backend vllm --model Qwen/Qwen3-0.6B \
  --endpoint /v1/completions \
  --dataset-name sharegpt \
  --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
  --num-prompts 100 \
  --request-rate 10 \
  --max-concurrency 32
```

Baseline (default, paged KV cache off, mlx_lm):

```
vllm serve Qwen/Qwen3-0.6B --max-model-len 2048
```

### Benchmark

Apple M2 Max 36GB, ShareGPT, num-prompts 100, request-rate 10, max-concurrency 32

<img width="2225" height="766" alt="bench_comparison" src="https://github.com/user-attachments/assets/c66c4847-f522-4cec-947c-ef5321c36a0b" />

* TTFT: better
* Output throughput: better
* Mean ITL: worse

Output equivalence (paged KV cache vs mlx_lm), both:

* Total input tokens: 23260
* Total generated tokens: 22061

Memory allocation:

* mlx_lm uses `auto` to claim only as much memory as it needs.
* The paged KV cache uses `VLLM_METAL_MEMORY_FRACTION` and allocates as much memory as possible up front.

The paged KV cache trades higher memory usage for better concurrency, making it faster system-wide than mlx_lm. Whether it is also faster at the kernel level is unclear, but advanced features like continuous batching and chunked prefill are infeasible to support with mlx_lm alone.

### PR Summary

<details>
<summary>Patch the mlx models with the paged attention kernel.</summary>

- mlx_lm requires a contiguous KV cache; this PR uses a paged (non-contiguous) KV cache.
- A paged KV cache is a prerequisite for future continuous batching and real chunked prefill.
- Integrates the https://huggingface.co/kernels-community/paged-attention Metal shader for the paged KV cache on Apple Silicon. (This can be replaced by MLX-native paged attention or other, better kernels in the future.)
- Patches existing mlx_lm model attention layers at runtime with a wrapper that routes cache reads/writes to the external Metal kernel, while keeping MLX for projections and the other layers (see the sketch after this summary).
- Prefill: standard MLX causal SDPA, then writes K/V to the MPS paged cache via `reshape_and_cache`.
- Decode: zero-copy attention via `paged_attention_v1` — reads K/V directly from block tables on the GPU, eliminating the O(seq_len) gather/copy per layer per step.
- Falls back to the original mlx_lm attention when the env var is not set.

</details>

<details>
<summary>Implement the model runner <--> vllm scheduler contract, so they are aligned.</summary>

- For prefill chunks 0:n-1: sample-then-drop the last token.
- For the final prefill chunk n: sample-and-keep the last token.
- For decoding: generate 1 new token.

</details>
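To make the routing concrete, here is a minimal sketch of the patched attention forward path. It is an illustration, not this PR's actual code: `ctx`, `to_torch`, and `to_mlx` are hypothetical stand-ins for the global forward context and the MLX <--> Torch tensor bridge, and the two op signatures are assumed to mirror vLLM's paged-attention calling conventions.

```python
# Sketch of the patched attention forward path (illustrative, NOT the PR's code).
# Assumptions: `ops` is the module returned by
# kernels.get_kernel("kernels-community/paged-attention"), and its two ops
# follow vLLM's paged-attention calling conventions; `to_torch`/`to_mlx`
# stand in for the MLX <-> Torch tensor bridge (zero-copy on unified memory).
import mlx.core as mx
import torch


def patched_attention(q, k, v, ctx, ops, to_torch, to_mlx, scale):
    """q/k/v are MLX arrays; `ctx` is the (hypothetical) global forward
    context carrying slot_mapping, block_tables, seq_lens and the caches."""
    key_cache, value_cache = ctx.kv_cache  # paged cache tensors on MPS

    # Both phases append the new K/V into paged blocks addressed by slot index.
    ops.reshape_and_cache(
        to_torch(k), to_torch(v), key_cache, value_cache,
        ctx.slot_mapping, "auto", ctx.k_scale, ctx.v_scale,
    )

    if not ctx.is_decode:
        # Prefill: standard MLX causal SDPA over the contiguous chunk.
        return mx.fast.scaled_dot_product_attention(
            q, k, v, scale=scale, mask="causal"
        )

    # Decode: read K/V by reference from block tables on the GPU, avoiding
    # the O(seq_len) gather/copy per layer per step.
    q_t = to_torch(q)
    out = torch.empty_like(q_t)
    ops.paged_attention_v1(
        out, q_t, key_cache, value_cache,
        ctx.num_kv_heads, scale, ctx.block_tables, ctx.seq_lens,
        ctx.block_size, ctx.max_seq_len, None, "auto",
        ctx.k_scale, ctx.v_scale,
    )
    return to_mlx(out)
```

When `VLLM_METAL_USE_PAGED_ATTENTION` is unset, the wrapper simply delegates to the original mlx_lm attention.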
### Known Limitations & Planned Future PRs

* **[High Priority]** Setting `VLLM_METAL_MEMORY_FRACTION=0.1` too small hits `RuntimeError: Not enough free blocks: need 21, have 0`. Cause: all KV blocks have been consumed while prefill/decode has not yet finished, so the engine deadlocks. We need to implement vllm-metal's paged KV cache preemption to align with the vLLM scheduler contract.
* **[High Priority]** `torch_to_mlx` in the tensor bridge may not be truly zero-copy.
* [Bug] Not working with HuggingFaceTB/SmolVLM-Instruct; #114 might be the fix.
* [Feature] Re-enable prefix caching under the paged KV cache. Prior version: #80.
* [Medium Priority] Real chunked prefill. This PR's implementation is wasteful. Expected: the prefill of chunk n reads the 0:n-1 KV cache and only prefills chunk n. Actual: all of 0:n is prefilled each time, so the time complexity is quadratic in the number of chunks. Why? Just to satisfy the vLLM scheduler.
* [Medium Priority] Real continuous batching to align with upstream vLLM. This requires variable-length prefill & decode operating on `[total_num_token, *]` instead of the current `[batch, seq, *]`.
* [Refactor] #97
* [Refactor] Five separate forward paths (`_prefill_single`, `_prefill_single_request_paged`, etc.) share the same pre/post-processing. The duplicated code could be merged, but that is purely for aesthetics.
* [Doc] The README architecture figure is no longer accurate.
* ~~[Testing] Need to test on macOS 14/15, Metal 3.2. It is expected to work.~~ It works.

### FAQ

<details>
<summary>Why hack the paged KV cache as a global variable?</summary>

The model's `__call__` signature is `(input_ids, cache=...)` — and mlx_lm's call requires contiguous tensors with no additional parameters. There's no way to pass `slot_mapping`, `block_tables`, or any other per-forward metadata down to the attention layers. This design is inspired by [nano-vllm](https://github.com/GeeeekExplorer/nano-vllm); see the sketch at the end of this message.

</details>

<details>
<summary>Why use this attention kernel?</summary>

This kernel supports variable-length prefill and decode, so attention can be computed over `[total_tokens, *]` instead of `[batch, seq, *]`. This is essential for supporting real continuous batching in the future.

</details>

<details>
<summary>Doesn't each call to the attention kernel trigger an MLX ↔ Torch round trip?</summary>

Yes, but it's okay as long as the MLX-to-Torch round trip is implemented in zero-copy mode. Besides, this kernel can be replaced by better ones in the future if they become available. The most important thing is to get the whole system working end-to-end (chunked prefill, continuous batching) first; then we can swap in better modules later.

</details>

<details>
<summary>What would a future paged attention kernel look like?</summary>

It would need to support variable sequence lengths, `slot_map`, etc. — similar to `flash_attn_varlen_func` and `flash_attn_with_kvcache` from FlashAttention. The difficulties are:

1. HuggingFace kernel libraries only expose PyTorch bindings, which require type conversion from our MLX tensors.
2. As far as I understand, a proper FlashAttention-style implementation would need to be written directly in Metal, not in MLX.

</details>

### Acknowledgement

Early prototype: #71

---------

Signed-off-by: ran <hzz5361@psu.edu>
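Appendix: to make the global-variable design from the first FAQ entry concrete, here is a minimal sketch of a nano-vllm-style forward context. All names in it (`ForwardContext`, `forward_context`, `get_forward_context`) are hypothetical, not this PR's actual identifiers.

```python
# Sketch of a nano-vllm-style global forward context (all names hypothetical).
# The runner installs per-step metadata before calling the model; the patched
# attention layers read it, since mlx_lm's `__call__(input_ids, cache=...)`
# cannot carry slot_mapping/block_tables down the stack.
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Iterator, Optional


@dataclass
class ForwardContext:
    is_decode: bool
    slot_mapping: Any   # per-token slot indices into the paged cache
    block_tables: Any   # per-sequence tables of block indices
    seq_lens: Any       # current length of each sequence in the batch


_FORWARD_CONTEXT: Optional[ForwardContext] = None


@contextmanager
def forward_context(ctx: ForwardContext) -> Iterator[None]:
    """Install `ctx` globally for the duration of one model forward."""
    global _FORWARD_CONTEXT
    _FORWARD_CONTEXT = ctx
    try:
        yield
    finally:
        _FORWARD_CONTEXT = None


def get_forward_context() -> ForwardContext:
    """Called from inside the patched attention layers."""
    assert _FORWARD_CONTEXT is not None, "forward called outside forward_context()"
    return _FORWARD_CONTEXT
```

The runner then wraps each step as `with forward_context(ctx): model(input_ids, cache=...)`, so mlx_lm's call signature never has to change.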
1 parent 3fbaea6 commit fba263c

17 files changed: 2018 additions & 112 deletions

README.md

Lines changed: 2 additions & 1 deletion
```diff
@@ -9,7 +9,7 @@ vLLM Metal is a plugin that enables vLLM to run on Apple Silicon Macs using MLX
 - **MLX-accelerated inference**: faster than PyTorch MPS on Apple Silicon
 - **Unified memory**: True zero-copy operations leveraging Apple Silicon's unified memory architecture
 - **vLLM compatibility**: Full integration with vLLM's engine, scheduler, and OpenAI-compatible API
-- **Paged attention**: Efficient KV cache management for long sequences
+- **Paged attention** *(experimental)*: Efficient KV cache management for long sequences — opt-in via `VLLM_METAL_USE_PAGED_ATTENTION=1` (requires `pip install 'vllm-metal[paged]'`); default path uses MLX-managed KV cache
 - **GQA support**: Grouped-Query Attention for efficient inference
 
 ## Requirements
@@ -78,6 +78,7 @@ Environment variables for customization:
 | `VLLM_METAL_USE_MLX` | `1` | Use MLX for compute (1=yes, 0=no) |
 | `VLLM_MLX_DEVICE` | `gpu` | MLX device (`gpu` or `cpu`) |
 | `VLLM_METAL_BLOCK_SIZE` | `16` | KV cache block size |
+| `VLLM_METAL_USE_PAGED_ATTENTION` | `0` | Enable experimental paged KV cache (requires `pip install 'vllm-metal[paged]'`) |
 | `VLLM_METAL_DEBUG` | `0` | Enable debug logging |
 | `VLLM_USE_MODELSCOPE` | `False` | Set True to change model registry to <https://www.modelscope.cn/> |
 | `VLLM_METAL_MODELSCOPE_CACHE` | None | Specify the absolute path of the local model |
```

pyproject.toml

Lines changed: 5 additions & 1 deletion
```diff
@@ -41,6 +41,10 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
+paged = [
+    # Paged attention Metal kernel (opt-in, experimental)
+    "kernels>=0.4.5; platform_system == 'Darwin' and platform_machine == 'arm64'",
+]
 vllm = ["vllm>=0.14.0"]
 stt = [
     # Speech-to-text audio processing (Whisper models)
@@ -54,7 +58,7 @@ dev = [
     "mypy>=1.19.1",
 ]
 all = [
-    "vllm-metal[vllm,stt,dev]",
+    "vllm-metal[vllm,paged,stt,dev]",
 ]
 
 [project.urls]
```
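Because the kernel ships as an optional `paged` extra, the `kernels` package may be absent at runtime. A minimal guard, as a sketch (the helper name is an assumption, not part of this diff):

```python
# Sketch: probe the optional `paged` extra at runtime (helper name assumed).
def paged_attention_available() -> bool:
    """True if the HF `kernels` package (pip install 'vllm-metal[paged]')
    is importable; otherwise the plugin falls back to the MLX-managed cache."""
    try:
        import kernels  # noqa: F401
    except ImportError:
        return False
    return True
```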

scripts/lib.sh

Lines changed: 1 addition & 1 deletion
```diff
@@ -49,7 +49,7 @@ ensure_venv() {
 # Install dev dependencies
 install_dev_deps() {
     section "Installing dependencies"
-    uv pip install -e ".[dev]"
+    uv pip install -e ".[dev,paged]"
 }
 
 # Full development environment setup
```

src/lib.rs

Lines changed: 9 additions & 9 deletions
```diff
@@ -21,7 +21,7 @@ pub struct BlockAllocator {
     /// Free blocks stored in a deque for O(1) operations
     free_blocks: VecDeque<usize>,
     /// Mapping from sequence ID to allocated blocks
-    sequence_blocks: HashMap<i64, Vec<usize>>,
+    sequence_blocks: HashMap<String, Vec<usize>>,
     /// Total number of blocks
     num_blocks: usize,
 }
@@ -53,7 +53,7 @@ impl BlockAllocator {
     ///
     /// # Raises
     /// RuntimeError if not enough free blocks
-    pub fn allocate_blocks(&mut self, seq_id: i64, num_blocks: usize) -> PyResult<Vec<usize>> {
+    pub fn allocate_blocks(&mut self, seq_id: String, num_blocks: usize) -> PyResult<Vec<usize>> {
         if self.free_blocks.len() < num_blocks {
             return Err(pyo3::exceptions::PyRuntimeError::new_err(format!(
                 "Not enough free blocks: need {}, have {}",
@@ -83,7 +83,7 @@ impl BlockAllocator {
     ///
     /// # Arguments
     /// * `seq_id` - Sequence identifier
-    pub fn free_sequence(&mut self, seq_id: i64) {
+    pub fn free_sequence(&mut self, seq_id: String) {
         if let Some(blocks) = self.sequence_blocks.remove(&seq_id) {
             // Return blocks to the free pool
             for block_idx in blocks {
@@ -99,7 +99,7 @@ impl BlockAllocator {
     ///
     /// # Returns
     /// List of block indices for the sequence
-    pub fn get_sequence_blocks(&self, seq_id: i64) -> Vec<usize> {
+    pub fn get_sequence_blocks(&self, seq_id: String) -> Vec<usize> {
         self.sequence_blocks
             .get(&seq_id)
             .cloned()
@@ -119,7 +119,7 @@ impl BlockAllocator {
     }
 
     /// Check if sequence has blocks allocated.
-    pub fn has_sequence(&self, seq_id: i64) -> bool {
+    pub fn has_sequence(&self, seq_id: String) -> bool {
         self.sequence_blocks.contains_key(&seq_id)
     }
 
@@ -130,7 +130,7 @@ impl BlockAllocator {
     }
 
     /// Get all sequence blocks as a dictionary.
-    pub fn get_all_sequence_blocks(&self) -> HashMap<i64, Vec<usize>> {
+    pub fn get_all_sequence_blocks(&self) -> HashMap<String, Vec<usize>> {
         self.sequence_blocks.clone()
     }
 }
@@ -280,10 +280,10 @@ pub fn compute_kv_block_indices(
 /// Batch compute block indices for multiple sequences.
 #[pyfunction]
 pub fn batch_compute_kv_indices(
-    sequence_blocks: HashMap<i64, Vec<usize>>,
-    seq_lens: HashMap<i64, usize>,
+    sequence_blocks: HashMap<String, Vec<usize>>,
+    seq_lens: HashMap<String, usize>,
     block_size: usize,
-) -> HashMap<i64, Vec<(usize, usize, usize)>> {
+) -> HashMap<String, Vec<(usize, usize, usize)>> {
     let mut result = HashMap::with_capacity(sequence_blocks.len());
 
     for (seq_id, blocks) in sequence_blocks {
```
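The `i64` → `String` change keys the allocator by vLLM's string request IDs instead of integer sequence numbers. A minimal usage sketch from Python, assuming the pyo3 module name and constructor shape (both are assumptions, neither appears in this diff):

```python
# Hypothetical usage of the Rust BlockAllocator via its pyo3 bindings.
# The module name and constructor signature below are assumptions.
from vllm_metal._rust import BlockAllocator

allocator = BlockAllocator(num_blocks=256)

# vLLM request IDs are strings, hence HashMap<String, Vec<usize>>.
blocks = allocator.allocate_blocks("req-abc123", 4)
assert allocator.has_sequence("req-abc123")

# Exhausting the pool raises the error noted under Known Limitations:
#   RuntimeError: Not enough free blocks: need 21, have 0
allocator.free_sequence("req-abc123")  # blocks return to the free pool
```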

tests/test_kernel_loader.py

Lines changed: 128 additions & 0 deletions
```python
# SPDX-License-Identifier: Apache-2.0
"""Tests for kernel_loader: OS-aware revision pinning for Metal compatibility.

Verifies that:
- macOS 16+ uses the latest HF kernel (default revision)
- macOS 15 and earlier pins to the Nov 2025 compat revision (Metal 3.2)
- Both revisions actually load and expose the expected ops

Run with:
    python -m pytest tests/test_kernel_loader.py -v -s
"""

from __future__ import annotations

from unittest import mock

import pytest

pytest.importorskip("kernels")

# ---------------------------------------------------------------------------
# Unit tests (no network, no GPU)
# ---------------------------------------------------------------------------


class TestNeedsCompatRevision:
    """Test _needs_compat_revision() with mocked macOS versions."""

    @pytest.mark.parametrize(
        "ver, expected",
        [
            ("15.7.4", True),   # macOS 15 — needs compat
            ("14.5", True),     # macOS 14 — needs compat
            ("26.3", False),    # macOS 26 — modern
            ("", False),        # empty — safe default
        ],
    )
    def test_version_check(self, ver, expected):
        from vllm_metal.metal_kernel_backend.kernel_loader import _needs_compat_revision

        with mock.patch("platform.mac_ver", return_value=(ver, ("", "", ""), "")):
            assert _needs_compat_revision() is expected


class TestGetKernelRevisionSelection:
    """Test that get_paged_attention_ops passes the right revision to get_kernel."""

    def _reset_kernel_cache(self):
        import vllm_metal.metal_kernel_backend.kernel_loader as kl

        kl._kernel = None

    def test_macos_15_uses_compat_revision(self):
        self._reset_kernel_cache()
        with (
            mock.patch("platform.mac_ver", return_value=("15.7.4", ("", "", ""), "")),
            mock.patch("kernels.get_kernel", return_value=mock.MagicMock()) as mk,
        ):
            from vllm_metal.metal_kernel_backend.kernel_loader import (
                _MACOS15_COMPAT_REVISION,
                get_paged_attention_ops,
            )

            get_paged_attention_ops()
            mk.assert_called_once_with(
                "kernels-community/paged-attention",
                revision=_MACOS15_COMPAT_REVISION,
            )
        self._reset_kernel_cache()

    def test_macos_26_uses_latest(self):
        self._reset_kernel_cache()
        with (
            mock.patch("platform.mac_ver", return_value=("26.3", ("", "", ""), "")),
            mock.patch("kernels.get_kernel", return_value=mock.MagicMock()) as mk,
        ):
            from vllm_metal.metal_kernel_backend.kernel_loader import (
                get_paged_attention_ops,
            )

            get_paged_attention_ops()
            mk.assert_called_once_with(
                "kernels-community/paged-attention",
                revision=None,
            )
        self._reset_kernel_cache()


# ---------------------------------------------------------------------------
# Integration tests (require network + MPS)
# ---------------------------------------------------------------------------


def _mps_available() -> bool:
    try:
        import torch

        return torch.backends.mps.is_available()
    except Exception:
        return False


@pytest.mark.skipif(not _mps_available(), reason="MPS not available")
class TestKernelLoadsForReal:
    """Actually load the kernel from HuggingFace and verify ops exist."""

    _EXPECTED_OPS = {"reshape_and_cache", "paged_attention_v1"}

    def test_latest_revision_loads(self):
        from kernels import get_kernel

        kernel = get_kernel("kernels-community/paged-attention")
        ops = set(dir(kernel))
        assert self._EXPECTED_OPS <= ops, f"Missing ops: {self._EXPECTED_OPS - ops}"

    def test_compat_revision_loads(self):
        from kernels import get_kernel

        from vllm_metal.metal_kernel_backend.kernel_loader import (
            _MACOS15_COMPAT_REVISION,
        )

        kernel = get_kernel(
            "kernels-community/paged-attention",
            revision=_MACOS15_COMPAT_REVISION,
        )
        ops = set(dir(kernel))
        assert self._EXPECTED_OPS <= ops, f"Missing ops: {self._EXPECTED_OPS - ops}"
```

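The tests above pin down the loader's contract; here is a minimal `kernel_loader` sketch consistent with them. The real pinned revision hash does not appear in this diff, so a placeholder stands in for it; treat the whole module as an illustration, not the shipped code.

```python
# Sketch of a kernel_loader consistent with the tests above (NOT the shipped
# module). The compat revision hash is elided because it is not in this diff.
from __future__ import annotations

import platform

_MACOS15_COMPAT_REVISION = "<nov-2025-metal-3.2-compat-revision>"  # placeholder

_kernel = None  # cached kernel module; the tests reset this to None


def _needs_compat_revision() -> bool:
    """True on macOS 15 and earlier (Metal 3.2 only); False on 16+/26+."""
    ver = platform.mac_ver()[0]
    if not ver:
        return False  # unknown version: safe default is the latest kernel
    major = int(ver.split(".")[0])
    return major < 16


def get_paged_attention_ops():
    """Load (and cache) the HF paged-attention Metal kernel, pinning the
    older revision on macOS versions that only support Metal 3.2."""
    global _kernel
    if _kernel is None:
        import kernels

        _kernel = kernels.get_kernel(
            "kernels-community/paged-attention",
            revision=_MACOS15_COMPAT_REVISION if _needs_compat_revision() else None,
        )
    return _kernel
```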