diff --git a/basic_pitch/note_creation.py b/basic_pitch/note_creation.py index 0338b6f..37e123b 100644 --- a/basic_pitch/note_creation.py +++ b/basic_pitch/note_creation.py @@ -55,25 +55,34 @@ def model_output_to_notes( multiple_pitch_bends: bool = False, melodia_trick: bool = True, midi_tempo: float = 120, + pitch_offset_correction: float = 1.2, ) -> Tuple[pretty_midi.PrettyMIDI, List[Tuple[float, float, int, float, Optional[List[int]]]]]: - """Convert model output to MIDI + """Convert model output to MIDI * with corrected pitch mapping * Args: - output: A dictionary with shape - { - 'frame': array of shape (n_times, n_freqs), - 'onset': array of shape (n_times, n_freqs), - 'contour': array of shape (n_times, 3*n_freqs) - } - representing the output of the basic pitch model. + output: A dictionary with shape { + 'frame': array of shape (n_times, n_freqs), + 'onset': array of shape (n_times, n_freqs), + 'contour': array of shape (n_times, 3*n_freqs) + } representing the output of the basic pitch model. + onset_thresh: Minimum amplitude of an onset activation to be considered an onset. + infer_onsets: If True, add additional onsets when there are large differences in frame amplitudes. + min_note_len: The minimum allowed note length in frames. + min_freq: Minimum allowed output frequency, in Hz. If None, all frequencies are used. + max_freq: Maximum allowed output frequency, in Hz. If None, all frequencies are used. + include_pitch_bends: If True, include pitch bends. + multiple_pitch_bends: If True, allow overlapping notes in midi file to have pitch bends. + melodia_trick: Use the melodia post-processing step. + + *pitch_offset_correction: Correction factor to align visualization with MIDI.* Returns: midi : pretty_midi.PrettyMIDI object @@ -83,6 +92,7 @@ def model_output_to_notes( onsets = output["onset"] contours = output["contour"] + #adjust pitch computation to account for the offset! estimated_notes = output_to_notes_polyphonic( frames, onsets, @@ -94,22 +104,33 @@ def model_output_to_notes( max_freq=max_freq, melodia_trick=melodia_trick, ) + + #apply the correction offset to align pitch (created corrected_notes instead of estimated notes) + corrected_notes = [ + (note[0], note[1], int(note[2] - pitch_offset_correction), note[3], note[4]) + for note in estimated_notes + ] + if include_pitch_bends: - estimated_notes_with_pitch_bend = get_pitch_bends(contours, estimated_notes) + corrected_notes_with_pitch_bend = get_pitch_bends(contours, corrected_notes) else: - estimated_notes_with_pitch_bend = [(note[0], note[1], note[2], note[3], None) for note in estimated_notes] + corrected_notes_with_pitch_bend = [ + (note[0], note[1], note[2], note[3], None) for note in corrected_notes + ] times_s = model_frames_to_time(contours.shape[0]) - estimated_notes_time_seconds = [ - (times_s[note[0]], times_s[note[1]], note[2], note[3], note[4]) for note in estimated_notes_with_pitch_bend + corrected_notes_time_seconds = [ + (times_s[note[0]], times_s[note[1]], note[2], note[3], note[4]) + for note in corrected_notes_with_pitch_bend ] return ( - note_events_to_midi(estimated_notes_time_seconds, multiple_pitch_bends, midi_tempo), - estimated_notes_time_seconds, + note_events_to_midi(corrected_notes_time_seconds, multiple_pitch_bends, midi_tempo), + corrected_notes_time_seconds, ) + def sonify_midi(midi: pretty_midi.PrettyMIDI, save_path: Union[pathlib.Path, str], sr: Optional[int] = 44100) -> None: """Sonify a pretty_midi midi object and save to a file. diff --git a/tests/test_note_creation.py b/tests/test_note_creation.py old mode 100644 new mode 100755 index 630bec1..1be7772 --- a/tests/test_note_creation.py +++ b/tests/test_note_creation.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + from basic_pitch.note_creation import drop_overlapping_pitch_bends @@ -47,4 +48,5 @@ def test_drop_overlapping_pitch_bends() -> None: (4.1, 4.2, 77, 1.0, None), # overlaps w prev ] result = drop_overlapping_pitch_bends(note_events_with_pitch_bends) + print("Test Result: /n", result, "/nExpected: /n", expected) assert sorted(result) == sorted(expected)