Skip to content

Commit 0bc794b

Browse files
committed
Add resample preprocessor
1 parent 19a7dd8 commit 0bc794b

11 files changed

Lines changed: 198 additions & 21 deletions

File tree

preprocessors/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from .gigaam import GigaamPreprocessor
22
from .kaldi import KaldiPreprocessor
33
from .nemo import NemoPreprocessor
4+
from .resample import ResamplePreprocessor
45
from .whisper import WhisperPreprocessor80, WhisperPreprocessor128
56

67
__all__ = [
78
"GigaamPreprocessor",
89
"KaldiPreprocessor",
910
"NemoPreprocessor",
11+
"ResamplePreprocessor",
1012
"WhisperPreprocessor80",
1113
"WhisperPreprocessor128",
1214
]

preprocessors/build.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,4 @@ def build():
2424
save_model(preprocessors.NemoPreprocessor, preprocessors_dir.joinpath("nemo.onnx"))
2525
save_model(preprocessors.WhisperPreprocessor80, preprocessors_dir.joinpath("whisper80.onnx"))
2626
save_model(preprocessors.WhisperPreprocessor128, preprocessors_dir.joinpath("whisper128.onnx"))
27+
save_model(preprocessors.ResamplePreprocessor, preprocessors_dir.joinpath("resample.onnx"))

preprocessors/resample.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import math
2+
3+
import torch
4+
import torchaudio
5+
from onnx import numpy_helper
6+
from onnxscript import FLOAT, INT64, script
7+
from onnxscript import opset17 as op
8+
9+
10+
def make_kernel(orig_freq: int):
    """Build the sinc resampling kernel used to convert `orig_freq` audio to 16 kHz.

    Args:
        orig_freq: Source sample rate in Hz.

    Returns:
        A 4-tuple of: the kernel as a numpy array with an extra axis inserted at
        position 1, the `width` value returned by torchaudio, the source rate
        divided by the GCD of both rates, and the target rate divided by that GCD.

    """
    target_freq = 16_000
    divisor = math.gcd(orig_freq, target_freq)
    # NOTE(review): `_get_sinc_resample_kernel` is a private torchaudio helper
    # (leading underscore), so it may change between torchaudio releases.
    sinc_kernel, width = torchaudio.functional.functional._get_sinc_resample_kernel(
        orig_freq, target_freq, divisor, dtype=torch.float32
    )
    reduced_orig = orig_freq // divisor
    reduced_new = target_freq // divisor
    return sinc_kernel.numpy()[:, None], width, reduced_orig, reduced_new
15+
16+
17+
# Precomputed resampling data for every supported non-16 kHz source rate.
# Each call yields (kernel, width, reduced source rate, reduced target rate);
# the numeric suffix is the source rate in kHz (08 -> 8 kHz, 22 -> 22.05 kHz, ...).
kernel08, width08, orig_freq08, new_freq08 = make_kernel(8_000)
kernel22, width22, orig_freq22, new_freq22 = make_kernel(22_050)
kernel44, width44, orig_freq44, new_freq44 = make_kernel(44_100)
kernel48, width48, orig_freq48, new_freq48 = make_kernel(48_000)
21+
22+
23+
@script(doc_string="Resampling waveform to 16 kHz")
def ResamplePreprocessor(
    waveforms: FLOAT["batch_size", "N"],
    waveforms_lens: INT64["batch_size"],
    sample_rate: INT64["1"],
) -> tuple[FLOAT["batch_size", "M"], INT64["batch_size"]]:
    # Insert two singleton axes (positions 1 and 2) so the batch can be fed to a 2-D Conv.
    waveforms = op.Unsqueeze(waveforms, axes=[1, 2])

    # One branch per supported source rate: convolve with the matching precomputed
    # kernel, striding by the reduced source rate, and rescale the lengths by
    # new/orig (the `+ orig - 1` term makes the integer division round up).
    # Any other sample rate leaves both the samples and the lengths untouched.
    if sample_rate[0] == 8_000:
        taps = op.Constant(value=numpy_helper.from_array(kernel08, "kernel"))
        frames = op.Conv(waveforms, taps, pads=(0, width08, 0, width08 + orig_freq08), strides=(1, orig_freq08))
        waveforms_lens = (new_freq08 * waveforms_lens + orig_freq08 - 1) / orig_freq08
    elif sample_rate[0] == 22_050:
        taps = op.Constant(value=numpy_helper.from_array(kernel22, "kernel"))
        frames = op.Conv(waveforms, taps, pads=(0, width22, 0, width22 + orig_freq22), strides=(1, orig_freq22))
        waveforms_lens = (new_freq22 * waveforms_lens + orig_freq22 - 1) / orig_freq22
    elif sample_rate[0] == 44_100:
        taps = op.Constant(value=numpy_helper.from_array(kernel44, "kernel"))
        frames = op.Conv(waveforms, taps, pads=(0, width44, 0, width44 + orig_freq44), strides=(1, orig_freq44))
        waveforms_lens = (new_freq44 * waveforms_lens + orig_freq44 - 1) / orig_freq44
    elif sample_rate[0] == 48_000:
        taps = op.Constant(value=numpy_helper.from_array(kernel48, "kernel"))
        frames = op.Conv(waveforms, taps, pads=(0, width48, 0, width48 + orig_freq48), strides=(1, orig_freq48))
        waveforms_lens = (new_freq48 * waveforms_lens + orig_freq48 - 1) / orig_freq48
    else:
        frames = waveforms

    resampled_lens = op.Identity(waveforms_lens)
    max_len = op.ReduceMax(resampled_lens, keepdims=0)

    # Per-row validity mask: position index < that row's resampled length.
    positions = op.Unsqueeze(op.Range(0, max_len, 1), [0])
    limits = op.Unsqueeze(resampled_lens, [1])
    mask = positions < limits

    # Swap axes 1 and 3, then flatten everything after the batch axis into one
    # sample axis; crop to the longest valid length and zero the padded positions.
    flat = op.Flatten(op.Transpose(frames, perm=(0, 3, 2, 1)))
    resampled = op.Where(mask, flat[:, :max_len], 0)
    return resampled, resampled_lens

