Skip to content

Commit 86187b8

Browse files
authored
MIDI-172: MidiPiece.trim cleanup (#12)
* simplify piece trimming logic, remove dubious __getitem__ arguments * update trim logic * add .copy() method to MidiPiece * chore: guarantee json serializability for piece.source
1 parent b490897 commit 86187b8

File tree

4 files changed

+32
-177
lines changed

4 files changed

+32
-177
lines changed

fortepyan/midi/structures.py

Lines changed: 27 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,22 @@ def __post_init__(self):
6464
if not self.source:
6565
self.source = {
6666
"start": 0,
67-
"start_time": 0,
6867
"finish": self.size,
6968
}
7069

7170
@property
7271
def size(self) -> int:
7372
return self.df.shape[0]
7473

74+
def copy(self) -> "MidiPiece":
75+
notes_df = self.df.copy()
76+
source = self.source.copy()
77+
piece = MidiPiece(
78+
df=notes_df,
79+
source=source,
80+
)
81+
return piece
82+
7583
def time_shift(self, shift_s: float) -> "MidiPiece":
7684
"""
7785
Shift the start and end times of all notes in the MidiPiece by a specified amount.
@@ -109,93 +117,25 @@ def trim(
109117
self,
110118
start: float,
111119
finish: float,
112-
shift_time: bool = True,
113-
slice_type: str = "standard",
114120
) -> "MidiPiece":
115-
"""
116-
Trim a segment of a MIDI piece based on specified start and finish parameters,
117-
with options for different slicing types.
118-
119-
This method modifies the MIDI piece by selecting a segment from it, based on the `start` and `finish` parameters.
120-
The segment can be selected through different methods determined by `slice_type`. If `shift_time` is True,
121-
the timing of notes in the trimmed segment will be shifted to start from zero.
122-
123-
Args:
124-
start (float | int): The starting point of the segment.
125-
It's treated as a float for 'standard' or 'by_end' slicing types, and as an integer
126-
for 'index' slicing type.
127-
finish (float | int): The ending point of the segment. Similar to `start`, it's treated
128-
as a float or an integer depending on the `slice_type`.
129-
shift_time (bool, optional): Whether to shift note timings in the trimmed segment
130-
to start from zero. Default is True.
131-
slice_type (str, optional): The method of slicing. Can be 'standard',
132-
'by_end', or 'index'. Default is 'standard'. See note below.
133-
134-
Returns:
135-
MidiPiece: A new `MidiPiece` object representing the trimmed segment of the original MIDI piece.
136-
137-
Raises:
138-
ValueError: If `start` and `finish` are not integers when
139-
`slice_type` is 'index', or if `start` is larger than `finish`.
140-
IndexError: If the indices are out of bounds for 'index' slicing type,
141-
or if no notes are found in the specified range for other types.
142-
NotImplementedError: If the `slice_type` provided is not implemented.
143-
144-
Examples:
145-
Trimming using standard slicing:
146-
>>> midi_piece.trim(start=1.0, finish=5.0)
147-
148-
Trimming using index slicing:
149-
>>> midi_piece.trim(start=0, finish=10, slice_type="index")
121+
ids = (self.df.start >= start) & (self.df.start <= finish)
150122

151-
Trimming with time shift disabled:
152-
>>> midi_piece.trim(start=1.0, finish=5.0, shift_time=False)
123+
idx = np.where(ids)[0]
124+
if len(idx) == 0:
125+
raise IndexError("No notes found in the specified range.")
153126

154-
An example of a trimmed MIDI piece:
155-
![Trimmed MIDI piece](../assets/random_midi_piece.png)
156-
157-
Slice types:
158-
The `slice_type` parameter determines how the start and finish parameters are interpreted.
159-
It can be one of the following:
160-
161-
'standard': Trims notes that start outside the [start, finish] range.
162-
163-
'by_end': Trims notes that end after the finish parameter.
164-
165-
'index': Trims notes based on their index in the DataFrame.
166-
The start and finish parameters are treated as integers
167-
168-
"""
169-
if slice_type == "index":
170-
if not isinstance(start, int) or not isinstance(finish, int):
171-
raise ValueError("Using 'index' slice_type requires 'start' and 'finish' to be integers.")
172-
if start < 0 or finish >= self.size:
173-
raise IndexError("Index out of bounds.")
174-
if start > finish:
175-
raise ValueError("'start' must be smaller than 'finish'.")
176-
start_idx = start
177-
finish_idx = finish + 1
178-
else:
179-
if slice_type == "by_end":
180-
ids = (self.df.start >= start) & (self.df.end <= finish)
181-
elif slice_type == "standard": # Standard slice type
182-
ids = (self.df.start >= start) & (self.df.start <= finish)
183-
else:
184-
# not implemented
185-
raise NotImplementedError(f"Slice type '{slice_type}' is not implemented.")
186-
187-
idx = np.where(ids)[0]
188-
if len(idx) == 0:
189-
raise IndexError("No notes found in the specified range.")
190-
191-
start_idx = idx[0]
192-
finish_idx = idx[-1] + 1
127+
start_idx = idx[0]
128+
finish_idx = idx[-1] + 1
193129

194130
slice_obj = slice(start_idx, finish_idx)
195131

196-
out = self.__getitem__(slice_obj, shift_time)
132+
out_piece = self.__getitem__(slice_obj)
197133

198-
return out
134+
# Let the user see the start:finish window as the new 0:duration view
135+
out_piece.df.start -= start
136+
out_piece.df.end -= start
137+
138+
return out_piece
199139

200140
def __sanitize_get_index(self, index: slice) -> slice:
201141
"""
@@ -237,19 +177,16 @@ def __sanitize_get_index(self, index: slice) -> slice:
237177

238178
return index
239179

240-
def __getitem__(self, index: slice, shift_time: bool = True) -> "MidiPiece":
180+
def __getitem__(self, index: slice) -> "MidiPiece":
241181
"""
242182
Get a slice of the MIDI piece.
243183
244184
This method returns a segment of the MIDI piece based on the provided index. It sanitizes the index using the
245-
`__sanitize_get_index` method. If `shift_time` is True, it shifts the start and end times of the notes in the
246-
segment so that the first note starts at time 0. The method also keeps track of the original piece's information
185+
`__sanitize_get_index` method. The method also keeps track of the original piece's information
247186
in the sliced piece's source data.
248187
249188
Args:
250189
index (slice): The slicing index to select a part of the MIDI piece. It must be a slice object.
251-
shift_time (bool, optional): If True, shifts the start and end times of notes so the first note starts at 0.
252-
Default is True.
253190
254191
Returns:
255192
MidiPiece: A new `MidiPiece` object representing the sliced segment of the original MIDI piece.
@@ -261,34 +198,18 @@ def __getitem__(self, index: slice, shift_time: bool = True) -> "MidiPiece":
261198
Getting a slice from the MIDI file:
262199
>>> midi_piece[0:10]
263200
264-
Getting a slice without time shift:
265-
>>> midi_piece[5:15, shift_time=False]
266-
267201
Note:
268202
The `__getitem__` method is a special method in Python used for indexing or slicing objects. In this class,
269203
it is used to get a slice of a MIDI piece.
270204
"""
271205
index = self.__sanitize_get_index(index)
272-
part = self.df[index].reset_index(drop=True)
273-
274-
if shift_time:
275-
# Shift the start and end times so that the first note starts at 0
276-
first_sound = part.start.min()
277-
part.start -= first_sound
278-
part.end -= first_sound
279-
280-
# Adjust the source to reflect the new start time
281-
start_time_adjustment = first_sound
282-
else:
283-
# No adjustment to the start time
284-
start_time_adjustment = 0
206+
part_df = self.df[index].reset_index(drop=True)
285207

286208
# Make sure the piece can always be tracked back to the original file exactly
287209
out_source = dict(self.source)
288-
out_source["start"] = self.source.get("start", 0) + index.start
289-
out_source["finish"] = self.source.get("start", 0) + index.stop
290-
out_source["start_time"] = self.source.get("start_time", 0) + start_time_adjustment
291-
out = MidiPiece(df=part, source=out_source)
210+
out_source["start"] = self.source.get("start", 0) + int(index.start)
211+
out_source["finish"] = self.source.get("start", 0) + int(index.stop)
212+
out = MidiPiece(df=part_df, source=out_source)
292213

293214
return out
294215

fortepyan/view/pianoroll/main.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,9 @@ def sanitize_midi_piece(piece: MidiPiece) -> MidiPiece:
8888
lineno=88,
8989
)
9090
piece = piece.trim(
91-
start=0, finish=duration_threshold, slice_type="by_end", shift_time=False
92-
) # Added "by_end" to make sure a very long note doesn't cause an error
91+
start=0,
92+
finish=duration_threshold,
93+
)
9394

