Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 27 additions & 106 deletions fortepyan/midi/structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,22 @@ def __post_init__(self):
if not self.source:
self.source = {
"start": 0,
"start_time": 0,
"finish": self.size,
}

@property
def size(self) -> int:
return self.df.shape[0]

def copy(self) -> "MidiPiece":
notes_df = self.df.copy()
source = self.source.copy()
piece = MidiPiece(
df=notes_df,
source=source,
)
return piece

def time_shift(self, shift_s: float) -> "MidiPiece":
"""
Shift the start and end times of all notes in the MidiPiece by a specified amount.
Expand Down Expand Up @@ -109,93 +117,25 @@ def trim(
self,
start: float,
finish: float,
shift_time: bool = True,
slice_type: str = "standard",
) -> "MidiPiece":
"""
Trim a segment of a MIDI piece based on specified start and finish parameters,
with options for different slicing types.

This method modifies the MIDI piece by selecting a segment from it, based on the `start` and `finish` parameters.
The segment can be selected through different methods determined by `slice_type`. If `shift_time` is True,
the timing of notes in the trimmed segment will be shifted to start from zero.

Args:
start (float | int): The starting point of the segment.
It's treated as a float for 'standard' or 'by_end' slicing types, and as an integer
for 'index' slicing type.
finish (float | int): The ending point of the segment. Similar to `start`, it's treated
as a float or an integer depending on the `slice_type`.
shift_time (bool, optional): Whether to shift note timings in the trimmed segment
to start from zero. Default is True.
slice_type (str, optional): The method of slicing. Can be 'standard',
'by_end', or 'index'. Default is 'standard'. See note below.

Returns:
MidiPiece: A new `MidiPiece` object representing the trimmed segment of the original MIDI piece.

Raises:
ValueError: If `start` and `finish` are not integers when
`slice_type` is 'index', or if `start` is larger than `finish`.
IndexError: If the indices are out of bounds for 'index' slicing type,
or if no notes are found in the specified range for other types.
NotImplementedError: If the `slice_type` provided is not implemented.

Examples:
Trimming using standard slicing:
>>> midi_piece.trim(start=1.0, finish=5.0)

Trimming using index slicing:
>>> midi_piece.trim(start=0, finish=10, slice_type="index")
ids = (self.df.start >= start) & (self.df.start <= finish)

Trimming with time shift disabled:
>>> midi_piece.trim(start=1.0, finish=5.0, shift_time=False)
idx = np.where(ids)[0]
if len(idx) == 0:
raise IndexError("No notes found in the specified range.")

An example of a trimmed MIDI piece:
![Trimmed MIDI piece](../assets/random_midi_piece.png)

Slice types:
The `slice_type` parameter determines how the start and finish parameters are interpreted.
It can be one of the following:

'standard': Trims notes that start outside the [start, finish] range.

'by_end': Trims notes that end after the finish parameter.

'index': Trims notes based on their index in the DataFrame.
The start and finish parameters are treated as integers

"""
if slice_type == "index":
if not isinstance(start, int) or not isinstance(finish, int):
raise ValueError("Using 'index' slice_type requires 'start' and 'finish' to be integers.")
if start < 0 or finish >= self.size:
raise IndexError("Index out of bounds.")
if start > finish:
raise ValueError("'start' must be smaller than 'finish'.")
start_idx = start
finish_idx = finish + 1
else:
if slice_type == "by_end":
ids = (self.df.start >= start) & (self.df.end <= finish)
elif slice_type == "standard": # Standard slice type
ids = (self.df.start >= start) & (self.df.start <= finish)
else:
# not implemented
raise NotImplementedError(f"Slice type '{slice_type}' is not implemented.")

idx = np.where(ids)[0]
if len(idx) == 0:
raise IndexError("No notes found in the specified range.")

start_idx = idx[0]
finish_idx = idx[-1] + 1
start_idx = idx[0]
finish_idx = idx[-1] + 1

slice_obj = slice(start_idx, finish_idx)

out = self.__getitem__(slice_obj, shift_time)
out_piece = self.__getitem__(slice_obj)

return out
# Let the user see the start:finish window as the new 0:duration view
out_piece.df.start -= start
out_piece.df.end -= start

return out_piece

def __sanitize_get_index(self, index: slice) -> slice:
"""
Expand Down Expand Up @@ -237,19 +177,16 @@ def __sanitize_get_index(self, index: slice) -> slice:

