Skip to content

Commit 09009e1

Browse files
committed
Update loader tests and fix model caching
1 parent 82a1cbb commit 09009e1

4 files changed

Lines changed: 197 additions & 39 deletions

File tree

tests/conftest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import huggingface_hub as hf
2+
import numpy as np
3+
import onnxruntime as rt
4+
5+
6+
def pytest_report_header() -> str:
7+
return f"onnx-asr deps: numpy-{np.__version__}, onnxruntime-{rt.__version__}, huggingface-hub-{hf.__version__}"

tests/onnx_asr/test_cli.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,10 @@ def test_cli_run(args_list: list[str]) -> None:
4545
except ImportError:
4646
pytest.skip("soundfile not available")
4747
args = parse_args(args_list)
48+
args.model_path = None
4849

4950
rng = np.random.default_rng(0)
5051
for file in args.filename:
5152
sf.write(file, rng.random((16_000), dtype=np.float32), 16_000)
5253

53-
run(parse_args(args_list))
54+
run(args)

tests/onnx_asr/test_load_model_errors.py

Lines changed: 0 additions & 38 deletions
This file was deleted.

tests/onnx_asr/test_loader.py

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
import sys
2+
from pathlib import Path
3+
from typing import get_args
4+
5+
import pytest
6+
7+
from onnx_asr.asr import Asr
8+
from onnx_asr.loader import (
9+
AsrLoader,
10+
InvalidModelTypeInConfigError,
11+
ModelFileNotFoundError,
12+
ModelNames,
13+
ModelNotSupportedError,
14+
ModelPathNotDirectoryError,
15+
ModelTypes,
16+
MoreThanOneModelFileFoundError,
17+
NoModelNameOrPathSpecifiedError,
18+
VadLoader,
19+
VadNames,
20+
)
21+
from onnx_asr.models import KaldiTransducer, NemoConformerAED, TOneCtc, WhisperHf
22+
from onnx_asr.models.pyannote import PyAnnoteVad
23+
from onnx_asr.models.silero import SileroVad
24+
from onnx_asr.vad import Vad
25+
26+
27+
@pytest.mark.parametrize("model", get_args(ModelNames))
28+
def test_model_names(model: ModelNames) -> None:
29+
loader = AsrLoader(model)
30+
assert issubclass(loader.model_type, Asr)
31+
assert not loader.offline
32+
assert loader.local_dir is None
33+
assert isinstance(loader.repo_id, str)
34+
35+
36+
@pytest.mark.parametrize("model", get_args(ModelNames))
37+
def test_model_names_with_path(model: ModelNames, tmp_path: Path) -> None:
38+
loader = AsrLoader(model, tmp_path)
39+
assert issubclass(loader.model_type, Asr)
40+
assert loader.offline
41+
assert loader.local_dir == tmp_path
42+
assert isinstance(loader.repo_id, str)
43+
44+
45+
@pytest.mark.parametrize(
46+
("model", "type"),
47+
[
48+
("alphacep/vosk-model-ru", KaldiTransducer),
49+
("alphacep/vosk-model-small-ru", KaldiTransducer),
50+
("t-tech/t-one", TOneCtc),
51+
("onnx-community/whisper-tiny", WhisperHf),
52+
("istupakov/canary-180m-flash-onnx", NemoConformerAED),
53+
],
54+
)
55+
def test_model_repos(model: str, type: type[Asr]) -> None:
56+
loader = AsrLoader(model)
57+
assert loader.model_type == type
58+
assert not loader.offline
59+
assert loader.local_dir is None
60+
assert loader.repo_id == model
61+
62+
63+
@pytest.mark.parametrize(
64+
("model", "type"),
65+
[
66+
("alphacep/vosk-model-ru", KaldiTransducer),
67+
("alphacep/vosk-model-small-ru", KaldiTransducer),
68+
("t-tech/t-one", TOneCtc),
69+
],
70+
)
71+
def test_model_repos_with_path(model: str, tmp_path: Path, type: type[Asr]) -> None:
72+
loader = AsrLoader(model, tmp_path)
73+
assert loader.model_type == type
74+
assert loader.offline
75+
assert loader.local_dir == tmp_path
76+
assert loader.repo_id == model
77+
78+
79+
@pytest.mark.parametrize("model", get_args(ModelTypes))
80+
def test_model_types(model: ModelTypes, tmp_path: Path) -> None:
81+
loader = AsrLoader(model, tmp_path)
82+
assert issubclass(loader.model_type, Asr)
83+
assert loader.offline
84+
assert loader.local_dir == tmp_path
85+
assert loader.repo_id is None
86+
87+
88+
def test_model_not_supported_error(tmp_path: Path) -> None:
89+
with pytest.raises(ModelNotSupportedError):
90+
AsrLoader("xxx", tmp_path)
91+
92+
93+
@pytest.mark.parametrize("model", get_args(ModelTypes))
94+
def test_no_model_name_or_path_specified_error(model: ModelTypes) -> None:
95+
with pytest.raises(NoModelNameOrPathSpecifiedError):
96+
AsrLoader(model)
97+
98+
99+
@pytest.mark.parametrize("model", get_args(ModelTypes))
100+
def test_no_model_name_and_empty_path_specified_error(model: ModelTypes, tmp_path: Path) -> None:
101+
with pytest.raises(NoModelNameOrPathSpecifiedError):
102+
AsrLoader(model, Path(tmp_path, "model"))
103+
104+
105+
@pytest.mark.parametrize("model", get_args(ModelTypes))
106+
def test_model_path_not_found_error(model: ModelTypes, tmp_path: Path) -> None:
107+
Path(tmp_path, "model").write_text("test")
108+
with pytest.raises(ModelPathNotDirectoryError):
109+
AsrLoader(model, Path(tmp_path, "model"))
110+
111+
112+
def test_model_file_not_found_error(tmp_path: Path) -> None:
113+
with pytest.raises(ModelFileNotFoundError):
114+
AsrLoader("onnx-community/whisper-tiny", tmp_path)
115+
116+
117+
def test_offline_model_file_not_found_error() -> None:
118+
with pytest.raises(ModelFileNotFoundError):
119+
AsrLoader("onnx-community/whisper-tiny", offline=True).resolve_model(quantization="fp16")
120+
121+
122+
def test_invalid_model_type_in_config_error(tmp_path: Path) -> None:
123+
Path(tmp_path, "config.json").write_text('{"model_type": "xxx"}')
124+
with pytest.raises(InvalidModelTypeInConfigError):
125+
AsrLoader("onnx-community/whisper-tiny", tmp_path)
126+
127+
128+
def test_remote_config_not_found_error() -> None:
129+
with pytest.raises(IOError): # noqa: PT011
130+
AsrLoader("alphacep/vosk-model-small-ru").resolve_config()
131+
132+
133+
def test_offline_config_not_found_error() -> None:
134+
with pytest.raises(FileNotFoundError):
135+
AsrLoader("alphacep/vosk-model-small-ru", offline=True).resolve_config()
136+
137+
138+
def test_resolve_model_file_not_found_error() -> None:
139+
loader = AsrLoader("onnx-community/whisper-tiny")
140+
with pytest.raises(ModelFileNotFoundError):
141+
loader.resolve_model(quantization="xxx")
142+
143+
144+
def test_more_than_one_model_file_found_error() -> None:
145+
loader = AsrLoader("onnx-community/whisper-tiny")
146+
with pytest.raises(MoreThanOneModelFileFoundError):
147+
loader.resolve_model(quantization="*int8")
148+
149+
150+
def test_with_offline_huggingface_hub() -> None:
151+
AsrLoader("onnx-community/whisper-tiny").resolve_model(quantization="uint8")
152+
153+
AsrLoader("onnx-community/whisper-tiny", offline=True).resolve_model(quantization="uint8")
154+
155+
156+
def test_without_huggingface_hub(monkeypatch: pytest.MonkeyPatch) -> None:
157+
loader = AsrLoader("onnx-community/whisper-tiny")
158+
159+
path = loader._download_model("uint8", local_files_only=False)
160+
161+
monkeypatch.setitem(sys.modules, "huggingface_hub", None)
162+
loader_with_path = AsrLoader("onnx-community/whisper-tiny", path)
163+
assert loader_with_path.offline
164+
loader_with_path.resolve_model(quantization="uint8")
165+
166+
167+
@pytest.mark.parametrize("model", [*get_args(VadNames), "pyannote"])
168+
def test_vad(model: str) -> None:
169+
loader = VadLoader(model)
170+
assert issubclass(loader.model_type, Vad)
171+
assert not loader.offline
172+
assert loader.local_dir is None
173+
assert isinstance(loader.repo_id, str)
174+
175+
176+
@pytest.mark.parametrize("model", [*get_args(VadNames), "pyannote"])
177+
def test_vad_with_path(model: str, tmp_path: Path) -> None:
178+
loader = VadLoader(model, tmp_path)
179+
assert issubclass(loader.model_type, SileroVad | PyAnnoteVad)
180+
assert loader.offline
181+
assert loader.local_dir == tmp_path
182+
assert isinstance(loader.repo_id, str)
183+
184+
185+
def test_resolve_vad_file_not_found_error() -> None:
186+
loader = VadLoader("silero")
187+
with pytest.raises(ModelFileNotFoundError):
188+
loader.resolve_model(quantization="xxx")

0 commit comments

Comments
 (0)