Skip to content

Commit 0bc1044

Browse files
authored
Refactor STT boundaries to trust upstream vLLM contracts (#220)
This PR: - Removes local STT compatibility and config glue that duplicates behavior already owned by upstream vLLM, including `hf_config.py`, `config.py`, and the no-op `register_ops` entrypoint. - Delegates Qwen3-ASR config parsing and prompt ownership back to upstream vLLM, so the plugin only adapts the typed upstream config into the local MLX model. - Moves Whisper language validation and decoder prompt construction into `WhisperTranscriber`, using upstream Whisper and vLLM language/tokenizer APIs. - Simplifies the STT tests so they cover the current adapter boundaries and the upstream integration contract instead of code paths deleted in this refactor. --------- Signed-off-by: Yuan Lik Xun <lxyuan0420@gmail.com>
1 parent 1898c19 commit 0bc1044

File tree

13 files changed

+338
-1269
lines changed

13 files changed

+338
-1269
lines changed

pyproject.toml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,6 @@ Issues = "https://github.com/vllm-project/vllm-metal/issues"
6767
[project.entry-points."vllm.platform_plugins"]
6868
metal = "vllm_metal:register"
6969

70-
[project.entry-points."vllm.general_plugins"]
71-
metal_ops = "vllm_metal:register_ops"
72-
7370
# Maturin configuration for mixed Rust/Python project
7471
[tool.maturin]
7572
# Build both the Rust extension and include Python source

tests/test_qwen3_asr.py

Lines changed: 64 additions & 198 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
11
# SPDX-License-Identifier: Apache-2.0
2-
"""Tests for Qwen3-ASR model: config, encoder shapes, weight sanitization."""
2+
"""Tests for Qwen3-ASR model behavior and weight mapping."""
33

44
from __future__ import annotations
55

66
import json
77
import os
88
from pathlib import Path
9-
from unittest.mock import MagicMock
9+
from types import SimpleNamespace
10+
from typing import Any, cast
1011

1112
import mlx.core as mx
13+
import numpy as np
1214
import pytest
15+
from transformers import WhisperFeatureExtractor
1316

17+
from vllm_metal.stt.audio import load_audio
1418
from vllm_metal.stt.detection import is_stt_model
1519
from vllm_metal.stt.loader import load_model
1620
from vllm_metal.stt.qwen3_asr.config import (
@@ -26,72 +30,47 @@
2630
)
2731
from vllm_metal.stt.qwen3_asr.transcriber import Qwen3ASRTranscriber
2832

29-
# ===========================================================================
30-
# Configuration
31-
# ===========================================================================
32-
33-
34-
class TestQwen3ASRConfig:
35-
"""Tests for Qwen3ASRConfig.from_dict with 0.6B config."""
36-
37-
def test_from_dict_basic(self) -> None:
38-
"""Config should be parsed from nested thinker_config dict."""
39-
d = {
40-
"model_type": "qwen3_asr",
41-
"thinker_config": {
42-
"audio_config": {
43-
"d_model": 896,
44-
"num_mel_bins": 128,
45-
"encoder_layers": 18,
46-
"encoder_attention_heads": 14,
47-
"encoder_ffn_dim": 3584,
48-
"downsample_hidden_size": 480,
49-
"output_dim": 1024,
50-
"max_source_positions": 1500,
51-
"n_window": 50,
52-
"n_window_infer": 800,
53-
},
54-
"text_config": {
55-
"hidden_size": 1024,
56-
"num_hidden_layers": 28,
57-
"num_attention_heads": 16,
58-
"num_key_value_heads": 8,
59-
"head_dim": 128,
60-
"intermediate_size": 3072,
61-
"vocab_size": 151936,
62-
"rms_norm_eps": 1e-6,
63-
"rope_theta": 1000000.0,
64-
"tie_word_embeddings": True,
65-
},
66-
"audio_token_id": 151676,
67-
"audio_start_token_id": 151669,
68-
"audio_end_token_id": 151670,
69-
},
70-
}
71-
config = Qwen3ASRConfig.from_dict(d)
72-
assert config.audio_config.d_model == 896
73-
assert config.audio_config.encoder_layers == 18
74-
assert config.audio_config.num_mel_bins == 128
75-
assert config.audio_config.n_window == 50
76-
assert config.text_config.hidden_size == 1024
77-
assert config.text_config.num_hidden_layers == 28
78-
assert config.text_config.num_attention_heads == 16
79-
assert config.text_config.num_key_value_heads == 8
80-
assert config.audio_token_id == 151676
81-
assert config.n_mels == 128
82-
assert config.n_audio_ctx == 1500
83-
84-
def test_defaults(self) -> None:
85-
"""Default config should have 0.6B model values."""
86-
config = Qwen3ASRConfig()
87-
assert config.audio_config.d_model == 896
88-
assert config.text_config.vocab_size == 151936
89-
assert config.eos_token_id == 151643
9033

34+
class TestQwen3ASRConfigAdaptation:
35+
"""Tests for adapting the upstream config into the local MLX config."""
36+
37+
def test_from_vllm_config_keeps_local_eos_default_when_upstream_omits_it(
38+
self,
39+
) -> None:
40+
upstream_config = SimpleNamespace(
41+
thinker_config=SimpleNamespace(
42+
audio_config=SimpleNamespace(
43+
d_model=896,
44+
num_mel_bins=128,
45+
encoder_layers=18,
46+
encoder_attention_heads=14,
47+
encoder_ffn_dim=3584,
48+
downsample_hidden_size=480,
49+
output_dim=1024,
50+
max_source_positions=1500,
51+
n_window=50,
52+
n_window_infer=800,
53+
activation_function="gelu",
54+
),
55+
text_config=SimpleNamespace(
56+
hidden_size=1024,
57+
num_hidden_layers=28,
58+
num_attention_heads=16,
59+
num_key_value_heads=8,
60+
head_dim=128,
61+
intermediate_size=3072,
62+
vocab_size=151936,
63+
rms_norm_eps=1e-6,
64+
rope_theta=1000000.0,
65+
tie_word_embeddings=True,
66+
),
67+
audio_token_id=151676,
68+
)
69+
)
70+
config = Qwen3ASRConfig._from_vllm_config(cast(Any, upstream_config))
9171

92-
# ===========================================================================
93-
# CNN output lengths
94-
# ===========================================================================
72+
assert config.audio_token_id == 151676
73+
assert config.eos_token_id == 151643
9574

9675

9776
class TestCNNOutputLengths:
@@ -127,11 +106,6 @@ def test_feat_extract_3000_frames(self) -> None:
127106
assert Qwen3ASRAudioConfig().feat_extract_output_length(3000) == 390
128107

129108

130-
# ===========================================================================
131-
# Audio Encoder shapes
132-
# ===========================================================================
133-
134-
135109
class TestAudioEncoderShapes:
136110
"""Tests for AudioEncoder output dimensions."""
137111

@@ -182,11 +156,6 @@ def test_with_batch_dim(self, tiny_encoder) -> None:
182156
assert out.shape == (13, 48)
183157

184158

185-
# ===========================================================================
186-
# Qwen3 Attention
187-
# ===========================================================================
188-
189-
190159
class TestQwen3Attention:
191160
"""Tests for GQA with QK normalization."""
192161

@@ -243,11 +212,6 @@ def test_cached_decode(self) -> None:
243212
assert cache2[0].shape == (1, 2, 6, 16) # 5 + 1 = 6
244213

245214

246-
# ===========================================================================
247-
# Weight sanitization
248-
# ===========================================================================
249-
250-
251215
class TestWeightSanitize:
252216
"""Tests for Qwen3ASRModel.sanitize() weight mapping."""
253217

@@ -353,11 +317,6 @@ def test_casts_dtype(self, model) -> None:
353317
assert sanitized["audio_tower.ln_post.weight"].dtype == mx.float32
354318

355319

356-
# ===========================================================================
357-
# Qwen3 LM forward
358-
# ===========================================================================
359-
360-
361320
class TestQwen3LM:
362321
"""Tests for Qwen3LM forward pass."""
363322

@@ -393,11 +352,6 @@ def test_decode_step(self, tiny_lm) -> None:
393352
assert logits.shape == (1, 1, 100)
394353

395354

396-
# ===========================================================================
397-
# Full model
398-
# ===========================================================================
399-
400-
401355
class TestQwen3ASRModel:
402356
"""Tests for the full Qwen3ASRModel."""
403357

@@ -430,9 +384,6 @@ def tiny_model(self):
430384
)
431385
return Qwen3ASRModel(config, dtype=mx.float32)
432386