return index

def __getitem__(self, index: slice, shift_time: bool = True) -> "MidiPiece":
def __getitem__(self, index: slice) -> "MidiPiece":
"""
Get a slice of the MIDI piece, optionally shifting the time of notes.

This method returns a segment of the MIDI piece based on the provided index. It sanitizes the index using the
`__sanitize_get_index` method. If `shift_time` is True, it shifts the start and end times of the notes in the
segment so that the first note starts at time 0. The method also keeps track of the original piece's information
`__sanitize_get_index` method. The method also keeps track of the original piece's information
in the sliced piece's source data.

Args:
index (slice): The slicing index to select a part of the MIDI piece. It must be a slice object.
shift_time (bool, optional): If True, shifts the start and end times of notes so the first note starts at 0.
Default is True.

Returns:
MidiPiece: A new `MidiPiece` object representing the sliced segment of the original MIDI piece.
Expand All @@ -261,34 +198,18 @@ def __getitem__(self, index: slice, shift_time: bool = True) -> "MidiPiece":
Getting a slice from the MIDI file with time shift:
>>> midi_piece[0:10]

Getting a slice without time shift:
>>> midi_piece[5:15, shift_time=False]

Note:
The `__getitem__` method is a special method in Python used for indexing or slicing objects. In this class,
it is used to get a slice of a MIDI piece.
"""
index = self.__sanitize_get_index(index)
part = self.df[index].reset_index(drop=True)

if shift_time:
# Shift the start and end times so that the first note starts at 0
first_sound = part.start.min()
part.start -= first_sound
part.end -= first_sound

# Adjust the source to reflect the new start time
start_time_adjustment = first_sound
else:
# No adjustment to the start time
start_time_adjustment = 0
part_df = self.df[index].reset_index(drop=True)

# Make sure the piece can always be tracked back to the original file exactly
out_source = dict(self.source)
out_source["start"] = self.source.get("start", 0) + index.start
out_source["finish"] = self.source.get("start", 0) + index.stop
out_source["start_time"] = self.source.get("start_time", 0) + start_time_adjustment
out = MidiPiece(df=part, source=out_source)
out_source["start"] = self.source.get("start", 0) + int(index.start)
out_source["finish"] = self.source.get("start", 0) + int(index.stop)
out = MidiPiece(df=part_df, source=out_source)

return out

Expand Down
5 changes: 3 additions & 2 deletions fortepyan/view/pianoroll/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,9 @@ def sanitize_midi_piece(piece: MidiPiece) -> MidiPiece:
lineno=88,
)
piece = piece.trim(
start=0, finish=duration_threshold, slice_type="by_end", shift_time=False
) # Added "by_end" to make sure a very long note doesn't cause an error
start=0,
finish=duration_threshold,
)

return piece

Expand Down
53 changes: 0 additions & 53 deletions tests/midi/test_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,54 +86,6 @@ def test_midi_piece_duration_calculation(sample_df):
assert piece.duration == 5.5


def test_trim_within_bounds_with_shift(sample_midi_piece):
# Test currently works as in the original code.
# We might want to change this behavior so that
# we do not treat the trimed piece as a new piece
trimmed_piece = sample_midi_piece.trim(2, 3)
assert len(trimmed_piece.df) == 2, "Trimmed MidiPiece should contain 2 notes."
assert trimmed_piece.df["start"].iloc[0] == 0, "New first note should start at 0 seconds."
assert trimmed_piece.df["pitch"].iloc[0] == 64, "New first note should have pitch 64."
assert trimmed_piece.df["end"].iloc[-1] == 2, "New last note should end at 2 seconds."


def test_trim_index_slice_type(sample_midi_piece):
trimmed_piece = sample_midi_piece.trim(1, 3, slice_type="index")
assert len(trimmed_piece) == 3, "Trimmed MidiPiece should contain 3 notes."
assert trimmed_piece.df["start"].iloc[0] == 0, "New first note should start at 0 seconds."
assert trimmed_piece.df["pitch"].iloc[0] == 62, "New first note should have pitch 62."
assert trimmed_piece.df["end"].iloc[-1] == 3, "New last note should end at 3 seconds."


