Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ Session.vim
/release
/release_models
/separated
/tests
/trash
/misc
/mdx
.mypy_cache
*.onnx
*.ort
*.config
44 changes: 34 additions & 10 deletions demucs/htdemucs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright (C) 2025 Mixxx Development Team.
# All rights reserved.
#
# This source code is licensed under the license found in the
Expand Down Expand Up @@ -130,6 +131,7 @@ def __init__(
samplerate=44100,
segment=10,
use_train_segment=True,
onnx_exportable=False,
):
"""
Args:
Expand Down Expand Up @@ -239,6 +241,7 @@ def __init__(
self.wiener_iters = wiener_iters
self.end_iters = end_iters
self.freq_emb = None
self.onnx_exportable = onnx_exportable
assert wiener_iters == end_iters

self.encoder = nn.ModuleList()
Expand Down Expand Up @@ -434,29 +437,46 @@ def _spec(self, x):
pad = hl // 2 * 3
x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect")

z = spectro(x, nfft, hl)[..., :-1, :]
assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
z = z[..., 2: 2 + le]
if self.onnx_exportable:
z = spectro(x, nfft, hl, onnx_exportable=True)[..., :-1, :, :] # adding one more dimension
assert z.shape[-2] == le + 4, (z.shape, x.shape, le) # from -1 to -2
z = z[..., 2: 2 + le, :] # adding one more dimension
else:
z = spectro(x, nfft, hl)[..., :-1, :]
assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
z = z[..., 2: 2 + le]
return z

def _ispec(self, z, length=None, scale=0):
hl = self.hop_length // (4**scale)
z = F.pad(z, (0, 0, 0, 1))
z = F.pad(z, (2, 2))
if self.onnx_exportable:
z = F.pad(z, (0, 0, 0, 0, 0, 1)) # add 0 padding for the last dim
z = F.pad(z, (0, 0, 2, 2)) # add 0 padding for the last dim
else:
z = F.pad(z, (0, 0, 0, 1))
z = F.pad(z, (2, 2))
pad = hl // 2 * 3
le = hl * int(math.ceil(length / hl)) + 2 * pad
x = ispectro(z, hl, length=le)
x = ispectro(z, hl, length=le, onnx_exportable=self.onnx_exportable)
x = x[..., pad: pad + length]
return x

def _magnitude(self, z):
# return the magnitude of the spectrogram, except when cac is True,
# in which case we just move the complex dimension to the channel one.
if self.cac:
B, C, Fr, T = z.shape
m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
if self.onnx_exportable:
B, C, Fr, T, dim = z.shape # dim should be 2, adding one more dimension
m = z.permute(0, 1, 4, 2, 3) # torch.view_as_real(z) changed to z
else:
B, C, Fr, T = z.shape
m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
m = m.reshape(B, C * 2, Fr, T)
else:
elif self.onnx_exportable:
real = z[..., 0]
imag = z[..., 1]
m = torch.sqrt(real**2 + imag**2) # for magnitude
else:
m = z.abs()
return m

Expand All @@ -467,8 +487,12 @@ def _mask(self, z, m):
if self.cac:
B, S, C, Fr, T = m.shape
out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
out = torch.view_as_complex(out.contiguous())
if self.onnx_exportable:
out = out.contiguous() # out shape is (B, S, -1, Fr, T, 2) shape
else:
out = torch.view_as_complex(out.contiguous()) # out shape is (B, S, -1, Fr, T) (complex)
return out
# TODO: Modify all the below paths for the new shape
if self.training:
niters = self.end_iters
if niters < 0:
Expand Down
164 changes: 164 additions & 0 deletions demucs/istft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# Copyright (C) 2025 Mixxx Development Team
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# First author is Anmol Mishra.
"""
This module implements a custom ISTFT process that is compatible with ONNX export.
It uses PyTorch's convolution operations to compute the inverse STFT, avoiding the use of
complex numbers directly, which can be problematic for ONNX export.
"""
import torch
import enum


# Constants set for Demucs
NFFT = 4096 # Number of FFT components for the STFT process
HOP_LENGTH = 1024 # Number of samples between successive frames in the STFT
WINDOW_TYPE = 'hann' # Type of window function used in the STFT
WINDOW_LENGTH = NFFT # Length of the window function
NORMALIZED = True # Whether to normalize the window function
MAX_SIGNAL_LENGTH = int(44100 * 8) # Maximum length of the audio signal WITH padding (8 seconds at 44100 Hz)
MAX_FRAMES = MAX_SIGNAL_LENGTH // HOP_LENGTH + 1 # Maximum number of frames for the audio length after STFT processed.
CENTER = True # Whether to center the input signal before STFT

# Enum for window types
class WindowType(enum.StrEnum):
BARTLETT = 'bartlett'
BLACKMAN = 'blackman'
HAMMING = 'hamming'
HANN = 'hann'
KAISER = 'kaiser'

def __call__(self, window_length):
match self:
case WindowType.BARTLETT:
return torch.bartlett_window(window_length)
case WindowType.BLACKMAN:
return torch.blackman_window(window_length)
case WindowType.HAMMING:
return torch.hamming_window(window_length)
case WindowType.HANN:
return torch.hann_window(window_length)
case WindowType.KAISER:
return torch.kaiser_window(window_length, periodic=True, beta=12.0)
case _:
raise NotImplementedError(f"Window type {self} doesn't yet have a function.")

class ISTFT_Process(torch.nn.Module):
def __init__(self, n_fft=NFFT, hop_len=HOP_LENGTH, window_type=WINDOW_TYPE, window_length=WINDOW_LENGTH, normalized=NORMALIZED, max_frames=MAX_FRAMES, center=CENTER):
super(ISTFT_Process, self).__init__()
self.n_fft = n_fft
self.hop_len = hop_len
self.window_type = window_type
self.window_length = window_length
self.normalized = normalized
self.max_frames = max_frames
self.center = center
self.half_n_fft = n_fft // 2 # Precompute once

# Get window function and compute window once
if self.window_length != self.n_fft:
raise NotImplementedError(f"The case of window length not equal to n_fft is not implemented in {self.__class__.__name__}.")
window = WindowType(window_type)(self.window_length).float()

# Check if center is false
if not self.center:
raise NotImplementedError("No centering is not supported in this implementation.")

# ISTFT forward pass preparation
# Pre-compute fourier basis
fourier_basis = torch.fft.fft(torch.eye(n_fft, dtype=torch.float32))
fourier_basis = torch.vstack([
torch.real(fourier_basis[:self.half_n_fft + 1, :]),
torch.imag(fourier_basis[:self.half_n_fft + 1, :])
]).float()

# Create forward and inverse basis
forward_basis = window * fourier_basis[:, None, :]
inverse_basis = window * torch.linalg.pinv((fourier_basis * n_fft) / hop_len).T[:, None, :]

# Calculate window sum for overlap-add
n = n_fft + hop_len * (max_frames - 1)
window_sum = torch.zeros(n, dtype=torch.float32)
window_normalized = window / window.abs().max()

# Pad window if needed
total_pad = n_fft - window_normalized.shape[0]
if total_pad > 0:
pad_left = total_pad // 2
pad_right = total_pad - pad_left
win_sq = torch.nn.functional.pad(window_normalized ** 2, (pad_left, pad_right), mode='constant', value=0)
else:
win_sq = window_normalized ** 2

# Calculate overlap-add weights
for i in range(max_frames):
sample = i * hop_len
window_sum[sample: min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]

# Normalize window if needed
if normalized:
inverse_basis = inverse_basis * torch.sqrt(torch.tensor([n_fft], dtype=torch.float32))

# Register buffers
self.register_buffer("forward_basis", forward_basis)
self.register_buffer("inverse_basis", inverse_basis)
self.register_buffer("window_sum_inv", n_fft / (window_sum * hop_len + 1e-8)) # Add epsilon to avoid division by zero

def forward(self, real, imag, length=None):
# Calculate magnitude and phase from real and imaginary parts
magnitude = torch.sqrt(real ** 2 + imag ** 2)
phase = torch.atan2(imag, real + torch.finfo(real.dtype).eps) # Add epsilon to avoid division by zero

# Pre-compute trig values
cos_phase = torch.cos(phase)
sin_phase = torch.sin(phase)

# Prepare input for transposed convolution
complex_input = torch.cat((magnitude * cos_phase, magnitude * sin_phase), dim=1)

# Perform transposed convolution
inverse_transform = torch.nn.functional.conv_transpose1d(
complex_input,
self.inverse_basis,
stride=self.hop_len,
padding=0,
)

# Apply window correction
output_len = inverse_transform.size(-1)
start_idx = self.half_n_fft
end_idx = output_len

output = inverse_transform[:, :, start_idx:end_idx] * self.window_sum_inv[start_idx:end_idx]

# If length is specified, trim the output to the desired length
if length:
pad_len = torch.clamp(torch.tensor(length) - output.size(-1), min=0)

# Create a zero pad tensor regardless of need
pad = torch.zeros(
output.size(0), output.size(1), pad_len,
dtype=output.dtype, device=output.device
)

# Always cat, pad_len will be 0 if not needed
output = torch.cat([output, pad], dim=-1)

# Crop in all cases to enforce exact length
output = output[..., :length]

output = output.squeeze(dim=1)
return output

demucs_istft = ISTFT_Process(
n_fft=NFFT,
hop_len=HOP_LENGTH,
window_type=WINDOW_TYPE,
window_length=WINDOW_LENGTH,
normalized=NORMALIZED,
max_frames=MAX_FRAMES,
center=CENTER,
)
28 changes: 26 additions & 2 deletions demucs/spec.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright (C) 2025 Mixxx Development Team.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Conveniance wrapper to perform STFT and iSTFT"""