src/onnx_asr/asr.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import numpy.typing as npt
1111

1212
from .preprocessors import Preprocessor
13-
from .utils import pad_list, read_wav_files
13+
from .utils import SampleRates, pad_list, read_wav_files
1414

1515

1616
class Asr(ABC):
@@ -21,26 +21,35 @@ class Asr(ABC):
2121
def _get_model_files(quantization: str | None = None) -> dict[str, str]: ...
2222

2323
@abstractmethod
24-
def _recognize_batch(self, waveforms: list[npt.NDArray[np.float32]], language: str | None = None) -> list[str]: ...
24+
def _recognize_batch(
25+
self, waveforms: list[npt.NDArray[np.float32]], sample_rate: SampleRates, language: str | None
26+
) -> list[str]: ...
2527

2628
def recognize(
27-
self, waveform: str | npt.NDArray[np.float32] | list[str | npt.NDArray[np.float32]], language: str | None = None
29+
self,
30+
waveform: str | npt.NDArray[np.float32] | list[str | npt.NDArray[np.float32]],
31+
*,
32+
sample_rate: SampleRates = 16_000,
33+
language: str | None = None,
2834
) -> str | list[str]:
2935
"""Recognize speech (single or batch).
3036
3137
Args:
3238
waveform: Path to wav file (only PCM_U8, PCM_16, PCM_24 and PCM_32 formats with 8, 16, 22.05, 44.1 or 48 kHz sample rate are supported)
3339
or Numpy array with PCM waveform.
3440
A list of file paths or numpy arrays for batch recognition are also supported.
41+
sample_rate: Sample rate for Numpy arrays in waveform.
3542
language: Speech language (only for Whisper models).
3643
3744
Returns:
3845
Speech recognition results (single string or list for batch recognition).
3946
4047
"""
4148
if isinstance(waveform, list):
42-
return self._recognize_batch(read_wav_files(waveform), language)
43-
return self._recognize_batch(read_wav_files([waveform]), language)[0]
49+
if not waveform:
50+
return []
51+
return self._recognize_batch(*read_wav_files(waveform, sample_rate), language)
52+
return self._recognize_batch(*read_wav_files([waveform], sample_rate), language)[0]
4453

4554

4655
class _AsrWithDecoding(Asr):
@@ -65,8 +74,12 @@ def _decode_tokens(self, tokens: list[int]) -> str:
6574
text = "".join([self._vocab[i] for i in tokens])
6675
return re.sub(self.DECODE_SPACE_PATTERN, lambda x: " " if x.group(1) else "", text)
6776

68-
def _recognize_batch(self, waveforms: list[npt.NDArray[np.float32]], language: str | None = None) -> list[str]:
69-
return list(map(self._decode_tokens, self._decoding(*self._encode(*self._preprocessor(*pad_list(waveforms))))))
77+
def _recognize_batch(
78+
self, waveforms: list[npt.NDArray[np.float32]], sample_rate: SampleRates, language: str | None = None
79+
) -> list[str]:
80+
return list(
81+
map(self._decode_tokens, self._decoding(*self._encode(*self._preprocessor(*pad_list(waveforms), sample_rate))))
82+
)
7083

7184

7285
class _AsrWithCtcDecoding(_AsrWithDecoding):

