|
6 | 6 | import contextlib |
7 | 7 | import json |
8 | 8 | import logging |
| 9 | +import multiprocessing as mp |
9 | 10 | import os |
10 | 11 | import os.path as op |
11 | 12 | import re |
12 | 13 | import shutil as sh |
13 | 14 | from collections import OrderedDict |
14 | 15 | from contextlib import nullcontext |
15 | | -from datetime import datetime, timezone |
| 16 | +from datetime import date, datetime, timezone |
16 | 17 | from pathlib import Path |
17 | 18 |
|
18 | 19 | import mne |
@@ -99,6 +100,49 @@ def fn(fname, *args, **kwargs): |
99 | 100 | _read_raw_edf = _wrap_read_raw(mne.io.read_raw_edf) |
100 | 101 |
|
101 | 102 |
|
def _make_parallel_raw(subject, *, seed=None):
    """Generate a lightweight Raw instance for parallel-reading tests.

    Parameters
    ----------
    subject : str
        Subject label; used to derive a deterministic RNG seed when
        ``seed`` is not given.
    seed : int | None
        Explicit RNG seed. If ``None``, a seed is derived from ``subject``
        so each subject gets reproducible data.

    Returns
    -------
    mne.io.RawArray
        A single magnetometer channel with 1 s of data at 100 Hz, with
        meas_date, line_freq, and subject_info set so write_raw_bids
        succeeds without warnings about missing metadata.
    """
    if seed is None:
        # Derive the seed from the subject's raw bytes: summing ord() values
        # collides for anagram labels (e.g. "01" and "10" would get identical
        # data), whereas the byte encoding is injective per label.
        seed = int.from_bytes(subject.encode("utf-8"), "big")
    rng = np.random.default_rng(seed)
    info = mne.create_info(["MEG0113"], 100, ch_types="mag")
    # Magnetometer-scale noise (~1e-12 T) so the data looks physiological.
    data = rng.standard_normal((1, 100)) * 1e-12
    raw = mne.io.RawArray(data, info)
    raw.set_meas_date(datetime(2020, 1, 1, tzinfo=timezone.utc))
    raw.info["line_freq"] = 60
    raw.info["subject_info"] = {
        "his_id": subject,
        "sex": 1,
        "hand": 2,
        "birthday": date(1990, 1, 1),
    }
    return raw
| 119 | + |
| 120 | + |
def _write_parallel_dataset(root, *, subject, run):
    """Write one minimal BIDS dataset entry via write_raw_bids."""
    raw = _make_parallel_raw(subject)
    bids_path = BIDSPath(
        subject=subject,
        task="rest",
        run=run,
        datatype="meg",
        root=Path(root),
    )
    write_raw_bids(raw, bids_path, allow_preload=True, format="FIF", verbose=False)
| 129 | + |
| 130 | + |
def _parallel_read_participants(root, expected_ids):
    """Read participants.tsv in a multiprocessing worker.

    Raises ``AssertionError`` (hence a non-zero child exit code) when the
    file does not list exactly ``expected_ids``.
    """
    participants_path = Path(root) / "participants.tsv"
    participants = _from_tsv(participants_path)
    found = set(participants["participant_id"])
    expected = set(expected_ids)
    # Raise explicitly rather than using a bare ``assert``: asserts are
    # stripped under ``python -O``, which would make the worker exit 0
    # vacuously and defeat the parent's exitcode check.
    if found != expected:
        raise AssertionError(
            f"participant_id mismatch: found {sorted(found)}, "
            f"expected {sorted(expected)}"
        )
