1212
1313
1414def output_test () -> None :
15+ """Test output of trained (streaming) model in different formats."""
1516 # loading data
1617 mix , fs = sf .read (
1718 "./gtcrn_micro/data/DNS3/noisy_blind_testset_v3_challenge_withSNR_16k/ms_realrec_nonenglish_female_SNR_23.01dB_headset_10_spanish_1.wav" ,
19+ # "./gtcrn_micro/data/DNS3/noisy_blind_testset_v3_challenge_withSNR_16k/ms_realrec_english_male_SNR_20.77dB_headset_door_near.wav",
1820 dtype = "float32" ,
1921 )
2022 assert fs == 16000 , f"Expected fs of 16000, instead got { fs } "
@@ -153,11 +155,10 @@ def output_test() -> None:
153155 # tflite_path = (
154156 # "gtcrn_micro/streaming/tflite/gtcrn_micro_stream_simple_float32.tflite"
155157 # )
156- # tflite_path = "gtcrn_micro/streaming/tflite/gtcrn_micro_stream_simple_dynamic_range_quant.tflite"
157- tflite_path = (
158- "gtcrn_micro/streaming/tflite/gtcrn_micro_stream_simple_float16.tflite"
159- )
160- # tflite_path = "./gtcrn_micro/streaming/tflite/gtcrn_micro_stream_simple_int8.tflite"
158+ tflite_path = "gtcrn_micro/streaming/tflite/gtcrn_micro_stream_simple_dynamic_range_quant.tflite"
159+ # tflite_path = (
160+ # "gtcrn_micro/streaming/tflite/gtcrn_micro_stream_simple_float16.tflite"
161+ # )
161162 tflite_stft = tflite_stream_infer (x , model_path = tflite_path )
162163 t_stft = tflite_stft
163164 enhanced_tflite = istft (
@@ -168,7 +169,7 @@ def output_test() -> None:
168169 window = np .hanning (512 ) ** 0.5 ,
169170 )
170171 sf .write (
171- "gtcrn_micro/streaming/sample/enh_tflite_f16 .wav" ,
172+ "gtcrn_micro/streaming/sample/enh_tflite_dynamic_range .wav" ,
172173 enhanced_tflite .squeeze (),
173174 16000 ,
174175 )
@@ -182,9 +183,11 @@ def output_test() -> None:
182183 print (f"STFT MAE ONNX vs PT: { np .mean (np .abs (o_stft - p_stft ))} " )
183184 print (f"STFT MAE TFL vs PT: { np .mean (np .abs (t_stft - p_stft ))} " )
184185 print (f"STFT MAE TFL vs ONNX: { np .mean (np .abs (t_stft - o_stft ))} " )
186+
185187 m = np .mean (np .abs (t_stft - p_stft ), axis = (0 , 1 , 3 ))
186188 print (f"TFL vs PT frame MAE start - mid - end: { m [0 ]} - { m [len (m ) // 2 ]} - { m [- 1 ]} " )
187189
190+ print ("\n Time-Domain:\n " )
188191 print (
189192 f"Onnx outputs error vs pytorch: { np .mean (np .abs (enhanced_onnx - enhanced_pytorch ))} "
190193 )
@@ -193,6 +196,7 @@ def output_test() -> None:
193196
194197 print ("onnx MAE:" , abs_diff .mean ())
195198 print ("onnx median abs diff:" , np .median (abs_diff ))
199+
196200 print (
197201 f"Tflite outputs error vs pytorch: { np .mean (np .abs (enhanced_tflite - enhanced_pytorch ))} "
198202 )
0 commit comments