src/onnx_asr/loader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def _download_model(model: str, files: list[str]) -> str:
146146
def load_model(
147147
model: str | ModelNames | ModelTypes,
148148
path: str | Path | None = None,
149+
*,
149150
quantization: str | None = None,
150151
providers: Sequence[str | tuple[str, dict]] | None = None,
151152
) -> Asr:

src/onnx_asr/models/whisper.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from onnx_asr.asr import Asr
1313
from onnx_asr.preprocessors.preprocessor import Preprocessor
14-
from onnx_asr.utils import pad_list
14+
from onnx_asr.utils import SampleRates, pad_list
1515

1616

1717
@typing.no_type_check
@@ -66,8 +66,8 @@ def _get_model_files(quantization: str | None = None) -> dict[str, str]:
6666
"added_tokens": "added_tokens.json",
6767
}
6868

69-
def _preprocess(self, waveforms: list[npt.NDArray[np.float32]]) -> npt.NDArray[np.float32]:
70-
input_features, _ = self._preprocessor(*pad_list(waveforms))
69+
def _preprocess(self, waveforms: list[npt.NDArray[np.float32]], sample_rate: SampleRates) -> npt.NDArray[np.float32]:
70+
input_features, _ = self._preprocessor(*pad_list(waveforms), sample_rate)
7171
return input_features
7272

7373
@abstractmethod
@@ -77,8 +77,10 @@ def _decode_tokens(self, tokens: npt.NDArray) -> str:
7777
text = "".join(token for id in tokens if (token := self._vocab[id]) and not token.startswith("<|"))
7878
return bytearray([self._byte_decoder[c] for c in text]).decode("utf-8", errors="replace").removeprefix(" ")
7979

80-
def _recognize_batch(self, waveforms: list[npt.NDArray[np.float32]], language: str | None = None) -> list[str]:
81-
input_features = self._preprocess(waveforms)
80+
def _recognize_batch(
81+
self, waveforms: list[npt.NDArray[np.float32]], sample_rate: SampleRates, language: str | None = None
82+
) -> list[str]:
83+
input_features = self._preprocess(waveforms, sample_rate)
8284
input_tokens = np.repeat(self._decoder_input, len(waveforms), axis=0)
8385

8486
if language:
@@ -149,8 +151,8 @@ def _get_model_files(quantization: str | None = None) -> dict[str, str]:
149151
"decoder": f"**/decoder_model{suffix}.onnx",
150152
} | _Whisper._get_model_files(suffix)
151153

152-
def _preprocess(self, waveforms: list[npt.NDArray[np.float32]]) -> npt.NDArray[np.float32]:
153-
input_features = super()._preprocess(waveforms)
154+
def _preprocess(self, waveforms: list[npt.NDArray[np.float32]], sample_rate: SampleRates) -> npt.NDArray[np.float32]:
155+
input_features = super()._preprocess(waveforms, sample_rate)
154156
(last_hidden_state,) = self._encoder.run(
155157
["last_hidden_state"],
156158
{"input_features": input_features},
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""ASR preprocessor implementations."""
22

33
from .preprocessor import Preprocessor
4+
from .resampler import Resampler
45

5-
__all__ = ["Preprocessor"]
6+
__all__ = ["Preprocessor", "Resampler"]

src/onnx_asr/preprocessors/preprocessor.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88
import numpy.typing as npt
99
import onnxruntime as rt
1010

11+
from onnx_asr.utils import SampleRates
12+
13+
from .resampler import Resampler
14+
1115

1216
class Preprocessor:
1317
"""ASR preprocessor implementation."""
@@ -22,11 +26,15 @@ def __init__(self, name: str, **kwargs: Any):
2226
"""
2327
filename = str(Path(name).with_suffix(".onnx"))
2428
self._preprocessor = rt.InferenceSession(files(__package__).joinpath(filename).read_bytes(), **kwargs)
29+
self._resampler = Resampler(**kwargs)
2530

2631
def __call__(
27-
self, waveforms: npt.NDArray[np.float32], waveforms_lens: npt.NDArray[np.int64]
32+
self, waveforms: npt.NDArray[np.float32], waveforms_lens: npt.NDArray[np.int64], sample_rate: SampleRates = 16_000
2833
) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]:
2934
"""Convert waveforms to model features."""
35+
if sample_rate != 16_000:
36+
waveforms, waveforms_lens = self._resampler(waveforms, waveforms_lens, sample_rate)
37+
3038
features, features_lens = self._preprocessor.run(
3139
["features", "features_lens"], {"waveforms": waveforms, "waveforms_lens": waveforms_lens}
3240
)
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
"""Waveform resampler implementations."""
2+
3+
from importlib.resources import files
4+
from typing import Any
5+
6+
import numpy as np
7+
import numpy.typing as npt
8+
import onnxruntime as rt
9+
10+
from onnx_asr.utils import SampleRates
11+
12+
13+
class Resampler:
    """Resample batched waveforms to 16 kHz using the packaged ONNX model."""

    def __init__(self, **kwargs: Any):
        """Load the bundled `resample.onnx` model into an inference session.

        Args:
            kwargs: Extra keyword arguments forwarded to onnxruntime.InferenceSession.

        """
        model_bytes = files(__package__).joinpath("resample.onnx").read_bytes()
        self._preprocessor = rt.InferenceSession(model_bytes, **kwargs)

    def __call__(
        self, waveforms: npt.NDArray[np.float32], waveforms_lens: npt.NDArray[np.int64], sample_rate: SampleRates
    ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]:
        """Run the resampling model on a batch of waveforms.

        Args:
            waveforms: Batch of waveforms.
            waveforms_lens: Length of each waveform in the batch.
            sample_rate: Sample rate of the input waveforms.

        Returns:
            The `resampled` and `resampled_lens` outputs of the model.

        """
        output_names = ["resampled", "resampled_lens"]
        feeds = {"waveforms": waveforms, "waveforms_lens": waveforms_lens, "sample_rate": [sample_rate]}
        resampled, resampled_lens = self._preprocessor.run(output_names, feeds)
        return resampled, resampled_lens