| 136 | + |
| 137 | + |
def _parallel_read_scans(root, expected_filenames):
    """Read scans.tsv in a multiprocessing worker.

    Raises ``AssertionError`` (hence a non-zero child exit code) when the
    scans file does not list exactly ``expected_filenames``.
    """
    scans_path = BIDSPath(subject="01", root=root, suffix="scans", extension=".tsv")
    scans = _from_tsv(scans_path.fpath)
    found = {str(filename) for filename in scans["filename"]}
    expected = set(expected_filenames)
    # Raise explicitly rather than using a bare ``assert``: asserts are
    # stripped under ``python -O``, which would make the worker exit 0
    # vacuously and defeat the parent's exitcode check.
    if found != expected:
        raise AssertionError(
            f"scans filename mismatch: found {sorted(found)}, "
            f"expected {sorted(expected)}"
        )
| 144 | + |
| 145 | + |
102 | 146 | def test_read_raw(): |
103 | 147 | """Test the raw reading.""" |
104 | 148 | # Use a file ending that does not exist |
@@ -133,6 +177,71 @@ def test_read_correct_inputs(): |
133 | 177 | read_raw_bids(bids_path) |
134 | 178 |
|
135 | 179 |
|
@pytest.mark.filterwarnings(
    "ignore:No events found or provided:RuntimeWarning",
    "ignore:Found no extension for raw file.*:RuntimeWarning",
)
def test_parallel_participants_multiprocess(tmp_path):
    """Ensure parallel reads keep all participants entries visible."""
    bids_root = tmp_path / "parallel_multiprocess"
    subjects = [f"{idx:02d}" for idx in range(1, 50)]

    for sub in subjects:
        _write_parallel_dataset(str(bids_root), subject=sub, run="01")

    expected_ids = [f"sub-{sub}" for sub in subjects]

    # Spawn a few concurrent reader processes, all hitting the same file.
    n_workers = len(subjects) // 10
    workers = [
        mp.Process(
            target=_parallel_read_participants, args=(str(bids_root), expected_ids)
        )
        for _ in range(n_workers)
    ]
    for worker in workers:
        worker.start()

    for worker in workers:
        worker.join()
        # A failed in-worker check raises, giving a non-zero exit code.
        assert worker.exitcode == 0

    # Re-check from the parent process as well.
    participants_path = bids_root / "participants.tsv"
    assert participants_path.exists()
    participants = _from_tsv(participants_path)
    assert set(participants["participant_id"]) == set(expected_ids)
    sh.rmtree(bids_root, ignore_errors=True)
| 210 | + |
| 211 | + |
@pytest.mark.filterwarnings(
    "ignore:No events found or provided:RuntimeWarning",
    "ignore:Found no extension for raw file.*:RuntimeWarning",
)
def test_parallel_scans_multiprocessing(tmp_path):
    """Ensure multiprocessing reads see all runs in scans.tsv."""
    bids_root = tmp_path / "parallel_multiprocessing"
    runs = [f"{idx:02d}" for idx in range(1, 50)]

    for run_id in runs:
        _write_parallel_dataset(str(bids_root), subject="01", run=run_id)

    expected = {f"meg/sub-01_task-rest_run-{run_id}_meg.fif" for run_id in runs}

    # Four concurrent reader processes, all hitting the same scans file.
    workers = [
        mp.Process(target=_parallel_read_scans, args=(str(bids_root), expected))
        for _ in range(4)
    ]
    for worker in workers:
        worker.start()

    for worker in workers:
        worker.join()
        # A failed in-worker check raises, giving a non-zero exit code.
        assert worker.exitcode == 0

    # Re-check from the parent process as well.
    scans_path = BIDSPath(
        subject="01", root=bids_root, suffix="scans", extension=".tsv"
    )
    assert scans_path.fpath.exists()
    scans = _from_tsv(scans_path.fpath)
    assert {str(fname) for fname in scans["filename"]} == expected
    sh.rmtree(bids_root, ignore_errors=True)
| 243 | + |
| 244 | + |
136 | 245 | @pytest.mark.filterwarnings(warning_str["channel_unit_changed"]) |
137 | 246 | @testing.requires_testing_data |
138 | 247 | def test_read_participants_data(tmp_path): |
|
0 commit comments