Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions preprocessors/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ def save_preprocessor_models(preprocessors_dir: Path, version: str) -> None:
"nemo128.onnx": nemo.NemoPreprocessor128,
"whisper80.onnx": whisper.WhisperPreprocessor80,
"whisper128.onnx": whisper.WhisperPreprocessor128,
"gigaam_v2_conv.onnx": gigaam.GigaamPreprocessorV2Conv,
"gigaam_v3_conv.onnx": gigaam.GigaamPreprocessorV3Conv,
"kaldi_conv.onnx": kaldi.KaldiPreprocessorFastConv,
"nemo80_conv.onnx": nemo.NemoPreprocessor80Conv,
"nemo128_conv.onnx": nemo.NemoPreprocessor128Conv,
"whisper80_conv.onnx": whisper.WhisperPreprocessor80Conv,
"whisper128_conv.onnx": whisper.WhisperPreprocessor128Conv,
}
for filename, model in preprocessors.items():
save_onnx(model, preprocessors_dir.joinpath(filename), version)
Expand Down
40 changes: 40 additions & 0 deletions preprocessors/gigaam.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from onnxscript import opset17 as op

from preprocessors.fbanks import melscale_fbanks
from preprocessors.stft import conv_power_spectrogram, stft_conv_weights

sample_rate = 16_000
n_fft_v2 = sample_rate // 40
Expand All @@ -27,6 +28,9 @@
)
hann_window_v3 = np.hanning(win_length_v3 + 1)[:-1].astype(bfloat16).astype(np.float32)

# Conv kernels (windowed cos / -sin Fourier basis) for the Conv-based GigaAM preprocessors.
# The v2 kernel is built from a Hann window of win_length_v2 + 1 points with the last
# point dropped, cast straight to float32.
# NOTE(review): confirm this is the same window the op.STFT-based v2 preprocessor uses.
stft_conv_weights_v2 = stft_conv_weights(np.hanning(win_length_v2 + 1)[:-1].astype(np.float32))
# The v3 kernel reuses hann_window_v3, i.e. the window that was rounded through bfloat16.
stft_conv_weights_v3 = stft_conv_weights(hann_window_v3)


