Skip to content

Commit 476f717

Browse files
janhilgardclaude
andcommitted
feat: Qwen3.5 VLM loading, streaming detokenizer, tool markup stripping
- Add _needs_strict_false() to detect VLM models and skip wasteful strict=True load - Add per-request streaming detokenizer pool for UTF-8 safe incremental decode - Strip leaked <tool_call> markup tags from content output Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent d8e6e42 commit 476f717

File tree

5 files changed

+153
-21
lines changed

5 files changed

+153
-21
lines changed

vllm_mlx/api/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
r"<\|end\|>|<\|eot_id\|>|<\|start_header_id\|>|<\|end_header_id\|>|"
1919
r"<\|channel\|>|<\|message\|>|<\|start\|>|<\|return\|>|<\|call\|>|<\|constrain\|>|"
2020
r"</s>|<s>|<pad>|\[PAD\]|\[SEP\]|\[CLS\]|"
21-
r"\[e~\[|\]~b\][a-z]*|\]~!b\["
21+
r"\[e~\[|\]~b\][a-z]*|\]~!b\[|"
22+
r"</?tool_call>|</?tool_call_reasoning>"
2223
)
2324

2425

@@ -133,6 +134,9 @@ def clean_output_text(text: str) -> str:
133134
"InternVL", # InternVL
134135
"deepseek-vl",
135136
"DeepSeek-VL", # DeepSeek-VL
137+
# NOTE: Qwen3.5 is natively multimodal but MoE produces ArraysCache
138+
# which is incompatible with MLLM continuous batching (requires KVCache).
139+
# Runs as text-only via strict=False fallback until upstream fixes this.
136140
]
137141

138142

vllm_mlx/mllm_scheduler.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from dataclasses import dataclass, field
2929
from typing import Any, AsyncIterator, Dict, List, Optional, Set, Tuple
3030

31+
from mlx_lm.tokenizer_utils import NaiveStreamingDetokenizer
3132