9495
return piece
9596

tests/midi/test_structures.py

Lines changed: 0 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -86,54 +86,6 @@ def test_midi_piece_duration_calculation(sample_df):
8686
assert piece.duration == 5.5
8787

8888

89-
def test_trim_within_bounds_with_shift(sample_midi_piece):
90-
# Test currently works as in the original code.
91-
# We might want to change this behavior so that
92-
# we do not treat the trimmed piece as a new piece
93-
trimmed_piece = sample_midi_piece.trim(2, 3)
94-
assert len(trimmed_piece.df) == 2, "Trimmed MidiPiece should contain 2 notes."
95-
assert trimmed_piece.df["start"].iloc[0] == 0, "New first note should start at 0 seconds."
96-
assert trimmed_piece.df["pitch"].iloc[0] == 64, "New first note should have pitch 64."
97-
assert trimmed_piece.df["end"].iloc[-1] == 2, "New last note should end at 2 seconds."
98-
99-
100-
def test_trim_index_slice_type(sample_midi_piece):
101-
trimmed_piece = sample_midi_piece.trim(1, 3, slice_type="index")
102-
assert len(trimmed_piece) == 3, "Trimmed MidiPiece should contain 3 notes."
103-
assert trimmed_piece.df["start"].iloc[0] == 0, "New first note should start at 0 seconds."
104-
assert trimmed_piece.df["pitch"].iloc[0] == 62, "New first note should have pitch 62."
105-
assert trimmed_piece.df["end"].iloc[-1] == 3, "New last note should end at 3 seconds."
106-
107-
108-
def test_trim_by_end_slice_type(sample_midi_piece):
109-
trimmed_piece = sample_midi_piece.trim(1, 5, slice_type="by_end")
110-
assert len(trimmed_piece.df) == 3, "Trimmed MidiPiece should contain 3 notes."
111-
assert trimmed_piece.df["start"].iloc[0] == 0, "New first note should start at 0 seconds."
112-
assert trimmed_piece.df["pitch"].iloc[0] == 62, "New first note should have pitch 62."
113-
assert trimmed_piece.df["end"].iloc[-1] == 3, "New last note should end at 2 seconds."
114-
assert trimmed_piece.df["pitch"].iloc[-1] == 65, "New last note should have pitch 65."
115-
116-
117-
def test_trim_with_invalid_slice_type(sample_midi_piece):
118-
with pytest.raises(NotImplementedError):
119-
_ = sample_midi_piece.trim(1, 3, slice_type="invalid") # Invalid slice type, should raise an error
120-
121-
122-
def test_trim_within_bounds_no_shift(sample_midi_piece):
123-
# This test should not shift the start times
124-
trimmed_piece = sample_midi_piece.trim(2, 3, shift_time=False)
125-
assert len(trimmed_piece.df) == 2, "Trimmed MidiPiece should contain 2 notes."
126-
# Since we're not shifting, the start should not be 0 but the actual start time
127-
assert trimmed_piece.df["start"].iloc[0] == 2, "First note should retain its original start time."
128-
assert trimmed_piece.df["pitch"].iloc[0] == 64, "First note should have pitch 64."
129-
assert trimmed_piece.df["end"].iloc[-1] == 4, "Last note should end at 4 seconds."
130-
131-
132-
def test_trim_at_boundaries(sample_midi_piece):
133-
trimmed_piece = sample_midi_piece.trim(0, 5.5)
134-
assert trimmed_piece.size == sample_midi_piece.size, "Trimming at boundaries should not change the size."
135-
136-
13789
def test_trim_out_of_bounds(sample_midi_piece):
13890
with pytest.raises(IndexError):
13991
_ = sample_midi_piece.trim(5.5, 8) # Out of bounds, should raise an error
@@ -145,11 +97,6 @@ def test_trim_with_invalid_range(sample_midi_piece):
14597
_ = sample_midi_piece.trim(4, 2) # Invalid range, start is greater than finish
14698