import torch as th
from .stft import demucs_stft
from .istft import demucs_istft


def spectro(x, n_fft=512, hop_length=None, pad=0):
def spectro(x, n_fft=512, hop_length=None, pad=0, onnx_exportable=False):
*other, length = x.shape
x = x.reshape(-1, length)
is_mps_xpu = x.device.type in ['mps', 'xpu']
if is_mps_xpu:
x = x.cpu()

if onnx_exportable:
z = demucs_stft(x.view(-1, 1, length)) # z will return 1 more dimension - z.size(-1) will be 2
_, freqs, frame, dim = z.shape
assert dim == 2, "STFT should return complex numbers"
return z.view(*other, freqs, frame, dim)

z = th.stft(x,
n_fft * (1 + pad),
hop_length or n_fft // 4,
Expand All @@ -27,7 +37,21 @@ def spectro(x, n_fft=512, hop_length=None, pad=0):
return z.view(*other, freqs, frame)


def ispectro(z, hop_length=None, length=None, pad=0):
def ispectro(z, hop_length=None, length=None, pad=0, onnx_exportable=False):
if onnx_exportable:
# B, S, -1, Fr, T (complex) -----> # B, S, -1, Fr, T, 2 shape
*other, freqs, frames, dim = z.shape # dim is 2
assert dim == 2, "iSTFT should receive complex numbers"
n_fft = 2 * freqs - 2
z = z.view(-1, freqs, frames, dim)
win_length = n_fft // (1 + pad)
is_mps_xpu = z.device.type in ['mps', 'xpu']
if is_mps_xpu:
z = z.cpu()
x = demucs_istft(z[..., 0], z[..., 1], length=length)
_, length = x.shape
return x.view(*other, length)

*other, freqs, frames = z.shape
n_fft = 2 * freqs - 2
z = z.view(-1, freqs, frames)
Expand Down
Loading