@@ -325,7 +325,7 @@ def _stft(
325325    We can write STFT in terms of convolutions with a DFT kernel. 
326326    At the end: 
327327        * The real part output is: cos_base * input_real + sin_base * input_imag 
328-         * The imaginary part output is: - (sin_base * input_real  - cos_base  * input_imag)  
328+         * The imaginary part output is: cos_base * input_imag  - sin_base  * input_real  
329329    Adapted from: https://github.com/adobe-research/convmelspec/blob/main/convmelspec/mil.py 
330330    """ 
331331    hop_length  =  hop_length  or  mb .floor_div (x = n_fft , y = 4 , before_op = before_op )
@@ -342,7 +342,7 @@ def _stft(
342342
343343    # create a window of centered 1s of the requested size 
344344    if  win_length :
345-         window  =  _get_window (win_length = win_length , n_fft = n_fft , before_op = before_op )
345+         window  =  _get_window (win_length = win_length , n_fft = n_fft , window = window ,  before_op = before_op )
346346
347347    # apply time window 
348348    if  window :
@@ -358,12 +358,13 @@ def _stft(
358358    if  input_imaginary :
359359        signal_imaginary  =  mb .expand_dims (x = input_imaginary , axes = (1 ,), before_op = before_op )
360360
361-     # conv with DFT kernel across the input signal 
362-     # The DFT matrix is obtained with the equation e^(2pi/N i) but the definition is: 
363-     # DFT(x) => X[k] = Σx[n]*e^(-2kpi/N i) 
364-     # If x is complex then x[n]=(a+i*b) 
365-     # So the real part = Σ(a*cos(2kpi/N)+b*sin(2kpi/N)) 
366-     # So the imag part = Σ(b*cos(2kpi/N)-a*sin(2kpi/N)) 
361+     # Convolve the DFT kernel with the input signal 
362+     # DFT(x[n]) --> X[k] = Σx[n]*e^(-i*2π*n/N*k), then if x is complex x[n]=(a[n]+i*b[n]) 
363+     #   real(X[k]) = Σ(a[n]*cos(2π*n/N*k)+b[n]*sin(2π*n/N*k)) 
364+     #   imag(X[k]) = Σ(b[n]*cos(2π*n/N*k)-a[n]*sin(2π*n/N*k)) 
365+     # But because our DFT matrix is obtained with the conjugate --> e^(i*2π*n/N*k): 
366+     #   real(X[k]) = Σ(a[n]*cos(2π*n/N*k)-b[n]*sin(2π*n/N*k)) 
367+     #   imag(X[k]) = Σ(b[n]*cos(2π*n/N*k)+a[n]*sin(2π*n/N*k)) 
367368    cos_windows_real  =  mb .conv (x = signal_real , weight = cos_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
368369    sin_windows_real  =  mb .conv (x = signal_real , weight = sin_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
369370    if  input_imaginary :
@@ -372,11 +373,11 @@ def _stft(
372373
373374    # add everything together 
374375    if  input_imaginary :
375-         real_result  =  mb .add (x = cos_windows_real , y = sin_windows_imag , before_op = before_op )
376-         imag_result  =  mb .sub (x = cos_windows_imag , y = sin_windows_real , before_op = before_op )
376+         real_result  =  mb .sub (x = cos_windows_real , y = sin_windows_imag , before_op = before_op )
377+         imag_result  =  mb .add (x = cos_windows_imag , y = sin_windows_real , before_op = before_op )
377378    else :
378379        real_result  =  cos_windows_real 
379-         imag_result  =  mb . sub ( x = 0. ,  y = sin_windows_real ,  before_op = before_op ) 
380+         imag_result  =  sin_windows_real 
380381
381382    # reduce the rank of the output 
382383    if  should_increase_rank :
@@ -417,17 +418,18 @@ def _istft(
417418    # By default, use the entire frame 
418419    win_length  =  win_length  or  n_fft 
419420
420-     input_shape  =  mb .shape (x = x , before_op = before_op )
421-     n_frames  =  input_shape .val [- 1 ]
422-     fft_size  =  input_shape .val [- 2 ]
423-     # expected_output_signal_len = n_fft.val + hop_length.val * (n_frames - 1) 
421+     input_shape  =  mb .shape (x = input_real , before_op = before_op )
422+     channels  =  input_shape .val [0 ]
423+     fft_size  =  input_shape .val [1 ]
424+     n_frames  =  input_shape .val [2 ]
425+     expected_output_signal_len  =  n_fft .val  +  hop_length .val  *  (n_frames  -  1 )
424426
425427    is_onesided  =  onesided .val  if  onesided  else  fft_size  !=  n_fft 
426428    cos_base , sin_base  =  _calculate_dft_matrix (n_fft , onesided = is_onesided , before_op = before_op )
427429
428430    # create a window of centered 1s of the requested size 
429431    if  win_length :
430-         window  =  _get_window (win_length = win_length , n_fft = n_fft , before_op = before_op )
432+         window  =  _get_window (win_length = win_length , n_fft = n_fft , window = window ,  before_op = before_op )
431433
432434    # apply time window 
433435    if  window :
@@ -447,14 +449,13 @@ def _istft(
447449        signal_real  =  mb .mul (x = signal_real , y = multiplier , before_op = before_op )
448450        signal_imaginary  =  mb .mul (x = signal_imaginary , y = multiplier , before_op = before_op )
449451
450-     # Conv with  DFT kernel across  the input signal 
451-     # We can describe the IDFT in terms of DFT just by swapping the input and output 
452+     # Convolve the  DFT kernel with  the input signal 
453+     # We can describe the IDFT in terms of DFT just by swapping the input and output.  
452454    # ref: https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Expressing_the_inverse_DFT_in_terms_of_the_DFT 
453-     # So IDFT(x) = (1/N) * swap(DFT(swap(x))) 
454-     # and DFT(x) = X[k] = Σx[n]*e^(-2kpi/N i) but we are using the conjugate e^(2kpi/N i) 
455-     # If x is complex then x[n]=(a+i*b) 
456-     # then real part = (1/N)*Σ(a*cos(2kpi/N)+b*sin(2kpi/N)) 
457-     # then imag part = (1/N)*Σ(b*cos(2kpi/N)-a*sin(2kpi/N)) 
455+     # IDFT(X[k]) --> x[n]=(1/N)*swap(DFT(swap(X[k]))), with K=N 
456+     # So using the definition in stft function, we get: 
457+     #   real(x[n]) = Σ(a[k]*cos(2π*k/K*n)+b[k]*sin(2π*k/K*n)) 
458+     #   imag(x[n]) = Σ(b[k]*cos(2π*k/K*n)-a[k]*sin(2π*k/K*n)) 
458459    cos_windows_real  =  mb .conv (x = signal_real , weight = cos_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
459460    sin_windows_real  =  mb .conv (x = signal_real , weight = sin_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
460461    cos_windows_imag  =  mb .conv (x = signal_imaginary , weight = cos_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
@@ -519,6 +520,7 @@ def _overlap_add(
519520def  _get_window (
520521    win_length : Var ,
521522    n_fft : Var ,
523+     window : Optional [Var ],
522524    before_op : Operation ,
523525) ->  Var :
524526    n_left  =  (n_fft .val  -  win_length .val ) //  2 
@@ -750,17 +752,21 @@ def _lower_complex_istft(op: Operation):
750752    is_complex  =  types .is_complex (op .input .dtype )
751753
752754    # check parameters for validity 
755+     if  is_complex :
756+         raise  ValueError ("Only complex inputs are allowed" )
753757    if  op .win_length  and  op .win_length .val  >  op .n_fft .val :
754758        raise  ValueError ("Window length must be less than or equal to n_fft" )
755-     if  is_complex  and  op .onesided  and  op .onesided .val :
756-         raise  ValueError ("Onesided  is only valid for real inputs " )
759+     if  op . return_complex  and  op .onesided  and  op .onesided .val :
760+         raise  ValueError ("Complex output  is not compatible with onesided " )
757761
758762    real , imag  =  _istft (
759-         op .input .real  if  is_complex  else  op .input ,
760-         op .input .imag  if  is_complex  else  None ,
761-         op .n_fft , op .hop_length , op .win_length , op .window , op .normalized , op .onesided , before_op = op )
763+         op .input .real , op .input .imag ,
764+         op .n_fft , op .hop_length , op .win_length , op .window , op .normalized , op .onesided , op .length , before_op = op )
762765
763-     return  _wrap_complex_output (op .outputs [0 ], real , imag )
766+     if  op .return_complex :
767+         return  _wrap_complex_output (op .outputs [0 ], real , imag )
768+     else 
769+         return  real 
764770
765771
766772@LowerComplex .register_lower_func (op_type = "complex_shape" ) 
0 commit comments