Skip to content

Commit bf78650

Browse files
janhilgardclaude
and committed
feat: patch Qwen3.5 attention for BatchKVCache compatibility
mlx-vlm's Qwen3_5Attention uses cache.offset for kv_seq_len and mask slicing, but BatchKVCache stores offset as mx.array (per-batch-item), causing "Slice indices must be integers or None" during generation. Add monkey-patch that converts cache.offset to int before arithmetic, while leaving actual cache.offset untouched for update_and_fetch. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 8eb783f commit bf78650

File tree

3 files changed

+126
-1
lines changed

3 files changed

+126
-1
lines changed

vllm_mlx/mllm_batch_generator.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,11 @@ def __init__(
327327
"MLLMBatchGenerator: Model does not have language_model, using model directly"
328328
)
329329

330+
# Patch Qwen3.5 attention for BatchKVCache compatibility
331+
from .patches.qwen3_5_mllm import patch_qwen35_attention_for_batching
332+
333+
patch_qwen35_attention_for_batching()
334+
330335
self.max_tokens = max_tokens
331336
self.stop_tokens = stop_tokens or set()
332337
self.sampler = sampler or (lambda x: mx.argmax(x, axis=-1))

vllm_mlx/mllm_scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -623,7 +623,7 @@ async def _process_loop(self) -> None:
623623
except asyncio.CancelledError:
624624
break
625625
except Exception as e:
626-
logger.error(f"Error in MLLM process loop: {e}")
626+
logger.error(f"Error in MLLM process loop: {e}", exc_info=True)
627627
await asyncio.sleep(0.1)
628628

629629
async def add_request_async(

vllm_mlx/patches/qwen3_5_mllm.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
"""
3+
Runtime patch for mlx-vlm's Qwen3.5 attention to support BatchKVCache.
4+
5+
mlx-vlm's Qwen3_5Attention uses cache.offset directly for kv_seq_len
6+
computation and mask slicing. BatchKVCache stores offset as mx.array
7+
(per-batch-item), not int, causing:
8+
9+
mask = mask[..., :kv_seq_len]
10+
ValueError: Slice indices must be integers or None.
11+
12+
This patch replaces Qwen3_5Attention.__call__ with a version that
13+
converts cache.offset to int before using it for arithmetic/slicing,
14+
while leaving the actual cache.offset untouched so update_and_fetch
15+
still works correctly with per-batch offsets.
16+
"""
17+
18+
import logging
19+
from typing import Optional
20+
21+
import mlx.core as mx
22+
23+
logger = logging.getLogger(__name__)
24+
25+
26+
def _cache_offset_to_int(cache) -> int:
27+
"""Extract cache offset as int, handling BatchKVCache mx.array offset."""
28+
if cache is None:
29+
return 0
30+
off = cache.offset
31+
if isinstance(off, int):
32+
return off
33+
if isinstance(off, mx.array):
34+
return int(off.max().item()) if off.ndim > 0 else int(off.item())
35+
return int(off)
36+
37+
38+
def patch_qwen35_attention_for_batching() -> bool:
    """Monkey-patch Qwen3_5Attention.__call__ to handle BatchKVCache.

    Returns True if patch was applied, False if mlx-vlm is not installed
    or Qwen3.5 module not available.
    """
    # Import lazily so this module loads (and the patch degrades gracefully)
    # when mlx-vlm, or a version without the Qwen3.5 model, is not installed.
    try:
        from mlx_vlm.models.qwen3_5.language import (
            Qwen3_5Attention,
            apply_multimodal_rotary_pos_emb,
        )
        from mlx_lm.models.base import scaled_dot_product_attention
    except ImportError:
        logger.debug("[Qwen3.5 patch] mlx-vlm Qwen3.5 module not available")
        return False

    # Idempotence guard: applying the patch twice would be harmless but noisy,
    # so a class-level flag short-circuits repeat calls.
    if getattr(Qwen3_5Attention, "_batch_patched", False):
        logger.debug("[Qwen3.5 patch] Already patched")
        return True

    def _patched_call(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache=None,
        position_ids: Optional[mx.array] = None,
    ) -> mx.array:
        # Drop-in replacement for Qwen3_5Attention.__call__. The only
        # intended behavioral difference from upstream is that cache.offset
        # is reduced to an int (via _cache_offset_to_int) before any
        # arithmetic or slicing; everything else mirrors mlx-vlm.
        B, L, D = x.shape

        # q_proj emits queries and a per-head gate concatenated on the last
        # axis; split them apart. NOTE(review): gate appears to implement
        # gated attention output (see the final sigmoid) — confirm against
        # the upstream Qwen3.5 implementation.
        q_proj_output = self.q_proj(x)
        queries, gate = mx.split(
            q_proj_output.reshape(B, L, self.num_attention_heads, -1),
            2,
            axis=-1,
        )
        gate = gate.reshape(B, L, -1)

        keys, values = self.k_proj(x), self.v_proj(x)

        # Normalize q/k, then move to (B, heads, L, head_dim) layout for
        # attention.
        queries = self.q_norm(queries).transpose(0, 2, 1, 3)
        keys = self.k_norm(keys.reshape(B, L, self.num_key_value_heads, -1)).transpose(
            0, 2, 1, 3
        )
        values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
            0, 2, 1, 3
        )

        kv_seq_len = keys.shape[-2]

        # Convert cache.offset to int for slice compatibility.
        # BatchKVCache stores offset as mx.array (per-batch-item),
        # but kv_seq_len must be int for mask[..., :kv_seq_len].
        _offset = _cache_offset_to_int(cache)

        if position_ids is None:
            # NOTE(review): the "+ 1" mirrors the upstream mlx-vlm kv_seq_len
            # computation — confirm it matches the installed version.
            kv_seq_len += _offset + 1
            position_ids = mx.arange(_offset, _offset + L)
            position_ids = mx.expand_dims(position_ids, axis=0)
            # Tiled to 3 leading sections — presumably the temporal/height/
            # width axes of Qwen's multimodal RoPE; verify against upstream.
            position_ids = mx.tile(position_ids, (3, 1, 1))
        else:
            # Parses as (_offset + 1) if cache is not None else 0, matching
            # the branch above when a cache is present.
            kv_seq_len += _offset + 1 if cache is not None else 0

        cos, sin = self.rotary_emb(values, position_ids)

        # Trim the mask to the effective kv length; slicing past the end is
        # a no-op, so an over-long kv_seq_len is safe here.
        if mask is not None and isinstance(mask, mx.array):
            mask = mask[..., :kv_seq_len]

        queries, keys = apply_multimodal_rotary_pos_emb(queries, keys, cos, sin)

        # The cache keeps its original (possibly per-batch mx.array) offset;
        # update_and_fetch is untouched by this patch.
        if cache is not None:
            keys, values = cache.update_and_fetch(keys, values)

        output = scaled_dot_product_attention(
            queries, keys, values, cache=cache, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)

        # Gated output projection: sigmoid(gate) modulates the attention
        # output element-wise before o_proj.
        return self.o_proj(output * mx.sigmoid(gate))

    Qwen3_5Attention.__call__ = _patched_call
    Qwen3_5Attention._batch_patched = True
    logger.info("[Qwen3.5 patch] Attention patched for BatchKVCache support")
    return True

0 commit comments

Comments
 (0)