Skip to content

Commit fa5997a

Browse files
authored
Update mirdata, add local index fixtures (#185)
* Update mirdata, add local index fixtures * update comment * update comment * Move fixtures to conftest, apply to more tests * format fixture metadata * rm unused variable * also move slakh fixtures
1 parent ad83963 commit fa5997a

18 files changed

Lines changed: 977 additions & 40 deletions

basic_pitch/data/datasets/guitarset.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ def determine_split(index: int) -> str:
144144
return "test"
145145

146146
guitarset = mirdata.initialize("guitarset")
147+
guitarset.download(["index"])
147148
track_ids = guitarset.track_ids
148149
random.shuffle(track_ids)
149150

basic_pitch/data/datasets/ikala.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from basic_pitch.data import commandline, pipeline
3030

3131

32+
# Oct 2025: Ikala remote download is broken on mirdata side # TODO: Re-evaluate later
3233
class IkalaInvalidTracks(beam.DoFn):
3334
def process(self, element: Tuple[str, str], *args: Tuple[Any, Any], **kwargs: Dict[str, Any]) -> Any:
3435
track_id, split = element
@@ -142,6 +143,7 @@ def create_input_data(train_percent: float, seed: Optional[int] = None) -> List[
142143
random.seed(seed)
143144

144145
ikala = mirdata.initialize("ikala")
146+
ikala.download(["index"])
145147
track_ids = ikala.track_ids
146148
random.shuffle(track_ids)
147149

basic_pitch/data/datasets/maestro.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import logging
2020
import os
2121
import sys
22-
import tempfile
2322
import time
2423
from typing import Any, Dict, List, TextIO, Tuple
2524

@@ -164,20 +163,10 @@ def process(self, element: List[str], *args: Tuple[Any, Any], **kwargs: Dict[str
164163
return [batch]
165164

166165

167-
def create_input_data(source: str) -> List[Tuple[str, str]]:
168-
import apache_beam as beam
169-
170-
filesystem = beam.io.filesystems.FileSystems()
171-
172-
with tempfile.TemporaryDirectory() as tmpdir:
173-
maestro = mirdata.initialize("maestro", data_home=tmpdir)
174-
metadata_path = maestro._index["metadata"]["maestro-v2.0.0"][0]
175-
with filesystem.open(
176-
os.path.join(source, metadata_path),
177-
) as s, open(os.path.join(tmpdir, metadata_path), "wb") as d:
178-
d.write(s.read())
179-
180-
return [(track_id, track.split) for track_id, track in maestro.load_tracks().items()]
166+
def create_input_data() -> List[Tuple[str, str]]:
167+
maestro = mirdata.initialize("maestro")
168+
maestro.download(["metadata"])
169+
return [(track_id, track.split) for track_id, track in maestro.load_tracks().items()]
181170

182171

183172
def main(known_args: argparse.Namespace, pipeline_args: List[str]) -> None:
@@ -198,7 +187,7 @@ def main(known_args: argparse.Namespace, pipeline_args: List[str]) -> None:
198187
"environment_type": "DOCKER",
199188
"environment_config": known_args.sdk_container_image,
200189
}
201-
input_data = create_input_data(known_args.source)
190+
input_data = create_input_data()
202191
pipeline.run(
203192
pipeline_options,
204193
pipeline_args,

basic_pitch/data/datasets/medleydb_pitch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def create_input_data(train_percent: float, seed: Optional[int] = None) -> List[
140140
random.seed(seed)
141141

142142
medleydb_pitch = mirdata.initialize("medleydb_pitch")
143+
medleydb_pitch.download(["index"])
143144
track_ids = medleydb_pitch.track_ids
144145
random.shuffle(track_ids)
145146

basic_pitch/data/datasets/slakh.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ def process(self, element: List[str]) -> List[Any]:
182182

183183
def create_input_data() -> List[Tuple[str, str]]:
184184
slakh = mirdata.initialize("slakh")
185+
slakh.download(["index"])
185186
return [(track_id, track.data_split) for track_id, track in slakh.load_tracks().items()]
186187

187188

pyproject.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,7 @@ bp-download = "basic_pitch.data.download:main"
5555
data = [
5656
"basic_pitch[tf,test]",
5757
"apache_beam",
58-
# TODO: mirdata 0.3.9 moves dataset indexes files which breaks our tests
59-
# Adapt our codebase to release that constraint
60-
"mirdata<=0.3.8",
58+
"mirdata>=1.0.0",
6159
"smart_open",
6260
"sox",
6361
"ffmpeg-python"

tests/data/conftest.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import pytest
2+
import json
3+
import pathlib
4+
from unittest import mock
5+
6+
RESOURCES_PATH = pathlib.Path(__file__).parent.parent / "resources"
7+
GUITAR_SET_TEST_INDEX = json.load(open(RESOURCES_PATH / "data" / "guitarset" / "dummy_index.json"))
8+
IKALA_TEST_INDEX = json.load(open(RESOURCES_PATH / "data" / "ikala" / "dummy_index.json"))
9+
MAESTRO_TEST_INDEX = json.load(open(RESOURCES_PATH / "data" / "maestro" / "dummy_index.json"))
10+
METADATA_TEST_INDEX = json.load(open(RESOURCES_PATH / "data" / "maestro" / "maestro-v2.0.0.json"))
11+
MEDLEYDB_PITCH_TEST_INDEX = json.load(open(RESOURCES_PATH / "data" / "medleydb_pitch" / "dummy_index.json"))
12+
SLAKH_TEST_INDEX = json.load(open(RESOURCES_PATH / "data" / "slakh" / "dummy_index.json"))
13+
14+
15+
@pytest.fixture # type: ignore[misc]
16+
def mock_slakh_index() -> None: # type: ignore[misc]
17+
with mock.patch("mirdata.datasets.slakh.Dataset.download"):
18+
with mock.patch("mirdata.datasets.slakh.Dataset._index", new=SLAKH_TEST_INDEX):
19+
yield
20+
21+
22+
@pytest.fixture # type: ignore[misc]
23+
def mock_medleydb_pitch_index() -> None: # type: ignore[misc]
24+
with mock.patch("mirdata.datasets.medleydb_pitch.Dataset.download"):
25+
with mock.patch("mirdata.datasets.medleydb_pitch.Dataset._index", new=MEDLEYDB_PITCH_TEST_INDEX):
26+
yield
27+
28+
29+
@pytest.fixture # type: ignore[misc]
30+
def mock_maestro_index() -> None: # type: ignore[misc]
31+
index_with_metadata = MAESTRO_TEST_INDEX
32+
metadata = {mdata["midi_filename"].split(".")[0]: mdata for mdata in METADATA_TEST_INDEX}
33+
with mock.patch("mirdata.datasets.maestro.Dataset.download"):
34+
with mock.patch("mirdata.datasets.maestro.Dataset._metadata", new=metadata):
35+
with mock.patch("mirdata.datasets.maestro.Dataset._index", new=index_with_metadata):
36+
yield
37+
38+
39+
@pytest.fixture # type: ignore[misc]
40+
def mock_guitarset_index() -> None: # type: ignore[misc]
41+
with mock.patch("mirdata.datasets.guitarset.Dataset.download"):
42+
with mock.patch("mirdata.datasets.guitarset.Dataset._index", new=GUITAR_SET_TEST_INDEX):
43+
yield
44+
45+
46+
@pytest.fixture # type: ignore[misc]
47+
def mock_ikala_index() -> None: # type: ignore[misc]
48+
with mock.patch("mirdata.datasets.ikala.Dataset.download"):
49+
with mock.patch("mirdata.datasets.ikala.Dataset._index", new=IKALA_TEST_INDEX):
50+
yield

tests/data/test_guitarset.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import os
2020
import pathlib
2121
import shutil
22-
2322
from apache_beam.testing.test_pipeline import TestPipeline
2423
from typing import List
2524

@@ -36,7 +35,7 @@
3635
TRACK_ID = "00_BN1-129-Eb_comp"
3736

3837

39-
def test_guitarset_to_tf_example(tmp_path: pathlib.Path) -> None:
38+
def test_guitarset_to_tf_example(tmp_path: pathlib.Path, mock_guitarset_index: None) -> None:
4039
mock_guitarset_home = tmp_path / "guitarset"
4140
mock_guitarset_audio = mock_guitarset_home / "audio_mono-mic"
4241
mock_guitarset_annotations = mock_guitarset_home / "annotation"
@@ -91,7 +90,7 @@ def test_guitarset_invalid_tracks(tmpdir: str) -> None:
9190
assert fp.read().strip() == str(i)
9291

9392

94-
def test_guitarset_create_input_data() -> None:
93+
def test_guitarset_create_input_data(mock_guitarset_index: None) -> None:
9594
data = create_input_data(train_percent=0.33, validation_percent=0.33)
9695
data.sort(key=lambda el: el[1]) # sort by split
9796
tolerance = 0.1

tests/data/test_ikala.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,13 @@
1717
import apache_beam as beam
1818
import itertools
1919
import os
20-
2120
from apache_beam.testing.test_pipeline import TestPipeline
2221

2322
from basic_pitch.data.datasets.ikala import (
2423
IkalaInvalidTracks,
2524
create_input_data,
2625
)
2726

28-
2927
# TODO: Create test_ikala_to_tf_example
3028

3129

@@ -51,7 +49,7 @@ def test_ikala_invalid_tracks(tmpdir: str) -> None:
5149
assert fp.read().strip() == str(i)
5250

5351

54-
def test_ikala_create_input_data() -> None:
52+
def test_ikala_create_input_data(mock_ikala_index: None) -> None:
5553
data = create_input_data(train_percent=0.5)
5654
data.sort(key=lambda el: el[1]) # sort by split
5755
tolerance = 0.1

tests/data/test_maestro.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
# limitations under the License.
1717
import os
1818
import pathlib
19-
2019
from typing import List
2120

2221
import apache_beam as beam
@@ -40,7 +39,7 @@
4039
GT_15M_TRACK_ID = "2004/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav"
4140

4241

43-
def test_maestro_to_tf_example(tmp_path: pathlib.Path) -> None:
42+
def test_maestro_to_tf_example(tmp_path: pathlib.Path, mock_maestro_index: None) -> None:
4443
mock_maestro_home = tmp_path / "maestro"
4544
mock_maestro_ext = mock_maestro_home / "2004"
4645
mock_maestro_ext.mkdir(parents=True, exist_ok=True)
@@ -68,7 +67,7 @@ def test_maestro_to_tf_example(tmp_path: pathlib.Path) -> None:
6867
assert len(data) != 0
6968

7069

71-
def test_maestro_invalid_tracks(tmp_path: pathlib.Path) -> None:
70+
def test_maestro_invalid_tracks(tmp_path: pathlib.Path, mock_maestro_index: None) -> None:
7271
mock_maestro_home = tmp_path / "maestro"
7372
mock_maestro_ext = mock_maestro_home / "2004"
7473
mock_maestro_ext.mkdir(parents=True, exist_ok=True)
@@ -98,7 +97,7 @@ def test_maestro_invalid_tracks(tmp_path: pathlib.Path) -> None:
9897
assert fp.read().strip() == track_id
9998

10099

101-
def test_maestro_invalid_tracks_over_15_min(tmp_path: pathlib.Path) -> None:
100+
def test_maestro_invalid_tracks_over_15_min(tmp_path: pathlib.Path, mock_maestro_index: None) -> None:
102101
"""
103102
The track id used here is a real track id in maestro, and it is part of the train split, but we mock the data so as
104103
not to store a large file in git, hence the variable name.
@@ -131,13 +130,13 @@ def test_maestro_invalid_tracks_over_15_min(tmp_path: pathlib.Path) -> None:
131130
assert fp.read().strip() == ""
132131

133132

134-
def test_maestro_create_input_data() -> None:
133+
def test_maestro_create_input_data(mock_maestro_index: None) -> None:
135134
"""
136135
A commuted metadata file is included in the repo for testing. mirdata references the metadata file to
137136
populate the tracklist with metadata. Since the file is commuted to only the filenames referenced here,
138137
we only consider these when testing the metadata.
139138
"""
140-
data = create_input_data(str(MAESTRO_TEST_DATA_PATH))
139+
data = create_input_data()
141140
assert len(data)
142141

143142
test_fnames = {TRAIN_TRACK_ID, VALID_TRACK_ID, TEST_TRACK_ID, GT_15M_TRACK_ID}

0 commit comments

Comments
 (0)