Skip to content

Commit 724e7df

Browse files
committed
Added model tests
1 parent 5228031 commit 724e7df

2 files changed

Lines changed: 16 additions & 5 deletions

File tree

.github/workflows/python-package.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ permissions:
88

99
on:
1010
push:
11-
branches: [ "main" ]
11+
branches: [ "main", "tests" ]
1212
pull_request:
1313
branches: [ "main" ]
1414

tests/onnx_asr/test_recognize.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,24 @@
55
import onnx_asr.utils
66
from onnx_asr.adapters import TextResultsAsrAdapter
77

8+
models = [
9+
"gigaam-v2-ctc",
10+
"gigaam-v2-rnnt",
11+
"nemo-fastconformer-ru-ctc",
12+
"nemo-fastconformer-ru-rnnt",
13+
"alphacep/vosk-model-ru",
14+
"alphacep/vosk-model-small-ru",
15+
"whisper-base",
16+
"onnx-community/whisper-tiny",
17+
]
18+
819

920
@pytest.fixture(scope="module")
1021
def model(request: pytest.FixtureRequest) -> TextResultsAsrAdapter:
11-
return onnx_asr.load_model(request.param)
22+
return onnx_asr.load_model(request.param, quantization="int8" if request.param != "onnx-community/whisper-tiny" else "uint8")
1223

1324

14-
@pytest.mark.parametrize("model", ["alphacep/vosk-model-small-ru", "onnx-community/whisper-tiny", "whisper-base"], indirect=True)
25+
@pytest.mark.parametrize("model", models, indirect=True)
1526
def test_supported_only_mono_audio_error(model: TextResultsAsrAdapter) -> None:
1627
rng = np.random.default_rng(0)
1728
waveform = rng.random((1 * 16_000, 2), dtype=np.float32)
@@ -20,7 +31,7 @@ def test_supported_only_mono_audio_error(model: TextResultsAsrAdapter) -> None:
2031
model.recognize(waveform)
2132

2233

23-
@pytest.mark.parametrize("model", ["alphacep/vosk-model-small-ru", "onnx-community/whisper-tiny", "whisper-base"], indirect=True)
34+
@pytest.mark.parametrize("model", models, indirect=True)
2435
def test_wrong_sample_rate_error(model: TextResultsAsrAdapter) -> None:
2536
rng = np.random.default_rng(0)
2637
waveform = rng.random((1 * 16_000), dtype=np.float32)
@@ -29,7 +40,7 @@ def test_wrong_sample_rate_error(model: TextResultsAsrAdapter) -> None:
2940
model.recognize(waveform, sample_rate=24_000) # type: ignore
3041

3142

32-
@pytest.mark.parametrize("model", ["alphacep/vosk-model-small-ru", "onnx-community/whisper-tiny", "whisper-base"], indirect=True)
43+
@pytest.mark.parametrize("model", models, indirect=True)
3344
def test_recognize(model: TextResultsAsrAdapter) -> None:
3445
rng = np.random.default_rng(0)
3546
waveform = rng.random((1 * 16_000), dtype=np.float32)

0 commit comments

Comments
 (0)