Commit 59b9be4

Fix paged-attention KV cache dtype + size accounting (issue #119) (#125)
This PR:

- Aligns the Metal paged-attention KV cache dtype with the model's dtype (fixes batched decode parity for #119).
- Computes KV cache byte sizes via `torch.dtype.itemsize` instead of allocating temporary tensors.

Notes:

- `tests/test_metal_kernel_paged.py::test_batched_decode_matches` now passes.
- `tests/test_metal_kernel_paged.py::test_greedy_output_matches` remains xfailed (tracked in #119). This is a remaining single-request greedy parity mismatch between the paged-kernel path and the standard path; fixing it likely requires deeper kernel/offset-semantics work, so I'm keeping it out of this PR to keep scope tight.

Quick manual smoke test:

Terminal 1:

```bash
vllm serve Qwen/Qwen3-0.6B --host 127.0.0.1 --port 8000 --max-model-len 2048
```

Terminal 2 (single request):

```bash
curl -fsS http://127.0.0.1:8000/v1/chat/completions \
  -H 'Content-Type: application/json' \
  -d '{"model":"Qwen/Qwen3-0.6B","messages":[{"role":"user","content":"Write a 2-sentence apple story."}],"max_tokens":512,"temperature":0.8}' \
  | jq -r '.choices[0].message.content'
```

Terminal 2 (4 concurrent requests):

```bash
for i in 1 2 3 4; do
  (
    echo "===== req $i ====="
    curl -fsS http://127.0.0.1:8000/v1/chat/completions \
      -H 'Content-Type: application/json' \
      -d "{\"model\":\"Qwen/Qwen3-0.6B\",\"messages\":[{\"role\":\"user\",\"content\":\"Write a 2-sentence apple story (${i}).\"}],\"max_tokens\":256,\"temperature\":0.8}" \
      | jq -r '.choices[0].message.content'
    echo
  ) &
done
wait
```

Related: #119

---------

Signed-off-by: Yuan Lik Xun <lxyuan0420@gmail.com>
1 parent 377eda3 commit 59b9be4

7 files changed

Lines changed: 134 additions & 16 deletions
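
A note on the size-accounting half of the change, before the per-file diffs: `torch.dtype.itemsize` gives the per-element byte width directly, so the byte math no longer needs a throwaway tensor. A minimal sketch — the allocation-based lookup shown first is an assumption about the pre-PR pattern, not code taken from this diff:

```python
import torch

# Assumed old-style lookup: allocate a scalar tensor just to ask its size.
old_size = torch.empty((), dtype=torch.float16).element_size()

# New-style lookup: read the width straight off the dtype, no allocation.
new_size = torch.float16.itemsize

assert old_size == new_size == 2
assert torch.bfloat16.itemsize == 2
assert torch.float32.itemsize == 4
```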


tests/test_metal_kernel_paged.py

Lines changed: 21 additions & 7 deletions
```diff
@@ -21,6 +21,8 @@
     import torch
     from mlx_lm import load as mlx_lm_load
     from mlx_lm.models.cache import make_prompt_cache
+
+    from vllm_metal.kv_cache_dtype import infer_kv_cache_dtype_from_model
 except ImportError as exc:
     pytest.skip(
         f"Metal kernel paged attention tests require mlx/torch/mlx_lm: {exc}",
@@ -68,6 +70,21 @@ def _paged_attention_ops_available() -> None:
 # ---------------------------------------------------------------------------
 
 
+def _test_infer_paged_kv_dtype(model) -> torch.dtype:
+    """Test-only helper: choose a float dtype for MPSPagedKVCache.
+
+    This is deliberately local to this test module. Production code uses
+    `vllm_metal.kv_cache_dtype.infer_kv_cache_dtype_from_model()`.
+    """
+    result = infer_kv_cache_dtype_from_model(model)
+    if result.warning is not None:
+        raise AssertionError(
+            "KV cache dtype inference unexpectedly fell back during tests: "
+            f"{result.warning}"
+        )
+    return result.dtype
+
+
 def _greedy_generate_standard(model, token_ids: list[int], max_new: int) -> list[int]:
     """Generate tokens using the standard mlx_lm KVCache path."""
     cache = make_prompt_cache(model)
@@ -109,7 +126,7 @@ def _greedy_generate_metal_kernel(
         head_dim=head_dim,
         num_blocks=num_blocks,
         block_size=BLOCK_SIZE,
-        dtype=torch.float16,
+        dtype=_test_infer_paged_kv_dtype(model),
     )
 
     n_patched = patch_model_attention_metal_kernel(model, mps_cache, BLOCK_SIZE)
@@ -190,9 +207,6 @@ def test_greedy_output_matches(self, qwen3_model):
         )
 
     @pytest.mark.slow
-    @pytest.mark.xfail(
-        reason="Metal paged-attention parity mismatch vs standard path (see #119)"
-    )
     def test_batched_decode_matches(self, qwen3_model):
         """Batched Metal kernel paged decode must match per-request sequential."""
         model, tokenizer = qwen3_model
@@ -225,7 +239,7 @@ def test_batched_decode_matches(self, qwen3_model):
             head_dim=head_dim,
             num_blocks=num_blocks,
             block_size=BLOCK_SIZE,
-            dtype=torch.float16,
+            dtype=_test_infer_paged_kv_dtype(model),
         )
         patch_model_attention_metal_kernel(model, mps_cache, BLOCK_SIZE)
 
@@ -300,7 +314,7 @@ def test_patch_replaces_self_attn(self, qwen3_model):
             head_dim=args.head_dim,
             num_blocks=32,
             block_size=BLOCK_SIZE,
-            dtype=torch.float16,
+            dtype=_test_infer_paged_kv_dtype(model),
         )
         patch_model_attention_metal_kernel(model, mps_cache, BLOCK_SIZE)
 
@@ -323,7 +337,7 @@ def test_fallback_when_no_context(self, qwen3_model):
             head_dim=args.head_dim,
             num_blocks=32,
             block_size=BLOCK_SIZE,
-            dtype=torch.float16,
+            dtype=_test_infer_paged_kv_dtype(model),
         )
         patch_model_attention_metal_kernel(model, mps_cache, BLOCK_SIZE)
```
tests/test_prefix_cache.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -302,6 +302,7 @@ def test_rotating_kvcache_merge_handles_prefill_exceeding_max_size(self) -> None
 
         assert isinstance(merged[0], mr.BatchRotatingKVCache)
         assert isinstance(extracted_req0, mr.RotatingKVCache)
+        assert isinstance(extracted_req1, mr.RotatingKVCache)
         assert extracted_req0.offset == cache_req0.offset
         assert extracted_req1.offset == cache_req1.offset
```
vllm_metal/kv_cache_dtype.py

Lines changed: 81 additions & 0 deletions
```diff
@@ -0,0 +1,81 @@
+# SPDX-License-Identifier: Apache-2.0
+"""KV cache dtype inference and policy.
+
+The Metal paged-attention backend stores *activation* K/V tensors in an
+MPS-backed cache. Those tensors must be floating point. Some models may have
+quantized *weights* (e.g. int8), so we must not derive the KV cache dtype from
+weights without enforcing a float-only policy.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any
+
+import torch
+
+from vllm_metal.paged_attention_common import find_layers_and_attr
+from vllm_metal.pytorch_backend.tensor_bridge import MLX_TO_TORCH_DTYPE
+
+DEFAULT_KV_CACHE_DTYPE = torch.float16
+ALLOWED_KV_CACHE_DTYPES: frozenset[torch.dtype] = frozenset(
+    {
+        torch.float16,
+        torch.bfloat16,
+        torch.float32,
+    }
+)
+
+
+@dataclass(frozen=True)
+class KvCacheDtypeInference:
+    """Result of inferring the KV cache dtype from a model."""
+
+    dtype: torch.dtype
+    warning: str | None = None
+
+
+def infer_kv_cache_dtype_from_model(
+    model: Any, *, default: torch.dtype = DEFAULT_KV_CACHE_DTYPE
+) -> KvCacheDtypeInference:
+    """Infer a float KV-cache dtype from an MLX(-LM/-VLM) model.
+
+    Policy:
+    - If we can map the model's attention weight dtype to torch and it's a
+      supported float dtype, use it.
+    - Otherwise, fall back to *default* and provide a warning string the caller
+      may log.
+    """
+    try:
+        layers, attn_attr = find_layers_and_attr(model)
+        if not layers:
+            raise ValueError("model has no transformer layers")
+
+        attn = getattr(layers[0], attn_attr)
+        # If the model is already patched, unwrap to the real attention module.
+        attn = getattr(attn, "_inner", attn)
+
+        mlx_dtype = attn.q_proj.weight.dtype
+    except (AttributeError, IndexError, TypeError, ValueError) as exc:
+        return KvCacheDtypeInference(
+            dtype=default,
+            warning=f"Cannot infer KV cache dtype from model ({exc}); using {default}",
+        )
+
+    torch_dtype = MLX_TO_TORCH_DTYPE.get(mlx_dtype)
+    if torch_dtype is None:
+        return KvCacheDtypeInference(
+            dtype=default,
+            warning=f"Unsupported MLX dtype for KV cache ({mlx_dtype!r}); using {default}",
+        )
+
+    if torch_dtype not in ALLOWED_KV_CACHE_DTYPES:
+        return KvCacheDtypeInference(
+            dtype=default,
+            warning=(
+                f"Model weight dtype {mlx_dtype!r} maps to non-float torch dtype "
+                f"{torch_dtype}; using {default} for KV cache instead"
+            ),
+        )
+
+    return KvCacheDtypeInference(dtype=torch_dtype)
```
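
A minimal usage sketch for the new helper. The model load line is illustrative only; the production callers are the runner changes in `vllm_metal/v1/model_runner.py` below, which route the warning through a logger instead of printing it:

```python
import torch
from mlx_lm import load as mlx_lm_load

from vllm_metal.kv_cache_dtype import infer_kv_cache_dtype_from_model

model, _tokenizer = mlx_lm_load("Qwen/Qwen3-0.6B")  # any mlx_lm model

result = infer_kv_cache_dtype_from_model(model)
if result.warning is not None:
    print(result.warning)  # inference fell back to DEFAULT_KV_CACHE_DTYPE
# Whichever branch was taken, the result is always a float dtype that
# MPSPagedKVCache can store.
assert result.dtype in {torch.float16, torch.bfloat16, torch.float32}
```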

vllm_metal/metal_kernel_backend/paged_attention.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -72,7 +72,7 @@
 from vllm_metal.metal_kernel_backend.kernel_loader import get_paged_attention_ops
 from vllm_metal.paged_attention_common import (
     PagedAttentionContext,
-    _find_layers_and_attr,
+    find_layers_and_attr,
     get_context,
 )
 from vllm_metal.pytorch_backend.tensor_bridge import mlx_to_torch, torch_to_mlx
@@ -327,7 +327,7 @@ def patch_model_attention_metal_kernel(
 
     Returns the number of patched layers.
     """
-    layer_list, attn_attr = _find_layers_and_attr(model)
+    layer_list, attn_attr = find_layers_and_attr(model)
     patched = 0
 
     for layer_idx, layer in enumerate(layer_list):
```

vllm_metal/paged_attention_common.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -103,7 +103,7 @@ def make_mask(
 # ---------------------------------------------------------------------------
 
 
-def _find_layers_and_attr(model: Any) -> tuple[list[Any], str]:
+def find_layers_and_attr(model: Any) -> tuple[list[Any], str]:
     """Find transformer layers and the attention attribute name.
 
     Returns (layer_list, attn_attr_name) where each layer has
```

vllm_metal/v1/model_runner.py

Lines changed: 20 additions & 2 deletions
```diff
@@ -47,6 +47,7 @@
 from vllm.v1.sample.sampler import Sampler
 
 from vllm_metal.config import get_config
+from vllm_metal.kv_cache_dtype import infer_kv_cache_dtype_from_model
 from vllm_metal.paged_attention_common import (
     OffsetCache,
     clear_context,
@@ -617,6 +618,7 @@ def __init__(
         self._paged_kv_cache: Any = None  # MPSPagedKVCache, set by worker
         self._paged_block_size: int = 0
         self._paged_request_seq_lens: dict[str, int] = {}  # req_id → seq_len
+        self.kv_cache_dtype: torch.dtype | None = None
 
     def _is_vlm_model(self) -> bool:
         """Check if the model is a vision-language model (VLM).
@@ -650,6 +652,7 @@ def load_model(self) -> None:
             )
             self._extract_model_args()
             self._resolve_model_dims()
+            self._initialize_kv_cache_dtype()
             return
 
         # Load model using appropriate backend
@@ -673,9 +676,20 @@ def load_model(self) -> None:
 
         self._extract_model_args()
         self._resolve_model_dims()
+        self._initialize_kv_cache_dtype()
         load_time = time.time() - start_time
         logger.info(f"Model loaded in {load_time:.2f}s: {model_name}")
 
+    def _initialize_kv_cache_dtype(self) -> None:
+        """Infer and store the KV cache dtype for this runner."""
+        if self.model is None:
+            raise RuntimeError("Model not loaded")
+
+        paged_kv_dtype = infer_kv_cache_dtype_from_model(self.model)
+        if paged_kv_dtype.warning:
+            logger.warning("%s", paged_kv_dtype.warning)
+        self.kv_cache_dtype = paged_kv_dtype.dtype
+
     def _extract_model_args(self) -> None:
         """Extract model configuration from loaded model.
 
@@ -782,6 +796,8 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
             Dictionary mapping attention layer names to KV cache specs
         """
         block_size = self.metal_config.block_size
+        if self.kv_cache_dtype is None:
+            raise RuntimeError("KV cache dtype not initialized; load_model() first")
 
         # Create a spec for each layer
         specs: dict[str, KVCacheSpec] = {}
@@ -791,7 +807,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
                 block_size=block_size,
                 num_kv_heads=self.num_kv_heads,
                 head_size=self.head_dim,
-                dtype=torch.float16,
+                dtype=self.kv_cache_dtype,
             )
 
         return specs
@@ -817,7 +833,9 @@ def get_cache_block_size_bytes(self) -> int:
 
         # Each block stores key and value for all layers
         # Block memory = 2 * num_layers * block_size * num_kv_heads * head_dim * dtype_size
-        dtype_size = 2  # float16
+        if self.kv_cache_dtype is None:
+            raise RuntimeError("KV cache dtype not initialized; load_model() first")
+        dtype_size = self.kv_cache_dtype.itemsize
        return (
            2
            * self.num_layers
```
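
To make the `get_cache_block_size_bytes()` formula concrete, here is the arithmetic for one block with hypothetical Qwen3-0.6B-like dimensions (the numbers are illustrative, not read from any config):

```python
# Block memory = 2 (K and V) * num_layers * block_size * num_kv_heads
#                * head_dim * dtype_size
num_layers, block_size, num_kv_heads, head_dim = 28, 16, 8, 128  # hypothetical
dtype_size = 2  # torch.float16.itemsize; would be 4 for torch.float32

block_bytes = 2 * num_layers * block_size * num_kv_heads * head_dim * dtype_size
assert block_bytes == 1_835_008  # 1.75 MiB per block for K + V
# The same shapes in float32 would double this, which is why the hard-coded
# `dtype_size = 2` under-counted for non-fp16 caches.
```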

vllm_metal/v1/worker.py

Lines changed: 8 additions & 4 deletions
```diff
@@ -146,7 +146,6 @@ def _setup_paged_attention(self) -> None:
         max_model_len.
         """
         import psutil
-        import torch
 
         from vllm_metal.metal_kernel_backend.cache import MPSPagedKVCache
         from vllm_metal.metal_kernel_backend.paged_attention import (
@@ -242,13 +241,15 @@ def _setup_paged_attention(self) -> None:
         )
 
         # --- Create cache and patch model ---
+        if runner.kv_cache_dtype is None:
+            raise RuntimeError("KV cache dtype not initialized; runner.load_model()")
         mps_kv_cache = MPSPagedKVCache(
             num_layers=runner.num_layers,
             num_kv_heads=runner.num_kv_heads,
             head_dim=runner.head_dim,
             num_blocks=num_blocks,
             block_size=block_size,
-            dtype=torch.float16,
+            dtype=runner.kv_cache_dtype,
         )
 
         n_patched = patch_model_attention_metal_kernel(
@@ -295,15 +296,18 @@ def _get_model_memory_usage(self) -> int:
         return 0
 
     def _one_sequence_kv_bytes(self) -> int:
-        """Bytes for one max-length sequence of KV cache (K + V, float16)."""
+        """Bytes for one max-length sequence of KV cache (K + V)."""
         runner = self.model_runner
+        dtype_size = (
+            runner.kv_cache_dtype.itemsize if runner.kv_cache_dtype is not None else 2
+        )
         return (
             2  # K and V
             * runner.num_layers
             * self.model_config.max_model_len
             * runner.num_kv_heads
             * runner.head_dim
-            * 2  # float16
+            * dtype_size
         )
 
     def determine_available_memory(self) -> int:
```
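
And the matching arithmetic for `_one_sequence_kv_bytes()`, using the same hypothetical dimensions plus the `--max-model-len 2048` from the smoke test above:

```python
num_layers, max_model_len, num_kv_heads, head_dim = 28, 2048, 8, 128  # hypothetical
dtype_size = 2  # the fallback used when runner.kv_cache_dtype is still None

seq_bytes = 2 * num_layers * max_model_len * num_kv_heads * head_dim * dtype_size
assert seq_bytes == 234_881_024  # 224 MiB of KV cache for one full-length sequence
```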
