Skip to content

Commit 95ad433

Browse files
authored
Refactor STT model loading (#168)
This PR is: - To move STT model loading out of `vllm_metal/stt/transcribe.py` into `vllm_metal/stt/loader.py`. - To centralize `model_type -> constructor` mapping in `vllm_metal/stt/registry.py`. - To keep `vllm_metal/stt/transcribe.py` orchestration-only while preserving the public `load_model()` API. Next: - Extract STT runtime glue out of `vllm_metal/v1/model_runner.py` into `vllm_metal/stt/runtime.py`. - Move model-specific request/mm payload interpretation into model-owned adapters under `stt/<model>/`. --------- Signed-off-by: Yuan Lik Xun <lxyuan0420@gmail.com>
1 parent 9ba3982 commit 95ad433

3 files changed

Lines changed: 155 additions & 143 deletions

File tree

vllm_metal/stt/loader.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
"""Speech-to-Text model loading."""
3+
4+
from __future__ import annotations
5+
6+
import json
7+
import logging
8+
from pathlib import Path
9+
10+
import mlx.core as mx
11+
import mlx.nn as nn
12+
13+
from vllm_metal.stt.registry import get_stt_model_constructor
14+
15+
logger = logging.getLogger(__name__)
16+
17+
# Supported floating-point dtypes for STT model loading.
18+
_SUPPORTED_LOAD_DTYPES = frozenset({mx.float16, mx.float32, mx.bfloat16})
19+
20+
21+
def load_model(model_path: str | Path, dtype: mx.Dtype = mx.float16):
    """Load an STT model from a local directory or HuggingFace repo.

    Auto-detects the model type from ``config.json`` and dispatches to the
    constructor registered in ``vllm_metal.stt.registry``.

    Args:
        model_path: Local path or HuggingFace repo ID.
        dtype: Model dtype (default: float16).

    Returns:
        Loaded model ready for inference.

    Raises:
        ValueError: If ``model_path`` is empty, the model type is
            unsupported, or the download fails.
        TypeError: If ``dtype`` is not a supported floating-point dtype.
        FileNotFoundError: If config.json or weight files are missing.
    """
    if isinstance(model_path, str) and not model_path.strip():
        raise ValueError(
            "model_path must be a non-empty local path or HuggingFace repo ID."
        )
    _validate_load_dtype(dtype)

    resolved_model_path = _resolve_model_path(model_path)
    config_dict = _read_config(resolved_model_path)
    # Missing "model_type" falls through to "" — the registry maps that to
    # Whisper for backward compatibility.
    model_type = config_dict.get("model_type", "").lower()

    model_constructor = get_stt_model_constructor(model_type)
    model = model_constructor(config_dict, dtype)
    return _load_and_init_model(model, resolved_model_path, config_dict)
36+
37+
38+
def _validate_load_dtype(dtype: mx.Dtype) -> None:
    """Reject dtypes that STT model loading does not support.

    Raises:
        TypeError: If ``dtype`` is not one of the supported float dtypes.
    """
    if dtype in _SUPPORTED_LOAD_DTYPES:
        return
    names = ", ".join(sorted(str(d) for d in _SUPPORTED_LOAD_DTYPES))
    raise TypeError(
        f"Unsupported STT model dtype: {dtype!r}. Must be one of {names}."
    )
45+
46+
47+
def _read_config(model_path: Path) -> dict:
48+
"""Read and return config.json from a model directory."""
49+
config_path = model_path / "config.json"
50+
if not config_path.exists():
51+
raise FileNotFoundError(f"config.json not found in {model_path}")
52+
with open(config_path) as f:
53+
return json.load(f)
54+
55+
56+
def _load_weights(model_path: Path) -> dict[str, mx.array]:
    """Load and merge model weights from safetensors (preferred) or npz files.

    Raises:
        FileNotFoundError: If the directory contains no weight files.
    """
    # Prefer *.safetensors; fall back to *.npz only when none are present.
    weight_files = sorted(model_path.glob("*.safetensors")) or sorted(
        model_path.glob("*.npz")
    )
    if not weight_files:
        raise FileNotFoundError(f"No weight files in {model_path}")

    merged: dict[str, mx.array] = {}
    for weight_file in weight_files:
        merged.update(mx.load(str(weight_file)))
    return merged
68+
69+
70+
def _resolve_model_path(model_path: str | Path) -> Path:
71+
"""Resolve model path, downloading from HF if needed."""
72+
model_path = Path(model_path)
73+
if model_path.exists():
74+
return model_path
75+
76+
try:
77+
from huggingface_hub import snapshot_download
78+
except ImportError as e: # pragma: no cover
79+
raise ValueError(
80+
f"Could not download model {model_path}: huggingface_hub is not installed"
81+
) from e
82+
83+
try:
84+
return Path(snapshot_download(repo_id=str(model_path)))
85+
except OSError as e:
86+
raise ValueError(f"Could not download model: {model_path}") from e
87+
88+
89+
def _load_and_init_model(model, model_path: Path, config_dict: dict):
    """Shared loader: quantize, sanitize, load weights, and eval.

    Args:
        model: Instantiated model exposing ``sanitize`` and ``load_weights``.
        model_path: Directory holding the weight files.
        config_dict: Raw config.json dict (checked for ``quantization``).

    Returns:
        The same model with weights loaded and materialized.
    """
    raw_weights = _load_weights(model_path)

    quant_cfg = config_dict.get("quantization")
    if quant_cfg is not None:
        # Only quantize layers that actually shipped quantized weights
        # (identified by a matching "<path>.scales" entry).
        def _has_quantized_weights(path, module):
            return (
                isinstance(module, (nn.Linear, nn.Embedding))
                and f"{path}.scales" in raw_weights
            )

        nn.quantize(model, **quant_cfg, class_predicate=_has_quantized_weights)

    sanitized = model.sanitize(raw_weights)
    model.load_weights(list(sanitized.items()), strict=False)
    mx.eval(model.parameters())
    return model

vllm_metal/stt/registry.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
"""Speech-to-Text model constructor registry."""
3+
4+
from __future__ import annotations
5+
6+
from collections.abc import Callable
7+
8+
import mlx.core as mx
9+
10+
from vllm_metal.stt.qwen3_asr.config import Qwen3ASRConfig
11+
from vllm_metal.stt.qwen3_asr.model import Qwen3ASRModel
12+
from vllm_metal.stt.whisper.config import WhisperConfig
13+
from vllm_metal.stt.whisper.model import WhisperModel
14+
15+
STTModel = WhisperModel | Qwen3ASRModel
16+
STTModelConstructor = Callable[[dict, mx.Dtype], STTModel]
17+
18+
19+
def get_stt_model_constructor(model_type: str) -> STTModelConstructor:
    """Return the model constructor for an STT ``model_type``.

    Raises:
        ValueError: If ``model_type`` is not registered.
    """
    normalized = model_type.lower()
    constructor = _STT_MODEL_CONSTRUCTORS.get(normalized)
    if constructor is None:
        raise ValueError(
            f"Unsupported STT model_type: {normalized!r}. "
            "Expected 'whisper' or 'qwen3_asr'."
        )
    return constructor
29+
30+
31+
def _construct_whisper_model(config_dict: dict, dtype: mx.Dtype) -> WhisperModel:
    """Build a ``WhisperModel`` from a raw config.json dict."""
    return WhisperModel(WhisperConfig.from_dict(config_dict), dtype)
34+
35+
36+
def _construct_qwen3_asr_model(config_dict: dict, dtype: mx.Dtype) -> Qwen3ASRModel:
    """Build a ``Qwen3ASRModel`` from a raw config.json dict."""
    return Qwen3ASRModel(Qwen3ASRConfig.from_dict(config_dict), dtype)
39+
40+
41+
# Maps the lowercase ``model_type`` value from config.json to a constructor
# callable.  Looked up by ``get_stt_model_constructor``.
_STT_MODEL_CONSTRUCTORS: dict[str, STTModelConstructor] = {
    # Default to Whisper for backward compatibility.
    "": _construct_whisper_model,
    "whisper": _construct_whisper_model,
    "qwen3_asr": _construct_qwen3_asr_model,
}

vllm_metal/stt/transcribe.py

Lines changed: 5 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -1,164 +1,26 @@
11
# SPDX-License-Identifier: Apache-2.0
2-
"""Speech-to-Text model loading and orchestration."""
2+
"""Speech-to-Text model orchestration."""
33

44
from __future__ import annotations
55

6-
import json
76
import logging
87
from pathlib import Path
98

109
import mlx.core as mx
11-
import mlx.nn as nn
1210
import numpy as np
1311

1412
from vllm_metal.stt.config import SpeechToTextConfig
13+
from vllm_metal.stt.loader import load_model as _load_model
1514
from vllm_metal.stt.protocol import TranscriptionResult
16-
from vllm_metal.stt.qwen3_asr.config import Qwen3ASRConfig
17-
from vllm_metal.stt.qwen3_asr.model import Qwen3ASRModel
1815
from vllm_metal.stt.qwen3_asr.transcriber import Qwen3ASRTranscriber # noqa: F401
19-
from vllm_metal.stt.whisper import WhisperConfig, WhisperModel, WhisperTranscriber
16+
from vllm_metal.stt.whisper import WhisperModel, WhisperTranscriber
2017

2118
logger = logging.getLogger(__name__)
2219

23-
try:
24-
from huggingface_hub import snapshot_download
25-
except ImportError: # pragma: no cover
26-
snapshot_download = None # type: ignore[assignment]
27-
28-
# Supported floating-point dtypes for STT model loading.
29-
_SUPPORTED_LOAD_DTYPES = frozenset({mx.float16, mx.float32, mx.bfloat16})
30-
31-
32-
# ===========================================================================
33-
# Model loading
34-
# ===========================================================================
35-
36-
37-
def _read_config(model_path: Path) -> dict:
38-
"""Read and return config.json from a model directory."""
39-
config_path = model_path / "config.json"
40-
if not config_path.exists():
41-
raise FileNotFoundError(f"config.json not found in {model_path}")
42-
with open(config_path) as f:
43-
return json.load(f)
44-
45-
46-
def _load_weights(model_path: Path) -> dict[str, mx.array]:
47-
"""Load model weights from safetensors or npz files."""
48-
weight_files = sorted(model_path.glob("*.safetensors"))
49-
if not weight_files:
50-
weight_files = sorted(model_path.glob("*.npz"))
51-
if not weight_files:
52-
raise FileNotFoundError(f"No weight files in {model_path}")
53-
54-
weights: dict[str, mx.array] = {}
55-
for wf in weight_files:
56-
weights.update(mx.load(str(wf)))
57-
return weights
58-
59-
60-
def _resolve_model_path(model_path: str | Path) -> Path:
61-
"""Resolve model path, downloading from HF if needed."""
62-
model_path = Path(model_path)
63-
if not model_path.exists():
64-
if snapshot_download is None:
65-
raise ValueError(
66-
f"Could not download model {model_path}: huggingface_hub is not installed"
67-
)
68-
try:
69-
model_path = Path(snapshot_download(repo_id=str(model_path)))
70-
except OSError as e:
71-
raise ValueError(f"Could not download model: {model_path}") from e
72-
return model_path
73-
74-
75-
def _validate_load_dtype(dtype: mx.Dtype) -> None:
76-
"""Validate the floating-point dtype used for model loading."""
77-
if dtype not in _SUPPORTED_LOAD_DTYPES:
78-
names = ", ".join(sorted(str(d) for d in _SUPPORTED_LOAD_DTYPES))
79-
raise TypeError(
80-
f"Unsupported STT model dtype: {dtype!r}. Must be one of {names}."
81-
)
82-
8320

8421
def load_model(model_path: str | Path, dtype: mx.Dtype = mx.float16):
    """Load an STT model from a local directory or HuggingFace repo.

    Thin wrapper that delegates to ``vllm_metal.stt.loader.load_model``;
    kept here so the public ``transcribe.load_model`` API is preserved
    after the loader refactor.
    """
    return _load_model(model_path, dtype)
16224

16325

16426
# ===========================================================================

0 commit comments

Comments
 (0)