Skip to content

Commit 16c19de

Browse files
committed
Add Metal platform plugin with MLX backend and tests
Signed-off-by: Eric Curtin <eric.curtin@docker.com>
1 parent 78276c7 commit 16c19de

22 files changed

Lines changed: 3288 additions & 27 deletions

README.md

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -27,39 +27,39 @@ vLLM Metal is a plugin that enables vLLM to run on Apple Silicon Macs using MLX
2727
## Architecture
2828

2929
```
30-
┌────────────────────────────────────────────────────────────
31-
│ vLLM Core (Unchanged)
32-
│ Engine, Scheduler, API Server, Tokenizers
33-
└────────────────────────────────────────────────────────────
30+
┌──────────────────────────────────────────────────────────┐
31+
│ vLLM Core (Unchanged) │
32+
│ Engine, Scheduler, API Server, Tokenizers │
33+
└──────────────────────────────────────────────────────────┘
3434
3535
36-
┌────────────────────────────────────────────────────────────
37-
│ vllm_metal Plugin Layer
38-
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────────────┐ │
39-
│ │MetalPlatform│ │ MetalWorker │ │ MetalModelRunner │ │
40-
│ │ (Platform) │ │ (Worker) │ │ (ModelRunner) │ │
41-
│ └─────────────┘ └─────────────┘ └─────────────────────┘ │
42-
└────────────────────────────────────────────────────────────
36+
┌──────────────────────────────────────────────────────────┐
37+
│ vllm_metal Plugin Layer │
38+
│ ┌─────────────┐ ┌─────────────┐ ┌───────────────────┐ │
39+
│ │MetalPlatform│ │ MetalWorker │ │ MetalModelRunner │ │
40+
│ │ (Platform) │ │ (Worker) │ │ (ModelRunner) │ │
41+
│ └─────────────┘ └─────────────┘ └───────────────────┘ │
42+
└──────────────────────────────────────────────────────────┘
4343
4444
45-
┌────────────────────────────────────────────────────────────
46-
│ Unified Compute Backend
47-
│ ┌──────────────────────┐ ┌─────────────────────────────┐ │
48-
│ │ MLX Backend │ │ PyTorch Backend │ │
49-
│ │ (Primary) │ │ (Model Loading/Interop) │ │
50-
│ │ │ │ │ │
51-
│ │ • SDPA Attention │ │ • HuggingFace Loading │ │
52-
│ │ • RMSNorm │ │ • Weight Conversion │ │
53-
│ │ • RoPE │ │ • Tensor Bridge │ │
54-
│ │ • Cache Ops │ │ │ │
55-
│ └──────────────────────┘ └─────────────────────────────┘ │
56-
└────────────────────────────────────────────────────────────
45+
┌──────────────────────────────────────────────────────────┐
46+
│ Unified Compute Backend │
47+
│ ┌──────────────────────┐ ┌───────────────────────────┐ │
48+
│ │ MLX Backend │ │ PyTorch Backend │ │
49+
│ │ (Primary) │ │ (Model Loading/Interop) │ │
50+
│ │ │ │ │ │
51+
│ │ • SDPA Attention │ │ • HuggingFace Loading │ │
52+
│ │ • RMSNorm │ │ • Weight Conversion │ │
53+
│ │ • RoPE │ │ • Tensor Bridge │ │
54+
│ │ • Cache Ops │ │ │ │
55+
│ └──────────────────────┘ └───────────────────────────┘ │
56+
└──────────────────────────────────────────────────────────┘
5757
5858
59-
┌────────────────────────────────────────────────────────────
60-
│ Metal GPU Layer
61-
│ Apple Silicon Unified Memory Architecture
62-
└────────────────────────────────────────────────────────────
59+
┌──────────────────────────────────────────────────────────┐
60+
│ Metal GPU Layer │
61+
│ Apple Silicon Unified Memory Architecture │
62+
└──────────────────────────────────────────────────────────┘
6363
```
6464

6565
## Configuration

tests/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
"""Tests for vLLM Metal plugin."""

