
Inconsistent outputs when running ONNX and PyTorch (STFT and iSTFT) #23219

@etemesi254


Describe the issue

Hi, thanks for the great library :)

Asteroid-filterbanks (https://github.com/asteroid-team/asteroid-filterbanks) provides an ONNX-exportable implementation of the STFT and iSTFT operations, which I am using in a speech separation model. The STFT and iSTFT are integrated into the model to make end-to-end inference easier.

Exporting to ONNX produces some warnings (shown below), and the exported model generates artifacts that make the audio sound as if it has extra noise, which is not ideal.

I'm not sure whether this is an issue in Asteroid or in ONNX, and would appreciate it if someone could look into it. Thanks!
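To put a number on the extra noise instead of relying on listening alone, a small helper can compare the PyTorch output with the ONNX Runtime output. This is a minimal NumPy sketch; `compare_outputs` is a hypothetical name and not part of asteroid-filterbanks, ONNX Runtime, or PyTorch. It reports the maximum absolute error and a signal-to-error ratio in dB.

```python
import numpy as np


def compare_outputs(reference, estimate):
    """Hypothetical helper: quantify how far `estimate` is from `reference`.

    Returns (max_abs_error, snr_db), where snr_db is the ratio of reference
    power to error power in dB (higher is better; identical inputs give +inf).
    """
    reference = np.asarray(reference, dtype=np.float64)
    estimate = np.asarray(estimate, dtype=np.float64)
    if reference.shape != estimate.shape:
        raise ValueError(f"shape mismatch: {reference.shape} vs {estimate.shape}")
    error = estimate - reference
    max_abs = float(np.max(np.abs(error)))
    err_power = float(np.sum(error ** 2))
    sig_power = float(np.sum(reference ** 2))
    if err_power == 0.0:
        return max_abs, float("inf")
    return max_abs, float(10.0 * np.log10(sig_power / err_power))
```

With the repro below, it could be called as `compare_outputs(c.detach().numpy(), output)`, where `c` is the PyTorch output and `output` is the ONNX Runtime output.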

Error/warning output logs from the ONNX export


  warnings.warn(
/miniconda3/envs/rizumu/lib/python3.11/site-packages/asteroid_filterbanks/enc_dec.py:294: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  length = min(length, wav.shape[-1])
miniconda3/envs/rizumu/lib/python3.11/site-packages/torch/onnx/_internal/jit_utils.py:308: UserWarning: Constant folding - Only steps=1 can be constant folded for opset >= 10 onnx::Slice op. Constant folding not applied. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/jit/passes/onnx/constant_fold.cpp:180.)
  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
miniconda3/envs/rizumu/lib/python3.11/site-packages/torch/onnx/utils.py:663: UserWarning: Constant folding - Only steps=1 can be constant folded for opset >= 10 onnx::Slice op. Constant folding not applied. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/jit/passes/onnx/constant_fold.cpp:180.)
  _C._jit_pass_onnx_graph_shape_type_inference(
miniconda3/envs/rizumu/lib/python3.11/site-packages/torch/onnx/utils.py:1186: UserWarning: Constant folding - Only steps=1 can be constant folded for opset >= 10 onnx::Slice op. Constant folding not applied. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/jit/passes/onnx/constant_fold.cpp:180.)
  _C._jit_pass_onnx_graph_shape_type_inference(
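The first warning points at `length = min(length, wav.shape[-1])` in `asteroid_filterbanks/enc_dec.py`. As the message says, the tracer cannot record Python-level values, so that result may be baked into the exported graph as a constant equal to the example input's length (20000 samples here), even though `dynamic_axes` marks the length as dynamic. One way to check is to run the exported model on inputs of several lengths. The sketch below is hypothetical (`length_mismatches` is not a library function); it accepts any callable, so it can wrap the `onnxruntime.InferenceSession` from the repro.

```python
import numpy as np


def length_mismatches(run, lengths, channels=1, seed=0):
    """Hypothetical probe: call `run` on random inputs of several lengths and
    return (input_shape, result) pairs where the output shape differs from the
    input shape, or where `run` raised an exception."""
    rng = np.random.default_rng(seed)
    mismatches = []
    for n in lengths:
        x = rng.standard_normal((channels, n)).astype(np.float32)
        try:
            y = np.asarray(run(x))
        except Exception as exc:  # e.g. a node that assumes the traced length
            mismatches.append((x.shape, repr(exc)))
            continue
        if y.shape != x.shape:
            mismatches.append((x.shape, y.shape))
    return mismatches


# Example usage against the model exported in the repro (requires onnxruntime):
#   import onnxruntime
#   sess = onnxruntime.InferenceSession("./temp_test.onnx")
#   run = lambda x: sess.run(["output"], {"input": x})[0]
#   print(length_mismatches(run, [20000, 16000, 32000]))
```

An empty list means the output length follows the input length. Entries only for lengths other than 20000 would be consistent with a traced constant. Matching shapes alone do not rule out the noise problem, so this check is best combined with a numerical comparison of the outputs.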

To reproduce

Colaboratory link: https://colab.research.google.com/drive/1mNCwjGqMWLSAIZOIi1FJOJgfJqOmjxWn#scrollTo=H2c-2PWuNxxg

Installing dependencies

!pip install onnxruntime onnx asteroid-filterbanks

Code

from typing import Optional

import onnxruntime
import torch
from torch import nn, Tensor

from asteroid_filterbanks.enc_dec import Encoder, Decoder
from asteroid_filterbanks.transforms import to_torchaudio, from_torchaudio
from asteroid_filterbanks import torch_stft_fb


class AsteroidSTFT(nn.Module):
    def __init__(self, fb):
        super(AsteroidSTFT, self).__init__()
        self.enc = Encoder(fb)

    def forward(self, x):
        aux = self.enc(x)
        return to_torchaudio(aux)


class AsteroidISTFT(nn.Module):
    def __init__(self, fb):
        super(AsteroidISTFT, self).__init__()
        self.dec = Decoder(fb)

    def forward(self, x: Tensor, length: Optional[int] = None) -> Tensor:
        aux = from_torchaudio(x)
        x = self.dec(aux, length=length)
        return x


def make_filterbanks(n_fft=4096, n_hop=1024, center=True, sample_rate=44100.0):
    window = nn.Parameter(torch.hann_window(n_fft), requires_grad=False)

    fb = torch_stft_fb.TorchSTFTFB.from_torch_args(
        n_fft=n_fft,
        hop_length=n_hop,
        win_length=n_fft,
        window=window,
        center=center,
        sample_rate=sample_rate,
    )
    encoder = AsteroidSTFT(fb)
    decoder = AsteroidISTFT(fb)

    return encoder, decoder



class TempTest(nn.Module):
    def __init__(self):
        super(TempTest, self).__init__()

        self.stft, self.istft = make_filterbanks()

    def forward(self, x: Tensor) -> Tensor:
        initial_size = x.shape[-1]
        was_unsqueezed = False

        if x.ndim == 2:
            # The STFT expects (batch, channels, time) while the model takes
            # (channels, time), so add a leading batch dimension
            x = x.unsqueeze(0)
            was_unsqueezed = True
        prev_device = x.device
        x_cpu = x.to("cpu")
        self.stft = self.stft.to("cpu")
        x = self.stft(x_cpu)
        x = self.istft(x, initial_size)
        # return back to previous device
        x = x.to(prev_device)

        if was_unsqueezed:
            # remove the fake batch dimension
            x = x.squeeze(dim=0)
        return x

if __name__ == '__main__':
    model = TempTest()
    audio = torch.randn((1, 20000))
    c = model(audio)
    torch.testing.assert_close(c, audio)
    # export to onnx
    # note: torch.onnx.export enables the dynamo-based exporter with `dynamo=True`;
    # `dynamo_export` is not one of its documented parameters, and the TracerWarning
    # in the logs above suggests the TorchScript exporter was used for this export
    torch.onnx.export(model, audio, "./temp_test.onnx",
                      dynamo_export=True,
                      external_data=False,
                      report=True,
                      verify=True,
                      input_names=["input"],
                      output_names=["output"],
                      dynamic_axes={"input": {0: "channels", 1: "length"},
                                    "output": {0: "channels", 1: "length"}})
    sess = onnxruntime.InferenceSession("./temp_test.onnx")
    output = sess.run(["output"], {"input": audio.detach().numpy()})[0]
    torch.testing.assert_close(torch.from_numpy(output), audio)
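For context on why `torch.testing.assert_close(c, audio)` is expected to pass: with a Hann window and a hop of `n_fft / 4`, the iSTFT divides the overlap-added frames by the overlap-added squared window, which undoes the analysis exactly wherever that envelope is nonzero. The NumPy sketch below mirrors the parameters in `make_filterbanks` (`n_fft=4096`, `n_hop=1024`, the periodic Hann window returned by `torch.hann_window`, `center=True`). It assumes reflect padding (the `torch.stft` default) and is an independent illustration, not asteroid-filterbanks' implementation. The function names are hypothetical, and frames are on the first axis, whereas `torch.stft` puts frequency first.

```python
import numpy as np


def hann_periodic(n_fft):
    # Same values as torch.hann_window(n_fft), which is periodic by default
    n = np.arange(n_fft)
    return 0.5 - 0.5 * np.cos(2 * np.pi * n / n_fft)


def stft_reference(x, n_fft=4096, hop=1024):
    """Centered STFT of a 1-D signal; returns shape (n_frames, n_fft // 2 + 1)."""
    w = hann_periodic(n_fft)
    pad = n_fft // 2
    xp = np.pad(x, (pad, pad), mode="reflect")  # center=True; reflect padding assumed
    n_frames = 1 + (len(xp) - n_fft) // hop
    frames = np.stack([xp[i * hop : i * hop + n_fft] * w for i in range(n_frames)])
    return np.fft.rfft(frames, axis=-1)


def istft_reference(spec, length, n_fft=4096, hop=1024):
    """Inverse of stft_reference: windowed overlap-add, then division by the
    overlap-added squared window (the normalization torch.istft applies)."""
    w = hann_periodic(n_fft)
    frames = np.fft.irfft(spec, n=n_fft, axis=-1) * w
    out_len = n_fft + hop * (frames.shape[0] - 1)
    y = np.zeros(out_len)
    envelope = np.zeros(out_len)
    for i, frame in enumerate(frames):
        y[i * hop : i * hop + n_fft] += frame
        envelope[i * hop : i * hop + n_fft] += w ** 2
    pad = n_fft // 2
    y = y[pad : pad + length]  # drop the centering padding, trim to length
    envelope = envelope[pad : pad + length]
    return y / np.maximum(envelope, 1e-11)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    audio = rng.standard_normal(20000)
    spec = stft_reference(audio)
    print(spec.shape)  # (20, 2049): 20 frames x 2049 frequency bins
    restored = istft_reference(spec, len(audio))
    print(np.max(np.abs(restored - audio)))  # expected to be at float64 round-off level
```

Because the round trip should reproduce the input up to round-off, any error well above float32 precision in the ONNX output points at the exported graph rather than at the STFT/iSTFT math itself.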

Urgency

No response

Platform

Mac

OS Version

15.0 (24A335)

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

onnx==1.17.0 onnxruntime==1.20.1

ONNX Runtime API

Python

Architecture

ARM64

Execution Provider

Default CPU

Execution Provider Library Version

No response
