diff --git a/basic_pitch/layers/signal.py b/basic_pitch/layers/signal.py
index f5f629f..b849209 100644
--- a/basic_pitch/layers/signal.py
+++ b/basic_pitch/layers/signal.py
@@ -183,3 +183,40 @@ def call(self, inputs: tf.Tensor) -> tf.Tensor:
         log_power_normalized = tf.math.divide_no_nan(log_power_offset, log_power_offset_max)
         return tf.reshape(log_power_normalized, tf.shape(inputs))
+
+
+class GlobalNormalizedLog(tf.keras.layers.Layer):
+    """
+    Takes an input with shape (batch, freq_bins, time) and rescales to dB, scaled 0 - 1.
+    The dB range is determined empirically from the CQT implementation by testing frequencies
+    from 13.8 Hz to 8820 Hz:
+    - For unit amplitude input (1.0), max dB is 54.7 (occurs at ~30.3 Hz)
+    - For unit amplitude input (1.0), min peak dB is -33.9 (occurs at ~8820 Hz)
+    - For 0.01 amplitude input, min peak dB is -73.9 (occurs at ~8820 Hz)
+    - The 40 dB difference between 1.0 and 0.01 amplitude inputs is preserved across all frequencies
+    """
+
+    def __init__(self):
+        super().__init__()
+        # Minimum peak dB observed for a 0.01 amplitude sinusoid.
+        # This ensures we don't clip actual signal content while still suppressing noise.
+        self.min_db = -74.0
+        # Maximum dB value determined empirically from the CQT of a unit amplitude sinusoid.
+        self.max_db = 54.7
+
+    def build(self, input_shape: tf.Tensor) -> None:
+        super().build(input_shape)
+
+    def call(self, inputs: tf.Tensor) -> tf.Tensor:
+        # Convert magnitude to power
+        power = tf.math.square(inputs)
+        log_power = 10 * log_base_b(power + 1e-10, 10)
+
+        # Values should (!) now be between self.min_db and self.max_db. There is no clean way to
+        # assert this inside a Keras graph, so out-of-range values are clipped instead.
+        # tf.debugging.assert_greater_equal(tf.math.reduce_min(log_power), self.min_db)
+        # tf.debugging.assert_less_equal(tf.math.reduce_max(log_power), self.max_db)
+        log_power = tf.clip_by_value(log_power, self.min_db, self.max_db)
+        normalized = (log_power - self.min_db) / (self.max_db - self.min_db)
+
+        return normalized
diff --git a/basic_pitch/models.py b/basic_pitch/models.py
index 0e72651..e667dd2 100644
--- a/basic_pitch/models.py
+++ b/basic_pitch/models.py
@@ -145,7 +145,7 @@ def get_cqt(inputs: tf.Tensor, n_harmonics: int, use_batchnorm: bool) -> tf.Tensor:
         n_harmonics: The number of harmonics to capture above the maximum output frequency.
             Used to calculate the number of semitones for the CQT.
         use_batchnorm: If True, applies batch normalization after computing the CQT
-
+        use_fixed_norm: If True, applies fixed (global) normalization after computing the CQT instead of per-example min-max
     Returns:
         The log-normalized CQT of the input audio.
     """
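Reviewer note: the mapping the new layer applies is easy to sanity-check outside TensorFlow. The sketch below re-implements the forward pass in plain NumPy using the constants from the class above; it is an illustration only, not part of the patch.

import numpy as np

MIN_DB, MAX_DB = -74.0, 54.7  # constants from GlobalNormalizedLog

def global_normalized_log(magnitude):
    """NumPy reference for the layer's forward pass."""
    power = np.square(magnitude)
    log_power = 10 * np.log10(power + 1e-10)
    log_power = np.clip(log_power, MIN_DB, MAX_DB)
    return (log_power - MIN_DB) / (MAX_DB - MIN_DB)

# A unit-amplitude bin sits at ~0 dB and maps to (0 + 74) / 128.7 ~= 0.57
print(global_normalized_log(np.array([1.0])))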
""" diff --git a/modify_model.py b/modify_model.py new file mode 100644 index 0000000..7ad4d47 --- /dev/null +++ b/modify_model.py @@ -0,0 +1,47 @@ +import tensorflow as tf +import numpy as np +from basic_pitch.layers.signal import GlobalNormalizedLog +from basic_pitch.models import transcription_loss, weighted_transcription_loss +from basic_pitch.nn import FlattenAudioCh, FlattenFreqCh, HarmonicStacking +from basic_pitch.layers import nnaudio, signal +from basic_pitch import ICASSP_2022_MODEL_PATH + +MODIFIED_MODEL_PATH= "icassp_2022_model_modified" + + +def modify_model(input_model_path, output_model_path): + # Register custom objects + custom_objects = { + 'FlattenAudioCh': FlattenAudioCh, + 'FlattenFreqCh': FlattenFreqCh, + 'HarmonicStacking': HarmonicStacking, + 'CQT': nnaudio.CQT, + 'NormalizedLog': signal.NormalizedLog, + 'transcription_loss': transcription_loss, + 'weighted_transcription_loss': weighted_transcription_loss, + '': lambda x, y: transcription_loss(x, y, label_smoothing=0.2), + } + + # Load the original model with custom objects + with tf.keras.utils.custom_object_scope(custom_objects): + original_model = tf.keras.models.load_model(input_model_path) + + # Tail model incorporating everything after CQT normalization + # Check position of normalized log layer + assert isinstance(original_model.layers[3], signal.NormalizedLog) + head_model = tf.keras.Model(inputs=original_model.inputs, outputs=original_model.layers[2].output) + tail_model = tf.keras.Model(inputs=original_model.layers[4].input, outputs=original_model.outputs) + + # Create a new model using head + global normalized log + tail + inputs = tf.keras.Input(shape=head_model.inputs[0].shape[1:]) + x = head_model(inputs) + x = GlobalNormalizedLog()(x) + x = tail_model(x) + x = {"contour": x[0], "note": x[1], "onset": x[2]} + new_model = tf.keras.Model(inputs=inputs, outputs=x) + new_model.save(output_model_path) + + +if __name__ == "__main__": + input_model_path = str(ICASSP_2022_MODEL_PATH) + modify_model(input_model_path, MODIFIED_MODEL_PATH) \ No newline at end of file diff --git a/test_cqt_range.py b/test_cqt_range.py new file mode 100644 index 0000000..de06af1 --- /dev/null +++ b/test_cqt_range.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python +# encoding: utf-8 + +import numpy as np +import tensorflow as tf +from basic_pitch.layers import nnaudio +from basic_pitch.nn import FlattenAudioCh +from basic_pitch.constants import ( + AUDIO_SAMPLE_RATE, + FFT_HOP, + ANNOTATIONS_BASE_FREQUENCY, + CONTOURS_BINS_PER_SEMITONE, + AUDIO_N_SAMPLES, +) +from basic_pitch.layers.math import log_base_b + +def create_sinusoid(freq=440, duration=None, amplitude=1.0): + """Create a sinusoid with given frequency, duration, and amplitude.""" + if duration is None: + # Use the exact number of samples expected by the model + n_samples = AUDIO_N_SAMPLES + else: + n_samples = int(duration * AUDIO_SAMPLE_RATE) + t = np.linspace(0, n_samples/AUDIO_SAMPLE_RATE, n_samples) + return amplitude * np.sin(2 * np.pi * freq * t) + +def process_audio(audio): + """Process audio through CQT and return dB values.""" + # Reshape audio to match expected format (batch, time, channels) + audio = tf.convert_to_tensor(audio, dtype=tf.float32) + audio = tf.reshape(audio, (1, -1, 1)) + + # Create the processing pipeline + flatten = FlattenAudioCh() + cqt = nnaudio.CQT( + sr=AUDIO_SAMPLE_RATE, + hop_length=FFT_HOP, + fmin=ANNOTATIONS_BASE_FREQUENCY, + n_bins=84, # Standard number of bins + bins_per_octave=12 * CONTOURS_BINS_PER_SEMITONE, + pad_mode="constant", # Use 
diff --git a/test_cqt_range.py b/test_cqt_range.py
new file mode 100644
index 0000000..de06af1
--- /dev/null
+++ b/test_cqt_range.py
@@ -0,0 +1,110 @@
+#!/usr/bin/env python
+# encoding: utf-8
+
+import numpy as np
+import tensorflow as tf
+from basic_pitch.layers import nnaudio
+from basic_pitch.nn import FlattenAudioCh
+from basic_pitch.constants import (
+    AUDIO_SAMPLE_RATE,
+    FFT_HOP,
+    ANNOTATIONS_BASE_FREQUENCY,
+    CONTOURS_BINS_PER_SEMITONE,
+    AUDIO_N_SAMPLES,
+)
+from basic_pitch.layers.math import log_base_b
+
+
+def create_sinusoid(freq=440, duration=None, amplitude=1.0):
+    """Create a sinusoid with the given frequency, duration, and amplitude."""
+    if duration is None:
+        # Use the exact number of samples expected by the model
+        n_samples = AUDIO_N_SAMPLES
+    else:
+        n_samples = int(duration * AUDIO_SAMPLE_RATE)
+    t = np.linspace(0, n_samples / AUDIO_SAMPLE_RATE, n_samples)
+    return amplitude * np.sin(2 * np.pi * freq * t)
+
+
+def process_audio(audio):
+    """Process audio through the CQT and return dB values."""
+    # Reshape audio to match the expected format (batch, time, channels)
+    audio = tf.convert_to_tensor(audio, dtype=tf.float32)
+    audio = tf.reshape(audio, (1, -1, 1))
+
+    # Create the processing pipeline
+    flatten = FlattenAudioCh()
+    cqt = nnaudio.CQT(
+        sr=AUDIO_SAMPLE_RATE,
+        hop_length=FFT_HOP,
+        fmin=ANNOTATIONS_BASE_FREQUENCY,
+        n_bins=84,  # Standard number of bins
+        bins_per_octave=12 * CONTOURS_BINS_PER_SEMITONE,
+        pad_mode="constant",  # Use constant padding instead of reflect
+    )
+
+    # Process the audio
+    x = flatten(audio)
+    x = cqt(x)
+
+    # Convert to power and then to dB
+    power = tf.math.square(x)
+    log_power = 10 * log_base_b(power + 1e-10, 10)
+
+    return log_power
+
+
+def analyze_frequency_range(min_freq, max_freq, n_freqs, amplitude):
+    """Analyze the CQT response across a range of frequencies."""
+    freqs = np.geomspace(min_freq, max_freq, n_freqs)  # Use geometric spacing for musical frequencies
+    min_peak_db = float('inf')  # Minimum of the peak responses
+    max_db = float('-inf')
+    mean_db = 0
+    max_freq_found = 0
+    min_peak_freq_found = 0
+
+    print(f"\nAnalyzing frequencies from {min_freq:.1f}Hz to {max_freq:.1f}Hz at amplitude {amplitude}:")
+    for freq in freqs:
+        signal = create_sinusoid(freq=freq, amplitude=amplitude)
+        db_values = process_audio(signal)
+
+        # Find the peak response for this frequency
+        curr_peak = float(tf.reduce_max(db_values))
+        curr_mean = float(tf.reduce_mean(db_values))
+
+        if curr_peak > max_db:
+            max_db = curr_peak
+            max_freq_found = freq
+        if curr_peak < min_peak_db:
+            min_peak_db = curr_peak
+            min_peak_freq_found = freq
+        mean_db += curr_mean
+
+    mean_db /= len(freqs)
+
+    print(f"  Minimum peak dB: {min_peak_db:.2f} (at {min_peak_freq_found:.1f}Hz)")
+    print(f"  Maximum dB: {max_db:.2f} (at {max_freq_found:.1f}Hz)")
+    print(f"  Mean dB: {mean_db:.2f}")
+
+    return min_peak_db, max_db, mean_db, min_peak_freq_found, max_freq_found
+
+
+def main():
+    # Test frequencies from just below the base frequency to just below Nyquist
+    min_freq = ANNOTATIONS_BASE_FREQUENCY / 2  # Test below the base frequency
+    max_freq = AUDIO_SAMPLE_RATE / 2.5  # Stay comfortably below Nyquist
+    n_freqs = 50  # Number of frequencies to test
+
+    # Test both low- and high-amplitude signals
+    print("\n=== Testing frequency response across the spectrum ===")
+    low_results = analyze_frequency_range(min_freq, max_freq, n_freqs, amplitude=0.01)
+    high_results = analyze_frequency_range(min_freq, max_freq, n_freqs, amplitude=1.0)
+
+    # Print the overall extremes
+    print("\n=== Overall Extremes ===")
+    print(f"Minimum peak dB: {min(low_results[0], high_results[0]):.2f}")
+    print(f"Absolute maximum dB: {max(low_results[1], high_results[1]):.2f}")
+
+    # Calculate the maximum dB difference between amplitudes
+    db_difference = high_results[1] - low_results[1]
+    print(f"\nMaximum dB difference between amplitudes: {db_difference:.2f}")
+    print(f"Expected dB difference: {20 * np.log10(1.0 / 0.01):.2f}")
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
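One line of arithmetic underpins the script's final check: scaling a signal's amplitude by a factor k shifts its power-dB level by 20*log10(k) at every frequency, which is why a 0.01 amplitude should sit exactly 40 dB below unit amplitude.

import numpy as np
for k in (1.0, 0.1, 0.01):
    print(f"amplitude {k}: {20 * np.log10(k):+.1f} dB relative to unit amplitude")
# -> +0.0, -20.0, -40.0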
diff --git a/test_normalization.py b/test_normalization.py
new file mode 100644
index 0000000..d7db68c
--- /dev/null
+++ b/test_normalization.py
@@ -0,0 +1,186 @@
+import numpy as np
+import tensorflow as tf
+import argparse
+from basic_pitch.layers.signal import GlobalNormalizedLog, NormalizedLog
+from basic_pitch.layers import nnaudio
+from basic_pitch.constants import (
+    AUDIO_SAMPLE_RATE, FFT_HOP, ANNOTATIONS_BASE_FREQUENCY,
+    AUDIO_N_SAMPLES, AUDIO_WINDOW_LENGTH, CONTOURS_BINS_PER_SEMITONE
+)
+from basic_pitch.models import get_cqt, transcription_loss, weighted_transcription_loss
+from basic_pitch.nn import FlattenAudioCh, FlattenFreqCh, HarmonicStacking
+import matplotlib.pyplot as plt
+from modify_model import MODIFIED_MODEL_PATH
+from basic_pitch import ICASSP_2022_MODEL_PATH
+import os
+
+
+def create_increasing_sinusoid(duration_seconds=10, freq=440, sample_rate=AUDIO_SAMPLE_RATE):
+    """Create a sinusoid with linearly increasing amplitude."""
+    t = np.linspace(0, duration_seconds, int(duration_seconds * sample_rate))
+    # Linear ramp from 0.01 to 1.0
+    amplitude = np.linspace(0.01, 1.0, len(t))
+    signal = amplitude * np.sin(2 * np.pi * freq * t)
+
+    # Print some diagnostic information
+    print("\nInput signal properties:")
+    print(f"Signal shape: {signal.shape}")
+    print(f"Amplitude range: {np.min(amplitude):.3f} to {np.max(amplitude):.3f}")
+    print(f"Signal range: {np.min(signal):.3f} to {np.max(signal):.3f}")
+    print(f"Mean amplitude: {np.mean(amplitude):.3f}")
+
+    return signal
+
+
+def process_chunk(chunk, model):
+    """Process a single chunk through the full model."""
+    # Ensure the chunk has the correct size: pad or truncate as needed
+    if len(chunk) != AUDIO_N_SAMPLES:
+        if len(chunk) < AUDIO_N_SAMPLES:
+            chunk = np.pad(chunk, (0, AUDIO_N_SAMPLES - len(chunk)))
+        else:
+            chunk = chunk[:AUDIO_N_SAMPLES]
+
+    # Add batch and channel dimensions
+    chunk = tf.convert_to_tensor(chunk.reshape(1, -1, 1), dtype=tf.float32)
+
+    # Get model outputs
+    outputs = model(chunk)
+
+    return {
+        'contours': outputs['contour'].numpy(),
+        'notes': outputs['note'].numpy(),
+        'onsets': outputs['onset'].numpy()
+    }
+
+
+def plot_results(results, signal, time, freq, save_path):
+    """Plot the results comparing original and modified model outputs."""
+    # Concatenate outputs from all chunks and transpose to (freq_bins, time)
+    all_contours = np.concatenate([r['contours'][0] for r in results], axis=0).T
+    all_notes = np.concatenate([r['notes'][0] for r in results], axis=0).T
+    all_onsets = np.concatenate([r['onsets'][0] for r in results], axis=0).T
+
+    # Calculate time points for the x-axis of the model outputs
+    hop_time = FFT_HOP / AUDIO_SAMPLE_RATE  # time between frames in seconds
+    output_time = np.arange(all_contours.shape[1]) * hop_time
+
+    # Create the figure and grid
+    fig = plt.figure(figsize=(15, 12))
+    gs = fig.add_gridspec(3, 2, height_ratios=[1, 2, 1])
+    fig.suptitle('Comparison of Original vs Modified Model Outputs')
+
+    # Plot input signal
+    ax = fig.add_subplot(gs[0, 0])
+    ax.plot(time, signal)
+    ax.set_title('Input Signal')
+    ax.set_xlabel('Time (s)')
+    ax.set_ylabel('Amplitude')
+
+    # Print shapes for debugging
+    print("\nConcatenated shapes:")
+    print(f"All contours shape: {all_contours.shape}")
+    print(f"All notes shape: {all_notes.shape}")
+    print(f"All onsets shape: {all_onsets.shape}")
+
+    # Plot contour outputs as a spectrogram-style image
+    ax = fig.add_subplot(gs[1, :])
+    im = ax.imshow(
+        all_contours,
+        aspect='auto',
+        origin='lower',
+        extent=[0, time[-1], 0, all_contours.shape[0]],
+        cmap='magma'
+    )
+    ax.set_title('Model Contours')
+    ax.set_xlabel('Time (s)')
+    ax.set_ylabel('Pitch (bins)')
+    plt.colorbar(im, ax=ax, label='Contour Value')
+
+    # Plot contour values at the frequency bin closest to our input frequency
+    ax = fig.add_subplot(gs[2, 0])
+    # Calculate the frequency bin from the pitch axis definition:
+    # ANNOTATIONS_BASE_FREQUENCY is the base frequency (e.g. C0) and
+    # CONTOURS_BINS_PER_SEMITONE sets the resolution (bins per semitone).
+    # First compute semitones above the base frequency, then multiply by bins per semitone.
+    semitones_from_base = 12 * np.log2(freq / ANNOTATIONS_BASE_FREQUENCY)
+    freq_bin = int(np.round(semitones_from_base * CONTOURS_BINS_PER_SEMITONE))
+    print(f"\nFrequency bin calculation:")
+    print(f"Input frequency: {freq} Hz")
+    print(f"Base frequency: {ANNOTATIONS_BASE_FREQUENCY} Hz")
+    print(f"Semitones from base: {semitones_from_base}")
+    print(f"Bins per semitone: {CONTOURS_BINS_PER_SEMITONE}")
+    print(f"Calculated freq_bin: {freq_bin}")
+    print(f"Total number of bins: {all_contours.shape[0]}")
+
+    # Find the bin with the highest mean activation
+    mean_activations = np.mean(all_contours, axis=1)
+    max_bin = np.argmax(mean_activations)
+    print(f"Bin with highest mean activation: {max_bin}")
+    print(f"Mean activation at calculated bin: {mean_activations[freq_bin]}")
+    print(f"Mean activation at max bin: {mean_activations[max_bin]}")
+
+    # Plot both the calculated bin and the bin with the highest activation
+    ax.plot(output_time, all_contours[max_bin, :], label=f'Bin {max_bin} (highest activation)')
+    ax.plot(output_time, all_contours[freq_bin, :], '--', label=f'Bin {freq_bin} (calculated)')
+    ax.set_title('Contour Values at Input Frequency')
+    ax.set_xlabel('Time (s)')
+    ax.set_ylabel("Contour Value")
+    ax.legend()
+
+    plt.tight_layout()
+    plt.savefig(save_path)
+    plt.close()
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Test normalization methods')
+    parser.add_argument('--normalization', choices=['original', 'global'], default='global',
+                        help='Which normalization method to use (original or global)')
+    args = parser.parse_args()
+
+    # Create the test signal
+    duration = 10  # seconds
+    freq = 440  # Hz
+    signal = create_increasing_sinusoid(duration, freq)
+
+    # Register custom objects
+    custom_objects = {
+        'FlattenAudioCh': FlattenAudioCh,
+        'FlattenFreqCh': FlattenFreqCh,
+        'HarmonicStacking': HarmonicStacking,
+        'CQT': nnaudio.CQT,
+        'NormalizedLog': NormalizedLog,
+        'GlobalNormalizedLog': GlobalNormalizedLog,
+        'transcription_loss': transcription_loss,
+        'weighted_transcription_loss': weighted_transcription_loss,
+        '': lambda x, y: transcription_loss(x, y, label_smoothing=0.2),
+    }
+
+    # Load the appropriate model
+    with tf.keras.utils.custom_object_scope(custom_objects):
+        if args.normalization == 'global':
+            # The modified model must already exist (see modify_model.py)
+            model = tf.keras.models.load_model(str(MODIFIED_MODEL_PATH))
+        else:
+            model = tf.keras.models.load_model(str(ICASSP_2022_MODEL_PATH))
+
+    # Process in ~2-second chunks of AUDIO_N_SAMPLES samples
+    chunk_size = AUDIO_N_SAMPLES
+    n_chunks = len(signal) // chunk_size
+
+    # Store results for each chunk
+    results = []
+    for i in range(n_chunks):
+        start = i * chunk_size
+        end = start + chunk_size
+        chunk = signal[start:end]
+
+        chunk_results = process_chunk(chunk, model)
+        results.append(chunk_results)
+
+    # Create time array for plotting
+    time = np.linspace(0, duration, len(signal))
+
+    # Plot results
+    plot_results(results, signal, time, freq, f'normalization_test_{args.normalization}.png')
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
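The bin arithmetic in plot_results is worth a worked example. Assuming basic_pitch's defaults of ANNOTATIONS_BASE_FREQUENCY = 27.5 Hz and CONTOURS_BINS_PER_SEMITONE = 3 (check constants.py if these differ), a 440 Hz input lands exactly four octaves above the base frequency:

import numpy as np
freq, base, bins_per_semitone = 440.0, 27.5, 3  # assumed defaults
semitones_from_base = 12 * np.log2(freq / base)                # 48.0 (4 octaves)
print(int(np.round(semitones_from_base * bins_per_semitone)))  # -> 144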
diff --git a/tests/test_global_normalized_log.py b/tests/test_global_normalized_log.py
new file mode 100644
index 0000000..740cf22
--- /dev/null
+++ b/tests/test_global_normalized_log.py
@@ -0,0 +1,162 @@
+import unittest
+import numpy as np
+import tensorflow as tf
+from basic_pitch.layers.signal import GlobalNormalizedLog
+from basic_pitch.layers import nnaudio
+from basic_pitch.nn import FlattenAudioCh
+from basic_pitch.constants import (
+    AUDIO_SAMPLE_RATE,
+    FFT_HOP,
+    ANNOTATIONS_BASE_FREQUENCY,
+    CONTOURS_BINS_PER_SEMITONE,
+)
+
+
+class TestGlobalNormalizedLog(unittest.TestCase):
+    def setUp(self):
+        # Create a 10-second sinusoid with increasing amplitude
+        duration = 10  # seconds
+        freq = 440  # Hz
+        t = np.linspace(0, duration, int(duration * AUDIO_SAMPLE_RATE))
+        amplitude = np.linspace(0.01, 1.0, len(t))
+        self.signal = amplitude * np.sin(2 * np.pi * freq * t)
+
+        # Create the processing pipeline
+        self.flatten = FlattenAudioCh()
+        self.cqt = nnaudio.CQT(
+            sr=AUDIO_SAMPLE_RATE,
+            hop_length=FFT_HOP,
+            fmin=ANNOTATIONS_BASE_FREQUENCY,
+            n_bins=84,  # Standard number of bins
+            bins_per_octave=12 * CONTOURS_BINS_PER_SEMITONE,
+        )
+        self.global_norm = GlobalNormalizedLog()
+
+    def test_global_normalized_log(self):
+        # Process the signal through the pipeline
+        x = self.flatten(self.signal[np.newaxis, :, np.newaxis])
+        cqt_output = self.cqt(x)
+
+        # Print CQT output properties
+        print("\nCQT output properties:")
+        print(f"CQT shape: {cqt_output.shape}")
+        print(f"CQT range: {np.min(cqt_output):.3f} to {np.max(cqt_output):.3f}")
+
+        # Convert to dB for verification
+        power = np.square(cqt_output)
+        db = 10 * np.log10(power + 1e-10)
+        print(f"dB range: {np.min(db):.3f} to {np.max(db):.3f}")
+
+        # Apply GlobalNormalizedLog
+        normalized_output = self.global_norm(cqt_output)
+
+        # Print normalized output properties
+        print("\nNormalized output properties:")
+        print(f"Normalized shape: {normalized_output.shape}")
+        print(f"Normalized range: {np.min(normalized_output):.3f} to {np.max(normalized_output):.3f}")
+
+        # Verify the output range is in [0, 1]
+        self.assertTrue(np.all(normalized_output >= 0))
+        self.assertTrue(np.all(normalized_output <= 1))
+
+        # Verify the output shape matches the input shape
+        self.assertEqual(normalized_output.shape, cqt_output.shape)
+
+        # Verify dB values are properly clamped
+        db_after_norm = self.global_norm.min_db + normalized_output * (self.global_norm.max_db - self.global_norm.min_db)
+        self.assertTrue(np.all(db_after_norm >= self.global_norm.min_db))
+        self.assertTrue(np.all(db_after_norm <= self.global_norm.max_db))
+
+    def test_global_normalized_log_with_clipping(self):
+        # Create a test signal with values that will be affected by clipping:
+        # some very small values (which give very negative dB values)
+        # and some very large values (which give positive dB values)
+        duration = 5  # seconds - increased from 1 to 5 to ensure enough samples for the CQT
+        t = np.linspace(0, duration, int(duration * AUDIO_SAMPLE_RATE))
+
+        # Create a signal with varying amplitudes
+        signal = np.zeros_like(t)
+        signal[:len(t)//3] = 0.0001          # Very small values
+        signal[len(t)//3:2*len(t)//3] = 1.0  # Normal values
+        signal[2*len(t)//3:] = 10.0          # Very large values
+
+        # Process through the pipeline
+        x = self.flatten(signal[np.newaxis, :, np.newaxis])
+        cqt_output = self.cqt(x)
+
+        # Print CQT output properties
+        print("\nCQT output properties (with clipping test):")
+        print(f"CQT shape: {cqt_output.shape}")
+        print(f"CQT range: {np.min(cqt_output):.3f} to {np.max(cqt_output):.3f}")
+
+        # Convert to dB for verification
+        power = np.square(cqt_output)
+        db = 10 * np.log10(power + 1e-10)
+        print(f"dB range: {np.min(db):.3f} to {np.max(db):.3f}")
+
+        # Apply GlobalNormalizedLog
+        normalized_output = self.global_norm(cqt_output)
+
+        # Print normalized output properties
print("\nNormalized output properties (with clipping test):") + print(f"Normalized shape: {normalized_output.shape}") + print(f"Normalized range: {np.min(normalized_output):.3f} to {np.max(normalized_output):.3f}") + + # Verify output range is in [0,1] + self.assertTrue(np.all(normalized_output >= 0)) + self.assertTrue(np.all(normalized_output <= 1)) + + # Verify output shape matches input shape + self.assertEqual(normalized_output.shape, cqt_output.shape) + + # Verify dB values are properly clamped + db_after_norm = self.global_norm.min_db + normalized_output * (self.global_norm.max_db - self.global_norm.min_db) + self.assertTrue(np.all(db_after_norm >= self.global_norm.min_db)) + self.assertTrue(np.all(db_after_norm <= self.global_norm.max_db)) + + def test_global_normalized_log_direct(self): + """Test GlobalNormalizedLog directly with values that would be affected by clipping""" + # Create test values that would give dB values outside the [-100, 0] range + test_values = np.array([ + [0.00001], # Very quiet: ~-100 dB + [0.1], # Quiet: ~-20 dB + [1.0], # Normal: 0 dB + [10.0], # Loud: +20 dB + [100.0], # Very loud: +40 dB + ]) + + # Convert to dB for verification + power = np.square(test_values) + db = 10 * np.log10(power + 1e-10) + print("\nDirect test dB values:") + print(f"Input values: {test_values.flatten()}") + print(f"Corresponding dB values: {db.flatten()}") + + # Apply GlobalNormalizedLog + normalized_output = self.global_norm(test_values) + normalized_output_np = normalized_output.numpy() + + print("\nNormalized values:") + print(f"Output: {normalized_output_np.flatten()}") + + # Convert back to dB to verify clamping + db_after_norm = self.global_norm.min_db + normalized_output_np * (self.global_norm.max_db - self.global_norm.min_db) + print(f"dB after normalization: {db_after_norm.flatten()}") + + # Verify output range is in [0,1] + self.assertTrue(np.all(normalized_output_np >= 0)) + self.assertTrue(np.all(normalized_output_np <= 1)) + + # Verify dB values are properly clamped + self.assertTrue(np.all(db_after_norm >= self.global_norm.min_db)) + self.assertTrue(np.all(db_after_norm <= self.global_norm.max_db)) + + # Verify that values outside the range are clamped + # Extract single elements properly to avoid deprecation warnings + self.assertAlmostEqual(normalized_output_np[0, 0], 0.0, places=2) # Very quiet (-97 dB) maps to 0.0 + self.assertAlmostEqual(normalized_output_np[1, 0], 0.42, places=2) # -20 dB maps to ~0.42 + self.assertAlmostEqual(normalized_output_np[2, 0], 0.57, places=2) # 0 dB maps to ~0.57 + self.assertAlmostEqual(normalized_output_np[3, 0], 0.73, places=2) # +20 dB maps to ~0.73 + self.assertAlmostEqual(normalized_output_np[4, 0], 0.89, places=2) # +40 dB maps to ~0.89 + +if __name__ == '__main__': + unittest.main() \ No newline at end of file