Skip to content

Commit fc021cd

Browse files
committed
updating output tests
1 parent 529999b commit fc021cd

3 files changed

Lines changed: 10 additions & 7 deletions

File tree

-16.2 KB
Binary file not shown.

gtcrn_micro/utils/output_tests.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@
1212

1313

1414
def output_test() -> None:
15+
"""Test output of trained (streaming) model in different formats."""
1516
# loading data
1617
mix, fs = sf.read(
1718
"./gtcrn_micro/data/DNS3/noisy_blind_testset_v3_challenge_withSNR_16k/ms_realrec_nonenglish_female_SNR_23.01dB_headset_10_spanish_1.wav",
19+
# "./gtcrn_micro/data/DNS3/noisy_blind_testset_v3_challenge_withSNR_16k/ms_realrec_english_male_SNR_20.77dB_headset_door_near.wav",
1820
dtype="float32",
1921
)
2022
assert fs == 16000, f"Expected fs of 16000, instead got {fs}"
@@ -153,11 +155,10 @@ def output_test() -> None:
153155
# tflite_path = (
154156
# "gtcrn_micro/streaming/tflite/gtcrn_micro_stream_simple_float32.tflite"
155157
# )
156-
# tflite_path = "gtcrn_micro/streaming/tflite/gtcrn_micro_stream_simple_dynamic_range_quant.tflite"
157-
tflite_path = (
158-
"gtcrn_micro/streaming/tflite/gtcrn_micro_stream_simple_float16.tflite"
159-
)
160-
# tflite_path = "./gtcrn_micro/streaming/tflite/gtcrn_micro_stream_simple_int8.tflite"
158+
tflite_path = "gtcrn_micro/streaming/tflite/gtcrn_micro_stream_simple_dynamic_range_quant.tflite"
159+
# tflite_path = (
160+
# "gtcrn_micro/streaming/tflite/gtcrn_micro_stream_simple_float16.tflite"
161+
# )
161162
tflite_stft = tflite_stream_infer(x, model_path=tflite_path)
162163
t_stft = tflite_stft
163164
enhanced_tflite = istft(
@@ -168,7 +169,7 @@ def output_test() -> None:
168169
window=np.hanning(512) ** 0.5,
169170
)
170171
sf.write(
171-
"gtcrn_micro/streaming/sample/enh_tflite_f16.wav",
172+
"gtcrn_micro/streaming/sample/enh_tflite_dynamic_range.wav",
172173
enhanced_tflite.squeeze(),
173174
16000,
174175
)
@@ -182,9 +183,11 @@ def output_test() -> None:
182183
print(f"STFT MAE ONNX vs PT: {np.mean(np.abs(o_stft - p_stft))}")
183184
print(f"STFT MAE TFL vs PT: {np.mean(np.abs(t_stft - p_stft))}")
184185
print(f"STFT MAE TFL vs ONNX: {np.mean(np.abs(t_stft - o_stft))}")
186+
185187
m = np.mean(np.abs(t_stft - p_stft), axis=(0, 1, 3))
186188
print(f"TFL vs PT frame MAE start - mid - end: {m[0]} - {m[len(m) // 2]} - {m[-1]}")
187189

190+
print("\nTime-Domain:\n")
188191
print(
189192
f"Onnx outputs error vs pytorch: {np.mean(np.abs(enhanced_onnx - enhanced_pytorch))}"
190193
)
@@ -193,6 +196,7 @@ def output_test() -> None:
193196

194197
print("onnx MAE:", abs_diff.mean())
195198
print("onnx median abs diff:", np.median(abs_diff))
199+
196200
print(
197201
f"Tflite outputs error vs pytorch: {np.mean(np.abs(enhanced_tflite - enhanced_pytorch))}"
198202
)

gtcrn_micro/utils/tflite_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,6 @@ def tflite_stream_infer(x: torch.Tensor, model_path: Path):
363363
print(f"{o}\n")
364364
print("-" * 20)
365365

366-
# prefix = "serving_default_"
367366
# getting the input tensor indexes
368367
audio_in = _pick(in_details, "audio", default_idx=0)
369368
conv_in = _pick(in_details, "conv_cache", default_idx=1)

0 commit comments

Comments
 (0)