3233
from .mllm_batch_generator import (
3334
MLLMBatchGenerator,
@@ -198,6 +199,9 @@ def __init__(
198199
self.request_id_to_uid: Dict[str, int] = {}
199200
self.uid_to_request_id: Dict[int, str] = {}
200201

202+
# Per-request streaming detokenizers for UTF-8-safe incremental decode
203+
self._detokenizer_pool: Dict[str, Any] = {}
204+
201205
# Output queues for async streaming
202206
self.output_queues: Dict[str, asyncio.Queue] = {}
203207

@@ -446,8 +450,17 @@ def _process_batch_responses(
446450
request.output_tokens.append(response.token)
447451
request.num_output_tokens = len(request.output_tokens)
448452

449-
# Decode the new token
450-
new_text = tokenizer.decode([response.token])
453+
# Decode the new token using streaming detokenizer (UTF-8 safe)
454+
if request_id not in self._detokenizer_pool:
455+
if hasattr(tokenizer, "detokenizer"):
456+
detok = tokenizer.detokenizer
457+
else:
458+
detok = NaiveStreamingDetokenizer(tokenizer)
459+
detok.reset()
460+
self._detokenizer_pool[request_id] = detok
461+
detok = self._detokenizer_pool[request_id]
462+
detok.add_token(response.token)
463+
new_text = detok.last_segment
451464

452465
# Create output
453466
output = RequestOutput(
@@ -470,10 +483,16 @@ def _process_batch_responses(
470483
output.finish_reason = response.finish_reason
471484
finished_ids.add(request_id)
472485

473-
# Decode full output
474-
output.output_text = tokenizer.decode(request.output_tokens)
486+
# Finalize streaming detokenizer and get full output
487+
detok = self._detokenizer_pool.get(request_id)
488+
if detok is not None:
489+
detok.finalize()
490+
output.output_text = detok.text
491+
else:
492+
output.output_text = tokenizer.decode(request.output_tokens)
475493
request.output_text = output.output_text
476494
request.finish_reason = response.finish_reason
495+
self._detokenizer_pool.pop(request_id, None)
477496

478497
self.total_completion_tokens += request.num_output_tokens
479498
self.num_requests_processed += 1

vllm_mlx/scheduler.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import mlx.core as mx
2121
from mlx_lm.generate import BatchGenerator
2222
from mlx_lm.sample_utils import make_logits_processors, make_sampler
23+
from mlx_lm.tokenizer_utils import NaiveStreamingDetokenizer
2324

2425
from .memory_cache import MemoryAwarePrefixCache, MemoryCacheConfig
2526
from .paged_cache import PagedCacheManager
@@ -403,7 +404,9 @@ def _chunked_next(self=batch_gen): # noqa: C901
403404

404405
if not is_cached:
405406
padded = _left_pad_prompts(inputs_raw, max_length=max_length)
406-
prompt_cache = _make_cache(self.model, padding, self.max_kv_size)
407+
prompt_cache = _make_cache(
408+
self.model, padding, self.max_kv_size
409+
)
407410
else:
408411
last_inputs = mx.array([p[-1:] for p in inputs_raw])
409412
padded = _right_pad_prompts(inputs_raw, max_length=max_length)
@@ -980,6 +983,9 @@ def __init__(
980983
# Detect if tokenizer is a processor (MLLM) and get the actual tokenizer
981984
self._actual_tokenizer = self._get_actual_tokenizer(tokenizer)
982985

986+
# Per-request streaming detokenizers for UTF-8-safe incremental decode
987+
self._detokenizer_pool: Dict[str, Any] = {}
988+
983989
# Request management - following vLLM's design
984990
self.waiting: deque[Request] = deque() # Waiting queue (FCFS)
985991
self.running: Dict[str, Request] = {} # Running requests by ID
@@ -1080,6 +1086,21 @@ def _decode_tokens(self, token_ids: List[int]) -> str:
10801086
"""
10811087
return self._actual_tokenizer.decode(token_ids)
10821088

1089+
def _get_detokenizer(self, request_id: str) -> Any:
1090+
"""Get or create a streaming detokenizer for a request."""
1091+
if request_id not in self._detokenizer_pool:
1092+
if hasattr(self.tokenizer, "detokenizer"):
1093+
detok = self.tokenizer.detokenizer
1094+
else:
1095+
detok = NaiveStreamingDetokenizer(self._actual_tokenizer)
1096+
detok.reset()
1097+
self._detokenizer_pool[request_id] = detok
1098+
return self._detokenizer_pool[request_id]
1099+
1100+
def _cleanup_detokenizer(self, request_id: str) -> None:
1101+
"""Remove the streaming detokenizer for a finished request."""
1102+
self._detokenizer_pool.pop(request_id, None)
1103+
10831104
def _get_stop_tokens(self) -> Set[int]:
10841105
"""Get stop token IDs from tokenizer or processor."""
10851106
stop_tokens = set()
@@ -1872,11 +1893,13 @@ def _process_batch_responses(
18721893

18731894
request.first_token_time = _time.time()
18741895

1875-
# Decode the new token (skip stop tokens — they are not content)
1896+
# Decode the new token using streaming detokenizer (UTF-8 safe)
18761897
if response.finish_reason == "stop":
18771898
new_text = ""
18781899
else:
1879-
new_text = self._decode_tokens([response.token])
1900+
detok = self._get_detokenizer(request_id)
1901+
detok.add_token(response.token)
1902+
new_text = detok.last_segment
18801903

18811904
# Create output
18821905
output = RequestOutput(
@@ -1899,9 +1922,15 @@ def _process_batch_responses(
18991922
output.finish_reason = response.finish_reason
19001923
finished_ids.add(request_id)
19011924

1902-
# Decode full output
1903-
output.output_text = self._decode_tokens(request.output_token_ids)
1925+
# Finalize streaming detokenizer and get full output
1926+
detok = self._detokenizer_pool.get(request_id)
1927+
if detok is not None:
1928+
detok.finalize()
1929+
output.output_text = detok.text
1930+
else:
1931+
output.output_text = self._decode_tokens(request.output_token_ids)
19041932
request.output_text = output.output_text
1933+
self._cleanup_detokenizer(request_id)
19051934

19061935
# Extract cache for future reuse (critical for agentic multi-turn)
19071936
if hasattr(response, "prompt_cache"):

vllm_mlx/server.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import json
4343
import logging
4444
import os
45+
import re
4546
import secrets
4647
import tempfile
4748
import threading
@@ -158,6 +159,11 @@ def _resolve_top_p(request_value: float | None) -> float:
158159
_tool_call_parser: str | None = None # Parser name: auto, mistral, qwen, llama, hermes
159160
_tool_parser_instance = None # Instantiated parser
160161

162+
# Pattern to strip leaked tool call markup from content output.
163+
# Safety net: the tool parser should consume these, but if it doesn't
164+
# (e.g. malformed JSON, stray closing tags), strip them before emitting.
165+
_TOOL_MARKUP_PATTERN = re.compile(r"</?tool_call>|</?tool_call_reasoning>")
166+
161167

162168
def _load_prefix_cache_from_disk() -> None:
163169
"""Load prefix cache from disk during startup."""
@@ -2097,6 +2103,9 @@ async def stream_chat_completion(
20972103

20982104
# Normal content from tool parser
20992105
content = tool_result.get("content", "")
2106+
# Strip any leaked tool markup tags
2107+
if content:
2108+
content = _TOOL_MARKUP_PATTERN.sub("", content)
21002109

21012110
chunk = ChatCompletionChunk(
21022111
id=response_id,
@@ -2187,6 +2196,9 @@ async def stream_chat_completion(
21872196

21882197
# Normal content from tool parser
21892198
content = tool_result.get("content", "")
2199+
# Strip any leaked tool markup tags
2200+
if content:
2201+
content = _TOOL_MARKUP_PATTERN.sub("", content)
21902202

21912203
chunk = ChatCompletionChunk(
21922204
id=response_id,

vllm_mlx/utils/tokenizer.py

Lines changed: 79 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,27 @@ def _needs_tokenizer_fallback(model_name: str) -> bool:
2828
return any(pattern.lower() in model_lower for pattern in FALLBACK_MODELS)
2929

3030

31+
def _needs_strict_false(model_name: str) -> bool:
32+
"""Check if model needs strict=False loading (VLM models with extra weights).
33+
34+
VLM models (e.g., Qwen3.5) have vision_tower weights that don't match
35+
the text-only model class. Loading with strict=True fails and wastes
36+
memory by loading all weights (~100 GB) before raising ValueError.
37+
Detect these models up-front to avoid the double-load penalty.
38+
"""
39+
from mlx_lm.utils import _download, load_config
40+
41+
try:
42+
model_path = _download(model_name)
43+
config = load_config(model_path)
44+
except Exception:
45+
return False
46+
# VLM models have vision_config or text_config with a separate model_type
47+
if "vision_config" in config and "text_config" in config:
48+
return True
49+
return False
50+
51+
3152
def load_model_with_fallback(model_name: str, tokenizer_config: dict = None):
3253
"""
3354
Load model and tokenizer with fallback for non-standard tokenizers.
@@ -50,32 +71,36 @@ def load_model_with_fallback(model_name: str, tokenizer_config: dict = None):
5071
)
5172
return _load_with_tokenizer_fallback(model_name)
5273

74+
# VLM models (e.g., Qwen3.5) have extra vision weights that cause
75+
# strict=True to fail. Skip the first load attempt to avoid loading
76+
# ~100 GB of weights twice (which can cause OOM on 256 GB systems).
77+
if _needs_strict_false(model_name):
78+
logger.info(
79+
f"Model {model_name} detected as VLM, loading directly with strict=False"
80+
)
81+
return _load_strict_false(model_name, tokenizer_config)
82+
5383
try:
5484
model, tokenizer = load(model_name, tokenizer_config=tokenizer_config)
5585
except ValueError as e:
5686
# Fallback for models with non-standard tokenizers
5787
if "TokenizersBackend" in str(e) or "Tokenizer class" in str(e):
5888
logger.warning(f"Standard tokenizer loading failed, using fallback: {e}")
5989
return _load_with_tokenizer_fallback(model_name)
60-
# Fallback for models with extra weights (e.g., MTP layers)
90+
# Fallback for models with extra weights (e.g., MTP layers, vision tower)
6191
elif "parameters not in model" in str(e):
62-
logger.warning(f"Extra parameters found (e.g., MTP weights), retrying with strict=False")
92+
logger.warning(
93+
f"Extra parameters found (e.g., MTP/vision weights), retrying with strict=False"
94+
)
6395
# Clear traceback references to free memory from the failed first load.
6496
# Without this, large models (200GB+) cause OOM during retry because
6597
# the traceback holds references to the first load's weight tensors.
6698
e.__traceback__ = None
6799
del e
68100
import gc
101+
69102
gc.collect()
70-
from mlx_lm.utils import _download, load_model, load_tokenizer
71-
model_path = _download(model_name)
72-
model, config = load_model(model_path, strict=False)
73-
tokenizer = load_tokenizer(
74-
model_path, tokenizer_config,
75-
eos_token_ids=config.get("eos_token_id", None),
76-
)
77-
_try_inject_mtp(model, model_path, config)
78-
return model, tokenizer
103+
return _load_strict_false(model_name, tokenizer_config)
79104
else:
80105
raise
81106

@@ -84,10 +109,53 @@ def load_model_with_fallback(model_name: str, tokenizer_config: dict = None):
84109
return model, tokenizer
85110

86111

112+
def _load_strict_false(model_name: str, tokenizer_config: dict = None):
113+
"""Load model with strict=False to discard extra weights.
114+
115+
Handles models with extra parameters that the text-only model class
116+
doesn't define (e.g., vision tower weights in VLM models like Qwen3.5,
117+
or MTP layers). The model's own sanitize() handles key remapping
118+
(e.g., language_model.* prefix), and strict=False silently drops
119+
unmatched keys.
120+
"""
121+
import mlx.core as mx
122+
from mlx_lm.utils import _download, load_model, load_tokenizer
123+
124+
model_path = _download(model_name)
125+
model, config = load_model(model_path, strict=False)
126+
127+
# Verify weights loaded correctly
128+
from mlx.utils import tree_flatten
129+
130+
params = tree_flatten(model.parameters())
131+
total_params = len(params)
132+
zero_params = sum(1 for _, v in params if mx.all(v == 0).item())
133+
logger.info(
134+
f"[strict=False] Loaded {total_params} parameters, "
135+
f"{zero_params} all-zero tensors"
136+
)
137+
# Spot-check embedding weights
138+
if hasattr(model, "language_model"):
139+
emb = model.language_model.model.embed_tokens.weight
140+
logger.info(
141+
f"[strict=False] embed_tokens: shape={emb.shape}, "
142+
f"dtype={emb.dtype}, mean={mx.mean(emb.astype(mx.float32)).item():.4f}"
143+
)
144+
145+
tokenizer = load_tokenizer(
146+
model_path,
147+
tokenizer_config or {},
148+
eos_token_ids=config.get("eos_token_id", None),
149+
)
150+
_try_inject_mtp(model, model_path, config)
151+
return model, tokenizer
152+
153+
87154
def _try_inject_mtp(model, model_path, config):
88155
"""Inject MTP support if model has MTP config + weights."""
89156
if config.get("num_nextn_predict_layers", 0) > 0:
90157
from ..patches.qwen3_next_mtp import inject_mtp_support
158+
91159
inject_mtp_support(model, model_path, config)
92160

93161

0 commit comments

Comments
 (0)