Skip to content

Commit 5772118

Browse files
read only for now
1 parent dc2f1a5 commit 5772118

File tree

2 files changed

+110
-99
lines changed

2 files changed

+110
-99
lines changed

mne_bids/tests/test_read.py

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@
66
import contextlib
77
import json
88
import logging
9+
import multiprocessing as mp
910
import os
1011
import os.path as op
1112
import re
1213
import shutil as sh
1314
from collections import OrderedDict
1415
from contextlib import nullcontext
15-
from datetime import datetime, timezone
16+
from datetime import date, datetime, timezone
1617
from pathlib import Path
1718

1819
import mne
@@ -99,6 +100,49 @@ def fn(fname, *args, **kwargs):
99100
_read_raw_edf = _wrap_read_raw(mne.io.read_raw_edf)
100101

101102

103+
def _make_parallel_raw(subject, *, seed=None):
104+
"""Generate a lightweight Raw instance for parallel-reading tests."""
105+
rng_seed = seed if seed is not None else sum(ord(ch) for ch in subject)
106+
rng = np.random.default_rng(rng_seed)
107+
info = mne.create_info(["MEG0113"], 100, ch_types="mag")
108+
data = rng.standard_normal((1, 100)) * 1e-12
109+
raw = mne.io.RawArray(data, info)
110+
raw.set_meas_date(datetime(2020, 1, 1, tzinfo=timezone.utc))
111+
raw.info["line_freq"] = 60
112+
raw.info["subject_info"] = {
113+
"his_id": subject,
114+
"sex": 1,
115+
"hand": 2,
116+
"birthday": date(1990, 1, 1),
117+
}
118+
return raw
119+
120+
121+
def _write_parallel_dataset(root, *, subject, run):
122+
"""Write a minimal dataset using write_raw_bids."""
123+
root = Path(root)
124+
raw = _make_parallel_raw(subject)
125+
bids_path = BIDSPath(
126+
subject=subject, task="rest", run=run, datatype="meg", root=root
127+
)
128+
write_raw_bids(raw, bids_path, allow_preload=True, format="FIF", verbose=False)
129+
130+
131+
def _parallel_read_participants(root, expected_ids):
132+
"""Read participants.tsv in a multiprocessing worker."""
133+
participants_path = Path(root) / "participants.tsv"
134+
participants = _from_tsv(participants_path)
135+
assert set(participants["participant_id"]) == set(expected_ids)
136+
137+
138+
def _parallel_read_scans(root, expected_filenames):
139+
"""Read scans.tsv in a multiprocessing worker."""
140+
scans_path = BIDSPath(subject="01", root=root, suffix="scans", extension=".tsv")
141+
scans = _from_tsv(scans_path.fpath)
142+
filenames = {str(filename) for filename in scans["filename"]}
143+
assert filenames == set(expected_filenames)
144+
145+
102146
def test_read_raw():
103147
"""Test the raw reading."""
104148
# Use a file ending that does not exist
@@ -133,6 +177,71 @@ def test_read_correct_inputs():
133177
read_raw_bids(bids_path)
134178

135179

180+
@pytest.mark.filterwarnings(
181+
"ignore:No events found or provided:RuntimeWarning",
182+
"ignore:Found no extension for raw file.*:RuntimeWarning",
183+
)
184+
def test_parallel_participants_multiprocess(tmp_path):
185+
"""Ensure parallel reads keep all participants entries visible."""
186+
bids_root = tmp_path / "parallel_multiprocess"
187+
subjects = [f"{i:02d}" for i in range(1, 50)]
188+
189+
for subject in subjects:
190+
_write_parallel_dataset(str(bids_root), subject=subject, run="01")
191+
192+
expected_ids = [f"sub-{subject}" for subject in subjects]
193+
processes = []
194+
for _ in range(len(subjects) // 10): # spawn a few processes
195+
proc = mp.Process(
196+
target=_parallel_read_participants, args=(str(bids_root), expected_ids)
197+
)
198+
proc.start()
199+
processes.append(proc)
200+
201+
for proc in processes:
202+
proc.join()
203+
assert proc.exitcode == 0
204+
205+
participants_path = bids_root / "participants.tsv"
206+
assert participants_path.exists()
207+
participants = _from_tsv(participants_path)
208+
assert set(participants["participant_id"]) == set(expected_ids)
209+
sh.rmtree(bids_root, ignore_errors=True)
210+
211+
212+
@pytest.mark.filterwarnings(
213+
"ignore:No events found or provided:RuntimeWarning",
214+
"ignore:Found no extension for raw file.*:RuntimeWarning",
215+
)
216+
def test_parallel_scans_multiprocessing(tmp_path):
217+
"""Ensure multiprocessing reads see all runs in scans.tsv."""
218+
bids_root = tmp_path / "parallel_multiprocessing"
219+
runs = [f"{i:02d}" for i in range(1, 50)]
220+
221+
for run in runs:
222+
_write_parallel_dataset(str(bids_root), subject="01", run=run)
223+
224+
expected = {f"meg/sub-01_task-rest_run-{run}_meg.fif" for run in runs}
225+
processes = []
226+
for _ in range(4):
227+
proc = mp.Process(target=_parallel_read_scans, args=(str(bids_root), expected))
228+
proc.start()
229+
processes.append(proc)
230+
231+
for proc in processes:
232+
proc.join()
233+
assert proc.exitcode == 0
234+
235+
scans_path = BIDSPath(
236+
subject="01", root=bids_root, suffix="scans", extension=".tsv"
237+
)
238+
assert scans_path.fpath.exists()
239+
scans = _from_tsv(scans_path.fpath)
240+
filenames = {str(filename) for filename in scans["filename"]}
241+
assert filenames == expected
242+
sh.rmtree(bids_root, ignore_errors=True)
243+
244+
136245
@pytest.mark.filterwarnings(warning_str["channel_unit_changed"])
137246
@testing.requires_testing_data
138247
def test_read_participants_data(tmp_path):

mne_bids/tests/test_write.py

Lines changed: 0 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
import codecs
1010
import json
11-
import multiprocessing as mp
1211
import os
1312
import os.path as op
1413
import shutil as sh
@@ -100,44 +99,6 @@
10099
)
101100