433-
def test_model_type(self, tiny_model) -> None:
434-
assert tiny_model.model_type == "qwen3_asr"
435-
436387
def test_encode(self, tiny_model) -> None:
437388
"""Encode should produce audio embeddings."""
438389
mel = mx.random.normal((16, 100))
@@ -464,11 +415,6 @@ def test_prefill_and_decode(self, tiny_model) -> None:
464415
assert logits2.shape == (1, 1, 100)
465416

466417

467-
# ===========================================================================
468-
# Post-process output
469-
# ===========================================================================
470-
471-
472418
class TestPostProcessOutput:
473419
"""Tests for Qwen3ASRTranscriber.post_process_output."""
474420

@@ -484,11 +430,6 @@ def test_empty_string(self) -> None:
484430
assert Qwen3ASRTranscriber.post_process_output("") == ""
485431

486432

487-
# ===========================================================================
488-
# Config detection
489-
# ===========================================================================
490-
491-
492433
class TestPostProcessOutputTruncation:
493434
"""Tests for special token truncation in post_process_output."""
494435

@@ -513,88 +454,6 @@ def test_strips_whitespace(self) -> None:
513454
assert Qwen3ASRTranscriber.post_process_output(text) == "Hello world"
514455

515456

516-
# ===========================================================================
517-
# Build prompt tokens
518-
# ===========================================================================
519-
520-
521-
class TestBuildPromptTokens:
522-
"""Tests for Qwen3ASRTranscriber.build_prompt_tokens structure."""
523-
524-
@pytest.fixture()
525-
def transcriber(self, tmp_path):
526-
"""Create a transcriber with a mock tokenizer for prompt tests."""
527-
config = Qwen3ASRConfig(
528-
audio_token_id=99,
529-
audio_start_token_id=97,
530-
audio_end_token_id=98,
531-
eos_token_id=0,
532-
)
533-
model = MagicMock()
534-
model.config = config
535-
536-
# Inject mock tokenizer with deterministic encode
537-
mock_tok = MagicMock()
538-
_encode_map = {
539-
"<|im_start|>": [10],
540-
"<|im_end|>": [11],
541-
"user\n": [20],
542-
"assistant\n": [30],
543-
"\n": [40],
544-
}
545-
mock_tok.encode = MagicMock(
546-
side_effect=lambda s, add_special_tokens=False: _encode_map.get(s, [0])
547-
)
548-
t = Qwen3ASRTranscriber(model, tokenizer=mock_tok)
549-
return t
550-
551-
def test_audio_pad_count_matches_frames(self, transcriber) -> None:
552-
"""Number of audio_pad tokens should equal n_audio_frames."""
553-
# Act
554-
prompt = transcriber.build_prompt_tokens(50)
555-
556-
# Assert
557-
audio_pad_count = prompt.count(99) # audio_token_id
558-
assert audio_pad_count == 50
559-
560-
def test_audio_pad_count_zero(self, transcriber) -> None:
561-
"""Zero audio frames should produce no audio_pad tokens."""
562-
# Act
563-
prompt = transcriber.build_prompt_tokens(0)
564-
565-
# Assert
566-
assert prompt.count(99) == 0
567-
568-
def test_prompt_contains_structural_tokens(self, transcriber) -> None:
569-
"""Prompt should contain audio_start, audio_end, im_start, user, assistant."""
570-
# Act
571-
prompt = transcriber.build_prompt_tokens(5)
572-
573-
# Assert
574-
assert 97 in prompt # audio_start
575-
assert 98 in prompt # audio_end
576-
assert 10 in prompt # im_start
577-
assert 20 in prompt # user
578-
assert 30 in prompt # assistant
579-
580-
def test_prompt_structure_order(self, transcriber) -> None:
581-
"""Audio tokens should be between audio_start and audio_end."""
582-
# Act
583-
prompt = transcriber.build_prompt_tokens(3)
584-
585-
# Assert
586-
start_idx = prompt.index(97) # audio_start
587-
end_idx = prompt.index(98) # audio_end
588-
for i, tok in enumerate(prompt):
589-
if tok == 99:
590-
assert start_idx < i < end_idx
591-
592-
593-
# ===========================================================================
594-
# Config detection
595-
# ===========================================================================
596-
597-
598457
class TestConfigDetection:
599458
"""Tests for is_stt_model with Qwen3-ASR config."""
600459

