Skip to content

Commit 50ec38e

Browse files
committed
Add tests for load_model and recognize
1 parent 1f1be16 commit 50ec38e

10 files changed

Lines changed: 68 additions & 0 deletions

tests/onnx_asr/__init__.py

Whitespace-only changes.

tests/onnx_asr/test_load_model.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import pytest
2+
3+
import onnx_asr
4+
5+
6+
def test_model_not_supported_error():
7+
with pytest.raises(onnx_asr.loader.ModelNotSupportedError):
8+
onnx_asr.load_model("xxx")
9+
10+
11+
def test_model_path_not_found_error():
12+
with pytest.raises(onnx_asr.loader.ModelPathNotFoundError):
13+
onnx_asr.load_model("onnx-community/whisper-tiny", "./xxx")
14+
15+
16+
def test_model_file_not_found_error():
17+
with pytest.raises(onnx_asr.loader.ModelFileNotFoundError):
18+
onnx_asr.load_model("onnx-community/whisper-tiny", quantization="xxx")
19+
20+
21+
def test_more_than_one_model_file_found_error():
22+
with pytest.raises(onnx_asr.loader.MoreThanOneModelFileFoundError):
23+
onnx_asr.load_model("onnx-community/whisper-tiny", quantization="*int8")
24+
25+
26+
def test_no_model_name_or_path_specified_error():
27+
with pytest.raises(onnx_asr.loader.NoModelNameOrPathSpecifiedError):
28+
onnx_asr.load_model("whisper-hf")
29+
30+
31+
@pytest.mark.parametrize("model", ["onnx-community/whisper-tiny", "alphacep/vosk-model-small-ru"])
32+
def test_load_model(model):
33+
onnx_asr.load_model(model)

tests/onnx_asr/test_recognize.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import numpy as np
2+
import pytest
3+
4+
import onnx_asr
5+
import onnx_asr.asr
6+
import onnx_asr.utils
7+
8+
9+
@pytest.fixture(scope="module")
10+
def model():
11+
return onnx_asr.load_model("onnx-community/whisper-tiny")
12+
13+
14+
def test_supported_only_mono_audio_error(model: onnx_asr.asr.Asr):
15+
rng = np.random.default_rng(0)
16+
waveform = rng.random((1 * 16_000, 2), dtype=np.float32)
17+
18+
with pytest.raises(onnx_asr.utils.SupportedOnlyMonoAudioError):
19+
model.recognize(waveform)
20+
21+
22+
def test_wrong_sample_rate_error(model: onnx_asr.asr.Asr):
23+
rng = np.random.default_rng(0)
24+
waveform = rng.random((1 * 16_000), dtype=np.float32)
25+
26+
with pytest.raises(onnx_asr.utils.WrongSampleRateError):
27+
model.recognize(waveform, sample_rate=24_000) # type: ignore
28+
29+
30+
def test_recognize(model: onnx_asr.asr.Asr):
31+
rng = np.random.default_rng(0)
32+
waveform = rng.random((1 * 16_000), dtype=np.float32)
33+
34+
result = model.recognize(waveform)
35+
assert isinstance(result, str)

tests/preprocessors/__init__.py

Whitespace-only changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)