102101

103-
def _make_parallel_raw(subject, *, seed=None):
104-
"""Generate a lightweight Raw instance for parallel-writing tests."""
105-
rng_seed = seed if seed is not None else sum(ord(ch) for ch in subject)
106-
rng = np.random.default_rng(rng_seed)
107-
info = mne.create_info(["MEG0113"], sfreq=100.0, ch_types="mag")
108-
data = rng.standard_normal((1, 100)) * 1e-12
109-
raw = mne.io.RawArray(data, info)
110-
raw.set_meas_date(datetime(2025, 1, 1, tzinfo=timezone.utc))
111-
raw.info["line_freq"] = 60
112-
raw.info["subject_info"] = {
113-
"his_id": subject,
114-
"sex": 1,
115-
"hand": 2,
116-
"birthday": date(1990, 1, 1),
117-
}
118-
return raw
119-
120-
121-
def _write_parallel_dataset(root, *, subject, run):
122-
"""Write a minimal dataset using write_raw_bids."""
123-
root = Path(root)
124-
raw = _make_parallel_raw(subject)
125-
bids_path = BIDSPath(
126-
subject=subject, task="rest", run=run, datatype="meg", root=root
127-
)
128-
write_raw_bids(raw, bids_path, allow_preload=True, format="FIF", verbose=False)
129-
130-
131-
def _parallel_write_subject(root, subject):
132-
"""Handle write_raw_bids call in a multiprocessing worker."""
133-
_write_parallel_dataset(root, subject=subject, run="01")
134-
135-
136-
def _multiprocessing_write_run(root, run):
137-
"""Handle write_raw_bids call in a multiprocessing worker."""
138-
_write_parallel_dataset(root, subject="01", run=run)
139-
140-
141102
def _wrap_read_raw(read_raw):
142103
def fn(fname, *args, **kwargs):
143104
if Path(fname).suffix == ".mff":
@@ -356,65 +317,6 @@ def test_write_participants(_bids_validate, tmp_path):
356317
assert participants_tsv["age"][idx] == "n/a"
357318

358319

359-
@pytest.mark.filterwarnings(
360-
"ignore:No events found or provided:RuntimeWarning",
361-
"ignore:Found no extension for raw file.*:RuntimeWarning",
362-
)
363-
def test_parallel_participants_multiprocess(tmp_path):
364-
"""Ensure parallel writes keep all participants entries."""
365-
bids_root = tmp_path / "parallel_multiprocess"
366-
subjects = [f"{i:02d}" for i in range(1, 50)]
367-
368-
processes = []
369-
for subject in subjects:
370-
proc = mp.Process(
371-
target=_parallel_write_subject, args=(str(bids_root), subject)
372-
)
373-
proc.start()
374-
processes.append(proc)
375-
376-
for proc in processes:
377-
proc.join()
378-
assert proc.exitcode == 0
379-
380-
participants_path = bids_root / "participants.tsv"
381-
assert participants_path.exists()
382-
participants = _from_tsv(participants_path)
383-
expected_ids = {f"sub-{subject}" for subject in subjects}
384-
assert set(participants["participant_id"]) == expected_ids
385-
sh.rmtree(bids_root, ignore_errors=True)
386-
387-
388-
@pytest.mark.filterwarnings(
389-
"ignore:No events found or provided:RuntimeWarning",
390-
"ignore:Found no extension for raw file.*:RuntimeWarning",
391-
)
392-
def test_parallel_scans_multiprocessing(tmp_path):
393-
"""Ensure multiprocessing writes add all runs to scans.tsv."""
394-
bids_root = tmp_path / "parallel_multiprocessing"
395-
runs = [f"{i:02d}" for i in range(1, 50)]
396-
397-
processes = []
398-
for run in runs:
399-
proc = mp.Process(target=_multiprocessing_write_run, args=(str(bids_root), run))
400-
proc.start()
401-
processes.append(proc)
402-
403-
for proc in processes:
404-
proc.join()
405-
assert proc.exitcode == 0
406-
407-
scans_path = BIDSPath(
408-
subject="01", root=bids_root, suffix="scans", extension=".tsv"
409-
)
410-
assert scans_path.fpath.exists()
411-
scans = _from_tsv(scans_path.fpath)
412-
filenames = {str(filename) for filename in scans["filename"]}
413-
expected = {f"meg/sub-01_task-rest_run-{run}_meg.fif" for run in runs}
414-
assert filenames == expected
415-
sh.rmtree(bids_root, ignore_errors=True)
416-
417-
418320
@testing.requires_testing_data
419321
def test_write_correct_inputs():
420322
"""Test that inputs of write_raw_bids is correct."""

0 commit comments

Comments (0)