# SPDX-License-Identifier: Apache-2.0
"""Tests for Whisper model: weight sanitization and decoder cache paths."""

from __future__ import annotations

import mlx.core as mx
import pytest

from vllm_metal.stt.whisper import WhisperConfig, WhisperModel
| 11 | + |
@pytest.fixture()
def model():
    """Minimal WhisperModel instance shared by the sanitize tests."""
    cfg = WhisperConfig()
    return WhisperModel(cfg, dtype=mx.float16)
| 16 | + |
| 17 | + |
# ===========================================================================
# Weight sanitization
# ===========================================================================

| 22 | + |
class TestWeightSanitize:
    """Exercise WhisperModel.sanitize(): key renaming, skips, transposes, casts."""

    def test_sanitize_hf_key_rename(self, model) -> None:
        """HuggingFace keys should be renamed to MLX format."""
        hf_key = "model.encoder.layers.0.self_attn.q_proj.weight"
        out = model.sanitize({hf_key: mx.zeros((512, 512))})
        # The HF name is gone; the MLX name takes its place.
        assert "encoder.blocks.0.attn.query.weight" in out
        assert hf_key not in out

    def test_sanitize_skips_encoder_positions(self, model) -> None:
        """encoder.embed_positions should be skipped (None mapping)."""
        out = model.sanitize(
            {
                "model.encoder.embed_positions.weight": mx.zeros((1500, 512)),
                "model.decoder.embed_tokens.weight": mx.zeros((51865, 512)),
            }
        )
        # Positions are dropped entirely; token embeddings survive, renamed.
        assert "encoder.embed_positions.weight" not in out
        assert "decoder.token_embedding.weight" in out

    def test_sanitize_transposes_conv_weights(self, model) -> None:
        """Conv1d weights should be transposed from HF format."""
        out = model.sanitize({"model.encoder.conv1.weight": mx.zeros((512, 80, 3))})
        # HF layout (out, in, kernel) becomes MLX layout (out, kernel, in).
        assert out["encoder.conv1.weight"].shape == (512, 3, 80)

    def test_sanitize_preserves_mlx_format(self, model) -> None:
        """Already-MLX-format weights pass through unchanged."""
        mlx_key = "encoder.blocks.0.attn.query.weight"
        out = model.sanitize({mlx_key: mx.zeros((512, 512))})
        assert mlx_key in out

    def test_sanitize_casts_dtype(self, model) -> None:
        """Weights should be cast to model dtype."""
        out = model.sanitize(
            {"encoder.ln_post.weight": mx.ones((512,), dtype=mx.float32)}
        )
        # Fixture model was built with float16, so float32 input is downcast.
        assert out["encoder.ln_post.weight"].dtype == mx.float16
| 65 | + |
| 66 | + |
# ===========================================================================
# Decoder cache paths
# ===========================================================================

| 71 | + |
class TestDecoderCachePaths:
    """Decoder self-attention behavior with and without a KV cache."""

    @pytest.fixture()
    def tiny_model(self):
        """Build a tiny model so decode tests run fast.

        n_audio_ctx must equal input_frames // 2 because conv2 has stride=2.
        We use input frames = 20 -> conv2 output = 10 -> n_audio_ctx = 10.
        """
        tiny_config = WhisperConfig(
            n_mels=80,
            n_audio_ctx=10,
            n_audio_state=64,
            n_audio_head=2,
            n_audio_layer=1,
            n_vocab=100,
            n_text_ctx=32,
            n_text_state=64,
            n_text_head=2,
            n_text_layer=1,
        )
        return WhisperModel(tiny_config, dtype=mx.float32)

    def test_prefill_without_cache(self, tiny_model) -> None:
        """Prefill (no cache) should produce logits without error."""
        features = tiny_model.encode(mx.random.normal((1, 20, 80)))
        logits, cache = tiny_model.decode(mx.array([[1, 2, 3]]), features)

        assert logits.shape == (1, 3, 100)
        assert cache is not None
        # One decoder layer -> one cache entry.
        assert len(cache) == 1

    def test_cached_decode_single_token(self, tiny_model) -> None:
        """Decode a single token with cache should work."""
        features = tiny_model.encode(mx.random.normal((1, 20, 80)))
        _, cache = tiny_model.decode(mx.array([[1, 2, 3]]), features)

        # Feed a single new token through the cached path.
        logits, updated_cache = tiny_model.decode(mx.array([[4]]), features, cache)

        assert logits.shape == (1, 1, 100)
        # Self-attn cache k should now hold 4 tokens (3 prefill + 1 new).
        assert updated_cache[0][0][0].shape[1] == 4

    def test_cached_decode_multiple_tokens(self, tiny_model) -> None:
        """Decode q_len > 1 with cache — the mask bug repro case."""
        features = tiny_model.encode(mx.random.normal((1, 20, 80)))
        _, cache = tiny_model.decode(mx.array([[1, 2, 3]]), features)

        # Two tokens at once against the cache (q_len=2, k_len=5).
        logits, updated_cache = tiny_model.decode(
            mx.array([[4, 5]]), features, cache
        )

        assert logits.shape == (1, 2, 100)
        assert updated_cache[0][0][0].shape[1] == 5

    def test_cached_vs_full_decode_match(self, tiny_model) -> None:
        """Cached decode should produce same logits as full non-cached decode."""
        mx.random.seed(42)
        features = tiny_model.encode(mx.random.normal((1, 20, 80)))

        # Reference: decode all five tokens in one pass, no cache.
        logits_full, _ = tiny_model.decode(
            mx.array([[1, 2, 3, 4, 5]]), features
        )

        # Incremental path: prefill the first 3, then decode the last 2.
        _, cache = tiny_model.decode(mx.array([[1, 2, 3]]), features)
        logits_cached, _ = tiny_model.decode(mx.array([[4, 5]]), features, cache)

        # The trailing two positions must agree between both paths.
        mx.eval(logits_full, logits_cached)
        diff = mx.abs(logits_full[:, 3:, :] - logits_cached).max().item()
        assert diff < 1e-4, f"Cached decode diverged from full decode: max diff={diff}"