Skip to content

Commit 6180470

Browse files
committed
Add wrapper models for sherpa-onnx and nemo
1 parent 8dc27be commit 6180470

4 files changed

Lines changed: 113 additions & 1 deletion

File tree

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ python_version = "3.10"
115115
strict = true
116116
pretty = true
117117
untyped_calls_exclude = "onnxruntime"
118-
exclude = ["^preprocessors.", "^tests.preprocessors."]
118+
exclude = ["^preprocessors.", "^tests.preprocessors.", "^wrappers."]
119119

120120
[[tool.mypy.overrides]]
121121
module = ["onnxruntime.*"]
@@ -153,6 +153,7 @@ ignore = [
153153
[tool.ruff.lint.per-file-ignores]
154154
"tests/*" = ["ANN", "D", "FBT001", "PGH003", "PLR0911", "PLR2004"]
155155
"preprocessors/*" = ["ANN", "D103", "F821", "N802", "N806"]
156+
"wrappers/*" = ["ANN401", "PLC0415"]
156157
"*.ipynb" = ["ANN", "D", "ERA", "RUF001", "T"]
157158

158159
[tool.pytest]

wrappers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Wrapper models for testing and comparison."""

wrappers/nemo.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
"""Wrapper for sherpa-onnx models."""
2+
3+
from collections.abc import Iterator
4+
from typing import Any, Literal
5+
6+
import numpy as np
7+
import numpy.typing as npt
8+
9+
from onnx_asr.asr import TimestampedResult
10+
11+
12+
class NemoASR:
13+
"""Wrapper model for NeMo Toolkit ASR."""
14+
15+
def __init__(self, model_name: str):
16+
"""Create wrapper."""
17+
from nemo.utils.nemo_logging import Logger
18+
19+
self.logger = Logger()
20+
self.logger.setLevel(Logger.ERROR)
21+
22+
import nemo.collections.asr as nemo_asr
23+
24+
self.model: Any = nemo_asr.models.ASRModel.from_pretrained(model_name=model_name)
25+
self.model.change_decoding_strategy({"strategy": "greedy_batch"})
26+
self.model.eval()
27+
28+
@staticmethod
29+
def _get_sample_rate() -> Literal[8_000, 16_000]:
30+
return 16_000
31+
32+
def recognize_batch(
33+
self, waveforms: npt.NDArray[np.float32], waveforms_len: npt.NDArray[np.int64], /, **kwargs: object | None
34+
) -> Iterator[TimestampedResult]:
35+
"""Recognize waveforms batch."""
36+
for waveform, waveform_len in zip(waveforms, waveforms_len, strict=True):
37+
hypot = self.model.transcribe(waveform[:waveform_len], verbose=False)
38+
yield TimestampedResult(hypot[0].text)

wrappers/sherpa.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
"""Wrapper for sherpa-onnx models."""
2+
3+
from collections.abc import Iterator
4+
from pathlib import Path
5+
from typing import Any, Literal
6+
7+
import numpy as np
8+
import numpy.typing as npt
9+
10+
from onnx_asr.asr import TimestampedResult
11+
from onnx_asr.resolver import Resolver
12+
13+
14+
class SherpaASR:
15+
"""Wrapper model for sherpa-onnx ASR."""
16+
17+
def __init__(
18+
self,
19+
repo_id: str | None = None,
20+
local_dir: Path | None = None,
21+
*,
22+
offline: bool | None = None,
23+
quantization: str | None = None,
24+
**kwargs: Any,
25+
):
26+
"""Create wrapper."""
27+
resolver = Resolver(SherpaASR, repo_id, local_dir, offline=offline)
28+
model_files = resolver.resolve_model(quantization=quantization)
29+
30+
import sherpa_onnx
31+
32+
self._recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
33+
str(model_files["encoder"]),
34+
str(model_files["decoder"]),
35+
str(model_files["joiner"]),
36+
str(model_files["tokens"]),
37+
bpe_vocab=str(model_files["bpe_vocab"]),
38+
modeling_unit="bpe",
39+
**kwargs,
40+
)
41+
42+
@staticmethod
43+
def _get_model_files(quantization: str | None = None) -> dict[str, str]:
44+
suffix = "?" + quantization if quantization else ""
45+
return {
46+
"encoder": f"*/encoder{suffix}.onnx",
47+
"decoder": f"*/decoder{suffix}.onnx",
48+
"joiner": f"*/joiner{suffix}.onnx",
49+
"tokens": "*/tokens.txt",
50+
"bpe_vocab": "*/unigram_500.vocab",
51+
}
52+
53+
@staticmethod
54+
def _get_sample_rate() -> Literal[8_000, 16_000]:
55+
return 16_000
56+
57+
def recognize_batch(
58+
self, waveforms: npt.NDArray[np.float32], waveforms_len: npt.NDArray[np.int64], /, **kwargs: object | None
59+
) -> Iterator[TimestampedResult]:
60+
"""Recognize waveforms batch."""
61+
streams = []
62+
for waveform, waveform_len in zip(waveforms, waveforms_len, strict=True):
63+
stream = self._recognizer.create_stream()
64+
stream.accept_waveform(self._get_sample_rate(), waveform[:waveform_len])
65+
streams.append(stream)
66+
self._recognizer.decode_streams(streams)
67+
return (
68+
TimestampedResult(
69+
stream.result.text, stream.result.timestamps, stream.result.tokens, stream.result.ys_log_probs
70+
)
71+
for stream in streams
72+
)

0 commit comments

Comments
 (0)