Description
Describe the issue
Hi, thanks for the great library :)
Asteroid-filterbanks (https://github.com/asteroid-team/asteroid-filterbanks) provides an ONNX-exportable implementation of the STFT and iSTFT operations, which I am using in a speech separation model. The STFT and iSTFT are integrated into the model for easier end-to-end inference.
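For context, as far as I understand it, asteroid-filterbanks avoids complex-valued ops by computing the STFT as a strided 1-D convolution with fixed real and imaginary Fourier kernels (and the iSTFT as the matching transposed convolution), which is what makes it exportable to ONNX. The NumPy sketch below illustrates that idea for the forward transform only. It is a simplified illustration, not the library's code: it assumes no `center` padding and no normalization, and `fourier_kernels` / `stft_as_conv` are hypothetical names.

```python
import numpy as np


def fourier_kernels(n_fft, window):
    """Real and imaginary DFT kernels for bins 0..n_fft//2, pre-multiplied by the window."""
    n = np.arange(n_fft)                      # sample index within a frame
    k = np.arange(n_fft // 2 + 1)[:, None]    # one row per frequency bin
    angle = 2 * np.pi * k * n / n_fft
    real = np.cos(angle) * window             # shape (n_bins, n_fft)
    imag = -np.sin(angle) * window            # minus sign matches e^{-i*angle}
    return real, imag


def stft_as_conv(x, n_fft, hop, window):
    """STFT of a 1-D signal as a strided correlation with fixed kernels (what conv1d computes)."""
    real_k, imag_k = fourier_kernels(n_fft, window)
    n_frames = 1 + (len(x) - n_fft) // hop    # no padding: only full frames
    frames = np.stack([x[i * hop : i * hop + n_fft] for i in range(n_frames)])
    # Each output bin is a dot product of a frame with one kernel row.
    return frames @ real_k.T, frames @ imag_k.T  # each (n_frames, n_bins)
```

Because the kernels are fixed weights, the exported graph only needs a convolution (here written as a matrix product over frames) instead of complex FFT ops; the result matches `np.fft.rfft` of the windowed frames up to floating-point error.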
Exporting to ONNX produces some warnings (shown below), and the exported model generates artifacts that make the audio sound like it has extra noise, which is not ideal.
I am seeking help in case this is an issue in asteroid or in ONNX, and would appreciate someone looking into it. Thanks!
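To put a number on the extra noise, a small helper could compare the PyTorch output with the ONNX Runtime output (for example `c` and `output` from the reproduction script). This is a hypothetical sketch (`compare_waveforms` is not part of asteroid, PyTorch or ONNX Runtime) that treats the PyTorch output as the reference and reports the maximum absolute error and the signal-to-noise ratio of the difference.

```python
import numpy as np


def compare_waveforms(reference, candidate):
    """Return (max absolute error, SNR in dB) of `candidate` relative to `reference`."""
    reference = np.asarray(reference, dtype=np.float64)
    candidate = np.asarray(candidate, dtype=np.float64)
    if reference.shape != candidate.shape:
        raise ValueError(f"shape mismatch: {reference.shape} vs {candidate.shape}")
    error = candidate - reference
    max_abs = float(np.max(np.abs(error)))
    noise_power = float(np.sum(error ** 2))
    signal_power = float(np.sum(reference ** 2))
    # Identical signals have no noise, so report an infinite SNR.
    snr_db = float("inf") if noise_power == 0 else float(10 * np.log10(signal_power / noise_power))
    return max_abs, snr_db
```

For example, `compare_waveforms(c.detach().numpy(), output)` would give a single figure for how far the ONNX Runtime result drifts from the PyTorch one, which is easier to track than a failing `assert_close`.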
Error/warning output logs from the ONNX export
warnings.warn(
/miniconda3/envs/rizumu/lib/python3.11/site-packages/asteroid_filterbanks/enc_dec.py:294: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
length = min(length, wav.shape[-1])
miniconda3/envs/rizumu/lib/python3.11/site-packages/torch/onnx/_internal/jit_utils.py:308: UserWarning: Constant folding - Only steps=1 can be constant folded for opset >= 10 onnx::Slice op. Constant folding not applied. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/jit/passes/onnx/constant_fold.cpp:180.)
_C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
miniconda3/envs/rizumu/lib/python3.11/site-packages/torch/onnx/utils.py:663: UserWarning: Constant folding - Only steps=1 can be constant folded for opset >= 10 onnx::Slice op. Constant folding not applied. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/jit/passes/onnx/constant_fold.cpp:180.)
_C._jit_pass_onnx_graph_shape_type_inference(
miniconda3/envs/rizumu/lib/python3.11/site-packages/torch/onnx/utils.py:1186: UserWarning: Constant folding - Only steps=1 can be constant folded for opset >= 10 onnx::Slice op. Constant folding not applied. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/jit/passes/onnx/constant_fold.cpp:180.)
_C._jit_pass_onnx_graph_shape_type_inference(
To reproduce
Colaboratory link: https://colab.research.google.com/drive/1mNCwjGqMWLSAIZOIi1FJOJgfJqOmjxWn#scrollTo=H2c-2PWuNxxg
Installing dependencies
!pip install onnxruntime onnx asteroid-filterbanks
Code
from typing import Optional
import onnxruntime
import torch
from torch import nn, Tensor
from asteroid_filterbanks.enc_dec import Encoder, Decoder
from asteroid_filterbanks.transforms import to_torchaudio, from_torchaudio
from asteroid_filterbanks import torch_stft_fb

class AsteroidSTFT(nn.Module):
    def __init__(self, fb):
        super(AsteroidSTFT, self).__init__()
        self.enc = Encoder(fb)

    def forward(self, x):
        aux = self.enc(x)
        return to_torchaudio(aux)

class AsteroidISTFT(nn.Module):
    def __init__(self, fb):
        super(AsteroidISTFT, self).__init__()
        self.dec = Decoder(fb)

    def forward(self, x: Tensor, length: Optional[int] = None) -> Tensor:
        aux = from_torchaudio(x)
        x = self.dec(aux, length=length)
        return x

def make_filterbanks(n_fft=4096, n_hop=1024, center=True, sample_rate=44100.0):
    window = nn.Parameter(torch.hann_window(n_fft), requires_grad=False)
    fb = torch_stft_fb.TorchSTFTFB.from_torch_args(
        n_fft=n_fft,
        hop_length=n_hop,
        win_length=n_fft,
        window=window,
        center=center,
        sample_rate=sample_rate,
    )
    encoder = AsteroidSTFT(fb)
    decoder = AsteroidISTFT(fb)
    return encoder, decoder

class TempTest(nn.Module):
    def __init__(self):
        super(TempTest, self).__init__()
        self.stft, self.istft = make_filterbanks()

    def forward(self, x: Tensor) -> Tensor:
        initial_size = x.shape[-1]
        was_unsqueezed = False
        if x.ndim == 2:
            # the STFT expects (batch, channels, time) while the model takes (channels, time),
            # so add a fake batch dimension
            x = x.unsqueeze(0)
            was_unsqueezed = True
        prev_device = x.device
        x_cpu = x.to("cpu")
        self.stft = self.stft.to("cpu")
        x = self.stft(x_cpu)
        x = self.istft(x, initial_size)
        # move back to the original device
        x = x.to(prev_device)
        if was_unsqueezed:
            # remove the fake batch dimension
            x = x.squeeze(dim=0)
        return x

if __name__ == '__main__':
    model = TempTest()
    audio = torch.randn((1, 20000))
    c = model(audio)
    torch.testing.assert_close(c, audio)
    # export to onnx
    torch.onnx.export(model, audio, "./temp_test.onnx",
                      dynamo_export=True,
                      external_data=False,
                      report=True,
                      verify=True,
                      input_names=["input"],
                      output_names=["output"],
                      dynamic_axes={"input": {0: "channels", 1: "length"},
                                    "output": {0: "channels", 1: "length"}})
    sess = onnxruntime.InferenceSession("./temp_test.onnx")
    output = sess.run(["output"], {"input": audio.detach().numpy()})[0]
    torch.testing.assert_close(torch.from_numpy(output), audio)
Urgency
No response
Platform
Mac
OS Version
15.0 (24A335)
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
onnx==1.17.0 onnxruntime==1.20.1
ONNX Runtime API
Python
Architecture
ARM64
Execution Provider
Default CPU
Execution Provider Library Version
No response
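One reason the PyTorch-only round trip in the reproduction script (`torch.testing.assert_close(c, audio)`) can reconstruct the input exactly: with `n_fft=4096`, `n_hop=1024` and the periodic Hann window returned by `torch.hann_window`, the overlap-added squared window is constant (1.5) wherever four frames overlap, so the inverse STFT's window normalization is well conditioned. A minimal NumPy sketch to check this, assuming frames spaced exactly `hop` samples apart; `squared_window_envelope` is a hypothetical helper, not an asteroid or PyTorch function.

```python
import numpy as np


def squared_window_envelope(n_fft, hop, n_frames):
    """Overlap-add the squared periodic Hann window over n_frames frames spaced by hop."""
    # Periodic Hann, matching torch.hann_window's default (periodic=True).
    window = 0.5 - 0.5 * np.cos(2 * np.pi * np.arange(n_fft) / n_fft)
    envelope = np.zeros(n_fft + hop * (n_frames - 1))
    for i in range(n_frames):
        envelope[i * hop : i * hop + n_fft] += window ** 2
    return envelope


env = squared_window_envelope(4096, 1024, 10)
# Samples 3*1024 .. 10*1024 are covered by exactly four frames.
interior = env[3 * 1024 : 10 * 1024]
print(interior.min(), interior.max())
```

Comparing this envelope with whatever normalization the exported ONNX graph applies could be one way to narrow down where the extra noise comes from.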