
Commit 8401ca0

[Feature] Add Whisper model definition and weight sanitization (#133)
## Summary

- Add `vllm_metal/stt/whisper.py` — Whisper encoder/decoder model (MLX)
- `WhisperConfig` supports both HuggingFace and MLX config formats
- Weight sanitization: HF key renaming, Conv1d transpose, dtype casting
- Add 5 unit tests (`TestWeightSanitize`)
- Add 4 decoder cache-path tests (`TestDecoderCachePaths`)

Related #91

## How to verify

### Run tests

```bash
source ~/.venv-vllm-metal/bin/activate
pytest tests/test_whisper.py -v
```

All 9 tests should pass, including the 4 new decoder cache-path tests.

### Minimal repro (before fix)

The mask bug triggers when `q_len > 1` during cached decode:

```python
import mlx.core as mx

from vllm_metal.stt.whisper import WhisperConfig, WhisperModel

config = WhisperConfig(
    n_mels=80,
    n_audio_ctx=10,
    n_audio_state=64,
    n_audio_head=2,
    n_audio_layer=1,
    n_vocab=100,
    n_text_ctx=32,
    n_text_state=64,
    n_text_head=2,
    n_text_layer=1,
)
model = WhisperModel(config, dtype=mx.float32)

mel = mx.random.normal((1, 20, 80))
audio_features = model.encode(mel)

# Prefill 3 tokens
_, kv_cache = model.decode(mx.array([[1, 2, 3]]), audio_features)

# Cached decode of 2 tokens — before the fix this raised:
# ValueError: [broadcast_shapes] Shapes (1,2,2,5) and (2,2) cannot be broadcast
logits, _ = model.decode(mx.array([[4, 5]]), audio_features, kv_cache)
mx.eval(logits)
print(f"OK — logits shape: {logits.shape}")  # (1, 2, 100)
```

---------

Signed-off-by: RickyChen / 陳昭儒 <ricky.chen@infinirc.com>
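For context on the repro above, here is a minimal sketch of an offset-aware causal mask that broadcasts correctly during cached decode. `cached_causal_mask` is a hypothetical helper reconstructed from the error message and the tests; it is not the code in `vllm_metal/stt/whisper.py`.

```python
import mlx.core as mx


def cached_causal_mask(q_len: int, offset: int, dtype=mx.float32) -> mx.array:
    """Additive causal mask of shape (q_len, offset + q_len).

    Hypothetical helper: query i sits at absolute position offset + i and may
    attend to key j only when j <= offset + i. A (q_len, k_len) mask broadcasts
    against attention scores of shape (B, n_head, q_len, k_len); the pre-fix
    (q_len, q_len) mask cannot once offset > 0, e.g. (2, 2) vs (1, 2, 2, 5).
    """
    k_len = offset + q_len
    q_pos = mx.arange(offset, k_len)[:, None]   # absolute query positions
    k_pos = mx.arange(k_len)[None, :]           # absolute key positions
    zeros = mx.zeros((q_len, k_len), dtype=dtype)
    neg_inf = mx.full((q_len, k_len), float("-inf"), dtype=dtype)
    return mx.where(k_pos <= q_pos, zeros, neg_inf)


# Prefill (offset=0, q_len=3) and cached decode (offset=3, q_len=2) shapes:
print(cached_causal_mask(3, 0).shape)  # (3, 3)
print(cached_causal_mask(2, 3).shape)  # (2, 5)
```

An equivalent approach would be to precompute a single `(n_text_ctx, n_text_ctx)` additive causal mask once and slice `[offset : offset + q_len, : offset + q_len]` at each decode step.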
1 parent 1a036e8 commit 8401ca0

2 files changed

Lines changed: 612 additions & 0 deletions

File tree

tests/test_whisper.py
vllm_metal/stt/whisper.py (diff not shown below)

tests/test_whisper.py

Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for Whisper model: weight sanitization and decoder cache paths."""

from __future__ import annotations

import mlx.core as mx
import pytest

from vllm_metal.stt.whisper import WhisperConfig, WhisperModel


@pytest.fixture()
def model():
    """Create a minimal WhisperModel for testing."""
    return WhisperModel(WhisperConfig(), dtype=mx.float16)


# ===========================================================================
# Weight sanitization
# ===========================================================================


class TestWeightSanitize:
    """Tests for WhisperModel.sanitize() weight mapping."""

    def test_sanitize_hf_key_rename(self, model) -> None:
        """HuggingFace keys should be renamed to MLX format."""
        weights = {
            "model.encoder.layers.0.self_attn.q_proj.weight": mx.zeros((512, 512)),
        }
        sanitized = model.sanitize(weights)
        assert "encoder.blocks.0.attn.query.weight" in sanitized
        assert "model.encoder.layers.0.self_attn.q_proj.weight" not in sanitized

    def test_sanitize_skips_encoder_positions(self, model) -> None:
        """encoder.embed_positions should be skipped (None mapping)."""
        weights = {
            "model.encoder.embed_positions.weight": mx.zeros((1500, 512)),
            "model.decoder.embed_tokens.weight": mx.zeros((51865, 512)),
        }
        sanitized = model.sanitize(weights)
        assert "encoder.embed_positions.weight" not in sanitized
        assert "decoder.token_embedding.weight" in sanitized

    def test_sanitize_transposes_conv_weights(self, model) -> None:
        """Conv1d weights should be transposed from HF format."""
        hf_conv = mx.zeros((512, 80, 3))
        weights = {"model.encoder.conv1.weight": hf_conv}
        sanitized = model.sanitize(weights)
        assert sanitized["encoder.conv1.weight"].shape == (512, 3, 80)

    def test_sanitize_preserves_mlx_format(self, model) -> None:
        """Already-MLX-format weights pass through unchanged."""
        weights = {
            "encoder.blocks.0.attn.query.weight": mx.zeros((512, 512)),
        }
        sanitized = model.sanitize(weights)
        assert "encoder.blocks.0.attn.query.weight" in sanitized

    def test_sanitize_casts_dtype(self, model) -> None:
        """Weights should be cast to model dtype."""
        weights = {"encoder.ln_post.weight": mx.ones((512,), dtype=mx.float32)}
        sanitized = model.sanitize(weights)
        assert sanitized["encoder.ln_post.weight"].dtype == mx.float16


# ===========================================================================
# Decoder cache paths
# ===========================================================================


class TestDecoderCachePaths:
    """Tests for decoder self-attention with and without KV cache."""

    @pytest.fixture()
    def tiny_model(self):
        """Create a tiny model for fast decode tests.

        n_audio_ctx must equal input_frames // 2 because conv2 has stride=2.
        We use input frames = 20 -> conv2 output = 10 -> n_audio_ctx = 10.
        """
        config = WhisperConfig(
            n_mels=80,
            n_audio_ctx=10,
            n_audio_state=64,
            n_audio_head=2,
            n_audio_layer=1,
            n_vocab=100,
            n_text_ctx=32,
            n_text_state=64,
            n_text_head=2,
            n_text_layer=1,
        )
        return WhisperModel(config, dtype=mx.float32)

    def test_prefill_without_cache(self, tiny_model) -> None:
        """Prefill (no cache) should produce logits without error."""
        mel = mx.random.normal((1, 20, 80))
        tokens = mx.array([[1, 2, 3]])

        audio_features = tiny_model.encode(mel)
        logits, kv_cache = tiny_model.decode(tokens, audio_features)

        assert logits.shape == (1, 3, 100)
        assert kv_cache is not None
        assert len(kv_cache) == 1  # 1 layer

    def test_cached_decode_single_token(self, tiny_model) -> None:
        """Decode a single token with cache should work."""
        mel = mx.random.normal((1, 20, 80))
        tokens_prefill = mx.array([[1, 2, 3]])

        audio_features = tiny_model.encode(mel)
        _, kv_cache = tiny_model.decode(tokens_prefill, audio_features)

        # Decode 1 new token
        next_token = mx.array([[4]])
        logits, kv_cache2 = tiny_model.decode(next_token, audio_features, kv_cache)

        assert logits.shape == (1, 1, 100)
        # Self-attn cache k should now have 4 tokens (3 prefill + 1 new)
        assert kv_cache2[0][0][0].shape[1] == 4

    def test_cached_decode_multiple_tokens(self, tiny_model) -> None:
        """Decode q_len > 1 with cache — the mask bug repro case."""
        mel = mx.random.normal((1, 20, 80))
        tokens_prefill = mx.array([[1, 2, 3]])

        audio_features = tiny_model.encode(mel)
        _, kv_cache = tiny_model.decode(tokens_prefill, audio_features)

        # Decode 2 tokens at once with cache (q_len=2, k_len=5)
        next_tokens = mx.array([[4, 5]])
        logits, kv_cache2 = tiny_model.decode(next_tokens, audio_features, kv_cache)

        assert logits.shape == (1, 2, 100)
        assert kv_cache2[0][0][0].shape[1] == 5

    def test_cached_vs_full_decode_match(self, tiny_model) -> None:
        """Cached decode should produce same logits as full non-cached decode."""
        mx.random.seed(42)
        mel = mx.random.normal((1, 20, 80))
        all_tokens = mx.array([[1, 2, 3, 4, 5]])

        audio_features = tiny_model.encode(mel)

        # Full decode without cache
        logits_full, _ = tiny_model.decode(all_tokens, audio_features)

        # Incremental: prefill 3, then decode 2
        tokens_prefill = mx.array([[1, 2, 3]])
        _, kv_cache = tiny_model.decode(tokens_prefill, audio_features)

        next_tokens = mx.array([[4, 5]])
        logits_cached, _ = tiny_model.decode(next_tokens, audio_features, kv_cache)

        # Last 2 logits should match
        mx.eval(logits_full, logits_cached)
        diff = mx.abs(logits_full[:, 3:, :] - logits_cached).max().item()
        assert diff < 1e-4, f"Cached decode diverged from full decode: max diff={diff}"
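As a reading aid for the `TestWeightSanitize` cases above, the sanitization behaviour they assert can be sketched roughly as below. The helper name `sanitize_sketch` and the partial rename table are inferred from the assertions, not taken from `vllm_metal/stt/whisper.py`; the real `WhisperModel.sanitize()` presumably covers many more keys (k/v/out projections, MLPs, layer norms, cross-attention).

```python
import mlx.core as mx

# Partial rename table, inferred from the test assertions above (illustrative only).
_HF_TO_MLX = {
    "model.encoder.layers.": "encoder.blocks.",
    "model.decoder.layers.": "decoder.blocks.",   # assumed symmetric with the encoder
    ".self_attn.q_proj.": ".attn.query.",
    "model.decoder.embed_tokens.": "decoder.token_embedding.",
    "model.encoder.": "encoder.",
    "model.decoder.": "decoder.",
}


def sanitize_sketch(weights: dict, dtype=mx.float16) -> dict:
    """Illustrative re-creation of the sanitize() behaviour asserted by the tests."""
    out = {}
    for key, value in weights.items():
        if key.startswith("model."):                    # HuggingFace checkpoint key
            if "encoder.embed_positions" in key:
                continue                                # skipped (None mapping)
            for hf_prefix, mlx_prefix in _HF_TO_MLX.items():
                key = key.replace(hf_prefix, mlx_prefix)
            if key.endswith(("conv1.weight", "conv2.weight")):
                # HF Conv1d is (out_ch, in_ch, kernel); MLX expects (out_ch, kernel, in_ch)
                value = value.transpose(0, 2, 1)
        out[key] = value.astype(dtype)                  # cast everything to the model dtype
    return out
```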