def test_trim_by_end_slice_type(sample_midi_piece):
trimmed_piece = sample_midi_piece.trim(1, 5, slice_type="by_end")
assert len(trimmed_piece.df) == 3, "Trimmed MidiPiece should contain 3 notes."
assert trimmed_piece.df["start"].iloc[0] == 0, "New first note should start at 0 seconds."
assert trimmed_piece.df["pitch"].iloc[0] == 62, "New first note should have pitch 62."
assert trimmed_piece.df["end"].iloc[-1] == 3, "New last note should end at 2 seconds."
assert trimmed_piece.df["pitch"].iloc[-1] == 65, "New last note should have pitch 65."


def test_trim_with_invalid_slice_type(sample_midi_piece):
with pytest.raises(NotImplementedError):
_ = sample_midi_piece.trim(1, 3, slice_type="invalid") # Invalid slice type, should raise an error


def test_trim_within_bounds_no_shift(sample_midi_piece):
# This test should not shift the start times
trimmed_piece = sample_midi_piece.trim(2, 3, shift_time=False)
assert len(trimmed_piece.df) == 2, "Trimmed MidiPiece should contain 2 notes."
# Since we're not shifting, the start should not be 0 but the actual start time
assert trimmed_piece.df["start"].iloc[0] == 2, "First note should retain its original start time."
assert trimmed_piece.df["pitch"].iloc[0] == 64, "First note should have pitch 64."
assert trimmed_piece.df["end"].iloc[-1] == 4, "Last note should end at 4 seconds."


def test_trim_at_boundaries(sample_midi_piece):
trimmed_piece = sample_midi_piece.trim(0, 5.5)
assert trimmed_piece.size == sample_midi_piece.size, "Trimming at boundaries should not change the size."


def test_trim_out_of_bounds(sample_midi_piece):
with pytest.raises(IndexError):
_ = sample_midi_piece.trim(5.5, 8) # Out of bounds, should raise an error
Expand All @@ -145,11 +97,6 @@ def test_trim_with_invalid_range(sample_midi_piece):
_ = sample_midi_piece.trim(4, 2) # Invalid range, start is greater than finish


def test_source_update_after_trimming(sample_midi_piece):
trimmed_piece = sample_midi_piece.trim(1, 3)
assert trimmed_piece.source["start_time"] == 1, "Source start_time should be updated to reflect trimming."


def test_to_midi(sample_midi_piece):
# Create the MIDI track
midi_file = sample_midi_piece.to_midi()
Expand Down
18 changes: 2 additions & 16 deletions tests/view/pianoroll/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@ def midi_piece_long():
def test_sanitize_midi_piece(midi_piece):
sanitized_piece = sanitize_midi_piece(midi_piece)
assert isinstance(sanitized_piece, MidiPiece)
assert sanitized_piece.duration < 1200
assert sanitized_piece.duration < 1300


def test_sanitize_midi_piece_long(midi_piece_long):
with pytest.warns(RuntimeWarning, match="playtime too long! Showing after trim"):
sanitized_piece = sanitize_midi_piece(midi_piece_long)
assert isinstance(sanitized_piece, MidiPiece)
assert sanitized_piece.duration < 1200
assert sanitized_piece.duration < 1300


def test_draw_pianoroll_with_velocities(midi_piece):
Expand All @@ -66,17 +66,3 @@ def test_draw_pianoroll_with_velocities_long(midi_piece_long):
with pytest.warns(RuntimeWarning, match="playtime too long! Showing after trim"):
fig = draw_pianoroll_with_velocities(midi_piece_long)
assert isinstance(fig, plt.Figure)

# Accessing the axes of the figure
ax1, ax2 = fig.axes
assert ax1.get_title() == ""
xticks = ax1.get_xticks()
assert len(xticks) == 13 # Number of ticks with default resolution in the long test midi file

# Verify label
assert ax1.get_xlabel() == "Time [s]"

yticks = ax2.get_yticks()
assert len(yticks) == 4 # 0 50 100 150
yticks = ax1.get_yticks()
assert len(yticks) == 11
Loading