Skip to content

Commit 5bba500

Browse files
authored
Fix eager import cascade causing platform fallback in platform.py + add GPT-OSS mlx_lm ground truth (#221)
1 parent 0bc1044 commit 5bba500

File tree

4 files changed

+175
-20
lines changed

4 files changed

+175
-20
lines changed

tests/test_platform.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
"""Tests for Metal platform."""
33

44
import platform
5-
from types import SimpleNamespace
5+
import sys
6+
from types import ModuleType, SimpleNamespace
67

78
import pytest
89
import torch
@@ -24,10 +25,12 @@ def _patch_stt_resolution(
2425
is_stt: bool,
2526
) -> None:
2627
monkeypatch.setattr(
27-
"vllm_metal.platform.get_model_download_path",
28+
"vllm_metal.utils.get_model_download_path",
2829
lambda model: model,
2930
)
30-
monkeypatch.setattr("vllm_metal.platform.is_stt_model", lambda _model: is_stt)
31+
monkeypatch.setattr(
32+
"vllm_metal.stt.detection.is_stt_model", lambda _model: is_stt
33+
)
3134

3235
def test_device_name(self) -> None:
3336
"""Test device name retrieval."""
@@ -123,6 +126,29 @@ def test_is_available_does_not_mutate_default_device(self) -> None:
123126

124127
assert before == after
125128

129+
def test_is_available_propagates_unexpected_mlx_errors(
130+
self, monkeypatch: pytest.MonkeyPatch
131+
) -> None:
132+
"""Unexpected MLX errors should surface instead of looking unavailable."""
133+
monkeypatch.setattr("vllm_metal.platform.py_platform.machine", lambda: "arm64")
134+
monkeypatch.setattr("vllm_metal.platform.py_platform.system", lambda: "Darwin")
135+
136+
mlx_module = ModuleType("mlx")
137+
mlx_core = ModuleType("mlx.core")
138+
139+
class _BrokenMetal:
140+
@staticmethod
141+
def is_available() -> bool:
142+
raise ValueError("unexpected mlx regression")
143+
144+
mlx_core.metal = _BrokenMetal()
145+
mlx_module.core = mlx_core
146+
monkeypatch.setitem(sys.modules, "mlx", mlx_module)
147+
monkeypatch.setitem(sys.modules, "mlx.core", mlx_core)
148+
149+
with pytest.raises(ValueError, match="unexpected mlx regression"):
150+
MetalPlatform.is_available()
151+
126152
def test_torch_device(self) -> None:
127153
"""Test PyTorch device retrieval."""
128154

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
#!/usr/bin/env python3
22
# SPDX-License-Identifier: Apache-2.0
3-
"""Generate golden token IDs for the deterministic smoke test.
3+
"""Generate golden token IDs for deterministic smoke tests.
44
55
Runs vLLM offline inference (greedy, max_num_seqs=1) and prints golden
6-
token-ID dicts to paste into test_paged_deterministic.py.
6+
token-ID dicts to paste into test files or smoke scripts.
77
88
Usage:
9-
# MLX inline cache (default):
9+
# Qwen3 (default, MLX inline cache):
1010
VLLM_ENABLE_V1_MULTIPROCESSING=0 python tools/gen_golden_token_ids_for_deterministics.py
1111
12+
# GPT-OSS (requires chat template):
13+
VLLM_ENABLE_V1_MULTIPROCESSING=0 python tools/gen_golden_token_ids_for_deterministics.py \
14+
--model openai/gpt-oss-20b --max-tokens 100 --chat-template
15+
1216
# Paged KV cache:
1317
VLLM_METAL_USE_PAGED_ATTENTION=1 VLLM_METAL_MEMORY_FRACTION=0.3 \
1418
VLLM_ENABLE_V1_MULTIPROCESSING=0 python tools/gen_golden_token_ids_for_deterministics.py
@@ -17,14 +21,12 @@
1721
Numeric fractions are only valid for the paged attention path.
1822
"""
1923

24+
import argparse
2025
import os
2126

2227
os.environ.setdefault("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
2328

24-
from vllm import LLM, SamplingParams
25-
26-
MODEL = "Qwen/Qwen3-0.6B"
27-
MAX_TOKENS = 10
29+
from vllm import LLM, SamplingParams # noqa: E402
2830

2931
PROMPTS = [
3032
"The capital of France is",
@@ -35,21 +37,54 @@
3537
"Machine learning is",
3638
]
3739

40+
41+
def _apply_chat_template(model_name, prompts):
42+
"""Apply chat template and return (formatted_prompts, reverse_map)."""
43+
from transformers import AutoTokenizer
44+
45+
tokenizer = AutoTokenizer.from_pretrained(model_name)
46+
formatted = []
47+
reverse_map = {}
48+
for prompt in prompts:
49+
messages = [{"role": "user", "content": prompt}]
50+
fmt = tokenizer.apply_chat_template(
51+
messages, add_generation_prompt=True, tokenize=False
52+
)
53+
formatted.append(fmt)
54+
reverse_map[fmt] = prompt
55+
return formatted, reverse_map
56+
57+
3858
if __name__ == "__main__":
59+
parser = argparse.ArgumentParser(description=__doc__)
60+
parser.add_argument("--model", default="Qwen/Qwen3-0.6B")
61+
parser.add_argument("--max-tokens", type=int, default=10)
62+
parser.add_argument(
63+
"--chat-template",
64+
action="store_true",
65+
help="Apply chat template before inference (required for GPT-OSS)",
66+
)
67+
args = parser.parse_args()
68+
3969
paged = os.environ.get("VLLM_METAL_USE_PAGED_ATTENTION", "0") == "1"
4070
label = "PAGED" if paged else "MLX"
41-
print(f"\n--- Generating golden values for {label} path ---\n")
71+
print(f"\n--- Generating golden values for {label} path ({args.model}) ---\n")
72+
73+
prompts = PROMPTS
74+
reverse_map = None
75+
if args.chat_template:
76+
prompts, reverse_map = _apply_chat_template(args.model, PROMPTS)
4277

43-
llm = LLM(model=MODEL, max_model_len=512, max_num_seqs=1)
44-
sp = SamplingParams(temperature=0, max_tokens=MAX_TOKENS)
45-
outputs = llm.generate(PROMPTS, sp)
78+
llm = LLM(model=args.model, max_model_len=512, max_num_seqs=1)
79+
sp = SamplingParams(temperature=0, max_tokens=args.max_tokens)
80+
outputs = llm.generate(prompts, sp)
4681

4782
print(f"\nGOLDEN_{label} = {{")
4883
for o in outputs:
49-
prompt = o.prompt
84+
display = reverse_map[o.prompt] if reverse_map else o.prompt
5085
ids = list(o.outputs[0].token_ids)
5186
text = o.outputs[0].text
52-
pad = 45 - len(prompt)
53-
print(f" {prompt!r}:{' ' * pad}{ids},")
87+
pad = 50 - len(display)
88+
print(f" {display!r}:{' ' * max(pad, 1)}{ids},")
5489
print(f" # → {text!r}")
5590
print("}")

tools/test_gpt_oss_smoke.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
#!/usr/bin/env python3
2+
# SPDX-License-Identifier: Apache-2.0
3+
"""GPT-OSS 20B smoke test: mlx_lm ground truth for sink attention work (#148).
4+
5+
Loads openai/gpt-oss-20b, generates with greedy decoding, and compares
6+
output against golden token IDs. Not in CI since it requires a ~21.5 GB model.
7+
8+
Run:
9+
VLLM_ENABLE_V1_MULTIPROCESSING=0 python tools/test_gpt_oss_smoke.py
10+
"""
11+
12+
import os
13+
import sys
14+
15+
os.environ.setdefault("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
16+
17+
from transformers import AutoTokenizer # noqa: E402
18+
from vllm import LLM, SamplingParams # noqa: E402
19+
20+
MODEL_NAME = "openai/gpt-oss-20b"
21+
MAX_TOKENS = 100
22+
23+
PROMPTS = [
24+
"The capital of France is",
25+
"The weather today is not",
26+
"One plus one equals",
27+
"The largest planet in our solar system is",
28+
"Water boils at a temperature of",
29+
]
30+
31+
# fmt: off
32+
# Golden token IDs from MLX inline cache, greedy decoding (openai/gpt-oss-20b).
33+
# Generated via:
34+
# VLLM_ENABLE_V1_MULTIPROCESSING=0 python tools/gen_golden_token_ids_for_deterministics.py \
35+
# --model openai/gpt-oss-20b --max-tokens 100 --chat-template
36+
#
37+
# Note: FP non-determinism at longer sequences may cause 2-3 prompts to diverge
38+
# after ~25 tokens across runs. Regenerate with the command above if needed.
39+
GOLDEN_MLX = {
40+
"The capital of France is": [200005, 35644, 200008, 976, 1825, 5003, 25, 392, 976, 9029, 328, 10128, 382, 4050, 3164, 6960, 1682, 290, 6052, 25, 392, 72782, 4050, 2632, 9570, 483, 392, 72782, 4050, 63659, 1327, 6052, 13, 200007, 200006, 173781, 200005, 17196, 200008, 72782, 200002],
41+
"The weather today is not": [200005, 35644, 200008, 976, 1825, 5003, 25, 392, 976, 11122, 4044, 382, 625, 4050, 4569, 7890, 60592, 13, 3164, 3572, 413, 8601, 261, 21872, 25, 392, 976, 11122, 4044, 382, 625, 723, 64493, 49706, 889, 1023, 9289, 9115, 13, 3164, 3572, 413, 16054, 395, 3543, 30, 2604, 10112, 1023, 1682, 316, 1761, 290, 11122, 30, 623, 1825, 5003, 392, 976, 11122, 4044, 382, 625, 4050, 4569, 382, 60592, 13, 1416, 1309, 316, 9570, 54286, 13, 1416, 2023, 3810, 395, 108041, 25, 392, 4827, 1481, 481, 1299, 316, 1761, 1078, 290, 11122, 16842, 2604, 581, 2023, 18135, 484, 1023, 1682, 316],
42+
"One plus one equals": [200005, 35644, 200008, 976, 1825, 5003, 25, 392, 5045, 2932, 1001, 29702, 4050, 3164, 6960, 1682, 290, 6052, 25, 220, 17, 13, 3072, 10112, 1023, 1682, 261, 945, 65742, 6052, 30, 623, 1825, 3572, 413, 11493, 13, 623, 63122, 6052, 25, 220, 17, 13, 3072, 10112, 1023, 1682, 261, 15681, 30, 623, 21179, 25, 392, 3575, 553, 17554, 162016, 11, 261, 4410, 6439, 2359, 22203, 656, 7788, 17527, 3692, 32711, 860, 3582, 21179, 13, 2632, 6052, 25, 220, 17, 13, 200007, 200006, 173781, 200005, 17196, 200008, 5045, 2932, 1001, 29702, 6240, 17, 410, 13, 200002],
43+
"The largest planet in our solar system is": [200005, 35644, 200008, 976, 1825, 31064, 25, 392, 976, 10574, 17921, 306, 1039, 17624, 2420, 382, 4050, 3164, 6960, 1682, 290, 6052, 25, 79575, 13, 3164, 3572, 1682, 261, 18128, 13, 2632, 6052, 25, 79575, 13, 138743, 8633, 4275, 290, 10574, 13, 2632, 9570, 25, 79575, 13, 200007, 200006, 173781, 200005, 17196, 200008, 976, 10574, 17921, 306, 1039, 17624, 2420, 382, 6240, 41, 26451, 410, 13, 200002],
44+
"Water boils at a temperature of": [200005, 35644, 200008, 976, 1825, 5003, 25, 392, 27874, 165683, 540, 261, 12088, 328, 4050, 3164, 6960, 1682, 290, 79667, 2438, 328, 3411, 13, 3072, 290, 4928, 382, 60592, 25, 392, 27874, 165683, 540, 261, 12088, 328, 4050, 3164, 3572, 1682, 290, 6052, 25, 220, 1353, 26557, 540, 220, 16, 83327, 11, 503, 220, 19584, 68854, 13, 3072, 10112, 1023, 1682, 290, 12088, 306, 181775, 25, 220, 33797, 13, 1055, 658, 13, 623, 1825, 3572, 413, 35885, 261, 52077, 6052, 13, 623, 4928, 382, 60592, 889, 6960, 1023, 1682, 290, 79667, 2438, 13, 2632, 6052, 25, 220, 1353, 26557, 350],
45+
}
46+
# fmt: on
47+
48+
49+
def _apply_chat_template(model_name, prompts):
50+
"""Apply chat template and return (formatted_prompts, reverse_map)."""
51+
tokenizer = AutoTokenizer.from_pretrained(model_name)
52+
formatted = []
53+
reverse_map = {}
54+
for prompt in prompts:
55+
messages = [{"role": "user", "content": prompt}]
56+
fmt = tokenizer.apply_chat_template(
57+
messages, add_generation_prompt=True, tokenize=False
58+
)
59+
formatted.append(fmt)
60+
reverse_map[fmt] = prompt
61+
return formatted, reverse_map
62+
63+
64+
if __name__ == "__main__":
65+
formatted_prompts, reverse_map = _apply_chat_template(MODEL_NAME, PROMPTS)
66+
67+
llm = LLM(model=MODEL_NAME, max_model_len=512, max_num_seqs=1)
68+
sp = SamplingParams(temperature=0, max_tokens=MAX_TOKENS)
69+
outputs = llm.generate(formatted_prompts, sp)
70+
71+
passed = 0
72+
failed = 0
73+
for o in outputs:
74+
prompt = reverse_map[o.prompt]
75+
token_ids = list(o.outputs[0].token_ids)
76+
text = o.outputs[0].text
77+
expected = GOLDEN_MLX[prompt]
78+
matched = token_ids == expected
79+
80+
status = "PASS" if matched else "FAIL"
81+
print(f" [{status}] {prompt!r}")
82+
print(f" output: {text!r}")
83+
if not matched:
84+
print(f" got: {token_ids}")
85+
print(f" expected: {expected}")
86+
failed += 1
87+
else:
88+
passed += 1
89+
90+
print(f"\n{passed} passed, {failed} failed")
91+
sys.exit(1 if failed else 0)

vllm_metal/platform.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,6 @@
1111
from vllm.v1.attention.backends.registry import AttentionBackendEnum
1212

1313
from vllm_metal.config import get_config
14-
from vllm_metal.stt.detection import is_stt_model
15-
from vllm_metal.stt.policy import apply_stt_scheduler_policy
16-
from vllm_metal.utils import get_model_download_path
1714

1815
if TYPE_CHECKING:
1916
from vllm.config import VllmConfig
@@ -273,6 +270,12 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
273270
model_config.disable_cascade_attn = True
274271

275272
# STT model detection — set tokenizer fallback if not already configured.
273+
# Lazy imports to avoid circular import: platform.py is loaded during
274+
# vllm.config init, and stt.detection imports from vllm.config.
275+
from vllm_metal.stt.detection import is_stt_model
276+
from vllm_metal.stt.policy import apply_stt_scheduler_policy
277+
from vllm_metal.utils import get_model_download_path
278+
276279
resolved_model = (
277280
get_model_download_path(model_config.model)
278281
if model_config is not None

0 commit comments

Comments
 (0)