@@ -9534,13 +9534,13 @@ def forward(self, x):
95349534            [None , 1 , 3 ], # channels  
95359535            [16 , 32 ], # n_fft  
95369536            [5 , 9 ], # num_frames  
9537-             [None , 4 ,  5 ], # hop_length  
9537+             [None , 5 ], # hop_length  
95389538            [None , 10 , 8 ], # win_length  
95399539            [None , torch .hann_window ], # window  
95409540            [False , True ], # center  
95419541            [False , True ], # normalized  
95429542            [None , False , True ], # onesided  
9543-             [None , 30 ,  40 ], # length  
9543+             [None , "shorter" ,  "larger" ], # length  
95449544            [False , True ], # return_complex  
95459545        ) 
95469546    ) 
@@ -9551,9 +9551,19 @@ def test_istft(self, compute_unit, backend, channels, n_fft, num_frames, hop_len
95519551        if  hop_length  is  None  and  win_length  is  not   None :
95529552            pytest .skip ("If win_length is set then we must set hop_length and 0 < hop_length <= win_length" )
95539553
9554+         # Compute input_shape to generate test case 
95549555        freq  =  n_fft // 2 + 1  if  onesided  else  n_fft 
95559556        input_shape  =  (channels , freq , num_frames ) if  channels  else  (freq , num_frames )
95569557
9558+         # If not set,c ompute hop_length for capturing errors 
9559+         if  hop_length  is  None :
9560+             hop_length  =  n_fft  //  4 
9561+ 
9562+         if  length  ==  "shorter" :
9563+             length  =  n_fft // 2  +  hop_length  *  (num_frames  -  1 )
9564+         elif  length  ==  "larger" :
9565+             length  =  n_fft * 3 // 2  +  hop_length  *  (num_frames  -  1 )
9566+ 
95579567        class  ISTFTModel (torch .nn .Module ):
95589568            def  forward (self , x ):
95599569                applied_window  =  window (win_length ) if  window  and  win_length  else  None 
@@ -9573,7 +9583,7 @@ def forward(self, x):
95739583                else :
95749584                    return  torch .real (x )
95759585
9576-         if  win_length   and  center  is   False :
9586+         if  ( center   is   False   and  win_length )  or  ( center  and   win_length   and   length ) :
95779587            # For some reason Pytorch raises an error https://github.com/pytorch/audio/issues/427#issuecomment-1829593033 
95789588            with  pytest .raises (RuntimeError , match = "istft\(.*\) window overlap add min: 1" ):
95799589                TorchBaseTest .run_compare_torch (
@@ -9582,7 +9592,7 @@ def forward(self, x):
95829592                    backend = backend ,
95839593                    compute_unit = compute_unit 
95849594                )
9585-         elif  length  is   not   None   and  return_complex   is   True :
9595+         elif  length  and  return_complex :
95869596            with  pytest .raises (ValueError , match = "New var type `<class 'coremltools.converters.mil.mil.types.type_tensor.tensor.<locals>.tensor'>` not a subtype of existing var type `<class 'coremltools.converters.mil.mil.types.type_tensor.tensor.<locals>.tensor'>`" ):
95879597                TorchBaseTest .run_compare_torch (
95889598                    input_shape ,
0 commit comments