@@ -605,11 +464,6 @@ def test_qwen3_asr_detected(self, tmp_path) -> None:
605464
assert is_stt_model(str(tmp_path)) is True
606465

607466

608-
# ===========================================================================
609-
# Slow tests (require real model)
610-
# ===========================================================================
611-
612-
613467
@pytest.mark.slow
614468
class TestModelLoad:
615469
"""Tests that load the real Qwen3-ASR-0.6B model.
@@ -629,7 +483,7 @@ def _model_path(self):
629483
def test_load_model(self) -> None:
630484
"""Should load model without errors."""
631485
model = load_model(self._MODEL_PATH)
632-
assert model.model_type == "qwen3_asr"
486+
assert isinstance(model, Qwen3ASRModel)
633487

634488
def test_encode_dummy_mel(self) -> None:
635489
"""Should encode a dummy mel spectrogram."""
@@ -642,11 +496,6 @@ def test_encode_dummy_mel(self) -> None:
642496

643497
def test_greedy_decode(self) -> None:
644498
"""Should encode + decode a real audio file using WhisperFeatureExtractor."""
645-
import numpy as np
646-
from transformers import WhisperFeatureExtractor
647-
648-
from vllm_metal.stt.audio import load_audio
649-
650499
audio_path = os.environ.get("QWEN3_ASR_AUDIO_PATH")
651500
if not audio_path or not Path(audio_path).exists():
652501
pytest.skip("QWEN3_ASR_AUDIO_PATH not set or file not found")
@@ -669,7 +518,24 @@ def test_greedy_decode(self) -> None:
669518

670519
# Build prompt
671520
n_audio = audio_emb.shape[0]
672-
prompt = transcriber.build_prompt_tokens(n_audio)
521+
tokenizer = transcriber.tokenizer
522+
audio_start_token_id = tokenizer.convert_tokens_to_ids(
523+
tokenizer.audio_bos_token
524+
)
525+
audio_token_id = tokenizer.convert_tokens_to_ids(tokenizer.audio_token)
526+
audio_end_token_id = tokenizer.convert_tokens_to_ids(tokenizer.audio_eos_token)
527+
prompt = (
528+
tokenizer.encode("<|im_start|>", add_special_tokens=False)
529+
+ tokenizer.encode("user\n", add_special_tokens=False)
530+
+ [audio_start_token_id]
531+
+ [audio_token_id] * n_audio
532+
+ [audio_end_token_id]
533+
+ tokenizer.encode("\n", add_special_tokens=False)
534+
+ tokenizer.encode("<|im_end|>", add_special_tokens=False)
535+
+ tokenizer.encode("\n", add_special_tokens=False)
536+
+ tokenizer.encode("<|im_start|>", add_special_tokens=False)
537+
+ tokenizer.encode("assistant\n", add_special_tokens=False)
538+
)
673539

674540
# Decode
675541
tokens = transcriber.greedy_decode_tokens(audio_emb, prompt, max_tokens=100)

0 commit comments

Comments
 (0)