src/onnx_asr/utils.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
11
"""Utils for ASR."""
22

33
import wave
4+
from typing import Literal, TypeGuard, get_args
45

56
import numpy as np
67
import numpy.typing as npt
78

9+
SampleRates = Literal[8_000, 16_000, 22_050, 44_100, 48_000]
10+
11+
12+
def is_supported_sample_rate(sample_rate: int) -> TypeGuard[SampleRates]:
    """Check whether `sample_rate` is one of the values listed in SampleRates."""
    supported_rates = get_args(SampleRates)
    return sample_rate in supported_rates
15+
816

917
class SupportedOnlyMonoAudioError(ValueError):
1018
"""Supported only mono audio error."""
@@ -19,7 +27,15 @@ class WrongSampleRateError(ValueError):
1927

2028
def __init__(self) -> None:
2129
"""Create error."""
22-
super().__init__("Supported only 16 kHz sample rate.")
30+
super().__init__("Supported only 8, 16, 22.05, 44.1 and 48 kHz sample rate.")
31+
32+
33+
class DifferentSampleRatesError(ValueError):
    """Raised when the waveforms of one batch do not share a single sample rate."""

    def __init__(self) -> None:
        """Initialise the error with its fixed message."""
        message = "All sample rates in a batch must be the same."
        super().__init__(message)
2339

2440

2541
def read_wav(filename: str) -> tuple[npt.NDArray[np.float32], int]:
@@ -41,23 +57,31 @@ def read_wav(filename: str) -> tuple[npt.NDArray[np.float32], int]:
4157
return buffer.reshape(f.getnframes(), f.getnchannels()).astype(np.float32) / max_value - zero_value, f.getframerate()
4258

4359

44-
def read_wav_files(
    waveforms: list[npt.NDArray[np.float32] | str], numpy_sample_rate: SampleRates
) -> tuple[list[npt.NDArray[np.float32]], SampleRates]:
    """Convert list of waveform or filenames to list of waveforms.

    Args:
        waveforms: Wav file paths and/or 1-D numpy waveforms.
        numpy_sample_rate: Sample rate assumed for every numpy array in `waveforms`.

    Returns:
        The mono waveforms and the single sample rate shared by the whole batch.
        For an empty `waveforms` list the result is an empty list together with
        `numpy_sample_rate`.

    Raises:
        SupportedOnlyMonoAudioError: A wav file has more than one channel or a
            numpy array is not 1-D.
        DifferentSampleRatesError: The batch mixes more than one sample rate.
        WrongSampleRateError: The batch sample rate is not a supported one.

    """
    results = []
    sample_rates = []
    for x in waveforms:
        if isinstance(x, str):
            waveform, sample_rate = read_wav(x)
            if waveform.shape[1] != 1:
                raise SupportedOnlyMonoAudioError()
            results.append(waveform[:, 0])
            sample_rates.append(sample_rate)
        else:
            if x.ndim != 1:
                raise SupportedOnlyMonoAudioError()
            results.append(x)
            sample_rates.append(numpy_sample_rate)

    if len(set(sample_rates)) > 1:
        raise DifferentSampleRatesError()

    # An empty batch has no rate of its own: fall back to the caller-supplied
    # rate instead of indexing an empty list (which raised IndexError).
    batch_rate = sample_rates[0] if sample_rates else numpy_sample_rate
    if is_supported_sample_rate(batch_rate):
        return results, batch_rate
    raise WrongSampleRateError()
6185

6286

6387
def pad_list(arrays: list[npt.NDArray[np.float32]], axis: int = 0) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]:

0 commit comments

Comments
 (0)