|
1 | 1 | # SPDX-License-Identifier: Apache-2.0 |
2 | | -"""Speech-to-Text model loading and orchestration.""" |
| 2 | +"""Speech-to-Text model orchestration.""" |
3 | 3 |
|
4 | 4 | from __future__ import annotations |
5 | 5 |
|
6 | | -import json |
7 | 6 | import logging |
8 | 7 | from pathlib import Path |
9 | 8 |
|
10 | 9 | import mlx.core as mx |
11 | | -import mlx.nn as nn |
12 | 10 | import numpy as np |
13 | 11 |
|
14 | 12 | from vllm_metal.stt.config import SpeechToTextConfig |
| 13 | +from vllm_metal.stt.loader import load_model as _load_model |
15 | 14 | from vllm_metal.stt.protocol import TranscriptionResult |
16 | | -from vllm_metal.stt.qwen3_asr.config import Qwen3ASRConfig |
17 | | -from vllm_metal.stt.qwen3_asr.model import Qwen3ASRModel |
18 | 15 | from vllm_metal.stt.qwen3_asr.transcriber import Qwen3ASRTranscriber # noqa: F401 |
19 | | -from vllm_metal.stt.whisper import WhisperConfig, WhisperModel, WhisperTranscriber |
| 16 | +from vllm_metal.stt.whisper import WhisperModel, WhisperTranscriber |
20 | 17 |
|
21 | 18 | logger = logging.getLogger(__name__) |
22 | 19 |
|
23 | | -try: |
24 | | - from huggingface_hub import snapshot_download |
25 | | -except ImportError: # pragma: no cover |
26 | | - snapshot_download = None # type: ignore[assignment] |
27 | | - |
28 | | -# Supported floating-point dtypes for STT model loading. |
29 | | -_SUPPORTED_LOAD_DTYPES = frozenset({mx.float16, mx.float32, mx.bfloat16}) |
30 | | - |
31 | | - |
32 | | -# =========================================================================== |
33 | | -# Model loading |
34 | | -# =========================================================================== |
35 | | - |
36 | | - |
37 | | -def _read_config(model_path: Path) -> dict: |
38 | | - """Read and return config.json from a model directory.""" |
39 | | - config_path = model_path / "config.json" |
40 | | - if not config_path.exists(): |
41 | | - raise FileNotFoundError(f"config.json not found in {model_path}") |
42 | | - with open(config_path) as f: |
43 | | - return json.load(f) |
44 | | - |
45 | | - |
46 | | -def _load_weights(model_path: Path) -> dict[str, mx.array]: |
47 | | - """Load model weights from safetensors or npz files.""" |
48 | | - weight_files = sorted(model_path.glob("*.safetensors")) |
49 | | - if not weight_files: |
50 | | - weight_files = sorted(model_path.glob("*.npz")) |
51 | | - if not weight_files: |
52 | | - raise FileNotFoundError(f"No weight files in {model_path}") |
53 | | - |
54 | | - weights: dict[str, mx.array] = {} |
55 | | - for wf in weight_files: |
56 | | - weights.update(mx.load(str(wf))) |
57 | | - return weights |
58 | | - |
59 | | - |
60 | | -def _resolve_model_path(model_path: str | Path) -> Path: |
61 | | - """Resolve model path, downloading from HF if needed.""" |
62 | | - model_path = Path(model_path) |
63 | | - if not model_path.exists(): |
64 | | - if snapshot_download is None: |
65 | | - raise ValueError( |
66 | | - f"Could not download model {model_path}: huggingface_hub is not installed" |
67 | | - ) |
68 | | - try: |
69 | | - model_path = Path(snapshot_download(repo_id=str(model_path))) |
70 | | - except OSError as e: |
71 | | - raise ValueError(f"Could not download model: {model_path}") from e |
72 | | - return model_path |
73 | | - |
74 | | - |
75 | | -def _validate_load_dtype(dtype: mx.Dtype) -> None: |
76 | | - """Validate the floating-point dtype used for model loading.""" |
77 | | - if dtype not in _SUPPORTED_LOAD_DTYPES: |
78 | | - names = ", ".join(sorted(str(d) for d in _SUPPORTED_LOAD_DTYPES)) |
79 | | - raise TypeError( |
80 | | - f"Unsupported STT model dtype: {dtype!r}. Must be one of {names}." |
81 | | - ) |
82 | | - |
83 | 20 |
|
84 | 21 | def load_model(model_path: str | Path, dtype: mx.Dtype = mx.float16): |
85 | | - """Load an STT model from a local directory or HuggingFace repo. |
86 | | -
|
87 | | - Auto-detects model type from config.json and dispatches to the |
88 | | - appropriate loader (Whisper or Qwen3-ASR). |
89 | | -
|
90 | | - Args: |
91 | | - model_path: Local path or HuggingFace repo ID. |
92 | | - dtype: Model dtype (default: float16). |
93 | | -
|
94 | | - Returns: |
95 | | - Loaded model ready for inference. |
96 | | -
|
97 | | - Raises: |
98 | | - ValueError: If the model type is unsupported or download fails. |
99 | | - FileNotFoundError: If config.json or weight files are missing. |
100 | | - """ |
101 | | - if isinstance(model_path, str) and not model_path.strip(): |
102 | | - raise ValueError( |
103 | | - "model_path must be a non-empty local path or HuggingFace repo ID." |
104 | | - ) |
105 | | - _validate_load_dtype(dtype) |
106 | | - model_path = _resolve_model_path(model_path) |
107 | | - config_dict = _read_config(model_path) |
108 | | - model_type = config_dict.get("model_type", "").lower() |
109 | | - |
110 | | - if model_type == "qwen3_asr": |
111 | | - return _load_qwen3_asr_model(model_path, config_dict, dtype) |
112 | | - if model_type in ("", "whisper"): |
113 | | - # Default to Whisper for backward compatibility |
114 | | - return _load_whisper_model(model_path, config_dict, dtype) |
115 | | - raise ValueError( |
116 | | - f"Unsupported STT model_type: {model_type!r}. " |
117 | | - "Expected 'whisper' or 'qwen3_asr'." |
118 | | - ) |
119 | | - |
120 | | - |
121 | | -def _load_and_init_model(model, model_path: Path, config_dict: dict): |
122 | | - """Shared loader: quantize, sanitize, load weights, and eval. |
123 | | -
|
124 | | - Args: |
125 | | - model: Instantiated model with a ``sanitize`` method. |
126 | | - model_path: Path to weight files. |
127 | | - config_dict: Raw config.json dict (checked for ``quantization``). |
128 | | -
|
129 | | - Returns: |
130 | | - The model with weights loaded and evaluated. |
131 | | - """ |
132 | | - weights = _load_weights(model_path) |
133 | | - |
134 | | - quantization = config_dict.get("quantization") |
135 | | - if quantization is not None: |
136 | | - |
137 | | - def class_predicate(p, m): |
138 | | - return isinstance(m, (nn.Linear, nn.Embedding)) and f"{p}.scales" in weights |
139 | | - |
140 | | - nn.quantize(model, **quantization, class_predicate=class_predicate) |
141 | | - |
142 | | - weights = model.sanitize(weights) |
143 | | - model.load_weights(list(weights.items()), strict=False) |
144 | | - mx.eval(model.parameters()) |
145 | | - return model |
146 | | - |
147 | | - |
148 | | -def _load_whisper_model( |
149 | | - model_path: Path, config_dict: dict, dtype: mx.Dtype |
150 | | -) -> WhisperModel: |
151 | | - """Load a Whisper model from config and weights.""" |
152 | | - config = WhisperConfig.from_dict(config_dict) |
153 | | - model = WhisperModel(config, dtype) |
154 | | - return _load_and_init_model(model, model_path, config_dict) |
155 | | - |
156 | | - |
157 | | -def _load_qwen3_asr_model(model_path: Path, config_dict: dict, dtype: mx.Dtype): |
158 | | - """Load a Qwen3-ASR model from config and weights.""" |
159 | | - config = Qwen3ASRConfig.from_dict(config_dict) |
160 | | - model = Qwen3ASRModel(config, dtype) |
161 | | - return _load_and_init_model(model, model_path, config_dict) |
| 22 | + """Load an STT model from a local directory or HuggingFace repo.""" |
| 23 | + return _load_model(model_path, dtype) |
162 | 24 |
|
163 | 25 |
|
164 | 26 | # =========================================================================== |
|
0 commit comments