Skip to content

Commit 76bc3f5

Browse files
committed
Add NumPy preprocessors
1 parent f6238b7 commit 76bc3f5

11 files changed

Lines changed: 301 additions & 20 deletions

File tree

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@ wheels/
1616
site/
1717
.cache/
1818

19-
# ML models
19+
# ML models and data
2020
*.onnx
21+
*.npz
2122

2223
/models/
2324
/*.ipynb

hatch_build.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@ def initialize(self, version: str, build_data: dict[str, Any]) -> None:
1818

1919
self.artifacts_path.mkdir(exist_ok=True)
2020
build(self.artifacts_path, self.metadata.version)
21-
build_data["artifacts"] = [str(self.artifacts_path.joinpath("*.onnx"))]
21+
build_data["artifacts"] = [
22+
str(self.artifacts_path.joinpath("*.onnx")),
23+
str(self.artifacts_path.joinpath("*.npz")),
24+
]
2225

2326
def dependencies(self) -> list[str]:
2427
return self.metadata.config["dependency-groups"]["build"] # type: ignore[no-any-return]

preprocessors/build.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
"""Build ONNX preprocessors."""
1+
"""Build ONNX and NumPy preprocessors."""
22

33
from pathlib import Path
44

5+
import numpy as np
56
import onnx
67
import onnxscript
78

@@ -52,6 +53,21 @@ def save_resampler_models(preprocessors_dir: Path, version: str) -> None:
5253
)
5354

5455

56+
def save_fbanks(preprocessors_dir: Path) -> None:
57+
fbanks = {
58+
"gigaam_v2": gigaam.melscale_fbanks_v2,
59+
"gigaam_v3": gigaam.melscale_fbanks_v3,
60+
"gigaam_v3_window": gigaam.hann_window_v3,
61+
"kaldi": kaldi.mel_banks,
62+
"nemo80": nemo.melscale_fbanks80,
63+
"nemo128": nemo.melscale_fbanks128,
64+
"whisper80": whisper.melscale_fbanks80,
65+
"whisper128": whisper.melscale_fbanks128,
66+
}
67+
np.savez_compressed(Path(preprocessors_dir, "fbanks"), allow_pickle=False, **fbanks)
68+
69+
5570
def build(preprocessors_dir: Path, version: str) -> None:
5671
save_preprocessor_models(preprocessors_dir, version)
5772
save_resampler_models(preprocessors_dir, version)
73+
save_fbanks(preprocessors_dir)

src/onnx_asr/loader.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,12 @@
1919
from onnx_asr.models.tone import TOneCtc
2020
from onnx_asr.models.whisper import WhisperHf, WhisperOrt
2121
from onnx_asr.onnx import OnnxSessionOptions, get_onnx_providers, update_onnx_providers
22+
from onnx_asr.preprocessors.numpy_preprocessor import (
23+
GigaamPreprocessorNumpy,
24+
KaldiPreprocessorNumpy,
25+
NemoPreprocessorNumpy,
26+
WhisperPreprocessorNumpy,
27+
)
2228
from onnx_asr.preprocessors.preprocessor import ConcurrentPreprocessor, IdentityPreprocessor, OnnxPreprocessor
2329
from onnx_asr.preprocessors.resampler import Resampler
2430
from onnx_asr.utils import (
@@ -138,8 +144,8 @@ def _download_config(self, *, local_files_only: bool) -> Path:
138144

139145
assert self.repo_id is not None
140146
return Path(
141-
hf_hub_download(self.repo_id, "config.json", local_dir=self.local_dir, local_files_only=local_files_only)
142-
) # nosec
147+
hf_hub_download(self.repo_id, "config.json", local_dir=self.local_dir, local_files_only=local_files_only) # nosec
148+
)
143149

144150
def _download_model(self, quantization: str | None, *, local_files_only: bool) -> Path:
145151
from huggingface_hub import snapshot_download # noqa: PLC0415
@@ -266,11 +272,25 @@ def create_preprocessor(name: str) -> Preprocessor:
266272
return IdentityPreprocessor()
267273

268274
providers = get_onnx_providers(preprocessor_config)
269-
if name == "kaldi" and providers and providers != ["CPUExecutionProvider"]:
270-
name = "kaldi_fast"
271-
272275
max_concurrent_workers = preprocessor_config.pop("max_concurrent_workers", 1)
273-
preprocessor = OnnxPreprocessor(name, preprocessor_config)
276+
277+
preprocessor: Preprocessor
278+
if not providers or providers == ["CPUExecutionProvider"]:
279+
if name.startswith("gigaam"):
280+
preprocessor = GigaamPreprocessorNumpy(name)
281+
elif name == "kaldi":
282+
preprocessor = KaldiPreprocessorNumpy(name)
283+
elif name.startswith("nemo"):
284+
preprocessor = NemoPreprocessorNumpy(name)
285+
elif name.startswith("whisper"):
286+
preprocessor = WhisperPreprocessorNumpy(name)
287+
else:
288+
raise ModelNotSupportedError(name)
289+
else:
290+
if name == "kaldi" and providers and providers != ["CPUExecutionProvider"]:
291+
name = "kaldi_fast"
292+
preprocessor = OnnxPreprocessor(name, preprocessor_config)
293+
274294
if max_concurrent_workers == 1:
275295
return preprocessor
276296
return ConcurrentPreprocessor(preprocessor, max_concurrent_workers)
@@ -360,9 +380,11 @@ def load_model(
360380

361381
loader = AsrLoader(model, path)
362382

363-
default_onnx_config: OnnxSessionOptions = {
383+
default_onnx_config = update_onnx_providers(
384+
{"providers": rt.get_available_providers()}, excluded_providers=["AzureExecutionProvider"]
385+
) | {
364386
"sess_options": sess_options,
365-
"providers": providers or rt.get_available_providers(),
387+
"providers": providers,
366388
"provider_options": provider_options,
367389
}
368390

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
"""ASR preprocessor implementations in NumPy."""
2+
3+
from __future__ import annotations
4+
5+
from importlib.resources import as_file, files
6+
7+
import numpy as np
8+
import numpy.typing as npt
9+
10+
import onnx_asr.preprocessors
11+
12+
13+
class _NumpyPreprocessor:
14+
def __init__(self, name: str):
15+
"""Create preprocessor.
16+
17+
Args:
18+
name: Preprocessor name.
19+
20+
"""
21+
with (
22+
as_file(files(onnx_asr.preprocessors).joinpath("data").joinpath("fbanks.npz")) as file,
23+
np.load(file) as data,
24+
):
25+
self._melscale_fbanks = data[name]
26+
if name == "gigaam_v3":
27+
self._window = data["gigaam_v3_window"]
28+
29+
30+
class GigaamPreprocessorNumpy(_NumpyPreprocessor):
31+
"""GigaAM preprocessor implementation in NumPy."""
32+
33+
_sample_rate = 16_000
34+
_hop_length = _sample_rate // 100
35+
_clamp_min = 1e-9
36+
_clamp_max = 1e9
37+
38+
def __init__(self, name: str): # noqa: D107
39+
assert name in ("gigaam_v2", "gigaam_v3")
40+
super().__init__(name)
41+
self._v2 = name == "gigaam_v2"
42+
self._n_fft = self._sample_rate // (40 if self._v2 else 50)
43+
self._win_length = self._n_fft
44+
if self._v2:
45+
self._window = np.hanning(self._win_length + 1)[:-1]
46+
47+
def __call__(
48+
self, waveforms: npt.NDArray[np.float32], waveforms_lens: npt.NDArray[np.int64]
49+
) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]:
50+
"""Convert waveforms to model features."""
51+
if self._v2:
52+
waveforms = np.pad(waveforms, ((0, 0), (self._n_fft // 2, self._n_fft // 2)), mode="reflect")
53+
54+
strided_input = np.lib.stride_tricks.sliding_window_view(waveforms, self._win_length, axis=1)[
55+
:, :: self._hop_length
56+
]
57+
strided_input = strided_input * self._window
58+
spectrum = np.abs(np.fft.rfft(strided_input, self._n_fft)) ** 2
59+
60+
mel_energies = np.matmul(spectrum, self._melscale_fbanks).astype(np.float32)
61+
62+
return np.log(np.clip(mel_energies, self._clamp_min, self._clamp_max)).transpose(0, 2, 1), (
63+
waveforms_lens - (0 if self._v2 else self._win_length)
64+
) // self._hop_length + 1
65+
66+
67+
class KaldiPreprocessorNumpy(_NumpyPreprocessor):
68+
"""Kaldi preprocessor implementation with NumPy."""
69+
70+
_n_fft = 512
71+
_win_length = 400
72+
_hop_length = 160
73+
_snip_edges = False
74+
_dither = 0.0
75+
_remove_dc_offset = True
76+
_preemphasis_coefficient = 0.97
77+
_float_eps = float(np.finfo(np.float32).eps)
78+
79+
def __init__(self, name: str): # noqa: D107
80+
assert name == "kaldi"
81+
super().__init__(name)
82+
83+
def _symmetric_pad(
84+
self, waveforms: npt.NDArray[np.float32], waveforms_lens: npt.NDArray[np.int64]
85+
) -> npt.NDArray[np.float32]:
86+
pad_left = self._win_length // 2 - self._hop_length // 2
87+
pad_right = self._win_length // 2
88+
res = np.pad(waveforms, ((0, 0), (pad_left, pad_right)), mode="symmetric")
89+
if waveforms.shape[0] == 1:
90+
return res
91+
92+
for i in range(waveforms.shape[0]):
93+
tail = res[i, pad_left + waveforms_lens[i] :]
94+
tail[:pad_right] = waveforms[i, waveforms_lens[i] - pad_right : waveforms_lens[i]][::-1]
95+
tail[pad_right:] = 0
96+
return res
97+
98+
def __call__(
99+
self, waveforms: npt.NDArray[np.float32], waveforms_lens: npt.NDArray[np.int64]
100+
) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]:
101+
"""Convert waveforms to model features."""
102+
if not self._snip_edges:
103+
waveforms = self._symmetric_pad(waveforms, waveforms_lens)
104+
features_lens = (waveforms_lens + self._hop_length // 2) // self._hop_length
105+
else:
106+
features_lens = 1 + (waveforms_lens - self._win_length) // self._hop_length
107+
108+
strided_input = np.lib.stride_tricks.sliding_window_view(waveforms, self._win_length, axis=1)[
109+
:, :: self._hop_length
110+
]
111+
112+
if self._dither != 0.0:
113+
rng = np.random.default_rng()
114+
strided_input = strided_input + self._dither * rng.standard_normal(strided_input.shape).astype(np.float32)
115+
116+
if self._remove_dc_offset:
117+
strided_input = strided_input - np.mean(strided_input, axis=-1, keepdims=True)
118+
119+
if self._preemphasis_coefficient != 0.0:
120+
offset_strided_input = np.pad(strided_input, ((0, 0), (0, 0), (1, 0)), mode="edge")
121+
strided_input = strided_input - self._preemphasis_coefficient * offset_strided_input[..., :-1]
122+
123+
strided_input = strided_input * np.pow(np.hanning(self._win_length), 0.85)
124+
spectrum = np.abs(np.fft.rfft(strided_input, self._n_fft)) ** 2
125+
mel_energies = np.matmul(spectrum, self._melscale_fbanks).astype(np.float32)
126+
127+
features = np.log(np.maximum(mel_energies, np.finfo(np.float32).eps))
128+
if features.shape[0] > 0:
129+
features[np.arange(features.shape[1]) >= features_lens[:, None]] = 0
130+
131+
return features, features_lens
132+
133+
134+
class NemoPreprocessorNumpy(_NumpyPreprocessor):
135+
"""Nemo preprocessor implementation with NumPy."""
136+
137+
_n_fft = 512
138+
_win_length = 400
139+
_hop_length = 160
140+
_preemph = 0.97
141+
_log_zero_guard_value = float(2**-24)
142+
143+
def __init__(self, name: str): # noqa: D107
144+
assert name.startswith("nemo")
145+
super().__init__(name)
146+
147+
def __call__(
148+
self, waveforms: npt.NDArray[np.float32], waveforms_lens: npt.NDArray[np.int64]
149+
) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]:
150+
"""Convert waveforms to model features."""
151+
if self._preemph != 0.0:
152+
waveforms = waveforms - self._preemph * np.pad(waveforms, ((0, 0), (1, 0)))[:, :-1]
153+
waveforms[np.arange(waveforms.shape[-1]) >= waveforms_lens[:, None]] = 0
154+
155+
waveforms = np.pad(waveforms, ((0, 0), (self._n_fft // 2, self._n_fft // 2)))
156+
strided_input = np.lib.stride_tricks.sliding_window_view(waveforms, self._n_fft, axis=1)[:, :: self._hop_length]
157+
strided_input = strided_input * np.pad(
158+
np.hanning(self._win_length), ((self._n_fft - self._win_length) // 2, (self._n_fft - self._win_length) // 2)
159+
)
160+
spectrogram = np.abs(np.fft.rfft(strided_input, self._n_fft)) ** 2
161+
mel_spectrogram = np.matmul(spectrogram, self._melscale_fbanks)
162+
log_mel_spectrogram = np.log(mel_spectrogram + self._log_zero_guard_value)
163+
164+
features_lens = waveforms_lens // self._hop_length
165+
mask = np.arange(log_mel_spectrogram.shape[1])[None, :, None] < features_lens[:, None, None]
166+
zero = np.float32(0)
167+
mean = np.where(mask, log_mel_spectrogram, zero).sum(axis=1, keepdims=True) / features_lens[:, None, None]
168+
var = np.where(mask, (log_mel_spectrogram - mean) ** 2, zero).sum(axis=1, keepdims=True) / (
169+
features_lens[:, None, None] - 1
170+
)
171+
features = np.where(mask, (log_mel_spectrogram - mean) / (np.sqrt(var) + 1e-5), zero)
172+
return features.transpose(0, 2, 1).astype(np.float32), features_lens
173+
174+
175+
class WhisperPreprocessorNumpy(_NumpyPreprocessor):
176+
"""Whisper preprocessor implementation with NumPy."""
177+
178+
_sample_rate = 16_000
179+
_chunk_length = 30
180+
_n_fft = 400
181+
_win_length = 400
182+
_hop_length = 160
183+
_clamp_min = 1e-10
184+
185+
def __init__(self, name: str): # noqa: D107
186+
assert name.startswith("whisper")
187+
super().__init__(name)
188+
189+
def __call__(
190+
self, waveforms: npt.NDArray[np.float32], waveforms_lens: npt.NDArray[np.int64]
191+
) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]:
192+
"""Convert waveforms to model features."""
193+
waveforms = waveforms[:, : self._chunk_length * self._sample_rate]
194+
waveforms = np.pad(waveforms, ((0, 0), (0, self._chunk_length * self._sample_rate - waveforms.shape[-1])))
195+
waveforms = np.pad(waveforms, ((0, 0), (self._n_fft // 2, self._n_fft // 2)), mode="reflect")
196+
197+
strided_input = np.lib.stride_tricks.sliding_window_view(waveforms, self._win_length, axis=1)[
198+
:, :: self._hop_length
199+
]
200+
strided_input = strided_input * np.hanning(self._win_length + 1)[:-1]
201+
spectrum = np.abs(np.fft.rfft(strided_input, self._n_fft)[:, :-1]) ** 2
202+
203+
mel_spectrogram = np.matmul(spectrum, self._melscale_fbanks).astype(np.float32)
204+
log_mel_spectrogram = np.log10(np.maximum(mel_spectrogram, self._clamp_min))
205+
features = (np.maximum(log_mel_spectrogram, log_mel_spectrogram.max() - 8.0) + 4.0) / 4.0
206+
return features.transpose(0, 2, 1), np.full_like(
207+
waveforms_lens, self._chunk_length * self._sample_rate // self._hop_length
208+
)

src/onnx_asr/preprocessors/preprocessor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@ def __call__(
2525

2626

2727
class OnnxPreprocessor:
28-
"""ASR preprocessor implementation."""
28+
"""ONNX preprocessor implementation."""
2929

3030
def __init__(self, name: str, onnx_options: OnnxSessionOptions):
31-
"""Create ASR preprocessor.
31+
"""Create preprocessor.
3232
3333
Args:
3434
name: Preprocessor name.
@@ -61,7 +61,7 @@ def __call__(
6161

6262

6363
class ConcurrentPreprocessor:
64-
"""Concurrent ASR preprocessor implementation."""
64+
"""Concurrent preprocessor implementation."""
6565

6666
def __init__(self, preprocessor: Preprocessor, max_concurrent_workers: int | None = None):
6767
"""Create preprocessor.

tests/preprocessors/test_build.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from pathlib import Path
22

3+
import numpy as np
34
import onnx
45

56
from preprocessors import build
@@ -8,6 +9,7 @@
89
def test_build(tmp_path: Path):
910
build.build(tmp_path, "tests")
1011
assert len(list(tmp_path.glob("*.onnx"))) == 22
12+
assert len(list(tmp_path.glob("*.npz"))) == 1
1113

1214

1315
def test_save_preprocessor_models(tmp_path: Path):
@@ -44,3 +46,12 @@ def test_save_resampler_models(tmp_path: Path):
4446
assert len(model.graph.output) == 2
4547
assert model.graph.output[0].name == "resampled"
4648
assert model.graph.output[1].name == "resampled_lens"
49+
50+
51+
def test_save_fbanks(tmp_path: Path):
52+
build.save_fbanks(tmp_path)
53+
filename = Path(tmp_path, "fbanks.npz")
54+
55+
assert filename.exists()
56+
with np.load(filename) as data:
57+
assert len(data.keys()) == 8

0 commit comments

Comments
 (0)