|
| 1 | +"""Wrapper for sherpa-onnx models.""" |
| 2 | + |
| 3 | +from collections.abc import Iterator |
| 4 | +from pathlib import Path |
| 5 | +from typing import Any, Literal |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +import numpy.typing as npt |
| 9 | + |
| 10 | +from onnx_asr.asr import TimestampedResult |
| 11 | +from onnx_asr.resolver import Resolver |
| 12 | + |
| 13 | + |
| 14 | +class SherpaASR: |
| 15 | + """Wrapper model for sherpa-onnx ASR.""" |
| 16 | + |
| 17 | + def __init__( |
| 18 | + self, |
| 19 | + repo_id: str | None = None, |
| 20 | + local_dir: Path | None = None, |
| 21 | + *, |
| 22 | + offline: bool | None = None, |
| 23 | + quantization: str | None = None, |
| 24 | + **kwargs: Any, |
| 25 | + ): |
| 26 | + """Create wrapper.""" |
| 27 | + resolver = Resolver(SherpaASR, repo_id, local_dir, offline=offline) |
| 28 | + model_files = resolver.resolve_model(quantization=quantization) |
| 29 | + |
| 30 | + import sherpa_onnx |
| 31 | + |
| 32 | + self._recognizer = sherpa_onnx.OfflineRecognizer.from_transducer( |
| 33 | + str(model_files["encoder"]), |
| 34 | + str(model_files["decoder"]), |
| 35 | + str(model_files["joiner"]), |
| 36 | + str(model_files["tokens"]), |
| 37 | + bpe_vocab=str(model_files["bpe_vocab"]), |
| 38 | + modeling_unit="bpe", |
| 39 | + **kwargs, |
| 40 | + ) |
| 41 | + |
| 42 | + @staticmethod |
| 43 | + def _get_model_files(quantization: str | None = None) -> dict[str, str]: |
| 44 | + suffix = "?" + quantization if quantization else "" |
| 45 | + return { |
| 46 | + "encoder": f"*/encoder{suffix}.onnx", |
| 47 | + "decoder": f"*/decoder{suffix}.onnx", |
| 48 | + "joiner": f"*/joiner{suffix}.onnx", |
| 49 | + "tokens": "*/tokens.txt", |
| 50 | + "bpe_vocab": "*/unigram_500.vocab", |
| 51 | + } |
| 52 | + |
| 53 | + @staticmethod |
| 54 | + def _get_sample_rate() -> Literal[8_000, 16_000]: |
| 55 | + return 16_000 |
| 56 | + |
| 57 | + def recognize_batch( |
| 58 | + self, waveforms: npt.NDArray[np.float32], waveforms_len: npt.NDArray[np.int64], /, **kwargs: object | None |
| 59 | + ) -> Iterator[TimestampedResult]: |
| 60 | + """Recognize waveforms batch.""" |
| 61 | + streams = [] |
| 62 | + for waveform, waveform_len in zip(waveforms, waveforms_len, strict=True): |
| 63 | + stream = self._recognizer.create_stream() |
| 64 | + stream.accept_waveform(self._get_sample_rate(), waveform[:waveform_len]) |
| 65 | + streams.append(stream) |
| 66 | + self._recognizer.decode_streams(streams) |
| 67 | + return ( |
| 68 | + TimestampedResult( |
| 69 | + stream.result.text, stream.result.timestamps, stream.result.tokens, stream.result.ys_log_probs |
| 70 | + ) |
| 71 | + for stream in streams |
| 72 | + ) |
0 commit comments