Skip to content

Commit 556aba9

Browse files
authored
[Qwen3.5] Hybrid cache allocation for SDPA + linear attention (Stage B) (#210)
## Summary Allocate per-layer-type cache buffers for hybrid models (Qwen3.5) where SDPA and GDN linear attention layers coexist. This is Stage B of the Qwen3.5 roadmap (#194) and builds on the dispatch refactor (Stage A, #201). - Unwrap `text_config` in `_extract_model_args` so Qwen3.5 dimensions are accessible - Add `is_hybrid` detection and GDN dimensions to `_resolve_model_dims` - Emit `FullAttentionSpec` for SDPA layers and `MambaSpec` for GDN layers in `get_kv_cache_spec` - Fix `get_cache_block_size_bytes` to count only SDPA layers - Add `LinearAttentionCache` with layout `[num_blocks, Hv, Dv, Dk]` per linear layer - Add `HybridPagedAttentionBackend` that allocates both `MetalPagedKVCache` (SDPA) and `LinearAttentionCache` (GDN) - Fail fast with `RuntimeError` when a hybrid model enables paged attention (gated until Stage C) - Only SDPA layers are patched; linear layers keep the original mlx_lm forward Ref: #194 (Stage B: Hybrid cache allocation) ## Cache layout | Layer type | Cache class | Shape per layer | |---|---|---| | SDPA | `MetalPagedKVCache` | `[num_blocks, block_size, num_kv_heads, head_dim]` | | Linear (GDN) | `LinearAttentionCache` | `[num_blocks, Hv, Dv, Dk]` | Both caches use the same `num_blocks` from the scheduler's memory budget. `get_kv_cache_spec` emits `MambaSpec` for GDN layers so the scheduler groups them separately. This PR delivers allocation infrastructure to unblock Stage C kernel work. --------- Signed-off-by: RickyChen / 陳昭儒 <rickychen@infinirc.com>
1 parent 9ba8afa commit 556aba9

File tree

7 files changed

+280
-101
lines changed

7 files changed

+280
-101
lines changed

tests/test_attention_dispatch.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -111,21 +111,15 @@ def test_find_layers_on_qwen3_model():
111111

112112

113113
@pytest.mark.slow
114-
@pytest.mark.xfail(
115-
raises=NotImplementedError,
116-
reason="Linear attention (GatedDeltaNet) Metal kernel not yet implemented",
117-
strict=True,
118-
)
119-
def test_qwen35_paged_attention_raises_on_linear_layers():
120-
"""Loading Qwen/Qwen3.5-0.8B with paged attention raises
121-
NotImplementedError on the linear attention layers."""
122-
from vllm import LLM, SamplingParams
114+
def test_qwen35_paged_attention_raises_on_hybrid():
115+
"""Loading Qwen/Qwen3.5-0.8B with paged attention raises RuntimeError
116+
at setup — hybrid models are not yet supported on the paged path."""
117+
from vllm import LLM
123118

124119
with pytest.MonkeyPatch.context() as mp:
125120
mp.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
126121
mp.setenv("VLLM_METAL_USE_PAGED_ATTENTION", "1")
127122
mp.setenv("VLLM_METAL_MEMORY_FRACTION", "0.2")
128123

129-
llm = LLM(model="Qwen/Qwen3.5-0.8B", max_model_len=512, max_num_seqs=1)
130-
sp = SamplingParams(temperature=0, max_tokens=5)
131-
llm.generate(["Hello"], sp)
124+
with pytest.raises(RuntimeError, match="not yet supported for hybrid"):
125+
LLM(model="Qwen/Qwen3.5-0.8B", max_model_len=512, max_num_seqs=1)

tests/test_v1_worker.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,50 @@ def test_get_supported_tasks_delegates_to_runner_capability(self) -> None:
110110

111111
assert tasks == ("transcription",)
112112
model_runner.supported_worker_tasks.assert_called_once_with()
113+
114+
115+
class TestOneSequenceKvBytes:
116+
"""_one_sequence_kv_bytes must account for hybrid linear state."""
117+
118+
def test_non_hybrid_counts_all_layers(self) -> None:
119+
# Arrange
120+
import mlx.core as mx
121+
122+
model_runner = SimpleNamespace(
123+
is_hybrid=False,
124+
num_layers=16,
125+
num_kv_heads=8,
126+
head_dim=64,
127+
kv_cache_dtype=mx.float16,
128+
)
129+
worker = _make_worker(model_runner, use_paged_attention=False)
130+
worker.model_config = SimpleNamespace(max_model_len=2048)
131+
132+
# Act
133+
result = MetalWorker._one_sequence_kv_bytes(worker)
134+
135+
# Assert — 2 * 16 * 2048 * 8 * 64 * 2
136+
assert result == 2 * 16 * 2048 * 8 * 64 * 2
137+
138+
def test_hybrid_adds_linear_state(self) -> None:
139+
# Arrange
140+
import mlx.core as mx
141+
142+
linear_bytes = 1_000_000
143+
model_runner = SimpleNamespace(
144+
is_hybrid=True,
145+
num_sdpa_layers=8,
146+
num_kv_heads=4,
147+
head_dim=256,
148+
kv_cache_dtype=mx.float16,
149+
linear_cache_bytes_per_slot=MagicMock(return_value=linear_bytes),
150+
)
151+
worker = _make_worker(model_runner, use_paged_attention=False)
152+
worker.model_config = SimpleNamespace(max_model_len=2048)
153+
154+
# Act
155+
result = MetalWorker._one_sequence_kv_bytes(worker)
156+
157+
# Assert — SDPA bytes + linear state
158+
sdpa_bytes = 2 * 8 * 2048 * 4 * 256 * 2
159+
assert result == sdpa_bytes + linear_bytes

vllm_metal/metal_kernel_backend/paged_attention.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,19 @@ def __init__(
5555
layer_idx: int,
5656
kv_cache: MetalPagedKVCache,
5757
block_size: int,
58+
*,
59+
cache_idx: int | None = None,
5860
) -> None:
5961
super().__init__()
6062
object.__setattr__(self, "_inner", inner)
6163
object.__setattr__(self, "_mk_layer_idx", layer_idx)
6264
object.__setattr__(self, "_mk_kv_cache", kv_cache)
6365
object.__setattr__(self, "_mk_block_size", block_size)
66+
# For compact caches (hybrid models), cache_idx maps to the
67+
# per-type cache array. Defaults to layer_idx for non-hybrid.
68+
object.__setattr__(
69+
self, "_mk_cache_idx", cache_idx if cache_idx is not None else layer_idx
70+
)
6471

6572
def __call__(self, x: mx.array, mask: Any = None, cache: Any = None) -> mx.array:
6673
ctx = get_context()
@@ -71,12 +78,11 @@ def __call__(self, x: mx.array, mask: Any = None, cache: Any = None) -> mx.array
7178
inner = self._inner
7279

7380
# Dispatch to the right attention backend
81+
cache_idx = self._mk_cache_idx
7482
if is_sdpa(inner):
75-
return sdpa_forward(inner, x, ctx, self._mk_kv_cache, self._mk_layer_idx)
83+
return sdpa_forward(inner, x, ctx, self._mk_kv_cache, cache_idx)
7684
elif is_linear_attention(inner):
77-
return linear_attention_forward(
78-
inner, x, ctx, self._mk_kv_cache, self._mk_layer_idx
79-
)
85+
return linear_attention_forward(inner, x, ctx, self._mk_kv_cache, cache_idx)
8086
else:
8187
raise NotImplementedError(
8288
f"No Metal attention backend for {type(inner).__name__}. "
@@ -94,19 +100,35 @@ def patch_model_attention_metal_kernel(
94100
model: Any,
95101
kv_cache: MetalPagedKVCache,
96102
block_size: int,
103+
*,
104+
cache_idx_map: dict[int, int] | None = None,
105+
only_layers: list[int] | None = None,
97106
) -> int:
98107
"""Walk model layers and replace each attention module with a
99108
``MetalKernelPagedAttentionWrapper``.
100109
101110
Supports hybrid models (e.g. Qwen3.5) where different layers use
102111
different attribute names (``self_attn``, ``linear_attn``, etc.).
103112
113+
Args:
114+
cache_idx_map: Optional mapping from model layer_idx to compact
115+
cache index. Used for hybrid models so that a compact
116+
``MetalPagedKVCache`` (SDPA layers only) is indexed correctly.
117+
When ``None``, ``layer_idx`` is used directly.
118+
only_layers: If provided, only patch these layer indices and skip
119+
the rest. Used by hybrid backend to avoid wrapping linear
120+
attention layers that have no kernel implementation yet.
121+
104122
Returns the number of patched layers.
105123
"""
106124
layer_list = find_layers(model)
125+
only_set = set(only_layers) if only_layers is not None else None
107126
patched = 0
108127

109128
for layer_idx, layer in enumerate(layer_list):
129+
if only_set is not None and layer_idx not in only_set:
130+
continue
131+
110132
attn_attr = find_attn_attr(layer)
111133
if attn_attr is None:
112134
continue
@@ -119,8 +141,13 @@ def patch_model_attention_metal_kernel(
119141
patched += 1
120142
continue
121143

144+
cache_idx = (
145+
cache_idx_map[layer_idx]
146+
if cache_idx_map is not None and layer_idx in cache_idx_map
147+
else layer_idx
148+
)
122149
wrapper = MetalKernelPagedAttentionWrapper(
123-
attn, layer_idx, kv_cache, block_size
150+
attn, layer_idx, kv_cache, block_size, cache_idx=cache_idx
124151
)
125152
setattr(layer, attn_attr, wrapper)
126153
patched += 1
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
"""Hybrid model helpers for paged attention backend.
3+
4+
Provides spec construction for GDN linear attention layers in hybrid
5+
models (Qwen3.5). The full hybrid backend will be added in Stage C
6+
when the linear attention kernel is implemented.
7+
"""
8+
9+
from __future__ import annotations
10+
11+
import torch
12+
from vllm.v1.kv_cache_interface import MambaSpec
13+
14+
15+
def _build_linear_layer_spec(
16+
*,
17+
conv_kernel_dim: int,
18+
conv_dim: int,
19+
num_v_heads: int,
20+
value_head_dim: int,
21+
key_head_dim: int,
22+
torch_dtype: torch.dtype,
23+
) -> MambaSpec:
24+
"""Build a MambaSpec for one GDN linear attention layer."""
25+
return MambaSpec(
26+
shapes=(
27+
(conv_kernel_dim - 1, conv_dim),
28+
(num_v_heads, value_head_dim, key_head_dim),
29+
),
30+
dtypes=(torch_dtype, torch_dtype),
31+
block_size=1,
32+
)

vllm_metal/paged_attention_backend/mha.py

Lines changed: 36 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import mlx.core as mx
88
from vllm.logger import init_logger
99

10+
from vllm_metal.metal import get_ops
11+
1012
if TYPE_CHECKING:
1113
from vllm_metal.metal_kernel_backend.cache import MetalPagedKVCache
1214

@@ -18,6 +20,39 @@
1820
_METAL_LANGUAGE_VERSION_ERROR = "language version"
1921

2022

23+
def warm_up_paged_cache(cache: MetalPagedKVCache) -> None:
24+
"""Trigger Metal shader compilation with a dummy reshape_and_cache call.
25+
26+
Shared by MHA and Hybrid backends to avoid duplicating warm-up logic.
27+
"""
28+
macos_version = platform.mac_ver()[0]
29+
logger.info("Warming up paged attention Metal kernel...")
30+
31+
try:
32+
ops = get_ops()
33+
except Exception as e:
34+
raise RuntimeError(
35+
f"Failed to load Metal kernel: {e}. macOS {macos_version}"
36+
) from e
37+
38+
try:
39+
dummy_k = mx.zeros((1, cache.num_kv_heads, cache.head_dim), dtype=cache.dtype)
40+
dummy_v = mx.zeros((1, cache.num_kv_heads, cache.head_dim), dtype=cache.dtype)
41+
dummy_slot = mx.zeros((1,), dtype=mx.int64)
42+
mx.eval(dummy_k, dummy_v, dummy_slot)
43+
ops.reshape_and_cache(
44+
dummy_k, dummy_v, cache.key_caches[0], cache.value_caches[0], dummy_slot
45+
)
46+
mx.eval(cache.key_caches[0])
47+
logger.info("Paged attention Metal kernel warm-up complete")
48+
except RuntimeError as e:
49+
if _METAL_LANGUAGE_VERSION_ERROR in str(e):
50+
raise RuntimeError(
51+
f"Metal kernel incompatible with macOS {macos_version}: {e}"
52+
) from e
53+
raise
54+
55+
2156
class MHAPagedAttentionBackend:
2257
"""Paged attention backend for standard MHA models.
2358
@@ -69,40 +104,7 @@ def patch_model(self, model: Any) -> int:
69104
return patch_model_attention_metal_kernel(model, cache, self._block_size)
70105

71106
def warm_up(self) -> None:
72-
cache = self._require_initialized("warm_up")
73-
74-
from vllm_metal.metal import get_ops
75-
76-
macos_version = platform.mac_ver()[0]
77-
logger.info("Warming up paged attention Metal kernel...")
78-
79-
try:
80-
ops = get_ops()
81-
except Exception as e:
82-
raise RuntimeError(
83-
f"Failed to load Metal kernel: {e}. macOS {macos_version}"
84-
) from e
85-
86-
try:
87-
dummy_k = mx.zeros(
88-
(1, cache.num_kv_heads, cache.head_dim), dtype=cache.dtype
89-
)
90-
dummy_v = mx.zeros(
91-
(1, cache.num_kv_heads, cache.head_dim), dtype=cache.dtype
92-
)
93-
dummy_slot = mx.zeros((1,), dtype=mx.int64)
94-
mx.eval(dummy_k, dummy_v, dummy_slot)
95-
ops.reshape_and_cache(
96-
dummy_k, dummy_v, cache.key_caches[0], cache.value_caches[0], dummy_slot
97-
)
98-
mx.eval(cache.key_caches[0])
99-
logger.info("Paged attention Metal kernel warm-up complete")
100-
except RuntimeError as e:
101-
if _METAL_LANGUAGE_VERSION_ERROR in str(e):
102-
raise RuntimeError(
103-
f"Metal kernel incompatible with macOS {macos_version}: {e}"
104-
) from e
105-
raise
107+
warm_up_paged_cache(self._require_initialized("warm_up"))
106108

107109
def num_blocks(self) -> int:
108110
return self._require_initialized("num_blocks").num_blocks

0 commit comments

Comments
 (0)