Skip to content

Commit a00cba6

Browse files
authored
Refactor STT validation boundaries (#153)
This PR is:
- To validate `SpeechToTextConfig` inputs at construction time
- To normalize and validate Whisper decode options up front
- To fail fast on invalid STT model loading inputs
- To add focused tests for the new validation paths

Notes:
- This keeps the change narrow and behavioral: no broad file moves, no larger refactor mixed in
- I also removed a couple of defensive fallbacks so the code stays strict about real model contracts

Validation:
- `pytest tests/test_stt.py tests/test_transcribe.py tests/test_qwen3_asr.py tests/test_v1_stt_integration.py -q`
- Result: `135 passed, 3 skipped`

Skipped tests:
- `tests/test_qwen3_asr.py::TestModelLoad::test_load_model`
- `tests/test_qwen3_asr.py::TestModelLoad::test_encode_dummy_mel`
- `tests/test_qwen3_asr.py::TestModelLoad::test_greedy_decode`
- Reason: `QWEN3_ASR_MODEL_PATH not set`

Next:
- Continue the STT cleanup in small slices, likely around transcriber/loader boundaries.

---------

Signed-off-by: Yuan Lik Xun <lxyuan0420@gmail.com>
1 parent fac064f commit a00cba6

4 files changed

Lines changed: 149 additions & 8 deletions

File tree

tests/test_stt.py

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -47,6 +47,26 @@ def test_custom_values(self) -> None:
4747
assert cfg.overlap_chunk_second == 0.5
4848
assert cfg.min_energy_split_window_size == 800
4949

50+
@pytest.mark.parametrize(
51+
("kwargs", "message"),
52+
[
53+
({"max_audio_clip_s": 0.0}, "max_audio_clip_s"),
54+
({"overlap_chunk_second": -0.1}, "overlap_chunk_second"),
55+
(
56+
{"max_audio_clip_s": 10.0, "overlap_chunk_second": 10.0},
57+
"overlap_chunk_second",
58+
),
59+
({"min_energy_split_window_size": 0}, "min_energy_split_window_size"),
60+
],
61+
)
62+
def test_invalid_values_raise(
63+
self,
64+
kwargs: dict[str, float | int],
65+
message: str,
66+
) -> None:
67+
with pytest.raises(ValueError, match=message):
68+
SpeechToTextConfig(**kwargs)
69+
5070

5171
# ===========================================================================
5272
# validate_language

tests/test_transcribe.py

Lines changed: 63 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -3,14 +3,18 @@
33

44
from __future__ import annotations
55

6+
import json
67
from pathlib import Path
8+
from types import SimpleNamespace
79

10+
import mlx.core as mx
811
import pytest
912

1013
from vllm_metal.stt.transcribe import (
1114
_MAX_PROMPT_TOKENS,
1215
TranscriptionResult,
1316
WhisperTranscriber,
17+
load_model,
1418
)
1519

1620
# ===========================================================================
@@ -166,17 +170,11 @@ class TestLoadModel:
166170

167171
def test_missing_config_json(self, tmp_path: Path) -> None:
168172
"""Should raise FileNotFoundError when config.json is absent."""
169-
from vllm_metal.stt.transcribe import load_model
170-
171173
with pytest.raises(FileNotFoundError, match="config.json not found"):
172174
load_model(tmp_path)
173175

174176
def test_missing_weight_files(self, tmp_path: Path) -> None:
175177
"""Should raise FileNotFoundError when no weight files exist."""
176-
import json
177-
178-
from vllm_metal.stt.transcribe import load_model
179-
180178
config = {
181179
"n_mels": 80,
182180
"n_audio_ctx": 10,
@@ -194,6 +192,65 @@ def test_missing_weight_files(self, tmp_path: Path) -> None:
194192
with pytest.raises(FileNotFoundError, match="No weight files"):
195193
load_model(tmp_path)
196194

195+
def test_empty_model_path_raises(self) -> None:
196+
"""Whitespace-only model paths should fail fast."""
197+
with pytest.raises(ValueError, match="model_path"):
198+
load_model(" ")
199+
200+
def test_invalid_dtype_raises(self, tmp_path: Path) -> None:
201+
"""Non-floating dtypes are rejected before any file I/O."""
202+
with pytest.raises(TypeError, match="Unsupported STT model dtype"):
203+
load_model(tmp_path, dtype=mx.int32)
204+
205+
def test_unknown_model_type_raises(self, tmp_path: Path) -> None:
206+
"""Unknown model_type values should not fall through to Whisper."""
207+
(tmp_path / "config.json").write_text(json.dumps({"model_type": "mystery_stt"}))
208+
209+
with pytest.raises(ValueError, match="Unsupported STT model_type"):
210+
load_model(tmp_path)
211+
212+
213+
class TestResolveDecodeOptions:
214+
"""Tests for WhisperTranscriber task/language validation."""
215+
216+
def test_multilingual_model_normalizes_inputs(self) -> None:
217+
transcriber = WhisperTranscriber(
218+
model=SimpleNamespace(is_multilingual=True),
219+
model_path=None,
220+
)
221+
222+
language, task = transcriber._resolve_decode_options(" EN ", "Transcribe")
223+
224+
assert language == "en"
225+
assert task == "transcribe"
226+
227+
def test_invalid_task_raises(self) -> None:
228+
transcriber = WhisperTranscriber(
229+
model=SimpleNamespace(is_multilingual=True),
230+
model_path=None,
231+
)
232+
233+
with pytest.raises(ValueError, match="Unsupported STT task"):
234+
transcriber._resolve_decode_options("en", "summarize")
235+
236+
def test_english_only_model_rejects_translation(self) -> None:
237+
transcriber = WhisperTranscriber(
238+
model=SimpleNamespace(is_multilingual=False),
239+
model_path=None,
240+
)
241+
242+
with pytest.raises(ValueError, match="do not support translation"):
243+
transcriber._resolve_decode_options(None, "translate")
244+
245+
def test_english_only_model_rejects_non_english_language(self) -> None:
246+
transcriber = WhisperTranscriber(
247+
model=SimpleNamespace(is_multilingual=False),
248+
model_path=None,
249+
)
250+
251+
with pytest.raises(ValueError, match="only support English transcription"):
252+
transcriber._resolve_decode_options("fr", "transcribe")
253+
197254

198255
# ===========================================================================
199256
# Greedy decode and encode chunk (require tiny model)

vllm_metal/stt/config.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -74,6 +74,17 @@ class SpeechToTextConfig:
7474
# Deprecated: Whisper requires 16kHz; this field is ignored.
7575
sample_rate: int = 16000
7676

77+
def __post_init__(self) -> None:
78+
"""Validate runtime chunking parameters."""
79+
if self.max_audio_clip_s <= 0:
80+
raise ValueError("max_audio_clip_s must be > 0")
81+
if self.overlap_chunk_second < 0:
82+
raise ValueError("overlap_chunk_second must be >= 0")
83+
if self.overlap_chunk_second >= self.max_audio_clip_s:
84+
raise ValueError("overlap_chunk_second must be < max_audio_clip_s")
85+
if self.min_energy_split_window_size <= 0:
86+
raise ValueError("min_energy_split_window_size must be > 0")
87+
7788

7889
def is_stt_model(model_path: str) -> bool:
7990
"""Return ``True`` if *model_path* points to a Speech-to-Text model.

vllm_metal/stt/transcribe.py

Lines changed: 55 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,7 @@
3131
QWEN3_ASR_MAX_DECODE_TOKENS,
3232
WHISPER_MAX_DECODE_TOKENS,
3333
SpeechToTextConfig,
34+
validate_language,
3435
)
3536
from vllm_metal.stt.protocol import TranscriptionSegment
3637
from vllm_metal.stt.whisper import WhisperConfig, WhisperModel
@@ -57,6 +58,12 @@
5758
# Regex to detect Whisper timestamp tokens like ``<|0.00|>``.
5859
_TIMESTAMP_RE = re.compile(r"<\|(\d+\.\d+)\|>")
5960

61+
# Supported tasks for Whisper transcription requests.
62+
_WHISPER_TASKS = frozenset({"transcribe", "translate"})
63+
64+
# Supported floating-point dtypes for STT model loading.
65+
_SUPPORTED_LOAD_DTYPES = frozenset({mx.float16, mx.float32, mx.bfloat16})
66+
6067

6168
# ===========================================================================
6269
# Data types
@@ -135,6 +142,8 @@ def transcribe(
135142
Returns:
136143
:class:`TranscriptionResult` with text and optional segments.
137144
"""
145+
language, task = self._resolve_decode_options(language, task)
146+
138147
if isinstance(audio, str):
139148
audio = load_audio(audio, sample_rate=SAMPLE_RATE)
140149
elif isinstance(audio, np.ndarray):
@@ -251,6 +260,31 @@ def _get_token_id(self, token: str) -> int:
251260
"""Resolve a special token string to its integer ID."""
252261
return self.tokenizer.convert_tokens_to_ids(token)
253262

263+
def _resolve_decode_options(
264+
self,
265+
language: str | None,
266+
task: str,
267+
) -> tuple[str | None, str]:
268+
"""Validate and normalize task/language options for Whisper."""
269+
task = task.strip().lower()
270+
if task not in _WHISPER_TASKS:
271+
supported = ", ".join(sorted(_WHISPER_TASKS))
272+
raise ValueError(
273+
f"Unsupported STT task: {task!r}. Must be one of {supported}."
274+
)
275+
276+
if self.model.is_multilingual:
277+
return validate_language(language, default=None), task
278+
279+
resolved_language = validate_language(language, default=None)
280+
if task == "translate":
281+
raise ValueError("English-only Whisper models do not support translation.")
282+
if resolved_language not in (None, "en"):
283+
raise ValueError(
284+
"English-only Whisper models only support English transcription."
285+
)
286+
return resolved_language, task
287+
254288
def _encode_prompt(self, prompt: str | None) -> list[int]:
255289
"""Encode a user prompt into ``<|startofprev|>`` prefix tokens.
256290
@@ -580,6 +614,15 @@ def _resolve_model_path(model_path: str | Path) -> Path:
580614
return model_path
581615

582616

617+
def _validate_load_dtype(dtype: mx.Dtype) -> None:
618+
"""Validate the floating-point dtype used for model loading."""
619+
if dtype not in _SUPPORTED_LOAD_DTYPES:
620+
names = ", ".join(sorted(str(d) for d in _SUPPORTED_LOAD_DTYPES))
621+
raise TypeError(
622+
f"Unsupported STT model dtype: {dtype!r}. Must be one of {names}."
623+
)
624+
625+
583626
def load_model(model_path: str | Path, dtype: mx.Dtype = mx.float16):
584627
"""Load an STT model from a local directory or HuggingFace repo.
585628
@@ -597,14 +640,24 @@ def load_model(model_path: str | Path, dtype: mx.Dtype = mx.float16):
597640
ValueError: If the model type is unsupported or download fails.
598641
FileNotFoundError: If config.json or weight files are missing.
599642
"""
643+
if isinstance(model_path, str) and not model_path.strip():
644+
raise ValueError(
645+
"model_path must be a non-empty local path or HuggingFace repo ID."
646+
)
647+
_validate_load_dtype(dtype)
600648
model_path = _resolve_model_path(model_path)
601649
config_dict = _read_config(model_path)
602650
model_type = config_dict.get("model_type", "").lower()
603651

604652
if model_type == "qwen3_asr":
605653
return _load_qwen3_asr_model(model_path, config_dict, dtype)
606-
# Default to Whisper for backward compatibility
607-
return _load_whisper_model(model_path, config_dict, dtype)
654+
if model_type in ("", "whisper"):
655+
# Default to Whisper for backward compatibility
656+
return _load_whisper_model(model_path, config_dict, dtype)
657+
raise ValueError(
658+
f"Unsupported STT model_type: {model_type!r}. "
659+
"Expected 'whisper' or 'qwen3_asr'."
660+
)
608661

609662

610663
def _load_and_init_model(model, model_path: Path, config_dict: dict):

0 commit comments

Comments (0)