|
| 1 | +import math |
| 2 | + |
| 3 | +import torch |
| 4 | +import torchaudio |
| 5 | +from onnx import numpy_helper |
| 6 | +from onnxscript import FLOAT, INT64, script |
| 7 | +from onnxscript import opset17 as op |
| 8 | + |
| 9 | + |
def make_kernel(orig_freq: int, new_freq: int = 16_000):
    """Build the sinc resampling kernel for converting ``orig_freq`` to ``new_freq``.

    The kernel is produced by torchaudio's private helper
    ``_get_sinc_resample_kernel`` with its default filter settings, then
    converted to NumPy so it can be embedded as a constant elsewhere.

    Args:
        orig_freq: Source sample rate in Hz.
        new_freq: Target sample rate in Hz. Defaults to 16 000 Hz, the value
            that was previously hard-coded, so existing single-argument calls
            behave exactly as before.

    Returns:
        A 4-tuple ``(kernel, width, orig_step, new_step)`` where

        * ``kernel`` is the torchaudio kernel as a float32 NumPy array with an
          extra axis inserted at position 1 (presumably giving a 4-D,
          2-D-convolution-style weight -- confirm against the torchaudio
          version in use);
        * ``width`` is the second value returned by the torchaudio helper
          (callers use it as the padding amount);
        * ``orig_step`` is ``orig_freq // gcd(orig_freq, new_freq)``;
        * ``new_step`` is ``new_freq // gcd(orig_freq, new_freq)``.
    """
    gcd = math.gcd(orig_freq, new_freq)
    # NOTE: this is a private torchaudio API (leading underscore); its
    # signature or location may change between torchaudio releases.
    kernel, width = torchaudio.functional.functional._get_sinc_resample_kernel(
        orig_freq, new_freq, gcd, dtype=torch.float32
    )
    return kernel.numpy()[:, None], width, orig_freq // gcd, new_freq // gcd
| 15 | + |
| 16 | + |
# Pre-compute one resampling kernel per supported input sample rate at import
# time. The unpacked names below are module globals that ResamplePreprocessor
# reads when it embeds the kernels and stride/padding values.
_RESAMPLE_SPECS = {rate: make_kernel(rate) for rate in (8_000, 22_050, 44_100, 48_000)}

kernel08, width08, orig_freq08, new_freq08 = _RESAMPLE_SPECS[8_000]
kernel22, width22, orig_freq22, new_freq22 = _RESAMPLE_SPECS[22_050]
kernel44, width44, orig_freq44, new_freq44 = _RESAMPLE_SPECS[44_100]
kernel48, width48, orig_freq48, new_freq48 = _RESAMPLE_SPECS[48_000]
| 21 | + |
| 22 | + |
# ONNX Script graph that resamples a batch of waveforms to 16 kHz.
#
# Inputs:
#   waveforms      -- float tensor annotated as (batch_size, N).
#   waveforms_lens -- int64 tensor annotated as (batch_size,), one length per row.
#   sample_rate    -- int64 tensor; only element 0 is read.
# Outputs:
#   resampled      -- float tensor annotated as (batch_size, M); positions at
#                     or beyond each row's resampled length are set to 0.
#   resampled_lens -- int64 tensor annotated as (batch_size,).
#
# Only 8 000, 22 050, 44 100 and 48 000 Hz are resampled; any other
# sample_rate value falls into the `else` branch and is passed through
# unchanged (lengths included).
#
# NOTE(review): the documentation is deliberately kept in `#` comments rather
# than a docstring, because this body is translated by onnxscript's @script
# and the decorator already sets doc_string= -- confirm how a Python docstring
# would interact with that before adding one.
@script(doc_string="Resampling waveform to 16 kHz")
def ResamplePreprocessor(
    waveforms: FLOAT["batch_size", "N"],
    waveforms_lens: INT64["batch_size"],
    sample_rate: INT64["1"],
) -> tuple[FLOAT["batch_size", "M"], INT64["batch_size"]]:
    # Insert two singleton axes: (batch, N) -> (batch, 1, 1, N), so the signal
    # can be fed to a 2-D Conv whose last axis is time.
    waveforms = op.Unsqueeze(waveforms, axes=[1, 2])

    # Each branch embeds the precomputed kernel for one input rate as a graph
    # constant. Conv pads are (begin_h, begin_w, end_h, end_w): the time axis
    # is padded by `width` on the left and `width + orig_freq` on the right,
    # and strided by `orig_freq` (the source rate divided by the gcd).
    # The length update is written as (new * len + orig - 1) / orig on int64
    # tensors -- presumably ceil(new * len / orig), relying on integer Div.
    if sample_rate[0] == 8_000:
        kernel = op.Constant(value=numpy_helper.from_array(kernel08, "kernel"))
        conv = op.Conv(waveforms, kernel, pads=(0, width08, 0, width08 + orig_freq08), strides=(1, orig_freq08))
        waveforms_lens = (new_freq08 * waveforms_lens + orig_freq08 - 1) / orig_freq08
    elif sample_rate[0] == 22_050:
        kernel = op.Constant(value=numpy_helper.from_array(kernel22, "kernel"))
        conv = op.Conv(waveforms, kernel, pads=(0, width22, 0, width22 + orig_freq22), strides=(1, orig_freq22))
        waveforms_lens = (new_freq22 * waveforms_lens + orig_freq22 - 1) / orig_freq22
    elif sample_rate[0] == 44_100:
        kernel = op.Constant(value=numpy_helper.from_array(kernel44, "kernel"))
        conv = op.Conv(waveforms, kernel, pads=(0, width44, 0, width44 + orig_freq44), strides=(1, orig_freq44))
        waveforms_lens = (new_freq44 * waveforms_lens + orig_freq44 - 1) / orig_freq44
    elif sample_rate[0] == 48_000:
        kernel = op.Constant(value=numpy_helper.from_array(kernel48, "kernel"))
        conv = op.Conv(waveforms, kernel, pads=(0, width48, 0, width48 + orig_freq48), strides=(1, orig_freq48))
        waveforms_lens = (new_freq48 * waveforms_lens + orig_freq48 - 1) / orig_freq48
    else:
        # Unrecognised rate: no convolution, lengths left untouched.
        conv = waveforms

    # Identity gives the output lengths their own graph value (presumably so
    # the output is not the very same value as the input in the pass-through
    # branch -- confirm).
    resampled_lens = op.Identity(waveforms_lens)
    # Longest resampled row in the batch; reduced to a scalar (keepdims=0).
    max_len = op.ReduceMax(resampled_lens, keepdims=0)
    # mask[b, t] is true while t < resampled_lens[b]; shape (batch, max_len).
    mask = op.Unsqueeze(op.Range(0, max_len, 1), [0]) < op.Unsqueeze(resampled_lens, [1])
    # Swap the channel and time axes, then flatten everything after the batch
    # axis, so the per-step output channels become consecutive samples. The
    # result is cut to max_len columns and zeroed outside each row's length.
    resampled = op.Where(mask, op.Flatten(op.Transpose(conv, perm=(0, 3, 2, 1)))[:, :max_len], 0)
    return resampled, resampled_lens
0 commit comments