Skip to content

Commit 3b6c257

Browse files
authored
Refactor STT: model-owned packages (#164)
This PR is: - To move Whisper and Qwen3-ASR implementations into model-owned packages (`stt/whisper/*`, `stt/qwen3_asr/*`) and keep `stt/transcribe.py` focused on orchestration. - To keep `stt/hf_config.py` importable without pulling in MLX by splitting Qwen3-ASR configs into an MLX-free module (with feature-length math on the config). - To avoid tokenizer side effects at construction time (Whisper tokenizer lazy-load; tokenizer loading is transcriber-owned). Verification: - `source .venv-vllm-metal/bin/activate && ruff check .` - `source .venv-vllm-metal/bin/activate && ruff format --check .` - `source .venv-vllm-metal/bin/activate && pytest -m "not slow"` Next: - Refactor STT: add `stt/registry.py` + `stt/loader.py` so model dispatch doesn’t grow shared branching. - Refactor STT: extract STT runtime glue from `v1/model_runner.py` into `stt/runtime.py` to reduce upstream-diff surface. --------- Signed-off-by: Yuan Lik Xun <lxyuan0420@gmail.com>
1 parent ecc11ff commit 3b6c257

13 files changed

Lines changed: 716 additions & 867 deletions

File tree

tests/test_qwen3_asr.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,16 @@
1212
import pytest
1313

1414
from vllm_metal.stt.config import is_stt_model
15-
from vllm_metal.stt.qwen3_asr import (
16-
AudioEncoder,
15+
from vllm_metal.stt.qwen3_asr.config import (
1716
Qwen3ASRAudioConfig,
1817
Qwen3ASRConfig,
19-
Qwen3ASRModel,
2018
Qwen3ASRTextConfig,
19+
)
20+
from vllm_metal.stt.qwen3_asr.model import (
21+
AudioEncoder,
22+
Qwen3ASRModel,
2123
Qwen3Attention,
2224
Qwen3LM,
23-
_get_cnn_output_lengths,
24-
_get_feat_extract_output_lengths,
2525
)
2626
from vllm_metal.stt.transcribe import Qwen3ASRTranscriber, load_model
2727

@@ -94,34 +94,36 @@ def test_defaults(self) -> None:
9494

9595

9696
class TestCNNOutputLengths:
97-
"""Tests for _get_cnn_output_lengths and _get_feat_extract_output_lengths."""
97+
"""Tests for Qwen3ASRAudioConfig shape helpers."""
9898

9999
def test_single_conv_stride(self) -> None:
100100
"""3x Conv2d stride-2 on 100 frames → 13 output frames."""
101-
assert _get_cnn_output_lengths(100) == 13
101+
assert Qwen3ASRAudioConfig.cnn_output_length(100) == 13
102102

103103
def test_small_inputs(self) -> None:
104104
"""Edge cases for small input lengths."""
105-
assert _get_cnn_output_lengths(1) == 1
106-
assert _get_cnn_output_lengths(2) == 1
105+
assert Qwen3ASRAudioConfig.cnn_output_length(1) == 1
106+
assert Qwen3ASRAudioConfig.cnn_output_length(2) == 1
107107
# 3 -> 2 -> 1 -> 1 after 3x stride-2
108-
assert _get_cnn_output_lengths(3) == 1
108+
assert Qwen3ASRAudioConfig.cnn_output_length(3) == 1
109109

110110
def test_feat_extract_full_chunks(self) -> None:
111111
"""Full chunks of 100 frames each produce 13 frames per chunk."""
112-
assert _get_feat_extract_output_lengths(100) == 13
113-
assert _get_feat_extract_output_lengths(200) == 26
114-
assert _get_feat_extract_output_lengths(300) == 39
112+
cfg = Qwen3ASRAudioConfig()
113+
assert cfg.feat_extract_output_length(100) == 13
114+
assert cfg.feat_extract_output_length(200) == 26
115+
assert cfg.feat_extract_output_length(300) == 39
115116

116117
def test_feat_extract_with_remainder(self) -> None:
117118
"""Partial chunk adds its CNN output to full chunks."""
118119
# 150 = 1 full chunk (13) + 50 remainder
119-
remainder_out = _get_cnn_output_lengths(50)
120-
assert _get_feat_extract_output_lengths(150) == 13 + remainder_out
120+
cfg = Qwen3ASRAudioConfig()
121+
remainder_out = Qwen3ASRAudioConfig.cnn_output_length(50)
122+
assert cfg.feat_extract_output_length(150) == 13 + remainder_out
121123

122124
def test_feat_extract_3000_frames(self) -> None:
123125
"""30 seconds at 16kHz/hop160 = 3000 frames → 30 * 13 = 390."""
124-
assert _get_feat_extract_output_lengths(3000) == 390
126+
assert Qwen3ASRAudioConfig().feat_extract_output_length(3000) == 390
125127

126128

127129
# ===========================================================================
@@ -154,7 +156,7 @@ def test_single_chunk(self, tiny_encoder) -> None:
154156
mel = mx.random.normal((16, 80)) # 80 < 100 frames
155157
out = tiny_encoder(mel)
156158
mx.eval(out)
157-
expected_frames = _get_cnn_output_lengths(80)
159+
expected_frames = Qwen3ASRAudioConfig.cnn_output_length(80)
158160
assert out.shape == (expected_frames, 48)
159161

160162
def test_exact_chunk(self, tiny_encoder) -> None:
@@ -529,7 +531,6 @@ def transcriber(self, tmp_path):
529531
)
530532
model = MagicMock()
531533
model.config = config
532-
t = Qwen3ASRTranscriber(model, model_path=str(tmp_path))
533534

534535
# Inject mock tokenizer with deterministic encode
535536
mock_tok = MagicMock()
@@ -543,7 +544,7 @@ def transcriber(self, tmp_path):
543544
mock_tok.encode = MagicMock(
544545
side_effect=lambda s, add_special_tokens=False: _encode_map.get(s, [0])
545546
)
546-
t._tokenizer = mock_tok
547+
t = Qwen3ASRTranscriber(model, tokenizer=mock_tok)
547548
return t
548549

549550
def test_audio_pad_count_matches_frames(self, transcriber) -> None:
@@ -635,7 +636,7 @@ def test_encode_dummy_mel(self) -> None:
635636
mel = mx.random.normal((128, 300))
636637
embeddings = model.encode(mel)
637638
mx.eval(embeddings)
638-
expected = _get_feat_extract_output_lengths(300)
639+
expected = model.config.audio_config.feat_extract_output_length(300)
639640
assert embeddings.shape == (expected, 1024)
640641

641642
def test_greedy_decode(self) -> None:

tests/test_transcribe.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
import pytest
1212

1313
from vllm_metal.stt.transcribe import (
14-
_MAX_PROMPT_TOKENS,
1514
TranscriptionResult,
1615
WhisperTranscriber,
1716
load_model,
1817
)
18+
from vllm_metal.stt.whisper.transcriber import MAX_PROMPT_TOKENS
1919

2020
# ===========================================================================
2121
# Fixtures
@@ -33,8 +33,8 @@ def transcriber():
3333
except ImportError:
3434
pytest.skip("transformers not available")
3535

36-
t = WhisperTranscriber(model=None, model_path=None) # type: ignore[arg-type]
37-
t._tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small")
36+
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small")
37+
t = WhisperTranscriber(model=None, tokenizer=tokenizer) # type: ignore[arg-type]
3838
return t
3939

4040

@@ -153,11 +153,11 @@ def test_prompt_contains_text_tokens(self, transcriber: WhisperTranscriber) -> N
153153
assert len(result) >= 2
154154

155155
def test_long_prompt_truncated(self, transcriber: WhisperTranscriber) -> None:
156-
"""Very long prompt should be truncated to _MAX_PROMPT_TOKENS + 1."""
156+
"""Very long prompt should be truncated to MAX_PROMPT_TOKENS + 1."""
157157
long_text = "word " * 500
158158
result = transcriber._encode_prompt(long_text)
159-
# startofprev (1) + at most _MAX_PROMPT_TOKENS text tokens
160-
assert len(result) <= _MAX_PROMPT_TOKENS + 1
159+
# startofprev (1) + at most MAX_PROMPT_TOKENS text tokens
160+
assert len(result) <= MAX_PROMPT_TOKENS + 1
161161

162162

163163
# ===========================================================================

vllm_metal/stt/hf_config.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from transformers.configuration_utils import PretrainedConfig
1818

1919
from vllm_metal.stt.config import get_whisper_languages
20+
from vllm_metal.stt.qwen3_asr.config import Qwen3ASRAudioConfig
2021

2122
logger = logging.getLogger(__name__)
2223

@@ -449,8 +450,6 @@ def get_num_audio_tokens(
449450
stt_config: SpeechToTextConfig,
450451
model_config: ModelConfig,
451452
) -> int | None:
452-
from vllm_metal.stt.qwen3_asr import _get_feat_extract_output_lengths
453-
454453
# Derive hop_length from WhisperFeatureExtractor defaults
455454
hop_length = WhisperFeatureExtractor().hop_length
456455

@@ -460,7 +459,9 @@ def get_num_audio_tokens(
460459
mel_frames = math.ceil(
461460
audio_duration_s * stt_config.sample_rate / hop_length
462461
)
463-
return _get_feat_extract_output_lengths(mel_frames, n_window=n_window)
462+
return Qwen3ASRAudioConfig(n_window=n_window).feat_extract_output_length(
463+
mel_frames
464+
)
464465

465466
# Attach multimodal processor factory to the stub class
466467
Qwen3ASRStub._processor_factory = _ProcessorFactories(

vllm_metal/stt/protocol.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# SPDX-License-Identifier: Apache-2.0
22
"""Response types for Speech-to-Text."""
33

4+
from dataclasses import dataclass, field
5+
46
from pydantic import BaseModel
57

68

@@ -16,3 +18,13 @@ class TranscriptionSegment(BaseModel):
1618
avg_logprob: float = 0.0
1719
compression_ratio: float = 0.0
1820
no_speech_prob: float = 0.0
21+
22+
23+
@dataclass
24+
class TranscriptionResult:
25+
"""Result of a transcription operation."""
26+
27+
text: str
28+
language: str | None = None
29+
segments: list[TranscriptionSegment] = field(default_factory=list)
30+
duration: float = 0.0
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
"""Qwen3-ASR configuration and entrypoints.
3+
4+
Keep this module MLX-free so ``vllm_metal.stt.hf_config`` can import
5+
``vllm_metal.stt.qwen3_asr.config`` without pulling in the model stack.
6+
"""
7+
8+
from .config import Qwen3ASRAudioConfig, Qwen3ASRConfig, Qwen3ASRTextConfig
9+
10+
__all__ = [
11+
"Qwen3ASRAudioConfig",
12+
"Qwen3ASRConfig",
13+
"Qwen3ASRTextConfig",
14+
]

vllm_metal/stt/qwen3_asr/config.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
"""Qwen3-ASR configuration (MLX-free).
3+
4+
Keep this module free of MLX imports so vLLM compat code can import config and
5+
shape helpers during planning/registration without pulling in the model stack.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
from dataclasses import dataclass, field
11+
12+
13+
@dataclass
class Qwen3ASRAudioConfig:
    """Audio encoder configuration."""

    num_mel_bins: int = 128
    d_model: int = 896
    encoder_layers: int = 18
    encoder_attention_heads: int = 14
    encoder_ffn_dim: int = 3584
    downsample_hidden_size: int = 480
    output_dim: int = 1024
    max_source_positions: int = 1500
    n_window: int = 50
    n_window_infer: int = 800
    activation_function: str = "gelu"

    @staticmethod
    def cnn_output_length(num_frames: int) -> int:
        """Return the time length after 3x Conv2d(stride=2) downsampling.

        Each stride-2 convolution halves the time axis, rounding up.
        """
        out = num_frames
        for _ in range(3):
            # Ceil-divide by 2; equivalent to (out - 1) // 2 + 1.
            out = (out + 1) // 2
        return int(out)

    def feat_extract_output_length(self, num_mel_frames: int) -> int:
        """Return the number of audio tokens for a mel with N time frames.

        The mel is processed in windows of ``2 * n_window`` frames; a trailing
        partial window contributes its own CNN output length.
        """
        window = 2 * self.n_window
        per_full_window = self.cnn_output_length(window)
        n_full, tail = divmod(num_mel_frames, window)
        total = n_full * per_full_window
        if tail:
            total += self.cnn_output_length(tail)
        return int(total)
47+
48+
49+
@dataclass
class Qwen3ASRTextConfig:
    """Text decoder (Qwen3 LM) configuration."""

    hidden_size: int = 1024
    num_hidden_layers: int = 28
    num_attention_heads: int = 16
    # NOTE(review): fewer KV heads than attention heads — presumably
    # grouped-query attention; confirm against the model implementation.
    num_key_value_heads: int = 8
    head_dim: int = 128
    intermediate_size: int = 3072
    vocab_size: int = 151936
    rms_norm_eps: float = 1e-6
    rope_theta: float = 1000000.0
    # Presumably ties the LM head to the input embedding (HF convention);
    # consumed by the model loader, not here.
    tie_word_embeddings: bool = True
63+
64+
65+
@dataclass
class Qwen3ASRConfig:
    """Top-level Qwen3-ASR model configuration."""

    audio_config: Qwen3ASRAudioConfig = field(default_factory=Qwen3ASRAudioConfig)
    text_config: Qwen3ASRTextConfig = field(default_factory=Qwen3ASRTextConfig)
    audio_token_id: int = 151676
    audio_start_token_id: int = 151669
    audio_end_token_id: int = 151670
    eos_token_id: int = 151643
    # Compatibility with Whisper interface for load_model dispatching
    n_mels: int = 128
    n_audio_ctx: int = 1500

    @classmethod
    def from_dict(cls, d: dict) -> Qwen3ASRConfig:
        """Create config from a ``config.json`` dictionary.

        Accepts either a flat dict or one nested under ``thinker_config``;
        any missing key falls back to the dataclass default.
        """
        thinker = d.get("thinker_config", d)

        audio_dict = thinker.get("audio_config", {})
        audio_cfg = Qwen3ASRAudioConfig(
            num_mel_bins=audio_dict.get("num_mel_bins", 128),
            d_model=audio_dict.get("d_model", 896),
            encoder_layers=audio_dict.get("encoder_layers", 18),
            encoder_attention_heads=audio_dict.get("encoder_attention_heads", 14),
            encoder_ffn_dim=audio_dict.get("encoder_ffn_dim", 3584),
            downsample_hidden_size=audio_dict.get("downsample_hidden_size", 480),
            output_dim=audio_dict.get("output_dim", 1024),
            max_source_positions=audio_dict.get("max_source_positions", 1500),
            n_window=audio_dict.get("n_window", 50),
            n_window_infer=audio_dict.get("n_window_infer", 800),
            activation_function=audio_dict.get("activation_function", "gelu"),
        )

        text_dict = thinker.get("text_config", {})
        text_cfg = Qwen3ASRTextConfig(
            hidden_size=text_dict.get("hidden_size", 1024),
            num_hidden_layers=text_dict.get("num_hidden_layers", 28),
            num_attention_heads=text_dict.get("num_attention_heads", 16),
            num_key_value_heads=text_dict.get("num_key_value_heads", 8),
            head_dim=text_dict.get("head_dim", 128),
            intermediate_size=text_dict.get("intermediate_size", 3072),
            vocab_size=text_dict.get("vocab_size", 151936),
            rms_norm_eps=text_dict.get("rms_norm_eps", 1e-6),
            rope_theta=text_dict.get("rope_theta", 1000000.0),
            tie_word_embeddings=text_dict.get("tie_word_embeddings", True),
        )

        return cls(
            audio_config=audio_cfg,
            text_config=text_cfg,
            audio_token_id=thinker.get("audio_token_id", 151676),
            audio_start_token_id=thinker.get("audio_start_token_id", 151669),
            audio_end_token_id=thinker.get("audio_end_token_id", 151670),
            # Fix: honor eos_token_id from config.json like the other token
            # ids above, instead of always keeping the dataclass default.
            # NOTE(review): some HF configs store eos_token_id as a list —
            # confirm Qwen3-ASR checkpoints always use a single int here.
            eos_token_id=thinker.get("eos_token_id", 151643),
            # Mirror audio_cfg so Whisper-style dispatch sees the same shapes.
            n_mels=audio_cfg.num_mel_bins,
            n_audio_ctx=audio_cfg.max_source_positions,
        )

0 commit comments

Comments
 (0)