From bb75aef6ed0c7ea4bc26eb17e2c32f49bf2f7ef5 Mon Sep 17 00:00:00 2001 From: ivan Date: Sun, 17 May 2026 15:25:31 +0300 Subject: [PATCH 1/3] Add Conv-based STFT variants for ONNX preprocessors op.STFT has no kernel in the onnxruntime CUDA execution provider: a preprocessor graph that uses it gets split, and the STFT node runs on CPU with host/device copies around it. Accelerators such as CoreML do not support it either, and for non-power-of-2 FFT sizes it is slow on CPU. Add a shared preprocessors/stft.py helper that expresses the windowed DFT as a 1d convolution with a fixed kernel, plus Conv-based variants of every STFT-using preprocessor: gigaam_v2/v3, nemo80/128, whisper80/128 and kaldi. The new graphs use only operators with kernels on every execution provider, so they run fully on GPU; they are numerically equivalent to the STFT graphs. --- preprocessors/build.py | 7 +++++ preprocessors/gigaam.py | 40 +++++++++++++++++++++++++++ preprocessors/kaldi.py | 37 +++++++++++++++++++++---- preprocessors/nemo.py | 59 ++++++++++++++++++++++++++++++++++++++++ preprocessors/stft.py | 50 ++++++++++++++++++++++++++++++++++ preprocessors/whisper.py | 58 +++++++++++++++++++++++++++++++++++++++ 6 files changed, 246 insertions(+), 5 deletions(-) create mode 100644 preprocessors/stft.py diff --git a/preprocessors/build.py b/preprocessors/build.py index c9bb04b..3f10878 100644 --- a/preprocessors/build.py +++ b/preprocessors/build.py @@ -34,6 +34,13 @@ def save_preprocessor_models(preprocessors_dir: Path, version: str) -> None: "nemo128.onnx": nemo.NemoPreprocessor128, "whisper80.onnx": whisper.WhisperPreprocessor80, "whisper128.onnx": whisper.WhisperPreprocessor128, + "gigaam_v2_conv.onnx": gigaam.GigaamPreprocessorV2Conv, + "gigaam_v3_conv.onnx": gigaam.GigaamPreprocessorV3Conv, + "kaldi_conv.onnx": kaldi.KaldiPreprocessorFastConv, + "nemo80_conv.onnx": nemo.NemoPreprocessor80Conv, + "nemo128_conv.onnx": nemo.NemoPreprocessor128Conv, + "whisper80_conv.onnx": whisper.WhisperPreprocessor80Conv, + "whisper128_conv.onnx": whisper.WhisperPreprocessor128Conv, } for filename, model in preprocessors.items(): save_onnx(model, preprocessors_dir.joinpath(filename), version) diff --git a/preprocessors/gigaam.py b/preprocessors/gigaam.py index ced63da..80b3831 100644 --- a/preprocessors/gigaam.py +++ b/preprocessors/gigaam.py @@ -6,6 +6,7 @@ from onnxscript import opset17 as op from preprocessors.fbanks import melscale_fbanks +from preprocessors.stft import conv_power_spectrogram, stft_conv_weights sample_rate = 16_000 n_fft_v2 = sample_rate // 40 @@ -27,6 +28,9 @@ ) hann_window_v3 = np.hanning(win_length_v3 + 1)[:-1].astype(bfloat16).astype(np.float32) +stft_conv_weights_v2 = stft_conv_weights(np.hanning(win_length_v2 + 1)[:-1].astype(np.float32)) +stft_conv_weights_v3 = stft_conv_weights(hann_window_v3) + @script(doc_string="LogMelSpectrogram feature extractor for GigaAM v2 models") def GigaamPreprocessorV2( @@ -65,3 +69,39 @@ def GigaamPreprocessorV3( features_lens = (waveforms_lens - win_length_v3) / hop_length + 1 features = op.Transpose(log_mel_spectrogram, perm=[0, 2, 1]) return features, features_lens + + +@script(doc_string="LogMelSpectrogram feature extractor for GigaAM v2 models (Conv-based STFT)") +def GigaamPreprocessorV2Conv( + waveforms: FLOAT["batch_size", "N"], + waveforms_lens: INT64["batch_size"], +) -> tuple[FLOAT["batch_size", n_mels, "T"], INT64["batch_size"]]: + waveforms = op.Pad( + waveforms, + pads=op.Constant(value=[0, n_fft_v2 // 2, 0, n_fft_v2 // 2]), + mode="reflect", + ) + + spectrogram = conv_power_spectrogram(waveforms, stft_conv_weights_v2) + + mel_spectrogram = op.MatMul(spectrogram, melscale_fbanks_v2) + log_mel_spectrogram = op.Log(op.Clip(mel_spectrogram, clamp_min, clamp_max)) + + features_lens = waveforms_lens / hop_length + 1 + features = op.Transpose(log_mel_spectrogram, perm=[0, 2, 1]) + return features, features_lens + + +@script(doc_string="LogMelSpectrogram feature extractor for GigaAM v3 models (Conv-based STFT)") +def GigaamPreprocessorV3Conv( + waveforms: FLOAT["batch_size", "N"], + waveforms_lens: INT64["batch_size"], +) -> tuple[FLOAT["batch_size", n_mels, "T"], INT64["batch_size"]]: + spectrogram = conv_power_spectrogram(waveforms, stft_conv_weights_v3) + + mel_spectrogram = op.MatMul(spectrogram, melscale_fbanks_v3) + log_mel_spectrogram = op.Log(op.Clip(mel_spectrogram, clamp_min, clamp_max)) + + features_lens = (waveforms_lens - win_length_v3) / hop_length + 1 + features = op.Transpose(log_mel_spectrogram, perm=[0, 2, 1]) + return features, features_lens diff --git a/preprocessors/kaldi.py b/preprocessors/kaldi.py index 3bb8e7a..ea433f4 100644 --- a/preprocessors/kaldi.py +++ b/preprocessors/kaldi.py @@ -5,6 +5,7 @@ from onnxscript import opset17 as op from preprocessors.fbanks import melscale_fbanks +from preprocessors.stft import conv_power_spectrogram, stft_conv_weights sample_rate = 16_000 n_fft = 512 @@ -29,6 +30,9 @@ np.float32 ) wespeaker_window = np.hamming(win_length).astype(np.float32) +stft_conv_weights_kaldi = stft_conv_weights( + np.pad(np.hanning(win_length) ** 0.85, (0, n_fft - win_length)).astype(np.float32) +) @script() @@ -75,13 +79,11 @@ def sliding_buffer(prev: FLOAT["batch_size", win_length - hop_length], curr: FLO @script() def calc_features( - image: FLOAT["batch_size", "T", n_fft // 2 + 1, 2], + spectrogram: FLOAT["batch_size", "T", n_fft // 2 + 1], waveforms_lens: INT64["batch_size"], mel_banks: FLOAT[n_fft // 2 + 1, 2, num_mel_bins], snip_edges: bool, ): - spectrogram = op.ReduceSumSquare(image, axes=[-1], keepdims=0) - mel_spectrogram = op.MatMul(spectrogram, mel_banks) log_mel_spectrogram = op.Log(op.Clip(mel_spectrogram, min=float_eps)) @@ -119,7 +121,8 @@ def preprocessor( frames = frames - preemphasis_coefficient * op.Pad(frames, pads=[0, 0, 1, 0, 0, -1], mode="edge") image = op.DFT(op.Unsqueeze(window * frames, axes=[-1]), n_fft, axis=-2, onesided=1) - features, features_lens = calc_features(image, waveforms_lens, mel_banks, snip_edges) + spectrogram = op.ReduceSumSquare(image, axes=[-1], keepdims=0) + features, features_lens = calc_features(spectrogram, waveforms_lens, mel_banks, snip_edges) return features, features_lens @@ -155,8 +158,32 @@ def KaldiPreprocessorFast( pads=op.Constant(value=[0, n_fft - win_length]), ) image = op.STFT(waveforms, hop_length, povey_window) + spectrogram = op.ReduceSumSquare(image, axes=[-1], keepdims=0) + + features, features_lens = calc_features(spectrogram, waveforms_lens, kaldi_mel_banks, snip_edges=snip_edges) + return features, features_lens + + +@script(doc_string="LogMelSpectrogram feature extractor for Kaldi models (Conv-based STFT)") +def KaldiPreprocessorFastConv( + waveforms: FLOAT["batch_size", "N"], waveforms_lens: INT64["batch_size"] +) -> tuple[FLOAT["batch_size", "T", num_mel_bins], INT64["batch_size"]]: + if dither != 0.0: + waveforms = waveforms + op.RandomNormalLike(waveforms, scale=dither) + + if remove_dc_offset: + waveforms = waveforms - op.ReduceMean(waveforms, axes=[-1]) + + if not snip_edges: + waveforms = symmetric_pad(waveforms, waveforms_lens) + + if preemphasis_coefficient != 0.0: + waveforms = waveforms - preemphasis_coefficient * op.Pad(waveforms, pads=[0, 1, 0, -1], mode="edge") + + waveforms = op.Pad(waveforms, pads=op.Constant(value=[0, 0, 0, n_fft - win_length])) + spectrogram = conv_power_spectrogram(waveforms, stft_conv_weights_kaldi) - features, features_lens = calc_features(image, waveforms_lens, kaldi_mel_banks, snip_edges=snip_edges) + features, features_lens = calc_features(spectrogram, waveforms_lens, kaldi_mel_banks, snip_edges=snip_edges) return features, features_lens diff --git a/preprocessors/nemo.py b/preprocessors/nemo.py index 3f04b06..5bb619b 100644 --- a/preprocessors/nemo.py +++ b/preprocessors/nemo.py @@ -5,6 +5,7 @@ from onnxscript import opset17 as op from preprocessors.fbanks import melscale_fbanks +from preprocessors.stft import conv_power_spectrogram, stft_conv_weights sample_rate = 16_000 n_fft = 512 @@ -20,6 +21,9 @@ melscale_fbanks128 = melscale_fbanks(n_fft // 2 + 1, 0, sample_rate // 2, 128, sample_rate, "slaney", "slaney").astype( np.float32 ) +stft_conv_weights_nemo = stft_conv_weights( + np.pad(np.hanning(win_length), (n_fft // 2 - win_length // 2, n_fft // 2 - win_length // 2)).astype(np.float32) +) @script() @@ -87,3 +91,58 @@ def NemoPreprocessor128( melscale_fbanks128, ) return features, features_lens + + +@script() +def nemo_preprocessor_conv( + waveforms: FLOAT["batch_size", "N"], + waveforms_lens: INT64["batch_size"], + melscale_fbanks: FLOAT[n_fft // 2 + 1, "M"], + conv_weights: FLOAT["channels", 1, n_fft], +): + if preemph != 0.0: + timemask = op.Range(0, op.Squeeze(op.Shape(waveforms, start=1, end=2)), 1) < op.Unsqueeze( + waveforms_lens, axes=[1] + ) + waveforms = op.Concat(waveforms[:, :1], waveforms[:, 1:] - preemph * waveforms[:, :-1], axis=-1) + waveforms = op.Where(timemask, waveforms, 0.0) + + waveforms = op.Pad( + waveforms, + pads=op.Constant(value=[0, n_fft // 2, 0, n_fft // 2]), + ) + spectrogram = conv_power_spectrogram(waveforms, conv_weights) + + mel_spectrogram = op.MatMul(spectrogram, melscale_fbanks) + log_mel_spectrogram = op.Log(mel_spectrogram + log_zero_guard_value) + + features_lens = waveforms_lens / hop_length + return normalize(op.Transpose(log_mel_spectrogram, perm=[0, 2, 1]), features_lens), features_lens + + +@script(doc_string="LogMelSpectrogram feature extractor for Nemo models (Conv-based STFT)", default_opset=op) +def NemoPreprocessor80Conv( + waveforms: FLOAT["batch_size", "N"], + waveforms_lens: INT64["batch_size"], +) -> tuple[FLOAT["batch_size", 80, "T"], INT64["batch_size"]]: + features, features_lens = nemo_preprocessor_conv( + waveforms, + waveforms_lens, + melscale_fbanks80, + stft_conv_weights_nemo, + ) + return features, features_lens + + +@script(doc_string="LogMelSpectrogram feature extractor for Nemo models (Conv-based STFT)", default_opset=op) +def NemoPreprocessor128Conv( + waveforms: FLOAT["batch_size", "N"], + waveforms_lens: INT64["batch_size"], +) -> tuple[FLOAT["batch_size", 128, "T"], INT64["batch_size"]]: + features, features_lens = nemo_preprocessor_conv( + waveforms, + waveforms_lens, + melscale_fbanks128, + stft_conv_weights_nemo, + ) + return features, features_lens diff --git a/preprocessors/stft.py b/preprocessors/stft.py new file mode 100644 index 0000000..c8be6e6 --- /dev/null +++ b/preprocessors/stft.py @@ -0,0 +1,50 @@ +"""STFT computed as a fixed 1d convolution. + +op.STFT is not supported by the TensorRT and CUDA execution providers, so any +preprocessor graph that uses it falls back to CPU. The discrete Fourier +transform can instead be written as a 1d convolution with a fixed kernel (the +cos/sin Fourier basis multiplied by the analysis window). The convolution and +the ops around it (Reshape, ReduceSumSquare, Transpose) are supported by every +execution provider, so the resulting graph can run fully on GPU / TensorRT. +""" + +import numpy as np +import numpy.typing as npt +from onnxscript import FLOAT, script +from onnxscript import opset17 as op + +hop_length = 160 + + +def stft_conv_weights(window: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]: + """Build Conv weights that compute a windowed DFT. + + Args: + window: Analysis window, already zero-padded to the FFT size by the + caller (its length is used as ``n_fft``). + + Returns: + Kernel of shape ``[2 * (n_fft // 2 + 1), 1, n_fft]`` stacking the real + (cos) and imaginary (-sin) parts of the Fourier basis. Used with Conv + (stride = ``hop_length``) it reproduces ``op.STFT``. + + """ + n_fft = window.shape[0] + indices = np.arange(n_fft // 2 + 1)[:, np.newaxis] * np.arange(n_fft)[np.newaxis, :] + angle = 2 * np.pi * indices / n_fft + basis = np.concatenate([np.cos(angle), -np.sin(angle)]) * window + return basis[:, np.newaxis, :].astype(np.float32) + + +@script() +def conv_power_spectrogram(waveforms: FLOAT["batch_size", "N"], conv_weights: FLOAT["channels", 1, "n_fft"]): + """Power spectrogram [batch_size, frames, n_bins] via a Conv-based STFT. + + Drop-in replacement for ``op.STFT`` followed by ``ReduceSumSquare`` over the + real/imaginary axis. ``conv_weights`` is built by :func:`stft_conv_weights`. + """ + image = op.Conv(op.Unsqueeze(waveforms, axes=[1]), conv_weights, strides=[hop_length]) + n_bins = op.Shape(conv_weights, start=0, end=1) / 2 + shape = op.Concat(op.Constant(value=[0, 2]), n_bins, op.Constant(value=[-1]), axis=0) + spectrogram = op.ReduceSumSquare(op.Reshape(image, shape), axes=[1], keepdims=0) + return op.Transpose(spectrogram, perm=[0, 2, 1]) diff --git a/preprocessors/whisper.py b/preprocessors/whisper.py index 69ab434..5f5734f 100644 --- a/preprocessors/whisper.py +++ b/preprocessors/whisper.py @@ -5,6 +5,7 @@ from onnxscript import opset17 as op from preprocessors.fbanks import melscale_fbanks +from preprocessors.stft import conv_power_spectrogram, stft_conv_weights chunk_length = 30 sample_rate = 16_000 @@ -22,6 +23,7 @@ melscale_fbanks128 = melscale_fbanks(n_fft // 2 + 1, 0, sample_rate // 2, 128, sample_rate, "slaney", "slaney").astype( np.float32 ) +stft_conv_weights_whisper = stft_conv_weights(np.hanning(win_length + 1)[:-1].astype(np.float32)) @script() @@ -77,3 +79,59 @@ def WhisperPreprocessor128( melscale_fbanks128, ) return features, features_lens + + +@script() +def whisper_preprocessor_conv( + waveforms: FLOAT["batch_size", "N"], + waveforms_lens: INT64["batch_size"], + melscale_fbanks: FLOAT[n_fft // 2 + 1, "M"], + conv_weights: FLOAT["channels", 1, n_fft], +): + waveforms = op.Pad( + waveforms, + pads=(chunk_length * sample_rate - op.Shape(waveforms, start=1, end=2)) * op.Constant(value=[0, 0, 0, 1]), + ) + waveforms = op.Pad( + waveforms, + pads=op.Constant(value=[0, n_fft // 2, 0, n_fft // 2]), + mode="reflect", + ) + + spectrogram = conv_power_spectrogram(waveforms, conv_weights)[:, :-1] + + mel_spectrogram = op.MatMul(spectrogram, melscale_fbanks) + log_mel_spectrogram = op.Log(op.Clip(mel_spectrogram, clamp_min)) / ln10 + log_mel_spectrogram = (op.Max(log_mel_spectrogram, op.ReduceMax(log_mel_spectrogram) - 8) + 4) / 4.0 + + return op.Transpose(log_mel_spectrogram, perm=[0, 2, 1]), op.ConstantOfShape( + op.Shape(waveforms_lens), value=features_length + ) + + +@script(doc_string="LogMelSpectrogram feature extractor for Whisper models (Conv-based STFT)", default_opset=op) +def WhisperPreprocessor80Conv( + waveforms: FLOAT["batch_size", "N"], + waveforms_lens: INT64["batch_size"], +) -> tuple[FLOAT["batch_size", 80, "T"], INT64["batch_size"]]: + features, features_lens = whisper_preprocessor_conv( + waveforms, + waveforms_lens, + melscale_fbanks80, + stft_conv_weights_whisper, + ) + return features, features_lens + + +@script(doc_string="LogMelSpectrogram feature extractor for Whisper models (Conv-based STFT)", default_opset=op) +def WhisperPreprocessor128Conv( + waveforms: FLOAT["batch_size", "N"], + waveforms_lens: INT64["batch_size"], +) -> tuple[FLOAT["batch_size", 128, "T"], INT64["batch_size"]]: + features, features_lens = whisper_preprocessor_conv( + waveforms, + waveforms_lens, + melscale_fbanks128, + stft_conv_weights_whisper, + ) + return features, features_lens From b6258b55a6b1f91a4d26bee41feffef94fea9449 Mon Sep 17 00:00:00 2001 From: ivan Date: Sun, 17 May 2026 15:25:31 +0300 Subject: [PATCH 2/3] Add use_conv_preprocessors runtime option PreprocessorRuntimeConfig gains a use_conv_preprocessors flag that selects the Conv-based ONNX preprocessor variants. It defaults to auto: enabled when a CUDA or TensorRT execution provider is used, disabled otherwise. When the Conv preprocessors are used the CUDA provider is no longer excluded from the preprocessor session (op.STFT has no CUDA kernel, the Conv graph does), so preprocessing runs on the GPU instead of falling back to a NumPy/CPU implementation. --- src/onnx_asr/loader.py | 46 +++++++++++++++++++++++++++++++++++------- 1 file changed, 39 insertions(+), 7 deletions(-) diff --git a/src/onnx_asr/loader.py b/src/onnx_asr/loader.py index 7f38f83..c5aaf2e 100644 --- a/src/onnx_asr/loader.py +++ b/src/onnx_asr/loader.py @@ -17,7 +17,7 @@ from onnx_asr.models.tone import TOneCtc from onnx_asr.models.wespeaker import WespeakerEmbeddings from onnx_asr.models.whisper import WhisperHf, WhisperOrt -from onnx_asr.onnx import OnnxSessionOptions, Provider, get_onnx_providers, update_onnx_providers +from onnx_asr.onnx import OnnxSessionOptions, Provider, TensorRtOptions, get_onnx_providers, update_onnx_providers from onnx_asr.preprocessors.numpy_preprocessor import ( GigaamPreprocessorNumpy, KaldiPreprocessorNumpy, @@ -142,6 +142,13 @@ def create_se_resolver( return Resolver(WespeakerEmbeddings, model, local_dir, offline=offline) +CONV_PREPROCESSORS = frozenset({"gigaam_v2", "gigaam_v3", "kaldi", "nemo80", "nemo128", "whisper80", "whisper128"}) +"""ONNX preprocessors that have a Conv-based STFT variant (named ``_conv``).""" + +CONV_DEFAULT_PROVIDERS = (*TensorRtOptions.get_provider_names(), "CUDAExecutionProvider") +"""Execution providers for which the Conv-based preprocessors are enabled by default.""" + + class PreprocessorRuntimeConfig(OnnxSessionOptions, total=False): """Preprocessor runtime config.""" @@ -151,6 +158,15 @@ class PreprocessorRuntimeConfig(OnnxSessionOptions, total=False): use_numpy_preprocessors: bool | None """Use NumPy preprocessors backend instead of ONNX.""" + use_conv_preprocessors: bool | None + """Use Conv-based STFT in ONNX preprocessors. + + op.STFT has no kernel in the CUDA execution provider (the STFT node falls + back to CPU with host/device copies around it). The Conv-based graph runs + natively on every provider and is faster on CPU for non-power-of-2 sizes. + None - auto (enabled when a CUDA or TensorRT execution provider is used). + """ + class Manager: """Manager for models creation.""" @@ -175,17 +191,31 @@ def __init__( } if preprocessor_config is None: - self.preprocessor_config = update_onnx_providers( - self.default_onnx_config, - new_options={"TensorrtExecutionProvider": {"trt_fp16_enable": False}}, - excluded_providers=OnnxPreprocessor._get_excluded_providers(), - ) self.preprocessor_max_workers: int | None = 1 self.use_numpy_preprocessors = None + self.use_conv_preprocessors: bool | None = None + base_config: OnnxSessionOptions = self.default_onnx_config else: self.preprocessor_max_workers = preprocessor_config.pop("max_concurrent_workers", 1) self.use_numpy_preprocessors = preprocessor_config.pop("use_numpy_preprocessors") - self.preprocessor_config = preprocessor_config + self.use_conv_preprocessors = preprocessor_config.pop("use_conv_preprocessors", None) + base_config = preprocessor_config + + if self.use_conv_preprocessors is None: + self.use_conv_preprocessors = any(p in CONV_DEFAULT_PROVIDERS for p in get_onnx_providers(base_config)) + + if preprocessor_config is None: + # op.STFT has no kernel in the CUDA execution provider, so the STFT + # preprocessors exclude it (the node would otherwise fall back to CPU + # with host/device copies). The Conv variant runs on every provider, + # so when it is used the CUDA provider is kept. + self.preprocessor_config = update_onnx_providers( + base_config, + new_options={"TensorrtExecutionProvider": {"trt_fp16_enable": False}}, + excluded_providers=[] if self.use_conv_preprocessors else OnnxPreprocessor._get_excluded_providers(), + ) + else: + self.preprocessor_config = base_config providers = get_onnx_providers(self.preprocessor_config) if self.use_numpy_preprocessors is None: @@ -213,6 +243,8 @@ def _create_preprocessor(self, name: str) -> Preprocessor: preprocessor = WhisperPreprocessorNumpy(name) else: raise ModelNotSupportedError(name) + elif self.use_conv_preprocessors and name in CONV_PREPROCESSORS: + preprocessor = OnnxPreprocessor(f"{name}_conv", self.preprocessor_config) else: preprocessor = OnnxPreprocessor(name, self.preprocessor_config) From ffc1ed0b170e658eba6f5b1aa8960ef7eb38e4dc Mon Sep 17 00:00:00 2001 From: ivan Date: Sun, 17 May 2026 15:25:31 +0300 Subject: [PATCH 3/3] Test Conv-based preprocessor variants Parametrize the preprocessor tests over the Conv variants, update the build file counts, and cover use_conv_preprocessors selection in the Manager and preprocessor-option tests. --- tests/onnx_asr/test_manager.py | 22 ++++++++++++++++++++-- tests/onnx_asr/test_recognize.py | 6 +++++- tests/preprocessors/test_build.py | 4 ++-- tests/preprocessors/test_gigaam.py | 18 ++++++++++++++++-- tests/preprocessors/test_kaldi.py | 8 +++++++- tests/preprocessors/test_nemo.py | 9 ++++++++- tests/preprocessors/test_whisper.py | 8 +++++++- 7 files changed, 65 insertions(+), 10 deletions(-) diff --git a/tests/onnx_asr/test_manager.py b/tests/onnx_asr/test_manager.py index 2328c5a..b289fc1 100644 --- a/tests/onnx_asr/test_manager.py +++ b/tests/onnx_asr/test_manager.py @@ -8,6 +8,7 @@ def test_with_cpu_provider() -> None: assert manager.default_onnx_config.get("providers") == providers assert manager.preprocessor_max_workers == 1 assert manager.use_numpy_preprocessors is True + assert manager.use_conv_preprocessors is False assert manager.preprocessor_config.get("providers") == providers assert manager.resampler_config.get("providers") == providers @@ -18,8 +19,10 @@ def test_with_cuda_provider() -> None: assert manager.default_onnx_config.get("providers") == providers assert manager.preprocessor_max_workers == 1 - assert manager.use_numpy_preprocessors is True - assert manager.preprocessor_config.get("providers") == [] + # op.STFT has no CUDA kernel, but the Conv preprocessors do run on CUDA. + assert manager.use_numpy_preprocessors is False + assert manager.use_conv_preprocessors is True + assert manager.preprocessor_config.get("providers") == providers assert manager.resampler_config.get("providers") == providers @@ -30,5 +33,20 @@ def test_with_tensorrt_provider() -> None: assert manager.default_onnx_config.get("providers") == providers assert manager.preprocessor_max_workers == 1 assert manager.use_numpy_preprocessors is False + assert manager.use_conv_preprocessors is True assert manager.preprocessor_config.get("providers") == providers assert manager.resampler_config.get("providers") == [] + + +def test_use_conv_preprocessors_override() -> None: + manager = Manager( + providers=["TensorrtExecutionProvider"], + preprocessor_config={"use_numpy_preprocessors": False, "use_conv_preprocessors": False}, + ) + assert manager.use_conv_preprocessors is False + + manager = Manager( + providers=["CPUExecutionProvider"], + preprocessor_config={"use_numpy_preprocessors": False, "use_conv_preprocessors": True}, + ) + assert manager.use_conv_preprocessors is True diff --git a/tests/onnx_asr/test_recognize.py b/tests/onnx_asr/test_recognize.py index 8a067e9..ef6897d 100644 --- a/tests/onnx_asr/test_recognize.py +++ b/tests/onnx_asr/test_recognize.py @@ -89,15 +89,19 @@ def test_recognize_batch(model: TextResultsAsrAdapter) -> None: assert all(isinstance(item, str) for item in result) +@pytest.mark.parametrize("use_conv_preprocessors", [False, True]) @pytest.mark.parametrize("max_concurrent_workers", [None, 1, 2]) @pytest.mark.parametrize("use_numpy_preprocessors", [None, True, False]) -def test_preprocessor_options(max_concurrent_workers: int | None, use_numpy_preprocessors: bool | None) -> None: +def test_preprocessor_options( + max_concurrent_workers: int | None, use_numpy_preprocessors: bool | None, use_conv_preprocessors: bool +) -> None: model = onnx_asr.load_model( "alphacep/vosk-model-small-ru", quantization="int8", preprocessor_config={ "max_concurrent_workers": max_concurrent_workers, "use_numpy_preprocessors": use_numpy_preprocessors, + "use_conv_preprocessors": use_conv_preprocessors, }, ) rng = np.random.default_rng(0) diff --git a/tests/preprocessors/test_build.py b/tests/preprocessors/test_build.py index b26de8e..ee0dfe1 100644 --- a/tests/preprocessors/test_build.py +++ b/tests/preprocessors/test_build.py @@ -8,7 +8,7 @@ def test_build(tmp_path: Path): build.build(tmp_path, "tests") - assert len(list(tmp_path.glob("*.onnx"))) == 22 + assert len(list(tmp_path.glob("*.onnx"))) == 29 assert len(list(tmp_path.glob("*.npz"))) == 1 @@ -16,7 +16,7 @@ def test_save_preprocessor_models(tmp_path: Path): build.save_preprocessor_models(tmp_path, "tests") files = list(tmp_path.glob("*.onnx")) - assert len(files) == 8 + assert len(files) == 15 for filename in files: onnx.checker.check_model(filename, full_check=True) model = onnx.load_model(filename) diff --git a/tests/preprocessors/test_gigaam.py b/tests/preprocessors/test_gigaam.py index 70dc620..d6f4b6c 100644 --- a/tests/preprocessors/test_gigaam.py +++ b/tests/preprocessors/test_gigaam.py @@ -95,7 +95,10 @@ def preprocessor_torch_v3(waveforms, lens): ) // gigaam.hop_length + 1 -@pytest.fixture(scope="module", params=["torch", "numpy", "onnx_func", "onnx_model", "onnx_model_mt"]) +@pytest.fixture( + scope="module", + params=["torch", "numpy", "onnx_func", "onnx_model", "onnx_model_mt", "onnx_func_conv", "onnx_model_conv"], +) def preprocessor_v2(request): match request.param: case "torch": @@ -108,9 +111,16 @@ def preprocessor_v2(request): return (OnnxPreprocessor("gigaam_v2", {}), 1e-3) case "onnx_model_mt": return (ConcurrentPreprocessor(OnnxPreprocessor("gigaam_v2", {}), 2), 1e-3) + case "onnx_func_conv": + return (gigaam.GigaamPreprocessorV2Conv, 1e-3) + case "onnx_model_conv": + return (OnnxPreprocessor("gigaam_v2_conv", {}), 1e-3) -@pytest.fixture(scope="module", params=["torch", "numpy", "onnx_func", "onnx_model", "onnx_model_mt"]) +@pytest.fixture( + scope="module", + params=["torch", "numpy", "onnx_func", "onnx_model", "onnx_model_mt", "onnx_func_conv", "onnx_model_conv"], +) def preprocessor_v3(request): match request.param: case "torch": @@ -123,6 +133,10 @@ def preprocessor_v3(request): return (OnnxPreprocessor("gigaam_v3", {}), 1e-3) case "onnx_model_mt": return (ConcurrentPreprocessor(OnnxPreprocessor("gigaam_v3", {}), 2), 1e-3) + case "onnx_func_conv": + return (gigaam.GigaamPreprocessorV3Conv, 1e-3) + case "onnx_model_conv": + return (OnnxPreprocessor("gigaam_v3_conv", {}), 1e-3) def test_gigaam_preprocessor_v2(preprocessor_v2, waveforms): diff --git a/tests/preprocessors/test_kaldi.py b/tests/preprocessors/test_kaldi.py index 8189253..68a2d30 100644 --- a/tests/preprocessors/test_kaldi.py +++ b/tests/preprocessors/test_kaldi.py @@ -83,7 +83,9 @@ def preprocessor(request): return kaldi.KaldiPreprocessor -@pytest.fixture(scope="module", params=["torch", "onnx_func", "onnx_model", "onnx_model_mt"]) +@pytest.fixture( + scope="module", params=["torch", "onnx_func", "onnx_model", "onnx_model_mt", "onnx_func_conv", "onnx_model_conv"] +) def preprocessor_fast(request): match request.param: case "torch": @@ -94,6 +96,10 @@ def preprocessor_fast(request): return OnnxPreprocessor("kaldi", {}) case "onnx_model_mt": return ConcurrentPreprocessor(OnnxPreprocessor("kaldi", {}), 2) + case "onnx_func_conv": + return kaldi.KaldiPreprocessorFastConv + case "onnx_model_conv": + return OnnxPreprocessor("kaldi_conv", {}) @pytest.fixture(scope="module", params=["torch", "numpy", "onnx_func", "onnx_model", "onnx_model_mt"]) diff --git a/tests/preprocessors/test_nemo.py b/tests/preprocessors/test_nemo.py index ac9dbfa..573e1c3 100644 --- a/tests/preprocessors/test_nemo.py +++ b/tests/preprocessors/test_nemo.py @@ -62,7 +62,10 @@ def preprocessor_torch(waveforms, lens, n_mels): return features, features_lens.numpy() -@pytest.fixture(scope="module", params=["torch", "numpy", "onnx_func", "onnx_model", "onnx_model_mt"]) +@pytest.fixture( + scope="module", + params=["torch", "numpy", "onnx_func", "onnx_model", "onnx_model_mt", "onnx_func_conv", "onnx_model_conv"], +) def preprocessor(request, n_mels): match request.param: case "torch": @@ -75,6 +78,10 @@ def preprocessor(request, n_mels): return OnnxPreprocessor(f"nemo{n_mels}", {}) case "onnx_model_mt": return ConcurrentPreprocessor(OnnxPreprocessor(f"nemo{n_mels}", {}), 2) + case "onnx_func_conv": + return nemo.NemoPreprocessor80Conv if n_mels == 80 else nemo.NemoPreprocessor128Conv + case "onnx_model_conv": + return OnnxPreprocessor(f"nemo{n_mels}_conv", {}) def test_nemo_preprocessor(preprocessor_origin, preprocessor, waveforms): diff --git a/tests/preprocessors/test_whisper.py b/tests/preprocessors/test_whisper.py index 2a91338..e20f9fa 100644 --- a/tests/preprocessors/test_whisper.py +++ b/tests/preprocessors/test_whisper.py @@ -50,7 +50,9 @@ def preprocessor_torch(waveforms, lens, n_mels): return features.numpy(), np.full_like(lens, whisper.chunk_length * whisper.sample_rate // whisper.hop_length) -@pytest.fixture(scope="module", params=["torch", "numpy", "onnx_func", "onnx_model"]) +@pytest.fixture( + scope="module", params=["torch", "numpy", "onnx_func", "onnx_model", "onnx_func_conv", "onnx_model_conv"] +) def preprocessor_tol(request, n_mels): match request.param: case "torch": @@ -61,6 +63,10 @@ def preprocessor_tol(request, n_mels): return whisper.WhisperPreprocessor80 if n_mels == 80 else whisper.WhisperPreprocessor128, 5e-3 case "onnx_model": return OnnxPreprocessor(f"whisper{n_mels}", {}), 5e-3 + case "onnx_func_conv": + return (whisper.WhisperPreprocessor80Conv if n_mels == 80 else whisper.WhisperPreprocessor128Conv), 5e-3 + case "onnx_model_conv": + return OnnxPreprocessor(f"whisper{n_mels}_conv", {}), 5e-3 def test_whisper_preprocessor(n_mels, preprocessor_tol, waveforms):