Skip to content

Commit de80df2

Browse files
authored
Fix MLA KV cache sizing to use latent-only factor (#233)
This PR is:
- To apply MLA-specific KV sizing (latent-only, not K+V) in cache sizing paths
- To keep `_one_sequence_kv_bytes` consistent with paged KV block sizing
- To add a focused MLA sizing test and document the latent dimension context

Notes:
- `get_cache_block_size_bytes()` and `_one_sequence_kv_bytes()` now use `kv_factor = 1` for MLA, `2` otherwise.
- Tests cover MLA sizing and document why `head_dim=576` (kv_lora_rank + qk_rope_head_dim).

---------

Signed-off-by: Yuan Lik Xun <lxyuan0420@gmail.com>
1 parent f518143 commit de80df2

File tree

3 files changed

+36
-13
lines changed

3 files changed

+36
-13
lines changed

tests/test_v1_worker.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from types import SimpleNamespace
77
from unittest.mock import MagicMock
88

9+
import mlx.core as mx
910
import pytest
1011

1112
pytest.importorskip("vllm", reason="vllm not installed")
@@ -117,11 +118,9 @@ class TestOneSequenceKvBytes:
117118
"""_one_sequence_kv_bytes must account for hybrid linear state and block alignment."""
118119

119120
def test_non_hybrid_counts_all_layers(self) -> None:
120-
# Arrange
121-
import mlx.core as mx
122-
123121
model_runner = SimpleNamespace(
124122
is_hybrid=False,
123+
is_mla=False,
125124
num_layers=16,
126125
num_kv_heads=8,
127126
head_dim=64,
@@ -141,12 +140,10 @@ def test_non_hybrid_counts_all_layers(self) -> None:
141140
assert result == 2 * 16 * 2048 * 8 * 64 * 2
142141

143142
def test_hybrid_adds_linear_state(self) -> None:
144-
# Arrange
145-
import mlx.core as mx
146-
147143
linear_bytes = 1_000_000
148144
model_runner = SimpleNamespace(
149145
is_hybrid=True,
146+
is_mla=False,
150147
num_sdpa_layers=8,
151148
num_kv_heads=4,
152149
head_dim=256,
@@ -175,10 +172,9 @@ def test_block_alignment_rounds_up_token_count(self) -> None:
175172
models (e.g. Granite 4.0-H) where the attention block_size is padded
176173
to 400 to match the mamba page size.
177174
"""
178-
import mlx.core as mx
179-
180175
model_runner = SimpleNamespace(
181176
is_hybrid=False,
177+
is_mla=False,
182178
num_layers=4,
183179
num_kv_heads=4,
184180
head_dim=64,
@@ -200,3 +196,28 @@ def test_block_alignment_rounds_up_token_count(self) -> None:
200196
# Verify this is strictly more than the unaligned calculation
201197
unaligned = 2 * 4 * 2048 * 4 * 64 * 2
202198
assert result > unaligned
199+
200+
def test_mla_uses_latent_only(self) -> None:
201+
"""MLA cache stores one latent vector per token, not K+V.
202+
203+
head_dim=576 represents kv_lora_rank + qk_rope_head_dim (e.g. GLM-4).
204+
The 2x K/V factor must NOT be applied — kv_factor=1.
205+
"""
206+
model_runner = SimpleNamespace(
207+
is_hybrid=False,
208+
is_mla=True,
209+
num_layers=4,
210+
num_kv_heads=1,
211+
head_dim=576,
212+
kv_cache_dtype=mx.float16,
213+
)
214+
worker = _make_worker(model_runner, use_paged_attention=False)
215+
worker.model_config = SimpleNamespace(max_model_len=2048)
216+
worker.vllm_config = SimpleNamespace(
217+
cache_config=SimpleNamespace(block_size=16)
218+
)
219+
220+
result = MetalWorker._one_sequence_kv_bytes(worker)
221+
222+
expected = 1 * 4 * 2048 * 1 * 576 * 2
223+
assert result == expected

vllm_metal/v1/model_runner.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -990,9 +990,9 @@ def _resolve_model_dims(self) -> None:
990990
self.head_dim: int = int(head_dim)
991991

992992
# MLA (GLM/DeepSeek lineage): cache stores a joint latent vector per
993-
# layer, not per-head K/V. One virtual head sized kv_lora_rank +
994-
# qk_rope_head_dim keeps get_cache_block_size_bytes() conservative (2x)
995-
# without MLA-specific logic in the sizing path.
993+
# layer, not per-head K/V. Use one virtual head sized kv_lora_rank +
994+
# qk_rope_head_dim so shared sizing paths can reuse head_dim/num_kv_heads
995+
# while get_cache_block_size_bytes() applies an MLA-specific factor.
996996
if self.is_mla:
997997
self.num_kv_heads = 1
998998
self.head_dim = int(args["kv_lora_rank"]) + int(
@@ -1155,8 +1155,9 @@ def get_cache_block_size_bytes(self) -> int:
11551155
raise RuntimeError("KV cache dtype not initialized; load_model() first")
11561156
dtype_size = self.kv_cache_dtype.size
11571157
num_kv_layers = self.num_sdpa_layers if self.is_hybrid else self.num_layers
1158+
kv_factor = 1 if self.is_mla else 2
11581159
return (
1159-
2
1160+
kv_factor
11601161
* num_kv_layers
11611162
* block_size
11621163
* self.num_kv_heads

vllm_metal/v1/worker.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,8 +379,9 @@ def _one_sequence_kv_bytes(self) -> int:
379379
block_size = self.vllm_config.cache_config.block_size
380380
max_tokens = -(-self.model_config.max_model_len // block_size) * block_size
381381

382+
kv_factor = 1 if runner.is_mla else 2
382383
sdpa_kv_bytes = (
383-
2
384+
kv_factor
384385
* num_kv_layers
385386
* max_tokens
386387
* runner.num_kv_heads

0 commit comments

Comments (0)