
Commit c010de5

Authored by raullenchai, Your Name, and claude
fix: auto-install Python + Metal shader warmup (#43)
fix: auto-install Python + Metal shader warmup on startup

* P0 — install.sh: if no Python 3.10+ and no Homebrew is found, automatically download standalone Python from python-build-standalone (no sudo needed). Eliminates the #1 install blocker for users without Homebrew.
* P0 — first-request hang: add a warmup step after model load that runs one forward pass to trigger Metal shader compilation. Prints "Warming up (compiling Metal shaders)..." so users know what's happening, and prevents the first real request from hanging for 5+ minutes.

fix: strip think tags from Anthropic endpoint + disk space check

* P2: Think tags leaked through the Anthropic /v1/messages endpoint because it bypassed the reasoning parser entirely. Both streaming and non-streaming paths now use the reasoning parser to separate reasoning from content, emitting only content to Anthropic clients.
* P1: Add a disk space check before model download — query HuggingFace for the model repo size and warn if available disk is insufficient. Skips silently for local/cached models.

fix: standalone Python URL + move warmup to lifespan hook

* P0: The hardcoded python-build-standalone URL pointed at the old indygreg repo, which now 404s. Updated to astral-sh/python-build-standalone with cpython 3.12.13 (release 20260320), verified accessible.
* P2: Metal shader warmup ran in the CLI before the batched/hybrid engines were started (they start in the FastAPI lifespan hook). Moved warmup into the lifespan hook so it runs after engine.start() for all engine types.

fix: add generate_warmup() to BatchedEngine and HybridEngine

* Both engines inherited the no-op base generate_warmup(), so Metal shader warmup in the lifespan hook was silently skipped for --continuous-batching and hybrid modes. Now both engines override it with a real forward pass, matching SimpleEngine's implementation.

Co-authored-by: Your Name <you@example.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 669f6e0 commit c010de5

File tree: 7 files changed, +206 −9 lines

install.sh (16 additions, 4 deletions)

@@ -46,15 +46,27 @@ done
 
 if [ -z "$PYTHON" ]; then
     echo ""
-    echo " Python 3.10+ not found."
+    echo " Python 3.10+ not found. Installing automatically..."
     if command -v brew &>/dev/null; then
         echo " Installing Python 3.12 via Homebrew..."
         brew install python@3.12
         PYTHON="python3.12"
     else
-        echo " Please install Python 3.10+ from https://www.python.org/downloads/"
-        echo " Or install Homebrew first: https://brew.sh"
-        exit 1
+        # Download standalone Python — no Homebrew or sudo needed
+        STANDALONE_DIR="${HOME}/.rapid-mlx-python"
+        PY_VERSION="3.12.13"
+        PY_BUILD="20260320"
+        PY_URL="https://github.com/astral-sh/python-build-standalone/releases/download/${PY_BUILD}/cpython-${PY_VERSION}+${PY_BUILD}-aarch64-apple-darwin-install_only.tar.gz"
+        echo " Downloading Python ${PY_VERSION} (standalone, no sudo needed)..."
+        mkdir -p "$STANDALONE_DIR"
+        curl -fsSL "$PY_URL" | tar xz -C "$STANDALONE_DIR" --strip-components=1
+        PYTHON="${STANDALONE_DIR}/bin/python3"
+        if ! "$PYTHON" --version &>/dev/null; then
+            echo " Error: Failed to install standalone Python."
+            echo " Please install Python 3.10+ from https://www.python.org/downloads/"
+            exit 1
+        fi
+        echo " Installed Python $("$PYTHON" --version 2>&1) to $STANDALONE_DIR"
     fi
 fi
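The script above gates on "Python 3.10+" but its fallback verification only checks that the interpreter runs at all. A minimal standalone sketch of the version test itself (the `PY`, `ver`, `major`, and `minor` names here are illustrative, not variables from install.sh):

```shell
#!/bin/sh
# Hypothetical version gate mirroring install.sh's 3.10+ requirement.
PY="${PY:-python3}"
ver="$("$PY" -c 'import sys; print("%d.%d" % sys.version_info[:2])')"
major="${ver%%.*}"
minor="${ver#*.}"
if [ "$major" -gt 3 ] || { [ "$major" -eq 3 ] && [ "$minor" -ge 10 ]; }; then
    echo "OK: Python $ver meets the 3.10+ requirement"
else
    echo "Too old: Python $ver"
fi
```

Doing the comparison on the numeric components (rather than string-comparing "3.9" vs "3.10") avoids the classic lexicographic trap where "3.9" sorts after "3.10".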

vllm_mlx/cli.py (77 additions, 0 deletions)

@@ -16,6 +16,78 @@
 import sys
 
 
+def _check_disk_space(model_name: str) -> None:
+    """Check if there's enough disk space to download the model.
+
+    Queries HuggingFace for model repo size and compares with available space.
+    Warns (but does not block) if disk space is insufficient.
+    Skips silently if the model is already local or if the check fails.
+    """
+    import os
+    from pathlib import Path
+
+    # Skip if model is a local path that already exists
+    if os.path.exists(model_name):
+        return
+
+    # Check if model is already cached by huggingface_hub
+    try:
+        from huggingface_hub import try_to_load_from_cache
+
+        # Quick check: see if config.json is cached (implies model is downloaded)
+        cached = try_to_load_from_cache(model_name, "config.json")
+        if isinstance(cached, str) and os.path.exists(cached):
+            return
+    except Exception:
+        pass
+
+    # Query HuggingFace API for model size
+    try:
+        from huggingface_hub import model_info
+
+        info = model_info(model_name, files_metadata=True)
+        # safetensors_total or siblings file sizes
+        model_size_bytes = 0
+        if hasattr(info, "safetensors") and info.safetensors:
+            # Total size from safetensors metadata
+            params = info.safetensors
+            if hasattr(params, "total"):
+                # This is parameter count, not file size — use siblings instead
+                pass
+        # Sum file sizes from siblings
+        if hasattr(info, "siblings") and info.siblings:
+            for sibling in info.siblings:
+                if hasattr(sibling, "size") and sibling.size:
+                    model_size_bytes += sibling.size
+
+        if model_size_bytes == 0:
+            return  # Can't determine size, skip check
+
+        # Get available disk space
+        cache_dir = Path.home() / ".cache" / "huggingface"
+        stat = os.statvfs(str(cache_dir) if cache_dir.exists() else str(Path.home()))
+        available_bytes = stat.f_bavail * stat.f_frsize
+
+        model_size_gb = model_size_bytes / (1024**3)
+        available_gb = available_bytes / (1024**3)
+
+        # Need ~10% extra for temp files during download
+        required_bytes = int(model_size_bytes * 1.1)
+
+        if available_bytes < required_bytes:
+            print()
+            print(
+                f" Warning: Model requires ~{model_size_gb:.1f} GB "
+                f"but only {available_gb:.1f} GB available on disk."
+            )
+            print(
+                " The download may fail. Free up disk space or choose a smaller model."
+            )
+            print()
+    except Exception:
+        pass  # Non-critical — don't block startup on check failure
+
+
 def serve_command(args):
     """Start the OpenAI-compatible server."""
     import logging

@@ -203,6 +275,9 @@ def serve_command(args):
     else:
         print("Mode: Simple (maximum throughput)")
 
+    # Check disk space before downloading model
+    _check_disk_space(args.model)
+
     # Load model with unified server
     load_model(
         args.model,

@@ -224,6 +299,8 @@ def serve_command(args):
     )
 
     # Start server
+    # Note: Metal shader warmup runs in the FastAPI lifespan hook (server.py)
+    # so it works for all engine types including batched/hybrid which start later.
     print()
     host_display = "localhost" if args.host == "0.0.0.0" else args.host
     print(f" Ready: http://{host_display}:{args.port}/v1")
vllm_mlx/engine/base.py (8 additions, 0 deletions)

@@ -69,6 +69,14 @@ def preserve_native_tool_format(self) -> bool:
     def preserve_native_tool_format(self, value: bool) -> None:
         self._preserve_native_tool_format = value
 
+    def generate_warmup(self) -> None:
+        """Run a minimal generation to compile Metal shaders.
+
+        This prevents the first real request from hanging for minutes
+        while shaders compile on-demand.
+        """
+        pass  # Subclasses may override
+
     @abstractmethod
     async def start(self) -> None:
         """Start the engine (load model if not loaded)."""

vllm_mlx/engine/batched.py (14 additions, 0 deletions)

@@ -186,6 +186,20 @@ def tokenizer(self) -> Any:
             return getattr(self._processor, "tokenizer", self._processor)
         return self._tokenizer
 
+    def generate_warmup(self) -> None:
+        """Run a minimal forward pass to compile Metal shaders."""
+        if not self._loaded or self._model is None or self._is_mllm:
+            return
+        try:
+            import mlx.core as mx
+
+            tokens = self._tokenizer.encode("Hi")
+            input_ids = mx.array([tokens])
+            self._model(input_ids)
+            mx.eval(mx.zeros(1))
+        except Exception:
+            pass  # Non-fatal
+
     async def start(self) -> None:
         """Start the engine (load model if not loaded)."""
         if self._loaded:

vllm_mlx/engine/hybrid.py (8 additions, 0 deletions)

@@ -127,6 +127,14 @@ def tokenizer(self) -> Any:
         """Get the tokenizer."""
         return self._shared_tokenizer
 
+    def generate_warmup(self) -> None:
+        """Run a minimal forward pass to compile Metal shaders."""
+        # Delegate to the simple engine if available (it has the model loaded)
+        if self._simple is not None:
+            self._simple.generate_warmup()
+        elif self._batched is not None:
+            self._batched.generate_warmup()
+
     async def start(self) -> None:
         """Start the engine (load shared model and initialize sub-engines)."""
         if self._loaded:

vllm_mlx/engine/simple.py (18 additions, 0 deletions)

@@ -101,6 +101,24 @@ def tokenizer(self) -> Any:
             return getattr(self._model, "processor", None)
         return self._model.tokenizer
 
+    def generate_warmup(self) -> None:
+        """Run a minimal generation to compile Metal shaders."""
+        if not self._loaded or self._model is None or self._is_mllm:
+            return
+        try:
+            import mlx.core as mx
+
+            model = self._model
+            tokenizer = model.tokenizer
+            # Encode a short prompt and generate 1 token
+            tokens = tokenizer.encode("Hi")
+            input_ids = mx.array([tokens])
+            # Run one forward pass to trigger shader compilation
+            model.model(input_ids)
+            mx.eval(mx.zeros(1))
+        except Exception:
+            pass  # Non-fatal
+
     async def start(self) -> None:
         """Start the engine (load model if not loaded)."""
         if self._loaded:

vllm_mlx/server.py (65 additions, 5 deletions)

@@ -327,6 +327,23 @@ async def lifespan(app: FastAPI):
     if _engine is not None and hasattr(_engine, "_loaded") and not _engine._loaded:
         await _engine.start()
 
+    # Warmup: generate one token to trigger Metal shader compilation.
+    # Runs here (not in CLI) so all engine types are fully started first.
+    if _engine is not None:
+        import time as _time
+
+        logger.info("Warming up (compiling Metal shaders)...")
+        _warmup_start = _time.monotonic()
+        try:
+            import mlx.core as mx
+
+            _engine.generate_warmup()
+            mx.eval(mx.zeros(1))  # Force sync
+        except Exception as e:
+            logger.debug(f"Warmup failed (non-fatal): {e}")
+        _warmup_secs = _time.monotonic() - _warmup_start
+        logger.info(f"Warmup complete ({_warmup_secs:.1f}s)")
+
     # Load persisted cache from disk (AFTER engine start — AsyncEngineCore must exist)
     if _engine is not None and hasattr(_engine, "load_cache_from_disk"):
         _load_prefix_cache_from_disk()

@@ -2204,10 +2221,10 @@ async def create_anthropic_message(
                 output.text, openai_request
             )
 
-        # Clean output text
+        # Clean output text — strip think tags so Anthropic clients get pure content
         final_content = None
         if cleaned_text:
-            final_content = clean_output_text(cleaned_text)
+            final_content = strip_thinking_tags(clean_output_text(cleaned_text))
 
         # Determine finish reason
         finish_reason = "tool_calls" if tool_calls else output.finish_reason

@@ -2370,10 +2387,15 @@ async def _stream_anthropic_messages(
         }
         yield f"event: content_block_start\ndata: {json.dumps(content_block_start)}\n\n"
 
-        # Stream content deltas
+        # Stream content deltas — use reasoning parser to strip think tags
         accumulated_text = ""
+        accumulated_raw = ""
         completion_tokens = 0
 
+        # Reset reasoning parser state for this stream
+        if _reasoning_parser:
+            _reasoning_parser.reset_state()
+
         async for output in engine.stream_chat(messages=messages, **chat_kwargs):
             delta_text = output.new_text

@@ -2382,8 +2404,25 @@
                 completion_tokens = output.completion_tokens
 
             if delta_text:
-                # Filter special tokens
-                content = strip_special_tokens(delta_text)
+                content = None
+
+                # Use reasoning parser to separate reasoning from content
+                if _reasoning_parser:
+                    previous_raw = accumulated_raw
+                    accumulated_raw += delta_text
+                    delta_msg = _reasoning_parser.extract_reasoning_streaming(
+                        previous_raw, accumulated_raw, delta_text
+                    )
+                    if delta_msg is not None:
+                        # Only emit content, discard reasoning for Anthropic clients
+                        content = delta_msg.content
+                else:
+                    # No reasoning parser — pass through with special token filter
+                    content = strip_special_tokens(delta_text)
+
+                if content:
+                    # Filter special tokens from parser output too
+                    content = strip_special_tokens(content)
 
                 if content:
                     accumulated_text += content

@@ -2394,6 +2433,27 @@
         }
         yield f"event: content_block_delta\ndata: {json.dumps(delta_event)}\n\n"
 
+        # Handle reasoning parser finalization (e.g. no-tag correction)
+        if _reasoning_parser and accumulated_raw:
+            final_msg = (
+                _reasoning_parser.finalize_streaming(accumulated_raw)
+                if hasattr(_reasoning_parser, "finalize_streaming")
+                else None
+            )
+            if final_msg and final_msg.content:
+                # Emit corrected content (model didn't use think tags at all)
+                content = strip_special_tokens(final_msg.content)
+                if content:
+                    accumulated_text = content  # Replace accumulated
+                    delta_event = {
+                        "type": "content_block_delta",
+                        "index": 0,
+                        "delta": {"type": "text_delta", "text": content},
+                    }
+                    yield f"event: content_block_delta\ndata: {json.dumps(delta_event)}\n\n"
+            # Reset parser state for next request
+            _reasoning_parser.reset_state()
+
         # Check for tool calls in accumulated text
         _, tool_calls = _parse_tool_calls_with_parser(accumulated_text, openai_request)
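The streaming hunk keeps both the raw accumulated text and the filtered output because a closing think tag can arrive split across deltas, so each delta must be classified against the full raw text. A simplified, self-contained sketch of that idea (this is not the project's reasoning parser; `content_after_think` and `stream_content` are hypothetical names, and only a literal `<think>...</think>` pair is handled):

```python
def content_after_think(raw: str) -> str:
    """Return only the text after the closing think tag.

    If the think block is still open (or a partial close tag has arrived),
    nothing is visible yet; if no think block exists, everything is.
    """
    close = "</think>"
    idx = raw.rfind(close)
    if idx != -1:
        return raw[idx + len(close):]
    return "" if "<think>" in raw else raw


def stream_content(deltas):
    """Yield only the newly visible content for each incoming delta."""
    raw = ""
    emitted = ""
    for d in deltas:
        raw += d
        visible = content_after_think(raw)
        # Emit just the suffix we haven't sent yet
        new = visible[len(emitted):] if visible.startswith(emitted) else visible
        emitted = visible
        if new:
            yield new


# "</think>" is split across two deltas; only post-think text is emitted.
chunks = list(stream_content(["<think>plan", "ning</th", "ink>Hello", " world"]))
# chunks == ["Hello", " world"]
```

This is why the real diff tracks `accumulated_raw` separately from `accumulated_text`: a per-delta regex would miss tags that straddle chunk boundaries.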