@script(doc_string="LogMelSpectrogram feature extractor for GigaAM v2 models")
def GigaamPreprocessorV2(
Expand Down Expand Up @@ -65,3 +69,39 @@ def GigaamPreprocessorV3(
features_lens = (waveforms_lens - win_length_v3) / hop_length + 1
features = op.Transpose(log_mel_spectrogram, perm=[0, 2, 1])
return features, features_lens


# Log-mel feature extractor for GigaAM v2 whose STFT is evaluated by
# conv_power_spectrogram (a fixed 1d convolution) instead of op.STFT.
#
# Inputs:
#   waveforms:      float tensor [batch_size, N].
#   waveforms_lens: int64 tensor [batch_size] (presumably the number of valid
#                   samples per waveform; confirm against callers).
# Outputs:
#   features:      float tensor [batch_size, n_mels, T].
#   features_lens: int64 tensor [batch_size], frame count derived from waveforms_lens.
@script(doc_string="LogMelSpectrogram feature extractor for GigaAM v2 models (Conv-based STFT)")
def GigaamPreprocessorV2Conv(
    waveforms: FLOAT["batch_size", "N"],
    waveforms_lens: INT64["batch_size"],
) -> tuple[FLOAT["batch_size", n_mels, "T"], INT64["batch_size"]]:
    # Reflect-pad n_fft_v2 // 2 samples at both ends of the last (time) axis;
    # the batch axis gets no padding.
    waveforms = op.Pad(
        waveforms,
        pads=op.Constant(value=[0, n_fft_v2 // 2, 0, n_fft_v2 // 2]),
        mode="reflect",
    )

    # Power spectrogram laid out as [batch_size, frames, n_bins].
    spectrogram = conv_power_spectrogram(waveforms, stft_conv_weights_v2)

    # Project onto the mel filterbank, clip to [clamp_min, clamp_max], take the natural log.
    mel_spectrogram = op.MatMul(spectrogram, melscale_fbanks_v2)
    log_mel_spectrogram = op.Log(op.Clip(mel_spectrogram, clamp_min, clamp_max))

    # Integer arithmetic on INT64 lengths: one frame per hop_length samples, plus one.
    features_lens = waveforms_lens / hop_length + 1
    # [batch_size, frames, mels] -> [batch_size, mels, frames].
    features = op.Transpose(log_mel_spectrogram, perm=[0, 2, 1])
    return features, features_lens


# Log-mel feature extractor for GigaAM v3 whose STFT is evaluated by
# conv_power_spectrogram (a fixed 1d convolution) instead of op.STFT.
# Unlike the v2 variant, the waveform is not padded before framing.
#
# Inputs:
#   waveforms:      float tensor [batch_size, N].
#   waveforms_lens: int64 tensor [batch_size] (presumably the number of valid
#                   samples per waveform; confirm against callers).
# Outputs:
#   features:      float tensor [batch_size, n_mels, T].
#   features_lens: int64 tensor [batch_size], frame count derived from waveforms_lens.
@script(doc_string="LogMelSpectrogram feature extractor for GigaAM v3 models (Conv-based STFT)")
def GigaamPreprocessorV3Conv(
    waveforms: FLOAT["batch_size", "N"],
    waveforms_lens: INT64["batch_size"],
) -> tuple[FLOAT["batch_size", n_mels, "T"], INT64["batch_size"]]:
    # Power spectrogram laid out as [batch_size, frames, n_bins].
    spectrogram = conv_power_spectrogram(waveforms, stft_conv_weights_v3)

    # Project onto the mel filterbank, clip to [clamp_min, clamp_max], take the natural log.
    mel_spectrogram = op.MatMul(spectrogram, melscale_fbanks_v3)
    log_mel_spectrogram = op.Log(op.Clip(mel_spectrogram, clamp_min, clamp_max))

    # No padding: only positions where a full win_length_v3 window fits produce a frame.
    features_lens = (waveforms_lens - win_length_v3) / hop_length + 1
    # [batch_size, frames, mels] -> [batch_size, mels, frames].
    features = op.Transpose(log_mel_spectrogram, perm=[0, 2, 1])
    return features, features_lens
37 changes: 32 additions & 5 deletions preprocessors/kaldi.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from onnxscript import opset17 as op

from preprocessors.fbanks import melscale_fbanks
from preprocessors.stft import conv_power_spectrogram, stft_conv_weights

sample_rate = 16_000
n_fft = 512
Expand All @@ -29,6 +30,9 @@
np.float32
)
wespeaker_window = np.hamming(win_length).astype(np.float32)
# Conv kernel for the Conv-based Kaldi preprocessor: a Hann window raised to the
# power 0.85, zero-padded on the right from win_length to n_fft samples.
# NOTE(review): presumably the same values as povey_window; confirm they stay in sync.
stft_conv_weights_kaldi = stft_conv_weights(
    np.pad(np.hanning(win_length) ** 0.85, (0, n_fft - win_length)).astype(np.float32)
)


@script()
Expand Down Expand Up @@ -75,13 +79,11 @@ def sliding_buffer(prev: FLOAT["batch_size", win_length - hop_length], curr: FLO

@script()
def calc_features(
image: FLOAT["batch_size", "T", n_fft // 2 + 1, 2],
spectrogram: FLOAT["batch_size", "T", n_fft // 2 + 1],
waveforms_lens: INT64["batch_size"],
mel_banks: FLOAT[n_fft // 2 + 1, 2, num_mel_bins],
snip_edges: bool,
):
spectrogram = op.ReduceSumSquare(image, axes=[-1], keepdims=0)

mel_spectrogram = op.MatMul(spectrogram, mel_banks)
log_mel_spectrogram = op.Log(op.Clip(mel_spectrogram, min=float_eps))

Expand Down Expand Up @@ -119,7 +121,8 @@ def preprocessor(
frames = frames - preemphasis_coefficient * op.Pad(frames, pads=[0, 0, 1, 0, 0, -1], mode="edge")

image = op.DFT(op.Unsqueeze(window * frames, axes=[-1]), n_fft, axis=-2, onesided=1)
features, features_lens = calc_features(image, waveforms_lens, mel_banks, snip_edges)
spectrogram = op.ReduceSumSquare(image, axes=[-1], keepdims=0)
features, features_lens = calc_features(spectrogram, waveforms_lens, mel_banks, snip_edges)
return features, features_lens


Expand Down Expand Up @@ -155,8 +158,32 @@ def KaldiPreprocessorFast(
pads=op.Constant(value=[0, n_fft - win_length]),
)
image = op.STFT(waveforms, hop_length, povey_window)
spectrogram = op.ReduceSumSquare(image, axes=[-1], keepdims=0)

features, features_lens = calc_features(spectrogram, waveforms_lens, kaldi_mel_banks, snip_edges=snip_edges)
return features, features_lens


# Kaldi-style log-mel feature extractor whose STFT is evaluated by
# conv_power_spectrogram (a fixed 1d convolution) instead of op.STFT.
#
# Inputs:
#   waveforms:      float tensor [batch_size, N].
#   waveforms_lens: int64 tensor [batch_size], forwarded to symmetric_pad and
#                   calc_features (presumably the number of valid samples).
# Outputs:
#   features:      float tensor [batch_size, T, num_mel_bins].
#   features_lens: int64 tensor [batch_size], as computed by calc_features.
@script(doc_string="LogMelSpectrogram feature extractor for Kaldi models (Conv-based STFT)")
def KaldiPreprocessorFastConv(
    waveforms: FLOAT["batch_size", "N"], waveforms_lens: INT64["batch_size"]
) -> tuple[FLOAT["batch_size", "T", num_mel_bins], INT64["batch_size"]]:
    # Optional dithering: add Gaussian noise scaled by `dither`.
    if dither != 0.0:
        waveforms = waveforms + op.RandomNormalLike(waveforms, scale=dither)

    # Optional DC removal: subtract the mean over the time axis of each waveform.
    if remove_dc_offset:
        waveforms = waveforms - op.ReduceMean(waveforms, axes=[-1])

    if not snip_edges:
        waveforms = symmetric_pad(waveforms, waveforms_lens)

    # Pre-emphasis y[t] = x[t] - c * x[t - 1]. The Pad shifts the signal right by
    # one sample (edge-replicating the first sample, cropping the last).
    if preemphasis_coefficient != 0.0:
        waveforms = waveforms - preemphasis_coefficient * op.Pad(waveforms, pads=[0, 1, 0, -1], mode="edge")

    # Append n_fft - win_length zeros to the end of the time axis: the conv kernel
    # is n_fft wide while the (zero-padded) window only covers win_length samples.
    waveforms = op.Pad(waveforms, pads=op.Constant(value=[0, 0, 0, n_fft - win_length]))
    spectrogram = conv_power_spectrogram(waveforms, stft_conv_weights_kaldi)

    # calc_features takes the power spectrogram directly; the stale call that
    # passed an undefined `image` tensor has been removed.
    features, features_lens = calc_features(spectrogram, waveforms_lens, kaldi_mel_banks, snip_edges=snip_edges)
    return features, features_lens


Expand Down
59 changes: 59 additions & 0 deletions preprocessors/nemo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from onnxscript import opset17 as op

from preprocessors.fbanks import melscale_fbanks
from preprocessors.stft import conv_power_spectrogram, stft_conv_weights

sample_rate = 16_000
n_fft = 512
Expand All @@ -20,6 +21,9 @@
melscale_fbanks128 = melscale_fbanks(n_fft // 2 + 1, 0, sample_rate // 2, 128, sample_rate, "slaney", "slaney").astype(
np.float32
)
# Conv kernel for the Conv-based NeMo preprocessors: a Hann window of win_length
# points zero-padded by the same amount on both sides.
# NOTE(review): the padded length equals n_fft only when win_length is even; confirm.
stft_conv_weights_nemo = stft_conv_weights(
    np.pad(np.hanning(win_length), (n_fft // 2 - win_length // 2, n_fft // 2 - win_length // 2)).astype(np.float32)
)


@script()
Expand Down Expand Up @@ -87,3 +91,58 @@ def NemoPreprocessor128(
melscale_fbanks128,
)
return features, features_lens


# Shared body of the Conv-based NeMo preprocessors.
#
# Inputs:
#   waveforms:       float tensor [batch_size, N].
#   waveforms_lens:  int64 tensor [batch_size]; positions at or beyond this length
#                    are zeroed after pre-emphasis.
#   melscale_fbanks: mel filterbank [n_fft // 2 + 1, M].
#   conv_weights:    STFT conv kernel [channels, 1, n_fft].
# Outputs:
#   normalized log-mel features (transposed to [batch_size, M, frames] before
#   `normalize` is applied) and features_lens [batch_size].
@script()
def nemo_preprocessor_conv(
    waveforms: FLOAT["batch_size", "N"],
    waveforms_lens: INT64["batch_size"],
    melscale_fbanks: FLOAT[n_fft // 2 + 1, "M"],
    conv_weights: FLOAT["channels", 1, n_fft],
):
    if preemph != 0.0:
        # Boolean mask [batch_size, N]: true where the sample index is below waveforms_lens.
        timemask = op.Range(0, op.Squeeze(op.Shape(waveforms, start=1, end=2)), 1) < op.Unsqueeze(
            waveforms_lens, axes=[1]
        )
        # Pre-emphasis y[t] = x[t] - preemph * x[t - 1]; the first sample is kept as is.
        waveforms = op.Concat(waveforms[:, :1], waveforms[:, 1:] - preemph * waveforms[:, :-1], axis=-1)
        # Zero everything outside the mask.
        waveforms = op.Where(timemask, waveforms, 0.0)

    # Pad n_fft // 2 samples at both ends of the time axis (Pad's default mode).
    waveforms = op.Pad(
        waveforms,
        pads=op.Constant(value=[0, n_fft // 2, 0, n_fft // 2]),
    )
    # Power spectrogram laid out as [batch_size, frames, n_bins].
    spectrogram = conv_power_spectrogram(waveforms, conv_weights)

    # Mel projection, then log with an additive guard instead of clipping.
    mel_spectrogram = op.MatMul(spectrogram, melscale_fbanks)
    log_mel_spectrogram = op.Log(mel_spectrogram + log_zero_guard_value)

    features_lens = waveforms_lens / hop_length
    # `normalize` is defined elsewhere in this module; it receives the features
    # as [batch_size, M, frames] together with the per-item frame counts.
    return normalize(op.Transpose(log_mel_spectrogram, perm=[0, 2, 1]), features_lens), features_lens


# Exported Conv-based NeMo preprocessor with 80 mel bins: forwards to
# nemo_preprocessor_conv with melscale_fbanks80 and the shared NeMo conv kernel.
# Returns features [batch_size, 80, T] and features_lens [batch_size].
@script(doc_string="LogMelSpectrogram feature extractor for Nemo models (Conv-based STFT)", default_opset=op)
def NemoPreprocessor80Conv(
    waveforms: FLOAT["batch_size", "N"],
    waveforms_lens: INT64["batch_size"],
) -> tuple[FLOAT["batch_size", 80, "T"], INT64["batch_size"]]:
    features, features_lens = nemo_preprocessor_conv(
        waveforms,
        waveforms_lens,
        melscale_fbanks80,
        stft_conv_weights_nemo,
    )
    return features, features_lens


# Exported Conv-based NeMo preprocessor with 128 mel bins: forwards to
# nemo_preprocessor_conv with melscale_fbanks128 and the shared NeMo conv kernel.
# Returns features [batch_size, 128, T] and features_lens [batch_size].
@script(doc_string="LogMelSpectrogram feature extractor for Nemo models (Conv-based STFT)", default_opset=op)
def NemoPreprocessor128Conv(
    waveforms: FLOAT["batch_size", "N"],
    waveforms_lens: INT64["batch_size"],
) -> tuple[FLOAT["batch_size", 128, "T"], INT64["batch_size"]]:
    features, features_lens = nemo_preprocessor_conv(
        waveforms,
        waveforms_lens,
        melscale_fbanks128,
        stft_conv_weights_nemo,
    )
    return features, features_lens
50 changes: 50 additions & 0 deletions preprocessors/stft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""STFT computed as a fixed 1d convolution.

op.STFT is not supported by the TensorRT and CUDA execution providers, so any
preprocessor graph that uses it falls back to CPU. The discrete Fourier
transform can instead be written as a 1d convolution with a fixed kernel (the
cos/sin Fourier basis multiplied by the analysis window). The convolution and
the ops around it (Reshape, ReduceSumSquare, Transpose) are supported by every
execution provider, so the resulting graph can run fully on GPU / TensorRT.
"""

import numpy as np
import numpy.typing as npt
from onnxscript import FLOAT, script
from onnxscript import opset17 as op

# Stride, in samples, of the STFT convolution performed by conv_power_spectrogram.
# NOTE(review): callers derive features_lens from their own module-level
# hop_length; confirm every one of them equals this value.
hop_length = 160


def stft_conv_weights(window: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
    """Return a Conv kernel that evaluates a windowed one-sided DFT.

    Args:
        window: Analysis window. Its length is taken as ``n_fft``, so the
            caller must already have zero-padded it to the FFT size.

    Returns:
        Array of shape ``[2 * (n_fft // 2 + 1), 1, n_fft]``. The first
        ``n_fft // 2 + 1`` output channels are the windowed cosine (real)
        rows and the remaining ones the windowed negated-sine (imaginary)
        rows. Applied through Conv with stride ``hop_length`` it reproduces
        ``op.STFT``.

    """
    n_fft = window.shape[0]
    bins = np.arange(n_fft // 2 + 1)
    samples = np.arange(n_fft)
    # phase[k, n] = 2 * pi * k * n / n_fft
    phase = 2 * np.pi * np.outer(bins, samples) / n_fft
    real_rows = np.cos(phase)
    imag_rows = -np.sin(phase)
    # The window broadcasts along the last axis, scaling every basis row.
    kernel = np.concatenate([real_rows, imag_rows]) * window
    return kernel[:, np.newaxis, :].astype(np.float32)


@script()
def conv_power_spectrogram(waveforms: FLOAT["batch_size", "N"], conv_weights: FLOAT["channels", 1, "n_fft"]):
    """Power spectrogram [batch_size, frames, n_bins] via a Conv-based STFT.

    Drop-in replacement for ``op.STFT`` followed by ``ReduceSumSquare`` over the
    real/imaginary axis. ``conv_weights`` is built by :func:`stft_conv_weights`.
    """
    # Add a channel axis ([batch_size, 1, N]) and slide the kernel with stride
    # hop_length; the result is [batch_size, channels, frames].
    image = op.Conv(op.Unsqueeze(waveforms, axes=[1]), conv_weights, strides=[hop_length])
    # The channels are the real rows followed by the imaginary rows, so half of
    # them is the number of frequency bins.
    n_bins = op.Shape(conv_weights, start=0, end=1) / 2
    # Reshape target [0, 2, n_bins, -1]: 0 keeps the batch dimension, -1 infers frames.
    shape = op.Concat(op.Constant(value=[0, 2]), n_bins, op.Constant(value=[-1]), axis=0)
    # Sum re^2 + im^2 over the size-2 axis -> [batch_size, n_bins, frames].
    spectrogram = op.ReduceSumSquare(op.Reshape(image, shape), axes=[1], keepdims=0)
    # -> [batch_size, frames, n_bins].
    return op.Transpose(spectrogram, perm=[0, 2, 1])
58 changes: 58 additions & 0 deletions preprocessors/whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from onnxscript import opset17 as op

from preprocessors.fbanks import melscale_fbanks
from preprocessors.stft import conv_power_spectrogram, stft_conv_weights

chunk_length = 30
sample_rate = 16_000
Expand All @@ -22,6 +23,7 @@
melscale_fbanks128 = melscale_fbanks(n_fft // 2 + 1, 0, sample_rate // 2, 128, sample_rate, "slaney", "slaney").astype(
np.float32
)
# Conv kernel for the Conv-based Whisper preprocessors, built from a Hann window
# of win_length + 1 points with the last point dropped.
# NOTE(review): the window is passed unpadded, so stft_conv_weights uses
# win_length as the FFT size; this matches the [channels, 1, n_fft] annotation
# in whisper_preprocessor_conv only if win_length == n_fft. Confirm.
stft_conv_weights_whisper = stft_conv_weights(np.hanning(win_length + 1)[:-1].astype(np.float32))


@script()
Expand Down Expand Up @@ -77,3 +79,59 @@ def WhisperPreprocessor128(
melscale_fbanks128,
)
return features, features_lens


# Shared body of the Conv-based Whisper preprocessors.
#
# Inputs:
#   waveforms:       float tensor [batch_size, N].
#   waveforms_lens:  int64 tensor [batch_size]; only its shape is used here.
#   melscale_fbanks: mel filterbank [n_fft // 2 + 1, M].
#   conv_weights:    STFT conv kernel [channels, 1, n_fft].
# Outputs:
#   log-mel features [batch_size, M, frames] and a [batch_size] tensor filled
#   with the constant features_length.
@script()
def whisper_preprocessor_conv(
    waveforms: FLOAT["batch_size", "N"],
    waveforms_lens: INT64["batch_size"],
    melscale_fbanks: FLOAT[n_fft // 2 + 1, "M"],
    conv_weights: FLOAT["channels", 1, n_fft],
):
    # Pad the end of the time axis so the length becomes chunk_length * sample_rate.
    # NOTE(review): for inputs longer than that, the pad amount is negative;
    # confirm the runtime crops in that case.
    waveforms = op.Pad(
        waveforms,
        pads=(chunk_length * sample_rate - op.Shape(waveforms, start=1, end=2)) * op.Constant(value=[0, 0, 0, 1]),
    )
    # Reflect-pad n_fft // 2 samples at both ends of the time axis.
    waveforms = op.Pad(
        waveforms,
        pads=op.Constant(value=[0, n_fft // 2, 0, n_fft // 2]),
        mode="reflect",
    )

    # Power spectrogram [batch_size, frames, n_bins]; the last frame is dropped.
    spectrogram = conv_power_spectrogram(waveforms, conv_weights)[:, :-1]

    mel_spectrogram = op.MatMul(spectrogram, melscale_fbanks)
    # Natural log divided by ln10 (presumably ln(10), giving log10; defined elsewhere).
    log_mel_spectrogram = op.Log(op.Clip(mel_spectrogram, clamp_min)) / ln10
    # Floor values at (max - 8), then apply (x + 4) / 4.
    # NOTE(review): ReduceMax has no axes, so the maximum is taken over the whole
    # batch rather than per item; confirm this is intended.
    log_mel_spectrogram = (op.Max(log_mel_spectrogram, op.ReduceMax(log_mel_spectrogram) - 8) + 4) / 4.0

    # [batch_size, frames, M] -> [batch_size, M, frames]; every item reports the
    # same frame count, features_length.
    return op.Transpose(log_mel_spectrogram, perm=[0, 2, 1]), op.ConstantOfShape(
        op.Shape(waveforms_lens), value=features_length
    )


# Exported Conv-based Whisper preprocessor with 80 mel bins: forwards to
# whisper_preprocessor_conv with melscale_fbanks80 and the shared Whisper conv kernel.
# Returns features [batch_size, 80, T] and features_lens [batch_size].
@script(doc_string="LogMelSpectrogram feature extractor for Whisper models (Conv-based STFT)", default_opset=op)
def WhisperPreprocessor80Conv(
    waveforms: FLOAT["batch_size", "N"],
    waveforms_lens: INT64["batch_size"],
) -> tuple[FLOAT["batch_size", 80, "T"], INT64["batch_size"]]:
    features, features_lens = whisper_preprocessor_conv(
        waveforms,
        waveforms_lens,
        melscale_fbanks80,
        stft_conv_weights_whisper,
    )
    return features, features_lens


# Exported Conv-based Whisper preprocessor with 128 mel bins: forwards to
# whisper_preprocessor_conv with melscale_fbanks128 and the shared Whisper conv kernel.
# Returns features [batch_size, 128, T] and features_lens [batch_size].
@script(doc_string="LogMelSpectrogram feature extractor for Whisper models (Conv-based STFT)", default_opset=op)
def WhisperPreprocessor128Conv(
    waveforms: FLOAT["batch_size", "N"],
    waveforms_lens: INT64["batch_size"],
) -> tuple[FLOAT["batch_size", 128, "T"], INT64["batch_size"]]:
    features, features_lens = whisper_preprocessor_conv(
        waveforms,
        waveforms_lens,
        melscale_fbanks128,
        stft_conv_weights_whisper,
    )
    return features, features_lens
Loading