From d69940ef365d2e9410a19f967771e2231a24a707 Mon Sep 17 00:00:00 2001 From: BO-NIAL Date: Fri, 28 Jul 2023 11:09:27 +0200 Subject: [PATCH] amplitude correction --- DTLN_model.py | 6 ++++-- real_time_dtln_audio.py | 5 +++++ real_time_processing.py | 5 +++++ real_time_processing_onnx.py | 5 +++++ real_time_processing_tf_lite.py | 5 ++++- 5 files changed, 23 insertions(+), 3 deletions(-) diff --git a/DTLN_model.py b/DTLN_model.py index a52584a..920a329 100644 --- a/DTLN_model.py +++ b/DTLN_model.py @@ -248,14 +248,16 @@ def ifftLayer(self, x): tf.exp( (1j * tf.cast(x[1], tf.complex64)))) # returning the time domain frames return tf.signal.irfft(s1_stft) - - + def overlapAddLayer(self, x): ''' Method for an overlap and add helper layer used with a Lambda layer. This layer reconstructs the waveform from a framed signal. ''' + #if more than 50% overlap, add scale factor to keep same amplitude as the input signal + if self.block_shift/self.blockLen < 1/2: + x *= (self.block_shift/self.blockLen) # calculating and returning the reconstructed waveform return tf.signal.overlap_and_add(x, self.block_shift) diff --git a/real_time_dtln_audio.py b/real_time_dtln_audio.py index b02ced5..2fba85a 100644 --- a/real_time_dtln_audio.py +++ b/real_time_dtln_audio.py @@ -122,6 +122,11 @@ def callback(indata, outdata, frames, time, status): # get output tensors out_block = interpreter_2.get_tensor(output_details_2[0]['index']) states_2 = interpreter_2.get_tensor(output_details_2[1]['index']) + + #if more than 50% overlap, add scale factor to keep same amplitude as the input signal + if block_shift/block_len < 1/2: + out_block *= (block_shift/block_len) + # write to buffer out_buffer[:-block_shift] = out_buffer[block_shift:] out_buffer[-block_shift:] = np.zeros((block_shift)) diff --git a/real_time_processing.py b/real_time_processing.py index aa089a9..86229c7 100644 --- a/real_time_processing.py +++ b/real_time_processing.py @@ -41,6 +41,11 @@ in_block = np.expand_dims(in_buffer, axis=0).astype('float32') # process one block out_block= infer(tf.constant(in_block))['conv1d_1'] + + #if more than 50% overlap, add scale factor to keep same amplitude as the input signal + if block_shift/block_len < 1/2: + out_block *= (block_shift/block_len) + # shift values and write to buffer out_buffer[:-block_shift] = out_buffer[block_shift:] out_buffer[-block_shift:] = np.zeros((block_shift)) diff --git a/real_time_processing_onnx.py b/real_time_processing_onnx.py index 27ee207..864094c 100644 --- a/real_time_processing_onnx.py +++ b/real_time_processing_onnx.py @@ -91,6 +91,11 @@ out_block = model_outputs_2[0] # set out states back to input model_inputs_2[model_input_names_2[1]] = model_outputs_2[1] + + #if more than 50% overlap, add scale factor to keep same amplitude as the input signal + if block_shift/block_len < 1/2: + out_block *= (block_shift/block_len) + # shift values and write to buffer out_buffer[:-block_shift] = out_buffer[block_shift:] out_buffer[-block_shift:] = np.zeros((block_shift)) diff --git a/real_time_processing_tf_lite.py b/real_time_processing_tf_lite.py index da63ec9..0e89c08 100644 --- a/real_time_processing_tf_lite.py +++ b/real_time_processing_tf_lite.py @@ -85,7 +85,10 @@ # get output tensors out_block = interpreter_2.get_tensor(output_details_2[0]['index']) states_2 = interpreter_2.get_tensor(output_details_2[1]['index']) - + + #if more than 50% overlap, add scale factor to keep same amplitude as the input signal + if block_shift/block_len < 1/2: + out_block *= (block_shift/block_len) # shift values and write to buffer out_buffer[:-block_shift] = out_buffer[block_shift:]