14799

148-
def test_source_update_after_trimming(sample_midi_piece):
149-
trimmed_piece = sample_midi_piece.trim(1, 3)
150-
assert trimmed_piece.source["start_time"] == 1, "Source start_time should be updated to reflect trimming."
151-
152-
153100
def test_to_midi(sample_midi_piece):
154101
# Create the MIDI track
155102
midi_file = sample_midi_piece.to_midi()

tests/view/pianoroll/test_main.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,14 @@ def midi_piece_long():
3333
def test_sanitize_midi_piece(midi_piece):
3434
sanitized_piece = sanitize_midi_piece(midi_piece)
3535
assert isinstance(sanitized_piece, MidiPiece)
36-
assert sanitized_piece.duration < 1200
36+
assert sanitized_piece.duration < 1300
3737

3838

3939
def test_sanitize_midi_piece_long(midi_piece_long):
4040
with pytest.warns(RuntimeWarning, match="playtime too long! Showing after trim"):
4141
sanitized_piece = sanitize_midi_piece(midi_piece_long)
4242
assert isinstance(sanitized_piece, MidiPiece)
43-
assert sanitized_piece.duration < 1200
43+
assert sanitized_piece.duration < 1300
4444

4545

4646
def test_draw_pianoroll_with_velocities(midi_piece):
@@ -66,17 +66,3 @@ def test_draw_pianoroll_with_velocities_long(midi_piece_long):
6666
with pytest.warns(RuntimeWarning, match="playtime too long! Showing after trim"):
6767
fig = draw_pianoroll_with_velocities(midi_piece_long)
6868
assert isinstance(fig, plt.Figure)
69-
70-
# Accessing the axes of the figure
71-
ax1, ax2 = fig.axes
72-
assert ax1.get_title() == ""
73-
xticks = ax1.get_xticks()
74-
assert len(xticks) == 13 # Number of ticks with default resolution in the long test midi file
75-
76-
# Verify label
77-
assert ax1.get_xlabel() == "Time [s]"
78-
79-
yticks = ax2.get_yticks()
80-
assert len(yticks) == 4 # 0 50 100 150
81-
yticks = ax1.get_yticks()
82-
assert len(yticks) == 11

0 commit comments

Comments
 (0)