
Add a variant of basic pitch with global CQT normalization #166

Draft · wants to merge 2 commits into base: main
37 changes: 37 additions & 0 deletions basic_pitch/layers/signal.py
@@ -183,3 +183,40 @@ def call(self, inputs: tf.Tensor) -> tf.Tensor:
log_power_normalized = tf.math.divide_no_nan(log_power_offset, log_power_offset_max)

return tf.reshape(log_power_normalized, tf.shape(inputs))


class GlobalNormalizedLog(tf.keras.layers.Layer):
"""
    Takes an input with shape (batch, freq_bins, time), converts it to dB, and rescales to the range [0, 1].
The dB range is determined empirically from the CQT implementation by testing frequencies
from 13.8Hz to 8820Hz:
- For unit amplitude input (1.0), max dB is 54.7 (occurs at ~30.3 Hz)
- For unit amplitude input (1.0), min peak dB is -33.9 (occurs at ~8820 Hz)
- For 0.01 amplitude input, min peak dB is -73.9 (occurs at ~8820 Hz)
- The 40 dB difference between 1.0 and 0.01 amplitude inputs is preserved across all frequencies
"""

def __init__(self):
super().__init__()
# Minimum peak dB observed for a 0.01 amplitude sinusoid
# This ensures we don't clip actual signal content while still suppressing noise
self.min_db = -74.0
# Maximum dB value determined empirically from CQT of unit amplitude sinusoid across frequencies
self.max_db = 54.7

def build(self, input_shape: tf.Tensor) -> None:
super().build(input_shape)

def call(self, inputs: tf.Tensor) -> tf.Tensor:
# convert magnitude to power
power = tf.math.square(inputs)
log_power = 10 * log_base_b(power + 1e-10, 10)

        # Values should now lie between self.min_db and self.max_db, so they can be normalized to [0, 1].
        # Asserting this inside a Keras model graph is awkward, so clip to the range instead;
        # users of this model should check for and handle out-of-range values upstream if it matters.
# tf.debugging.assert_greater_equal(tf.math.reduce_min(log_power), self.min_db)
# tf.debugging.assert_less_equal(tf.math.reduce_max(log_power), self.max_db)
log_power = tf.clip_by_value(log_power, self.min_db, self.max_db)
normalized = (log_power - self.min_db) / (self.max_db - self.min_db)

return normalized
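
A minimal sketch (not part of this diff) of how GlobalNormalizedLog maps CQT magnitudes into [0, 1] when used as a standalone Keras layer; the input values below are made up for illustration:

```python
import tensorflow as tf
from basic_pitch.layers.signal import GlobalNormalizedLog

layer = GlobalNormalizedLog()
# Toy CQT magnitudes with shape (batch, freq_bins, time)
mags = tf.constant([[[1.0, 0.01, 1e-6]]])
out = layer(mags)
# 1.0  ->   0 dB -> (0 + 74) / (54.7 + 74)   ~= 0.575
# 0.01 -> -40 dB -> (-40 + 74) / (54.7 + 74) ~= 0.264
# 1e-6 -> below min_db, clipped to -74 dB    ->  0.0
print(out.numpy())
```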
2 changes: 1 addition & 1 deletion basic_pitch/models.py
@@ -145,7 +145,7 @@ def get_cqt(inputs: tf.Tensor, n_harmonics: int, use_batchnorm: bool) -> tf.Tensor:
n_harmonics: The number of harmonics to capture above the maximum output frequency.
Used to calculate the number of semitones for the CQT.
use_batchnorm: If True, applies batch normalization after computing the CQT

use_fixed_norm: If True, applies fixed normalization after computing the CQT instead of min-max
Returns:
The log-normalized CQT of the input audio.
"""
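The body of get_cqt is collapsed in the diff above; the following is a hypothetical sketch of the branch the new use_fixed_norm docstring line describes. The helper name and signature are assumptions; only NormalizedLog and GlobalNormalizedLog come from the PR itself.

```python
import tensorflow as tf
from basic_pitch.layers import signal

def apply_cqt_normalization(cqt: tf.Tensor, use_fixed_norm: bool) -> tf.Tensor:
    # Use the fixed, global dB normalization instead of the per-example
    # min-max NormalizedLog when use_fixed_norm is True.
    if use_fixed_norm:
        return signal.GlobalNormalizedLog()(cqt)
    return signal.NormalizedLog()(cqt)
```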
47 changes: 47 additions & 0 deletions modify_model.py
@@ -0,0 +1,47 @@
import tensorflow as tf
import numpy as np
from basic_pitch.layers.signal import GlobalNormalizedLog
from basic_pitch.models import transcription_loss, weighted_transcription_loss
from basic_pitch.nn import FlattenAudioCh, FlattenFreqCh, HarmonicStacking
from basic_pitch.layers import nnaudio, signal
from basic_pitch import ICASSP_2022_MODEL_PATH

MODIFIED_MODEL_PATH = "icassp_2022_model_modified"


def modify_model(input_model_path, output_model_path):
# Register custom objects
custom_objects = {
'FlattenAudioCh': FlattenAudioCh,
'FlattenFreqCh': FlattenFreqCh,
'HarmonicStacking': HarmonicStacking,
'CQT': nnaudio.CQT,
'NormalizedLog': signal.NormalizedLog,
'transcription_loss': transcription_loss,
'weighted_transcription_loss': weighted_transcription_loss,
'<lambda>': lambda x, y: transcription_loss(x, y, label_smoothing=0.2),
}

# Load the original model with custom objects
with tf.keras.utils.custom_object_scope(custom_objects):
original_model = tf.keras.models.load_model(input_model_path)

    # Split the model around the min-max normalization layer so it can be replaced:
    # layers[2] is the CQT output and layers[3] must be the NormalizedLog being swapped out
    assert isinstance(original_model.layers[3], signal.NormalizedLog)
head_model = tf.keras.Model(inputs=original_model.inputs, outputs=original_model.layers[2].output)
tail_model = tf.keras.Model(inputs=original_model.layers[4].input, outputs=original_model.outputs)

# Create a new model using head + global normalized log + tail
inputs = tf.keras.Input(shape=head_model.inputs[0].shape[1:])
x = head_model(inputs)
x = GlobalNormalizedLog()(x)
x = tail_model(x)
x = {"contour": x[0], "note": x[1], "onset": x[2]}
new_model = tf.keras.Model(inputs=inputs, outputs=x)
new_model.save(output_model_path)


if __name__ == "__main__":
input_model_path = str(ICASSP_2022_MODEL_PATH)
modify_model(input_model_path, MODIFIED_MODEL_PATH)
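
A minimal sketch of loading the converted model afterwards. Since the saved model still contains the original custom layers and losses plus the new GlobalNormalizedLog, all of them need to be registered, as test_normalization.py below also does:

```python
import tensorflow as tf
from basic_pitch.layers import nnaudio, signal
from basic_pitch.layers.signal import GlobalNormalizedLog
from basic_pitch.models import transcription_loss, weighted_transcription_loss
from basic_pitch.nn import FlattenAudioCh, FlattenFreqCh, HarmonicStacking
from modify_model import MODIFIED_MODEL_PATH

custom_objects = {
    'FlattenAudioCh': FlattenAudioCh,
    'FlattenFreqCh': FlattenFreqCh,
    'HarmonicStacking': HarmonicStacking,
    'CQT': nnaudio.CQT,
    'NormalizedLog': signal.NormalizedLog,
    'GlobalNormalizedLog': GlobalNormalizedLog,
    'transcription_loss': transcription_loss,
    'weighted_transcription_loss': weighted_transcription_loss,
    '<lambda>': lambda x, y: transcription_loss(x, y, label_smoothing=0.2),
}
with tf.keras.utils.custom_object_scope(custom_objects):
    modified_model = tf.keras.models.load_model(MODIFIED_MODEL_PATH)
```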
110 changes: 110 additions & 0 deletions test_cqt_range.py
@@ -0,0 +1,110 @@
#!/usr/bin/env python
# encoding: utf-8

import numpy as np
import tensorflow as tf
from basic_pitch.layers import nnaudio
from basic_pitch.nn import FlattenAudioCh
from basic_pitch.constants import (
AUDIO_SAMPLE_RATE,
FFT_HOP,
ANNOTATIONS_BASE_FREQUENCY,
CONTOURS_BINS_PER_SEMITONE,
AUDIO_N_SAMPLES,
)
from basic_pitch.layers.math import log_base_b

def create_sinusoid(freq=440, duration=None, amplitude=1.0):
"""Create a sinusoid with given frequency, duration, and amplitude."""
if duration is None:
# Use the exact number of samples expected by the model
n_samples = AUDIO_N_SAMPLES
else:
n_samples = int(duration * AUDIO_SAMPLE_RATE)
t = np.linspace(0, n_samples/AUDIO_SAMPLE_RATE, n_samples)
return amplitude * np.sin(2 * np.pi * freq * t)

def process_audio(audio):
"""Process audio through CQT and return dB values."""
# Reshape audio to match expected format (batch, time, channels)
audio = tf.convert_to_tensor(audio, dtype=tf.float32)
audio = tf.reshape(audio, (1, -1, 1))

# Create the processing pipeline
flatten = FlattenAudioCh()
cqt = nnaudio.CQT(
sr=AUDIO_SAMPLE_RATE,
hop_length=FFT_HOP,
fmin=ANNOTATIONS_BASE_FREQUENCY,
n_bins=84, # Standard number of bins
bins_per_octave=12 * CONTOURS_BINS_PER_SEMITONE,
pad_mode="constant", # Use constant padding instead of reflect
)

# Process the audio
x = flatten(audio)
x = cqt(x)

# Convert to power and then to dB
power = tf.math.square(x)
log_power = 10 * log_base_b(power + 1e-10, 10)

return log_power

def analyze_frequency_range(min_freq, max_freq, n_freqs, amplitude):
"""Analyze CQT response across a range of frequencies."""
freqs = np.geomspace(min_freq, max_freq, n_freqs) # Use geometric spacing for musical frequencies
min_peak_db = float('inf') # Minimum of the peak responses
max_db = float('-inf')
mean_db = 0
max_freq_found = 0
min_peak_freq_found = 0

print(f"\nAnalyzing frequencies from {min_freq:.1f}Hz to {max_freq:.1f}Hz at amplitude {amplitude}:")
for freq in freqs:
signal = create_sinusoid(freq=freq, amplitude=amplitude)
db_values = process_audio(signal)

# Find the peak response for this frequency
curr_peak = float(tf.reduce_max(db_values))
curr_mean = float(tf.reduce_mean(db_values))

if curr_peak > max_db:
max_db = curr_peak
max_freq_found = freq
if curr_peak < min_peak_db:
min_peak_db = curr_peak
min_peak_freq_found = freq
mean_db += curr_mean

mean_db /= len(freqs)

print(f" Minimum peak dB: {min_peak_db:.2f} (at {min_peak_freq_found:.1f}Hz)")
print(f" Maximum dB: {max_db:.2f} (at {max_freq_found:.1f}Hz)")
print(f" Mean dB: {mean_db:.2f}")

return min_peak_db, max_db, mean_db, min_peak_freq_found, max_freq_found

def main():
# Test frequencies from just below the base frequency to just below Nyquist
min_freq = ANNOTATIONS_BASE_FREQUENCY / 2 # Test below the base frequency
max_freq = AUDIO_SAMPLE_RATE / 2.5 # Stay comfortably below Nyquist
n_freqs = 50 # Number of frequencies to test

# Test both low and high amplitude signals
print("\n=== Testing frequency response across the spectrum ===")
low_results = analyze_frequency_range(min_freq, max_freq, n_freqs, amplitude=0.01)
high_results = analyze_frequency_range(min_freq, max_freq, n_freqs, amplitude=1.0)

# Print the overall extremes
print("\n=== Overall Extremes ===")
print(f"Minimum peak dB: {min(low_results[0], high_results[0]):.2f}")
print(f"Absolute maximum dB: {max(low_results[1], high_results[1]):.2f}")

# Calculate the maximum dB difference between amplitudes
db_difference = high_results[1] - low_results[1]
print(f"\nMaximum dB difference between amplitudes: {db_difference:.2f}")
print(f"Expected dB difference: {20 * np.log10(1.0/0.01):.2f}")

if __name__ == "__main__":
main()
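
The "expected dB difference" printed at the end follows directly from the power computation used in process_audio; a quick sanity check of the 40 dB figure, independent of the CQT:

```python
import numpy as np

# Power scales with amplitude squared, so the gap between amplitudes 1.0 and 0.01 is
# 10 * log10((1.0 / 0.01) ** 2) = 20 * log10(100) = 40 dB.
print(10 * np.log10((1.0 / 0.01) ** 2))  # 40.0
print(20 * np.log10(1.0 / 0.01))         # 40.0
```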
186 changes: 186 additions & 0 deletions test_normalization.py
@@ -0,0 +1,186 @@
import numpy as np
import tensorflow as tf
import argparse
from basic_pitch.layers.signal import GlobalNormalizedLog, NormalizedLog
from basic_pitch.layers import nnaudio
from basic_pitch.constants import (
AUDIO_SAMPLE_RATE, FFT_HOP, ANNOTATIONS_BASE_FREQUENCY,
AUDIO_N_SAMPLES, AUDIO_WINDOW_LENGTH, CONTOURS_BINS_PER_SEMITONE
)
from basic_pitch.models import get_cqt, transcription_loss, weighted_transcription_loss
from basic_pitch.nn import FlattenAudioCh, FlattenFreqCh, HarmonicStacking
import matplotlib.pyplot as plt
from modify_model import MODIFIED_MODEL_PATH
from basic_pitch import ICASSP_2022_MODEL_PATH
import os

def create_increasing_sinusoid(duration_seconds=10, freq=440, sample_rate=AUDIO_SAMPLE_RATE):
"""Create a sinusoid with linearly increasing amplitude"""
t = np.linspace(0, duration_seconds, int(duration_seconds * sample_rate))
# Linear ramp from 0.01 to 1.0
amplitude = np.linspace(0.01, 1.0, len(t))
signal = amplitude * np.sin(2 * np.pi * freq * t)

# Print some diagnostic information
print("\nInput signal properties:")
print(f"Signal shape: {signal.shape}")
print(f"Amplitude range: {np.min(amplitude):.3f} to {np.max(amplitude):.3f}")
print(f"Signal range: {np.min(signal):.3f} to {np.max(signal):.3f}")
print(f"Mean amplitude: {np.mean(amplitude):.3f}")

return signal

def process_chunk(chunk, model):
"""Process a single chunk through the full model"""
# Ensure chunk has correct size
if len(chunk) != AUDIO_N_SAMPLES:
# Pad or truncate to correct size
if len(chunk) < AUDIO_N_SAMPLES:
chunk = np.pad(chunk, (0, AUDIO_N_SAMPLES - len(chunk)))
else:
chunk = chunk[:AUDIO_N_SAMPLES]

# Add batch and channel dimensions
chunk = tf.convert_to_tensor(chunk.reshape(1, -1, 1), dtype=tf.float32)

# Get model outputs
outputs = model(chunk)

return {
'contours': outputs['contour'].numpy(),
'notes': outputs['note'].numpy(),
'onsets': outputs['onset'].numpy()
}

def plot_results(results, signal, time, freq, save_path):
"""Plot the results comparing original and modified model outputs."""
# Concatenate contour outputs from all chunks and transpose to (freq_bins, time)
all_contours = np.concatenate([r['contours'][0] for r in results], axis=0).T
all_notes = np.concatenate([r['notes'][0] for r in results], axis=0).T
all_onsets = np.concatenate([r['onsets'][0] for r in results], axis=0).T

# Calculate time points for the x-axis of the model outputs
hop_time = FFT_HOP / AUDIO_SAMPLE_RATE # time between frames in seconds
output_time = np.arange(all_contours.shape[1]) * hop_time

# Create the figure and grid
fig = plt.figure(figsize=(15, 12))
gs = fig.add_gridspec(3, 2, height_ratios=[1, 2, 1])
fig.suptitle('Comparison of Original vs Modified Model Outputs')

# Plot input signal
ax = fig.add_subplot(gs[0, 0])
ax.plot(time, signal)
ax.set_title('Input Signal')
ax.set_xlabel('Time (s)')
ax.set_ylabel('Amplitude')

# Print shapes for debugging
print("\nConcatenated shapes:")
print(f"All contours shape: {all_contours.shape}")
print(f"All notes shape: {all_notes.shape}")
print(f"All onsets shape: {all_onsets.shape}")

# Plot contour outputs as spectrograms
ax = fig.add_subplot(gs[1, :])
im = ax.imshow(
all_contours,
aspect='auto',
origin='lower',
extent=[0, time[-1], 0, all_contours.shape[0]],
cmap='magma'
)
ax.set_title('Model Contours')
ax.set_xlabel('Time (s)')
ax.set_ylabel('Pitch (bins)')
plt.colorbar(im, ax=ax, label='Contour Value')

# Plot contour values at the frequency bin closest to our input frequency
ax = fig.add_subplot(gs[2, 0])
# Calculate the frequency bin more accurately
# ANNOTATIONS_BASE_FREQUENCY is the base frequency (e.g. C0)
# CONTOURS_BINS_PER_SEMITONE determines the resolution (bins per semitone)
# First calculate semitones from base frequency, then multiply by bins per semitone
semitones_from_base = 12 * np.log2(freq / ANNOTATIONS_BASE_FREQUENCY)
freq_bin = int(np.round(semitones_from_base * CONTOURS_BINS_PER_SEMITONE))
print(f"\nFrequency bin calculation:")
print(f"Input frequency: {freq} Hz")
print(f"Base frequency: {ANNOTATIONS_BASE_FREQUENCY} Hz")
print(f"Semitones from base: {semitones_from_base}")
print(f"Bins per semitone: {CONTOURS_BINS_PER_SEMITONE}")
print(f"Calculated freq_bin: {freq_bin}")
print(f"Total number of bins: {all_contours.shape[0]}")

# Find bin with highest mean activation
mean_activations = np.mean(all_contours, axis=1)
max_bin = np.argmax(mean_activations)
print(f"Bin with highest mean activation: {max_bin}")
print(f"Mean activation at calculated bin: {mean_activations[freq_bin]}")
print(f"Mean activation at max bin: {mean_activations[max_bin]}")

# Plot both the calculated bin and the bin with highest activation
ax.plot(output_time, all_contours[max_bin, :], label=f'Bin {max_bin} (highest activation)')
ax.plot(output_time, all_contours[freq_bin, :], '--', label=f'Bin {freq_bin} (calculated)')
ax.set_title('Note contour values at Input Frequency')
ax.set_xlabel('Time (s)')
ax.set_ylabel("Contour Value")
ax.legend()

plt.tight_layout()
plt.savefig(save_path)
plt.close()

def main():
parser = argparse.ArgumentParser(description='Test normalization methods')
parser.add_argument('--normalization', choices=['original', 'global'], default='global',
help='Which normalization method to use (original or global)')
args = parser.parse_args()

# Create the test signal
duration = 10 # seconds
freq = 440 # Hz
signal = create_increasing_sinusoid(duration, freq)

# Register custom objects
custom_objects = {
'FlattenAudioCh': FlattenAudioCh,
'FlattenFreqCh': FlattenFreqCh,
'HarmonicStacking': HarmonicStacking,
'CQT': nnaudio.CQT,
'NormalizedLog': NormalizedLog,
'GlobalNormalizedLog': GlobalNormalizedLog,
'transcription_loss': transcription_loss,
'weighted_transcription_loss': weighted_transcription_loss,
'<lambda>': lambda x, y: transcription_loss(x, y, label_smoothing=0.2),
}

# Load the appropriate model
with tf.keras.utils.custom_object_scope(custom_objects):
if args.normalization == 'global':
            # Load the modified model written by modify_model.py
model = tf.keras.models.load_model(str(MODIFIED_MODEL_PATH))
else:
model = tf.keras.models.load_model(str(ICASSP_2022_MODEL_PATH))

    # Process in chunks of AUDIO_N_SAMPLES samples (roughly two seconds each)
chunk_size = AUDIO_N_SAMPLES
n_chunks = len(signal) // chunk_size

# Store results for each chunk
results = []
for i in range(n_chunks):
start = i * chunk_size
end = start + chunk_size
chunk = signal[start:end]

chunk_results = process_chunk(chunk, model)
results.append(chunk_results)

# Create time array for plotting
time = np.linspace(0, duration, len(signal))

# Plot results
plot_results(results, signal, time, freq, f'normalization_test_{args.normalization}.png')

if __name__ == "__main__":
main()
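
To reproduce the comparison: run modify_model.py once to write the icassp_2022_model_modified model, then run test_normalization.py with --normalization original and again with --normalization global; each run writes a normalization_test_<mode>.png plot for side-by-side inspection.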