diff --git a/preprocessors/build.py b/preprocessors/build.py index c9bb04b..3f10878 100644 --- a/preprocessors/build.py +++ b/preprocessors/build.py @@ -34,6 +34,13 @@ def save_preprocessor_models(preprocessors_dir: Path, version: str) -> None: "nemo128.onnx": nemo.NemoPreprocessor128, "whisper80.onnx": whisper.WhisperPreprocessor80, "whisper128.onnx": whisper.WhisperPreprocessor128, + "gigaam_v2_conv.onnx": gigaam.GigaamPreprocessorV2Conv, + "gigaam_v3_conv.onnx": gigaam.GigaamPreprocessorV3Conv, + "kaldi_conv.onnx": kaldi.KaldiPreprocessorFastConv, + "nemo80_conv.onnx": nemo.NemoPreprocessor80Conv, + "nemo128_conv.onnx": nemo.NemoPreprocessor128Conv, + "whisper80_conv.onnx": whisper.WhisperPreprocessor80Conv, + "whisper128_conv.onnx": whisper.WhisperPreprocessor128Conv, } for filename, model in preprocessors.items(): save_onnx(model, preprocessors_dir.joinpath(filename), version) diff --git a/preprocessors/gigaam.py b/preprocessors/gigaam.py index ced63da..80b3831 100644 --- a/preprocessors/gigaam.py +++ b/preprocessors/gigaam.py @@ -6,6 +6,7 @@ from onnxscript import opset17 as op from preprocessors.fbanks import melscale_fbanks +from preprocessors.stft import conv_power_spectrogram, stft_conv_weights sample_rate = 16_000 n_fft_v2 = sample_rate // 40 @@ -27,6 +28,9 @@ ) hann_window_v3 = np.hanning(win_length_v3 + 1)[:-1].astype(bfloat16).astype(np.float32) +stft_conv_weights_v2 = stft_conv_weights(np.hanning(win_length_v2 + 1)[:-1].astype(np.float32)) +stft_conv_weights_v3 = stft_conv_weights(hann_window_v3) + @script(doc_string="LogMelSpectrogram feature extractor for GigaAM v2 models") def GigaamPreprocessorV2( @@ -65,3 +69,39 @@ def GigaamPreprocessorV3( features_lens = (waveforms_lens - win_length_v3) / hop_length + 1 features = op.Transpose(log_mel_spectrogram, perm=[0, 2, 1]) return features, features_lens + + +@script(doc_string="LogMelSpectrogram feature extractor for GigaAM v2 models (Conv-based STFT)") +def GigaamPreprocessorV2Conv( + waveforms: FLOAT["batch_size", "N"], + waveforms_lens: INT64["batch_size"], +) -> tuple[FLOAT["batch_size", n_mels, "T"], INT64["batch_size"]]: + waveforms = op.Pad( + waveforms, + pads=op.Constant(value=[0, n_fft_v2 // 2, 0, n_fft_v2 // 2]), + mode="reflect", + ) + + spectrogram = conv_power_spectrogram(waveforms, stft_conv_weights_v2) + + mel_spectrogram = op.MatMul(spectrogram, melscale_fbanks_v2) + log_mel_spectrogram = op.Log(op.Clip(mel_spectrogram, clamp_min, clamp_max)) + + features_lens = waveforms_lens / hop_length + 1 + features = op.Transpose(log_mel_spectrogram, perm=[0, 2, 1]) + return features, features_lens + + +@script(doc_string="LogMelSpectrogram feature extractor for GigaAM v3 models (Conv-based STFT)") +def GigaamPreprocessorV3Conv( + waveforms: FLOAT["batch_size", "N"], + waveforms_lens: INT64["batch_size"], +) -> tuple[FLOAT["batch_size", n_mels, "T"], INT64["batch_size"]]: + spectrogram = conv_power_spectrogram(waveforms, stft_conv_weights_v3) + + mel_spectrogram = op.MatMul(spectrogram, melscale_fbanks_v3) + log_mel_spectrogram = op.Log(op.Clip(mel_spectrogram, clamp_min, clamp_max)) + + features_lens = (waveforms_lens - win_length_v3) / hop_length + 1 + features = op.Transpose(log_mel_spectrogram, perm=[0, 2, 1]) + return features, features_lens diff --git a/preprocessors/kaldi.py b/preprocessors/kaldi.py index 3bb8e7a..ea433f4 100644 --- a/preprocessors/kaldi.py +++ b/preprocessors/kaldi.py @@ -5,6 +5,7 @@ from onnxscript import opset17 as op from preprocessors.fbanks import melscale_fbanks +from preprocessors.stft import conv_power_spectrogram, stft_conv_weights sample_rate = 16_000 n_fft = 512 @@ -29,6 +30,9 @@ np.float32 ) wespeaker_window = np.hamming(win_length).astype(np.float32) +stft_conv_weights_kaldi = stft_conv_weights( + np.pad(np.hanning(win_length) ** 0.85, (0, n_fft - win_length)).astype(np.float32) +) @script() @@ -75,13 +79,11 @@ def sliding_buffer(prev: FLOAT["batch_size", win_length - hop_length], curr: FLO @script() def calc_features( - image: FLOAT["batch_size", "T", n_fft // 2 + 1, 2], + spectrogram: FLOAT["batch_size", "T", n_fft // 2 + 1], waveforms_lens: INT64["batch_size"], mel_banks: FLOAT[n_fft // 2 + 1, 2, num_mel_bins], snip_edges: bool, ): - spectrogram = op.ReduceSumSquare(image, axes=[-1], keepdims=0) - mel_spectrogram = op.MatMul(spectrogram, mel_banks) log_mel_spectrogram = op.Log(op.Clip(mel_spectrogram, min=float_eps)) @@ -119,7 +121,8 @@ def preprocessor( frames = frames - preemphasis_coefficient * op.Pad(frames, pads=[0, 0, 1, 0, 0, -1], mode="edge") image = op.DFT(op.Unsqueeze(window * frames, axes=[-1]), n_fft, axis=-2, onesided=1) - features, features_lens = calc_features(image, waveforms_lens, mel_banks, snip_edges) + spectrogram = op.ReduceSumSquare(image, axes=[-1], keepdims=0) + features, features_lens = calc_features(spectrogram, waveforms_lens, mel_banks, snip_edges) return features, features_lens @@ -155,8 +158,32 @@ def KaldiPreprocessorFast( pads=op.Constant(value=[0, n_fft - win_length]), ) image = op.STFT(waveforms, hop_length, povey_window) + spectrogram = op.ReduceSumSquare(image, axes=[-1], keepdims=0) + + features, features_lens = calc_features(spectrogram, waveforms_lens, kaldi_mel_banks, snip_edges=snip_edges) + return features, features_lens + + +@script(doc_string="LogMelSpectrogram feature extractor for Kaldi models (Conv-based STFT)") +def KaldiPreprocessorFastConv( + waveforms: FLOAT["batch_size", "N"], waveforms_lens: INT64["batch_size"] +) -> tuple[FLOAT["batch_size", "T", num_mel_bins], INT64["batch_size"]]: + if dither != 0.0: + waveforms = waveforms + op.RandomNormalLike(waveforms, scale=dither) + + if remove_dc_offset: + waveforms = waveforms - op.ReduceMean(waveforms, axes=[-1]) + + if not snip_edges: + waveforms = symmetric_pad(waveforms, waveforms_lens) + + if preemphasis_coefficient != 0.0: + waveforms = waveforms - preemphasis_coefficient * op.Pad(waveforms, pads=[0, 1, 0, -1], mode="edge") + + waveforms = op.Pad(waveforms, pads=op.Constant(value=[0, 0, 0, n_fft - win_length])) + spectrogram = conv_power_spectrogram(waveforms, stft_conv_weights_kaldi) - features, features_lens = calc_features(image, waveforms_lens, kaldi_mel_banks, snip_edges=snip_edges) + features, features_lens = calc_features(spectrogram, waveforms_lens, kaldi_mel_banks, snip_edges=snip_edges) return features, features_lens diff --git a/preprocessors/nemo.py b/preprocessors/nemo.py index 3f04b06..5bb619b 100644 --- a/preprocessors/nemo.py +++ b/preprocessors/nemo.py @@ -5,6 +5,7 @@ from onnxscript import opset17 as op from preprocessors.fbanks import melscale_fbanks +from preprocessors.stft import conv_power_spectrogram, stft_conv_weights sample_rate = 16_000 n_fft = 512 @@ -20,6 +21,9 @@ melscale_fbanks128 = melscale_fbanks(n_fft // 2 + 1, 0, sample_rate // 2, 128, sample_rate, "slaney", "slaney").astype( np.float32 ) +stft_conv_weights_nemo = stft_conv_weights( + np.pad(np.hanning(win_length), (n_fft // 2 - win_length // 2, n_fft // 2 - win_length // 2)).astype(np.float32) +) @script() @@ -87,3 +91,58 @@ def NemoPreprocessor128( melscale_fbanks128, ) return features, features_lens + + +@script() +def nemo_preprocessor_conv( + waveforms: FLOAT["batch_size", "N"], + waveforms_lens: INT64["batch_size"], + melscale_fbanks: FLOAT[n_fft // 2 + 1, "M"], + conv_weights: FLOAT["channels", 1, n_fft], +): + if preemph != 0.0: + timemask = op.Range(0, op.Squeeze(op.Shape(waveforms, start=1, end=2)), 1) < op.Unsqueeze( + waveforms_lens, axes=[1] + ) + waveforms = op.Concat(waveforms[:, :1], waveforms[:, 1:] - preemph * waveforms[:, :-1], axis=-1) + waveforms = op.Where(timemask, waveforms, 0.0) + + waveforms = op.Pad( + waveforms, + pads=op.Constant(value=[0, n_fft // 2, 0, n_fft // 2]), + ) + spectrogram = conv_power_spectrogram(waveforms, conv_weights) + + mel_spectrogram = op.MatMul(spectrogram, melscale_fbanks) + log_mel_spectrogram = op.Log(mel_spectrogram + log_zero_guard_value) + + features_lens = waveforms_lens / hop_length + return normalize(op.Transpose(log_mel_spectrogram, perm=[0, 2, 1]), features_lens), features_lens + + +@script(doc_string="LogMelSpectrogram feature extractor for Nemo models (Conv-based STFT)", default_opset=op) +def NemoPreprocessor80Conv( + waveforms: FLOAT["batch_size", "N"], + waveforms_lens: INT64["batch_size"], +) -> tuple[FLOAT["batch_size", 80, "T"], INT64["batch_size"]]: + features, features_lens = nemo_preprocessor_conv( + waveforms, + waveforms_lens, + melscale_fbanks80, + stft_conv_weights_nemo, + ) + return features, features_lens + + +@script(doc_string="LogMelSpectrogram feature extractor for Nemo models (Conv-based STFT)", default_opset=op) +def NemoPreprocessor128Conv( + waveforms: FLOAT["batch_size", "N"], + waveforms_lens: INT64["batch_size"], +) -> tuple[FLOAT["batch_size", 128, "T"], INT64["batch_size"]]: + features, features_lens = nemo_preprocessor_conv( + waveforms, + waveforms_lens, + melscale_fbanks128, + stft_conv_weights_nemo, + ) + return features, features_lens diff --git a/preprocessors/stft.py b/preprocessors/stft.py new file mode 100644 index 0000000..c8be6e6 --- /dev/null +++ b/preprocessors/stft.py @@ -0,0 +1,50 @@ +"""STFT computed as a fixed 1d convolution. + +op.STFT is not supported by the TensorRT and CUDA execution providers, so any +preprocessor graph that uses it falls back to CPU. The discrete Fourier +transform can instead be written as a 1d convolution with a fixed kernel (the +cos/sin Fourier basis multiplied by the analysis window). The convolution and +the ops around it (Reshape, ReduceSumSquare, Transpose) are supported by every +execution provider, so the resulting graph can run fully on GPU / TensorRT. +""" + +import numpy as np +import numpy.typing as npt +from onnxscript import FLOAT, script +from onnxscript import opset17 as op + +hop_length = 160 + + +def stft_conv_weights(window: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]: + """Build Conv weights that compute a windowed DFT. + + Args: + window: Analysis window, already zero-padded to the FFT size by the + caller (its length is used as ``n_fft``). + + Returns: + Kernel of shape ``[2 * (n_fft // 2 + 1), 1, n_fft]`` stacking the real + (cos) and imaginary (-sin) parts of the Fourier basis. Used with Conv + (stride = ``hop_length``) it reproduces ``op.STFT``. + + """ + n_fft = window.shape[0] + indices = np.arange(n_fft // 2 + 1)[:, np.newaxis] * np.arange(n_fft)[np.newaxis, :] + angle = 2 * np.pi * indices / n_fft + basis = np.concatenate([np.cos(angle), -np.sin(angle)]) * window + return basis[:, np.newaxis, :].astype(np.float32) + + +@script() +def conv_power_spectrogram(waveforms: FLOAT["batch_size", "N"], conv_weights: FLOAT["channels", 1, "n_fft"]): + """Power spectrogram [batch_size, frames, n_bins] via a Conv-based STFT. + + Drop-in replacement for ``op.STFT`` followed by ``ReduceSumSquare`` over the + real/imaginary axis. ``conv_weights`` is built by :func:`stft_conv_weights`. + """ + image = op.Conv(op.Unsqueeze(waveforms, axes=[1]), conv_weights, strides=[hop_length]) + n_bins = op.Shape(conv_weights, start=0, end=1) / 2 + shape = op.Concat(op.Constant(value=[0, 2]), n_bins, op.Constant(value=[-1]), axis=0) + spectrogram = op.ReduceSumSquare(op.Reshape(image, shape), axes=[1], keepdims=0) + return op.Transpose(spectrogram, perm=[0, 2, 1]) diff --git a/preprocessors/whisper.py b/preprocessors/whisper.py index 69ab434..5f5734f 100644 --- a/preprocessors/whisper.py +++ b/preprocessors/whisper.py @@ -5,6 +5,7 @@ from onnxscript import opset17 as op from preprocessors.fbanks import melscale_fbanks +from preprocessors.stft import conv_power_spectrogram, stft_conv_weights chunk_length = 30 sample_rate = 16_000 @@ -22,6 +23,7 @@ melscale_fbanks128 = melscale_fbanks(n_fft // 2 + 1, 0, sample_rate // 2, 128, sample_rate, "slaney", "slaney").astype( np.float32 ) +stft_conv_weights_whisper = stft_conv_weights(np.hanning(win_length + 1)[:-1].astype(np.float32)) @script() @@ -77,3 +79,59 @@ def WhisperPreprocessor128( melscale_fbanks128, ) return features, features_lens + + +@script() +def whisper_preprocessor_conv( + waveforms: FLOAT["batch_size", "N"], + waveforms_lens: INT64["batch_size"], + melscale_fbanks: FLOAT[n_fft // 2 + 1, "M"], + conv_weights: FLOAT["channels", 1, n_fft], +): + waveforms = op.Pad( + waveforms, + pads=(chunk_length * sample_rate - op.Shape(waveforms, start=1, end=2)) * op.Constant(value=[0, 0, 0, 1]), + ) + waveforms = op.Pad( + waveforms, + pads=op.Constant(value=[0, n_fft // 2, 0, n_fft // 2]), + mode="reflect", + ) + + spectrogram = conv_power_spectrogram(waveforms, conv_weights)[:, :-1] + + mel_spectrogram = op.MatMul(spectrogram, melscale_fbanks) + log_mel_spectrogram = op.Log(op.Clip(mel_spectrogram, clamp_min)) / ln10 + log_mel_spectrogram = (op.Max(log_mel_spectrogram, op.ReduceMax(log_mel_spectrogram) - 8) + 4) / 4.0 + + return op.Transpose(log_mel_spectrogram, perm=[0, 2, 1]), op.ConstantOfShape( + op.Shape(waveforms_lens), value=features_length + ) + + +@script(doc_string="LogMelSpectrogram feature extractor for Whisper models (Conv-based STFT)", default_opset=op) +def WhisperPreprocessor80Conv( + waveforms: FLOAT["batch_size", "N"], + waveforms_lens: INT64["batch_size"], +) -> tuple[FLOAT["batch_size", 80, "T"], INT64["batch_size"]]: + features, features_lens = whisper_preprocessor_conv( + waveforms, + waveforms_lens, + melscale_fbanks80, + stft_conv_weights_whisper, + ) + return features, features_lens + + +@script(doc_string="LogMelSpectrogram feature extractor for Whisper models (Conv-based STFT)", default_opset=op) +def WhisperPreprocessor128Conv( + waveforms: FLOAT["batch_size", "N"], + waveforms_lens: INT64["batch_size"], +) -> tuple[FLOAT["batch_size", 128, "T"], INT64["batch_size"]]: + features, features_lens = whisper_preprocessor_conv( + waveforms, + waveforms_lens, + melscale_fbanks128, + stft_conv_weights_whisper, + ) + return features, features_lens diff --git a/src/onnx_asr/loader.py b/src/onnx_asr/loader.py index 7f38f83..c5aaf2e 100644 --- a/src/onnx_asr/loader.py +++ b/src/onnx_asr/loader.py @@ -17,7 +17,7 @@ from onnx_asr.models.tone import TOneCtc from onnx_asr.models.wespeaker import WespeakerEmbeddings from onnx_asr.models.whisper import WhisperHf, WhisperOrt -from onnx_asr.onnx import OnnxSessionOptions, Provider, get_onnx_providers, update_onnx_providers +from onnx_asr.onnx import OnnxSessionOptions, Provider, TensorRtOptions, get_onnx_providers, update_onnx_providers from onnx_asr.preprocessors.numpy_preprocessor import ( GigaamPreprocessorNumpy, KaldiPreprocessorNumpy, @@ -142,6 +142,13 @@ def create_se_resolver( return Resolver(WespeakerEmbeddings, model, local_dir, offline=offline) +CONV_PREPROCESSORS = frozenset({"gigaam_v2", "gigaam_v3", "kaldi", "nemo80", "nemo128", "whisper80", "whisper128"}) +"""ONNX preprocessors that have a Conv-based STFT variant (named ``_conv``).""" + +CONV_DEFAULT_PROVIDERS = (*TensorRtOptions.get_provider_names(), "CUDAExecutionProvider") +"""Execution providers for which the Conv-based preprocessors are enabled by default.""" + + class PreprocessorRuntimeConfig(OnnxSessionOptions, total=False): """Preprocessor runtime config.""" @@ -151,6 +158,15 @@ class PreprocessorRuntimeConfig(OnnxSessionOptions, total=False): use_numpy_preprocessors: bool | None """Use NumPy preprocessors backend instead of ONNX.""" + use_conv_preprocessors: bool | None + """Use Conv-based STFT in ONNX preprocessors. + + op.STFT has no kernel in the CUDA execution provider (the STFT node falls + back to CPU with host/device copies around it). The Conv-based graph runs + natively on every provider and is faster on CPU for non-power-of-2 sizes. + None - auto (enabled when a CUDA or TensorRT execution provider is used). + """ + class Manager: """Manager for models creation.""" @@ -175,17 +191,31 @@ def __init__( } if preprocessor_config is None: - self.preprocessor_config = update_onnx_providers( - self.default_onnx_config, - new_options={"TensorrtExecutionProvider": {"trt_fp16_enable": False}}, - excluded_providers=OnnxPreprocessor._get_excluded_providers(), - ) self.preprocessor_max_workers: int | None = 1 self.use_numpy_preprocessors = None + self.use_conv_preprocessors: bool | None = None + base_config: OnnxSessionOptions = self.default_onnx_config else: self.preprocessor_max_workers = preprocessor_config.pop("max_concurrent_workers", 1) self.use_numpy_preprocessors = preprocessor_config.pop("use_numpy_preprocessors") - self.preprocessor_config = preprocessor_config + self.use_conv_preprocessors = preprocessor_config.pop("use_conv_preprocessors", None) + base_config = preprocessor_config + + if self.use_conv_preprocessors is None: + self.use_conv_preprocessors = any(p in CONV_DEFAULT_PROVIDERS for p in get_onnx_providers(base_config)) + + if preprocessor_config is None: + # op.STFT has no kernel in the CUDA execution provider, so the STFT + # preprocessors exclude it (the node would otherwise fall back to CPU + # with host/device copies). The Conv variant runs on every provider, + # so when it is used the CUDA provider is kept. + self.preprocessor_config = update_onnx_providers( + base_config, + new_options={"TensorrtExecutionProvider": {"trt_fp16_enable": False}}, + excluded_providers=[] if self.use_conv_preprocessors else OnnxPreprocessor._get_excluded_providers(), + ) + else: + self.preprocessor_config = base_config providers = get_onnx_providers(self.preprocessor_config) if self.use_numpy_preprocessors is None: @@ -213,6 +243,8 @@ def _create_preprocessor(self, name: str) -> Preprocessor: preprocessor = WhisperPreprocessorNumpy(name) else: raise ModelNotSupportedError(name) + elif self.use_conv_preprocessors and name in CONV_PREPROCESSORS: + preprocessor = OnnxPreprocessor(f"{name}_conv", self.preprocessor_config) else: preprocessor = OnnxPreprocessor(name, self.preprocessor_config) diff --git a/tests/onnx_asr/test_manager.py b/tests/onnx_asr/test_manager.py index 2328c5a..b289fc1 100644 --- a/tests/onnx_asr/test_manager.py +++ b/tests/onnx_asr/test_manager.py @@ -8,6 +8,7 @@ def test_with_cpu_provider() -> None: assert manager.default_onnx_config.get("providers") == providers assert manager.preprocessor_max_workers == 1 assert manager.use_numpy_preprocessors is True + assert manager.use_conv_preprocessors is False assert manager.preprocessor_config.get("providers") == providers assert manager.resampler_config.get("providers") == providers @@ -18,8 +19,10 @@ def test_with_cuda_provider() -> None: assert manager.default_onnx_config.get("providers") == providers assert manager.preprocessor_max_workers == 1 - assert manager.use_numpy_preprocessors is True - assert manager.preprocessor_config.get("providers") == [] + # op.STFT has no CUDA kernel, but the Conv preprocessors do run on CUDA. + assert manager.use_numpy_preprocessors is False + assert manager.use_conv_preprocessors is True + assert manager.preprocessor_config.get("providers") == providers assert manager.resampler_config.get("providers") == providers @@ -30,5 +33,20 @@ def test_with_tensorrt_provider() -> None: assert manager.default_onnx_config.get("providers") == providers assert manager.preprocessor_max_workers == 1 assert manager.use_numpy_preprocessors is False + assert manager.use_conv_preprocessors is True assert manager.preprocessor_config.get("providers") == providers assert manager.resampler_config.get("providers") == [] + + +def test_use_conv_preprocessors_override() -> None: + manager = Manager( + providers=["TensorrtExecutionProvider"], + preprocessor_config={"use_numpy_preprocessors": False, "use_conv_preprocessors": False}, + ) + assert manager.use_conv_preprocessors is False + + manager = Manager( + providers=["CPUExecutionProvider"], + preprocessor_config={"use_numpy_preprocessors": False, "use_conv_preprocessors": True}, + ) + assert manager.use_conv_preprocessors is True diff --git a/tests/onnx_asr/test_recognize.py b/tests/onnx_asr/test_recognize.py index 8a067e9..ef6897d 100644 --- a/tests/onnx_asr/test_recognize.py +++ b/tests/onnx_asr/test_recognize.py @@ -89,15 +89,19 @@ def test_recognize_batch(model: TextResultsAsrAdapter) -> None: assert all(isinstance(item, str) for item in result) +@pytest.mark.parametrize("use_conv_preprocessors", [False, True]) @pytest.mark.parametrize("max_concurrent_workers", [None, 1, 2]) @pytest.mark.parametrize("use_numpy_preprocessors", [None, True, False]) -def test_preprocessor_options(max_concurrent_workers: int | None, use_numpy_preprocessors: bool | None) -> None: +def test_preprocessor_options( + max_concurrent_workers: int | None, use_numpy_preprocessors: bool | None, use_conv_preprocessors: bool +) -> None: model = onnx_asr.load_model( "alphacep/vosk-model-small-ru", quantization="int8", preprocessor_config={ "max_concurrent_workers": max_concurrent_workers, "use_numpy_preprocessors": use_numpy_preprocessors, + "use_conv_preprocessors": use_conv_preprocessors, }, ) rng = np.random.default_rng(0) diff --git a/tests/preprocessors/test_build.py b/tests/preprocessors/test_build.py index b26de8e..ee0dfe1 100644 --- a/tests/preprocessors/test_build.py +++ b/tests/preprocessors/test_build.py @@ -8,7 +8,7 @@ def test_build(tmp_path: Path): build.build(tmp_path, "tests") - assert len(list(tmp_path.glob("*.onnx"))) == 22 + assert len(list(tmp_path.glob("*.onnx"))) == 29 assert len(list(tmp_path.glob("*.npz"))) == 1 @@ -16,7 +16,7 @@ def test_save_preprocessor_models(tmp_path: Path): build.save_preprocessor_models(tmp_path, "tests") files = list(tmp_path.glob("*.onnx")) - assert len(files) == 8 + assert len(files) == 15 for filename in files: onnx.checker.check_model(filename, full_check=True) model = onnx.load_model(filename) diff --git a/tests/preprocessors/test_gigaam.py b/tests/preprocessors/test_gigaam.py index 70dc620..d6f4b6c 100644 --- a/tests/preprocessors/test_gigaam.py +++ b/tests/preprocessors/test_gigaam.py @@ -95,7 +95,10 @@ def preprocessor_torch_v3(waveforms, lens): ) // gigaam.hop_length + 1 -@pytest.fixture(scope="module", params=["torch", "numpy", "onnx_func", "onnx_model", "onnx_model_mt"]) +@pytest.fixture( + scope="module", + params=["torch", "numpy", "onnx_func", "onnx_model", "onnx_model_mt", "onnx_func_conv", "onnx_model_conv"], +) def preprocessor_v2(request): match request.param: case "torch": @@ -108,9 +111,16 @@ def preprocessor_v2(request): return (OnnxPreprocessor("gigaam_v2", {}), 1e-3) case "onnx_model_mt": return (ConcurrentPreprocessor(OnnxPreprocessor("gigaam_v2", {}), 2), 1e-3) + case "onnx_func_conv": + return (gigaam.GigaamPreprocessorV2Conv, 1e-3) + case "onnx_model_conv": + return (OnnxPreprocessor("gigaam_v2_conv", {}), 1e-3) -@pytest.fixture(scope="module", params=["torch", "numpy", "onnx_func", "onnx_model", "onnx_model_mt"]) +@pytest.fixture( + scope="module", + params=["torch", "numpy", "onnx_func", "onnx_model", "onnx_model_mt", "onnx_func_conv", "onnx_model_conv"], +) def preprocessor_v3(request): match request.param: case "torch": @@ -123,6 +133,10 @@ def preprocessor_v3(request): return (OnnxPreprocessor("gigaam_v3", {}), 1e-3) case "onnx_model_mt": return (ConcurrentPreprocessor(OnnxPreprocessor("gigaam_v3", {}), 2), 1e-3) + case "onnx_func_conv": + return (gigaam.GigaamPreprocessorV3Conv, 1e-3) + case "onnx_model_conv": + return (OnnxPreprocessor("gigaam_v3_conv", {}), 1e-3) def test_gigaam_preprocessor_v2(preprocessor_v2, waveforms): diff --git a/tests/preprocessors/test_kaldi.py b/tests/preprocessors/test_kaldi.py index 8189253..68a2d30 100644 --- a/tests/preprocessors/test_kaldi.py +++ b/tests/preprocessors/test_kaldi.py @@ -83,7 +83,9 @@ def preprocessor(request): return kaldi.KaldiPreprocessor -@pytest.fixture(scope="module", params=["torch", "onnx_func", "onnx_model", "onnx_model_mt"]) +@pytest.fixture( + scope="module", params=["torch", "onnx_func", "onnx_model", "onnx_model_mt", "onnx_func_conv", "onnx_model_conv"] +) def preprocessor_fast(request): match request.param: case "torch": @@ -94,6 +96,10 @@ def preprocessor_fast(request): return OnnxPreprocessor("kaldi", {}) case "onnx_model_mt": return ConcurrentPreprocessor(OnnxPreprocessor("kaldi", {}), 2) + case "onnx_func_conv": + return kaldi.KaldiPreprocessorFastConv + case "onnx_model_conv": + return OnnxPreprocessor("kaldi_conv", {}) @pytest.fixture(scope="module", params=["torch", "numpy", "onnx_func", "onnx_model", "onnx_model_mt"]) diff --git a/tests/preprocessors/test_nemo.py b/tests/preprocessors/test_nemo.py index ac9dbfa..573e1c3 100644 --- a/tests/preprocessors/test_nemo.py +++ b/tests/preprocessors/test_nemo.py @@ -62,7 +62,10 @@ def preprocessor_torch(waveforms, lens, n_mels): return features, features_lens.numpy() -@pytest.fixture(scope="module", params=["torch", "numpy", "onnx_func", "onnx_model", "onnx_model_mt"]) +@pytest.fixture( + scope="module", + params=["torch", "numpy", "onnx_func", "onnx_model", "onnx_model_mt", "onnx_func_conv", "onnx_model_conv"], +) def preprocessor(request, n_mels): match request.param: case "torch": @@ -75,6 +78,10 @@ def preprocessor(request, n_mels): return OnnxPreprocessor(f"nemo{n_mels}", {}) case "onnx_model_mt": return ConcurrentPreprocessor(OnnxPreprocessor(f"nemo{n_mels}", {}), 2) + case "onnx_func_conv": + return nemo.NemoPreprocessor80Conv if n_mels == 80 else nemo.NemoPreprocessor128Conv + case "onnx_model_conv": + return OnnxPreprocessor(f"nemo{n_mels}_conv", {}) def test_nemo_preprocessor(preprocessor_origin, preprocessor, waveforms): diff --git a/tests/preprocessors/test_whisper.py b/tests/preprocessors/test_whisper.py index 2a91338..e20f9fa 100644 --- a/tests/preprocessors/test_whisper.py +++ b/tests/preprocessors/test_whisper.py @@ -50,7 +50,9 @@ def preprocessor_torch(waveforms, lens, n_mels): return features.numpy(), np.full_like(lens, whisper.chunk_length * whisper.sample_rate // whisper.hop_length) -@pytest.fixture(scope="module", params=["torch", "numpy", "onnx_func", "onnx_model"]) +@pytest.fixture( + scope="module", params=["torch", "numpy", "onnx_func", "onnx_model", "onnx_func_conv", "onnx_model_conv"] +) def preprocessor_tol(request, n_mels): match request.param: case "torch": @@ -61,6 +63,10 @@ def preprocessor_tol(request, n_mels): return whisper.WhisperPreprocessor80 if n_mels == 80 else whisper.WhisperPreprocessor128, 5e-3 case "onnx_model": return OnnxPreprocessor(f"whisper{n_mels}", {}), 5e-3 + case "onnx_func_conv": + return (whisper.WhisperPreprocessor80Conv if n_mels == 80 else whisper.WhisperPreprocessor128Conv), 5e-3 + case "onnx_model_conv": + return OnnxPreprocessor(f"whisper{n_mels}_conv", {}), 5e-3 def test_whisper_preprocessor(n_mels, preprocessor_tol, waveforms):