tests/test_cache.py

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
"""Tests for KV cache implementations."""
3+
4+
import mlx.core as mx
5+
import pytest
6+
7+
from vllm_metal.mlx_backend.cache import KVCache, PagedKVCache
8+
9+
10+
class TestKVCache:
11+
"""Tests for simple KV cache."""
12+
13+
def test_cache_initialization(self) -> None:
14+
"""Test cache initialization."""
15+
cache = KVCache(
16+
num_layers=4,
17+
num_kv_heads=8,
18+
head_dim=64,
19+
max_seq_len=512,
20+
)
21+
22+
assert cache.num_layers == 4
23+
assert cache.num_kv_heads == 8
24+
assert cache.head_dim == 64
25+
assert cache.max_seq_len == 512
26+
assert cache.seq_len == 0
27+
28+
def test_cache_update(self) -> None:
29+
"""Test cache update and retrieval."""
30+
cache = KVCache(
31+
num_layers=2,
32+
num_kv_heads=4,
33+
head_dim=32,
34+
max_seq_len=128,
35+
)
36+
37+
# Create test tensors
38+
batch = 1
39+
seq_len = 4
40+
key = mx.random.normal((batch, seq_len, 4, 32))
41+
value = mx.random.normal((batch, seq_len, 4, 32))
42+
positions = mx.arange(seq_len)[None, :]
43+
44+
# Update layer 0
45+
cached_k, cached_v = cache.update(0, key, value, positions)
46+
mx.eval(cached_k, cached_v)
47+
48+
assert cached_k.shape == (1, seq_len, 4, 32)
49+
assert cached_v.shape == (1, seq_len, 4, 32)
50+
assert cache.seq_len == seq_len
51+
52+
def test_cache_incremental_update(self) -> None:
53+
"""Test incremental cache updates."""
54+
cache = KVCache(
55+
num_layers=2,
56+
num_kv_heads=4,
57+
head_dim=32,
58+
max_seq_len=128,
59+
)
60+
61+
# First update
62+
key1 = mx.random.normal((1, 4, 4, 32))
63+
value1 = mx.random.normal((1, 4, 4, 32))
64+
positions1 = mx.arange(4)[None, :]
65+
66+
cache.update(0, key1, value1, positions1)
67+
assert cache.seq_len == 4
68+
69+
# Second update (incremental)
70+
key2 = mx.random.normal((1, 1, 4, 32))
71+
value2 = mx.random.normal((1, 1, 4, 32))
72+
positions2 = mx.array([[4]])
73+
74+
cached_k, cached_v = cache.update(0, key2, value2, positions2)
75+
mx.eval(cached_k, cached_v)
76+
77+
assert cached_k.shape == (1, 5, 4, 32)
78+
assert cached_v.shape == (1, 5, 4, 32)
79+
assert cache.seq_len == 5
80+
81+
def test_cache_reset(self) -> None:
82+
"""Test cache reset."""
83+
cache = KVCache(
84+
num_layers=2,
85+
num_kv_heads=4,
86+
head_dim=32,
87+
max_seq_len=128,
88+
)
89+
90+
# Add some data
91+
key = mx.random.normal((1, 4, 4, 32))
92+
value = mx.random.normal((1, 4, 4, 32))
93+
positions = mx.arange(4)[None, :]
94+
cache.update(0, key, value, positions)
95+
96+
assert cache.seq_len == 4
97+
98+
# Reset
99+
cache.reset()
100+
101+
assert cache.seq_len == 0
102+
103+
104+
class TestPagedKVCache:
105+
"""Tests for paged KV cache."""
106+
107+
def test_paged_cache_initialization(self) -> None:
108+
"""Test paged cache initialization."""
109+
cache = PagedKVCache(
110+
num_layers=4,
111+
num_kv_heads=8,
112+
head_dim=64,
113+
num_blocks=100,
114+
block_size=16,
115+
)
116+
117+
assert cache.num_layers == 4
118+
assert cache.num_kv_heads == 8
119+
assert cache.head_dim == 64
120+
assert cache.num_blocks == 100
121+
assert cache.block_size == 16
122+
assert cache.num_free_blocks == 100
123+
124+
def test_block_allocation(self) -> None:
125+
"""Test block allocation."""
126+
cache = PagedKVCache(
127+
num_layers=2,
128+
num_kv_heads=4,
129+
head_dim=32,
130+
num_blocks=10,
131+
block_size=16,
132+
)
133+
134+
# Allocate blocks for sequence 0
135+
blocks = cache.allocate_blocks(seq_id=0, num_blocks=3)
136+
137+
assert len(blocks) == 3
138+
assert cache.num_free_blocks == 7
139+
assert 0 in cache.sequence_blocks
140+
141+
def test_block_allocation_insufficient(self) -> None:
142+
"""Test block allocation with insufficient blocks."""
143+
cache = PagedKVCache(
144+
num_layers=2,
145+
num_kv_heads=4,
146+
head_dim=32,
147+
num_blocks=5,
148+
block_size=16,
149+
)
150+
151+
# Try to allocate more blocks than available
152+
with pytest.raises(RuntimeError, match="Not enough free blocks"):
153+
cache.allocate_blocks(seq_id=0, num_blocks=10)
154+
155+
def test_sequence_free(self) -> None:
156+
"""Test freeing sequence blocks."""
157+
cache = PagedKVCache(
158+
num_layers=2,
159+
num_kv_heads=4,
160+
head_dim=32,
161+
num_blocks=10,
162+
block_size=16,
163+
)
164+
165+
# Allocate blocks
166+
cache.allocate_blocks(seq_id=0, num_blocks=3)
167+
cache.allocate_blocks(seq_id=1, num_blocks=2)
168+
169+
assert cache.num_free_blocks == 5
170+
171+
# Free sequence 0
172+
cache.free_sequence(seq_id=0)
173+
174+
assert cache.num_free_blocks == 8
175+
assert 0 not in cache.sequence_blocks
176+
assert 1 in cache.sequence_blocks
177+
178+
def test_block_update(self) -> None:
179+
"""Test updating block contents."""
180+
cache = PagedKVCache(
181+
num_layers=2,
182+
num_kv_heads=4,
183+
head_dim=32,
184+
num_blocks=10,
185+
block_size=16,
186+
)
187+
188+
blocks = cache.allocate_blocks(seq_id=0, num_blocks=1)
189+
block_idx = blocks[0]
190+
191+
# Update block
192+
key = mx.random.normal((8, 4, 32))
193+
value = mx.random.normal((8, 4, 32))
194+
195+
cache.update_block(
196+
block_idx=block_idx,
197+
layer_idx=0,
198+
key=key,
199+
value=value,
200+
slot_offset=0,
201+
)
202+
203+
# Verify update
204+
cached_k, cached_v = cache.get_sequence_kv(seq_id=0, layer_idx=0, seq_len=8)
205+
mx.eval(cached_k, cached_v)
206+
207+
assert cached_k.shape == (8, 4, 32)
208+
assert cached_v.shape == (8, 4, 32)

