|
| 1 | +"""ASR preprocessor implementations in NumPy.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +from importlib.resources import as_file, files |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +import numpy.typing as npt |
| 9 | + |
| 10 | +import onnx_asr.preprocessors |
| 11 | + |
| 12 | + |
class _NumpyPreprocessor:
    """Common base for the NumPy preprocessors: loads packaged filterbank data."""

    def __init__(self, name: str):
        """Create preprocessor.

        Reads the array stored under ``name`` from the ``data/fbanks.npz``
        resource of ``onnx_asr.preprocessors`` into ``self._melscale_fbanks``.
        For ``"gigaam_v3"`` the ``"gigaam_v3_window"`` array from the same
        archive is additionally stored in ``self._window``.

        Args:
            name: Preprocessor name (key of the filterbank inside the archive).

        """
        resource = files(onnx_asr.preprocessors).joinpath("data").joinpath("fbanks.npz")
        # as_file() yields a real filesystem path for the packaged resource.
        with as_file(resource) as archive_path:
            with np.load(archive_path) as archive:
                self._melscale_fbanks = archive[name]
                if name == "gigaam_v3":
                    self._window = archive["gigaam_v3_window"]
| 28 | + |
| 29 | + |
class GigaamPreprocessorNumpy(_NumpyPreprocessor):
    """GigaAM preprocessor implementation in NumPy."""

    _sample_rate = 16_000
    _hop_length = _sample_rate // 100
    _clamp_min = 1e-9
    _clamp_max = 1e9

    def __init__(self, name: str):
        """Create a GigaAM preprocessor.

        Args:
            name: Either ``"gigaam_v2"`` or ``"gigaam_v3"``.

        """
        assert name in ("gigaam_v2", "gigaam_v3")
        super().__init__(name)
        self._v2 = name == "gigaam_v2"
        # The FFT size is sample_rate // 40 for v2 and sample_rate // 50 for v3.
        fft_divisor = 40 if self._v2 else 50
        self._n_fft = self._sample_rate // fft_divisor
        self._win_length = self._n_fft
        if self._v2:
            # v2 builds its window here; for v3 the base class has already
            # loaded ``self._window`` from the data archive.
            self._window = np.hanning(self._win_length + 1)[:-1]

    def __call__(
        self, waveforms: npt.NDArray[np.float32], waveforms_lens: npt.NDArray[np.int64]
    ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]:
        """Convert waveforms to model features.

        Args:
            waveforms: 2-D array of samples, one utterance per row.
            waveforms_lens: Number of valid samples in each row.

        Returns:
            The clipped log-mel energies with the last two axes swapped, and
            the number of feature frames for each utterance.

        """
        if self._v2:
            # Only v2 reflect-pads the signal by half an FFT size on each side.
            half_fft = self._n_fft // 2
            waveforms = np.pad(waveforms, ((0, 0), (half_fft, half_fft)), mode="reflect")

        # Every window position, then keep one frame per hop.
        all_windows = np.lib.stride_tricks.sliding_window_view(waveforms, self._win_length, axis=1)
        frames = all_windows[:, :: self._hop_length] * self._window

        power = np.abs(np.fft.rfft(frames, self._n_fft)) ** 2
        mel_energies = np.matmul(power, self._melscale_fbanks).astype(np.float32)
        log_mel = np.log(np.clip(mel_energies, self._clamp_min, self._clamp_max))

        # Without padding (v3) a full window must fit before the first frame.
        consumed = 0 if self._v2 else self._win_length
        features_lens = (waveforms_lens - consumed) // self._hop_length + 1
        return log_mel.transpose(0, 2, 1), features_lens
| 65 | + |
| 66 | + |
class KaldiPreprocessorNumpy(_NumpyPreprocessor):
    """Kaldi preprocessor implementation with NumPy."""

    _n_fft = 512
    _win_length = 400
    _hop_length = 160
    _snip_edges = False
    _dither = 0.0
    _remove_dc_offset = True
    _preemphasis_coefficient = 0.97
    _float_eps = float(np.finfo(np.float32).eps)

    def __init__(self, name: str):
        """Create a Kaldi preprocessor.

        Args:
            name: Must be ``"kaldi"``.

        """
        assert name == "kaldi"
        super().__init__(name)

    def _symmetric_pad(
        self, waveforms: npt.NDArray[np.float32], waveforms_lens: npt.NDArray[np.int64]
    ) -> npt.NDArray[np.float32]:
        """Mirror-pad every utterance at both ends.

        Args:
            waveforms: 2-D array of samples, one utterance per row.
            waveforms_lens: Number of valid samples in each row.

        Returns:
            The padded batch. For a batch of more than one row, the right-hand
            mirror of each row starts at that row's own length and everything
            after the mirrored segment is zeroed.

        """
        pad_left = self._win_length // 2 - self._hop_length // 2
        pad_right = self._win_length // 2
        res = np.pad(waveforms, ((0, 0), (pad_left, pad_right)), mode="symmetric")
        if waveforms.shape[0] == 1:
            # Single utterance: np.pad has already mirrored the end of the row.
            return res

        for i in range(waveforms.shape[0]):
            # ``tail`` is a view into ``res``, so the assignments edit it in place.
            tail = res[i, pad_left + waveforms_lens[i] :]
            tail[:pad_right] = waveforms[i, waveforms_lens[i] - pad_right : waveforms_lens[i]][::-1]
            tail[pad_right:] = 0
        return res

    def __call__(
        self, waveforms: npt.NDArray[np.float32], waveforms_lens: npt.NDArray[np.int64]
    ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]:
        """Convert waveforms to model features.

        Args:
            waveforms: 2-D array of samples, one utterance per row.
            waveforms_lens: Number of valid samples in each row.

        Returns:
            The log-mel features, with frames at or beyond each utterance's
            feature length set to zero, and the feature lengths.

        """
        if not self._snip_edges:
            waveforms = self._symmetric_pad(waveforms, waveforms_lens)
            features_lens = (waveforms_lens + self._hop_length // 2) // self._hop_length
        else:
            features_lens = 1 + (waveforms_lens - self._win_length) // self._hop_length

        strided_input = np.lib.stride_tricks.sliding_window_view(waveforms, self._win_length, axis=1)[
            :, :: self._hop_length
        ]

        if self._dither != 0.0:
            rng = np.random.default_rng()
            strided_input = strided_input + self._dither * rng.standard_normal(strided_input.shape).astype(np.float32)

        if self._remove_dc_offset:
            # The mean is removed per frame, not per utterance.
            strided_input = strided_input - np.mean(strided_input, axis=-1, keepdims=True)

        if self._preemphasis_coefficient != 0.0:
            # Edge padding makes the first sample of a frame its own predecessor.
            offset_strided_input = np.pad(strided_input, ((0, 0), (0, 0), (1, 0)), mode="edge")
            strided_input = strided_input - self._preemphasis_coefficient * offset_strided_input[..., :-1]

        # np.power is used instead of np.pow: the ``pow`` alias exists only in
        # NumPy >= 2.0, while np.power is available in every NumPy release.
        strided_input = strided_input * np.power(np.hanning(self._win_length), 0.85)
        spectrum = np.abs(np.fft.rfft(strided_input, self._n_fft)) ** 2
        mel_energies = np.matmul(spectrum, self._melscale_fbanks).astype(np.float32)

        # Floor the energies at float32 epsilon before taking the logarithm.
        features = np.log(np.maximum(mel_energies, self._float_eps))
        if features.shape[0] > 0:
            features[np.arange(features.shape[1]) >= features_lens[:, None]] = 0

        return features, features_lens
| 132 | + |
| 133 | + |
class NemoPreprocessorNumpy(_NumpyPreprocessor):
    """Nemo preprocessor implementation with NumPy."""

    _n_fft = 512
    _win_length = 400
    _hop_length = 160
    _preemph = 0.97
    _log_zero_guard_value = float(2**-24)

    def __init__(self, name: str):
        """Create a Nemo preprocessor.

        Args:
            name: Preprocessor name; must start with ``"nemo"``.

        """
        assert name.startswith("nemo")
        super().__init__(name)

    def __call__(
        self, waveforms: npt.NDArray[np.float32], waveforms_lens: npt.NDArray[np.int64]
    ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]:
        """Convert waveforms to model features.

        Args:
            waveforms: 2-D array of samples, one utterance per row.
            waveforms_lens: Number of valid samples in each row.

        Returns:
            Per-utterance normalized log-mel features with the last two axes
            swapped (frames at or beyond the feature length are zero), and the
            feature lengths ``waveforms_lens // hop_length``.

        """
        if self._preemph != 0.0:
            # Shift right by one sample; the first sample has a zero predecessor.
            delayed = np.pad(waveforms, ((0, 0), (1, 0)))[:, :-1]
            waveforms = waveforms - self._preemph * delayed
            past_end = np.arange(waveforms.shape[-1]) >= waveforms_lens[:, None]
            waveforms[past_end] = 0

        half_fft = self._n_fft // 2
        waveforms = np.pad(waveforms, ((0, 0), (half_fft, half_fft)))
        frames = np.lib.stride_tricks.sliding_window_view(waveforms, self._n_fft, axis=1)[:, :: self._hop_length]

        # Zero-pad the Hann window on both sides up to the FFT size.
        window_margin = (self._n_fft - self._win_length) // 2
        window = np.pad(np.hanning(self._win_length), ((window_margin, window_margin)))
        frames = frames * window

        power = np.abs(np.fft.rfft(frames, self._n_fft)) ** 2
        mel = np.matmul(power, self._melscale_fbanks)
        log_mel = np.log(mel + self._log_zero_guard_value)

        features_lens = waveforms_lens // self._hop_length
        lens = features_lens[:, None, None]
        valid = np.arange(log_mel.shape[1])[None, :, None] < lens
        zero = np.float32(0)

        # Mean and variance over the valid frames only; the variance divides by
        # (frames - 1), so a one-frame utterance divides by zero here.
        mean = np.where(valid, log_mel, zero).sum(axis=1, keepdims=True) / lens
        squared_dev = np.where(valid, (log_mel - mean) ** 2, zero)
        var = squared_dev.sum(axis=1, keepdims=True) / (lens - 1)
        normalized = np.where(valid, (log_mel - mean) / (np.sqrt(var) + 1e-5), zero)
        return normalized.transpose(0, 2, 1).astype(np.float32), features_lens
| 173 | + |
| 174 | + |
class WhisperPreprocessorNumpy(_NumpyPreprocessor):
    """Whisper preprocessor implementation with NumPy."""

    _sample_rate = 16_000
    _chunk_length = 30
    _n_fft = 400
    _win_length = 400
    _hop_length = 160
    _clamp_min = 1e-10

    def __init__(self, name: str):
        """Create a Whisper preprocessor.

        Args:
            name: Preprocessor name; must start with ``"whisper"``.

        """
        assert name.startswith("whisper")
        super().__init__(name)

    def __call__(
        self, waveforms: npt.NDArray[np.float32], waveforms_lens: npt.NDArray[np.int64]
    ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]:
        """Convert waveforms to model features.

        Every utterance is truncated or zero-padded to ``_chunk_length``
        seconds, so the returned lengths are the same constant for every row.

        Args:
            waveforms: 2-D array of samples, one utterance per row.
            waveforms_lens: Number of valid samples in each row; only its
                shape and dtype are used for the returned lengths.

        Returns:
            The scaled log-mel features with the last two axes swapped, and an
            array shaped like ``waveforms_lens`` filled with the fixed number
            of frames per chunk.

        """
        chunk_samples = self._chunk_length * self._sample_rate
        waveforms = waveforms[:, :chunk_samples]
        waveforms = np.pad(waveforms, ((0, 0), (0, chunk_samples - waveforms.shape[-1])))
        waveforms = np.pad(waveforms, ((0, 0), (self._n_fft // 2, self._n_fft // 2)), mode="reflect")

        strided_input = np.lib.stride_tricks.sliding_window_view(waveforms, self._win_length, axis=1)[
            :, :: self._hop_length
        ]
        strided_input = strided_input * np.hanning(self._win_length + 1)[:-1]
        # The slice drops the last frame along the frame axis (axis 1).
        spectrum = np.abs(np.fft.rfft(strided_input, self._n_fft)[:, :-1]) ** 2

        mel_spectrogram = np.matmul(spectrum, self._melscale_fbanks).astype(np.float32)
        log_mel_spectrogram = np.log10(np.maximum(mel_spectrogram, self._clamp_min))

        # The dynamic-range floor is taken per utterance. A batch-wide max would
        # make one utterance's features depend on the loudest one in the batch.
        peak = log_mel_spectrogram.max(axis=(1, 2), keepdims=True)
        features = (np.maximum(log_mel_spectrogram, peak - 8.0) + 4.0) / 4.0
        return features.transpose(0, 2, 1), np.full_like(waveforms_lens, chunk_samples // self._hop_length)
0 commit comments