|
9 | 9 | import shutil as sh |
10 | 10 | import timeit |
11 | 11 | from datetime import datetime, timezone |
| 12 | +from functools import partial |
12 | 13 | from pathlib import Path |
13 | 14 |
|
14 | 15 | import mne |
|
32 | 33 | _filter_fnames, |
33 | 34 | _find_best_candidates, |
34 | 35 | _parse_ext, |
| 36 | + _return_root_paths, |
35 | 37 | find_matching_paths, |
36 | 38 | get_bids_path_from_fname, |
37 | 39 | get_entities_from_fname, |
@@ -259,6 +261,68 @@ def test_path_benchmark(tmp_path_factory): |
259 | 261 | assert out_4 == [] |
260 | 262 |
|
261 | 263 |
|
def _scan_targeted_meg(root, entities=None):
    """Run ``_return_root_paths`` on ``root`` with fixed MEG scan options.

    ``entities`` is forwarded unchanged; the default ``None`` is passed
    through as-is.
    """
    scan_options = {
        "datatype": "meg",
        "ignore_json": True,
        "ignore_nosub": False,
        "entities": entities,
    }
    return _return_root_paths(root, **scan_options)
| 272 | + |
| 273 | + |
def test_entity_targeted_scan_benchmark(tmp_path_factory, benchmark):
    """Benchmark the entity-aware root scan optimisation.

    Builds a synthetic tree of empty ``*_meg.fif`` files laid out as
    ``sub-XX/ses-YY/meg``, then checks that a scan restricted to a single
    subject/session returns exactly that session's runs. When benchmark
    statistics are available, it additionally requires the targeted scan
    to take less than half the mean time of an unrestricted scan.
    """
    bids_root = Path(tmp_path_factory.mktemp("mnebids_entity_scan"))

    n_subjects = 60
    n_sessions = 4
    n_runs = 12

    target_entities = {"subject": "01", "session": "02"}
    target_sub = f"sub-{target_entities['subject']}"
    target_ses = f"ses-{target_entities['session']}"

    # Only empty placeholder files are needed: the scan is exercised on
    # file names and directory layout, never on file contents.
    for subj_idx in range(1, n_subjects + 1):
        sub_label = f"{subj_idx:02d}"
        for ses_idx in range(1, n_sessions + 1):
            ses_label = f"{ses_idx:02d}"
            meg_dir = bids_root / f"sub-{sub_label}" / f"ses-{ses_label}" / "meg"
            meg_dir.mkdir(parents=True, exist_ok=True)
            for run_idx in range(1, n_runs + 1):
                fname = (
                    f"sub-{sub_label}_ses-{ses_label}_task-"
                    f"task_run-{run_idx:02d}_meg.fif"
                )
                (meg_dir / fname).touch()

    # Warm-up to mitigate cold-cache effects.
    _scan_targeted_meg(bids_root)
    baseline_durations = timeit.repeat(
        partial(_scan_targeted_meg, bids_root), repeat=3, number=1
    )
    baseline_mean = sum(baseline_durations) / len(baseline_durations)
    benchmark.extra_info["baseline_mean"] = baseline_mean

    optimized_paths = benchmark(
        partial(_scan_targeted_meg, bids_root, entities=target_entities)
    )

    # Check correctness before timing so that a functional regression is
    # reported as such rather than being hidden behind a timing failure.
    assert len(optimized_paths) == n_runs
    assert all(
        target_sub in path.as_posix() and target_ses in path.as_posix()
        for path in optimized_paths
    )

    # pytest-benchmark leaves ``benchmark.stats`` as None when benchmarking
    # is disabled (e.g. ``--benchmark-disable`` or when running under
    # pytest-xdist); the function is still called once, so the correctness
    # checks above remain meaningful but no timing comparison is possible.
    stats = getattr(benchmark, "stats", None)
    if stats is None:
        return

    optimized_mean = stats.stats.mean
    benchmark.extra_info["optimized_mean"] = optimized_mean
    # Require a substantial speed-up to guard against regressions.
    assert optimized_mean < baseline_mean * 0.5
| 324 | + |
| 325 | + |
262 | 326 | def test_search_folder_for_text(capsys): |
263 | 327 | """Test finding entries.""" |
264 | 328 | with pytest.raises(ValueError, match="is not a directory"): |
|
0 commit comments