tests/test_config.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
"""Tests for vLLM Metal configuration."""
3+
4+
import os
5+
6+
from vllm_metal.config import MetalConfig, get_config, reset_config
7+
8+
9+
class TestMetalConfig:
10+
"""Tests for MetalConfig class."""
11+
12+
def setup_method(self) -> None:
13+
"""Reset config before each test."""
14+
reset_config()
15+
# Clear environment variables
16+
for var in [
17+
"VLLM_METAL_MEMORY_FRACTION",
18+
"VLLM_METAL_USE_MLX",
19+
"VLLM_MLX_DEVICE",
20+
"VLLM_METAL_BLOCK_SIZE",
21+
"VLLM_METAL_DEBUG",
22+
]:
23+
os.environ.pop(var, None)
24+
25+
def teardown_method(self) -> None:
26+
"""Reset config after each test."""
27+
reset_config()
28+
for var in [
29+
"VLLM_METAL_MEMORY_FRACTION",
30+
"VLLM_METAL_USE_MLX",
31+
"VLLM_MLX_DEVICE",
32+
"VLLM_METAL_BLOCK_SIZE",
33+
"VLLM_METAL_DEBUG",
34+
]:
35+
os.environ.pop(var, None)
36+
37+
def test_default_config(self) -> None:
38+
"""Test default configuration values."""
39+
config = MetalConfig.from_env()
40+
41+
assert config.memory_fraction == 0.9
42+
assert config.use_mlx is True
43+
assert config.mlx_device == "gpu"
44+
assert config.block_size == 16
45+
assert config.debug is False
46+
47+
def test_custom_config_from_env(self) -> None:
48+
"""Test configuration from environment variables."""
49+
os.environ["VLLM_METAL_MEMORY_FRACTION"] = "0.75"
50+
os.environ["VLLM_METAL_USE_MLX"] = "0"
51+
os.environ["VLLM_MLX_DEVICE"] = "cpu"
52+
os.environ["VLLM_METAL_BLOCK_SIZE"] = "32"
53+
os.environ["VLLM_METAL_DEBUG"] = "1"
54+
55+
config = MetalConfig.from_env()
56+
57+
assert config.memory_fraction == 0.75
58+
assert config.use_mlx is False
59+
assert config.mlx_device == "cpu"
60+
assert config.block_size == 32
61+
assert config.debug is True
62+
63+
def test_get_config_singleton(self) -> None:
64+
"""Test that get_config returns a singleton."""
65+
config1 = get_config()
66+
config2 = get_config()
67+
68+
assert config1 is config2
69+
70+
def test_reset_config(self) -> None:
71+
"""Test that reset_config clears the singleton."""
72+
config1 = get_config()
73+
reset_config()
74+
config2 = get_config()
75+
76+
# After reset, we get a new config instance
77+
# (but with same values since env vars haven't changed)
78+
assert config1 is not config2

0 commit comments

Comments
 (0)