Skip to content

Commit 835f1f9

Browse files
Add test for pipeline saving.
1 parent 47858f6 commit 835f1f9

File tree

3 files changed

+65
-17
lines changed

3 files changed

+65
-17
lines changed

pylossless/conftest.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
"""Pytest fixtures that can be reused across our unit tests."""
22
# Author: Scott Huberty <[email protected]>
3+
# Christian O'Reilly <[email protected]>
34
#
45
# License: MIT
56

67
from pathlib import Path
78
import shutil
89

10+
import mne
11+
import numpy as np
912
from mne import Annotations
1013

1114
import pylossless as ll
@@ -48,3 +51,30 @@ def pipeline_fixture():
4851
Path("test_config.yaml").unlink() # delete config file
4952
shutil.rmtree(bids_path.root)
5053
return pipeline
54+
55+
56+
@pytest.fixture(scope="session")
57+
@pytest.mark.filterwarnings("ignore:Converting data files to EDF format")
58+
def bids_dataset_fixture(tmpdir_factory):
59+
"""Make sure MNE's annotate_break function can run."""
60+
def edf_import_fct(path_in):
61+
# read in a file
62+
raw = mne.io.read_raw_edf(path_in, preload=True)
63+
match_alias = {ch_name: ch_name.strip(".") for ch_name in raw.ch_names}
64+
raw.set_montage("standard_1005", match_alias=match_alias, match_case=False)
65+
return raw, np.array([[0, 0, 0]]), {"test": 0, "T0": 1, "T1": 2, "T2": 3}
66+
67+
tmp_path = tmpdir_factory.mktemp('bids_dataset')
68+
testing_path = mne.datasets.testing.data_path()
69+
fname = testing_path / "EDF" / "test_edf_overlapping_annotations.edf"
70+
import_args = [{"path_in": fname}]
71+
bids_path_args = [{'subject': '001', 'run': '01', 'session': '01',
72+
"task": "test"}]
73+
bids_path = ll.bids.convert_dataset_to_bids(
74+
edf_import_fct,
75+
import_args,
76+
bids_path_args,
77+
bids_root=tmp_path,
78+
overwrite=True
79+
)[0]
80+
return bids_path

pylossless/pipeline.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,7 @@ def __init__(self, config_path=None, config=None):
457457
config : pylossless.config.Config | None
458458
:class:`pylossless.config.Config` object for the pipeline.
459459
"""
460+
self.bids_path = None
460461
self.flags = {
461462
"ch": FlaggedChs(self),
462463
"epoch": FlaggedEpochs(self),
@@ -1060,7 +1061,7 @@ def flag_noisy_ics(self):
10601061

10611062
# icsd_epoch_flags=padflags(raw, icsd_epoch_flags,1,'value',.5);
10621063

1063-
def save(self, derivatives_path, overwrite=False, format="EDF", event_id=None):
1064+
def save(self, derivatives_path=None, overwrite=False, format="EDF", event_id=None):
10641065
"""Save the file at the end of the pipeline.
10651066
10661067
Parameters
@@ -1075,6 +1076,10 @@ def save(self, derivatives_path, overwrite=False, format="EDF", event_id=None):
10751076
event_id : dict | None (default None)
10761077
Dictionary mapping annotation descriptions to event codes.
10771078
"""
1079+
1080+
if derivatives_path is None:
1081+
derivatives_path = self.get_derivative_path(self.bids_path)
1082+
10781083
mne_bids.write_raw_bids(
10791084
self.raw,
10801085
derivatives_path,
@@ -1121,22 +1126,17 @@ def filter(self):
11211126
# 5.a. Filter lowpass/highpass
11221127
self.raw.filter(**self.config["filtering"]["filter_args"])
11231128

1129+
# 5.b. Filter notch
11241130
if "notch_filter_args" in self.config["filtering"]:
11251131
notch_args = self.config["filtering"]["notch_filter_args"]
1126-
# in raw.notch_filter, freqs=None is ok if method=spectrum_fit
1127-
if not notch_args["freqs"] and "method" not in notch_args:
1128-
logger.info("No notch filter arguments provided. Skipping")
1129-
else:
1132+
spectrum_fit_method = (
1133+
"method" in notch_args and notch_args["method"] == "spectrum_fit"
1134+
)
1135+
if notch_args["freqs"] or spectrum_fit_method:
1136+
# in raw.notch_filter, freqs=None is ok if method=='spectrum_fit'
11301137
self.raw.notch_filter(**notch_args)
1131-
1132-
# 5.b. Filter notch
1133-
notch_args = self.config["filtering"]["notch_filter_args"]
1134-
spectrum_fit_method = (
1135-
"method" in notch_args and notch_args["method"] == "spectrum_fit"
1136-
)
1137-
if notch_args["freqs"] or spectrum_fit_method:
1138-
# in raw.notch_filter, freqs=None is ok if method=='spectrum_fit'
1139-
self.raw.notch_filter(**notch_args)
1138+
else:
1139+
logger.info("No notch filter arguments provided. Skipping")
11401140
else:
11411141
logger.info("No notch filter arguments provided. Skipping")
11421142

pylossless/tests/test_pipeline.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
from pathlib import Path
2-
2+
import mne
33
import pytest
4+
import shutil
45

56
import pylossless as ll
67

7-
import mne
8-
98

109
def test_empty_repr(tmp_path):
1110
"""Test the __repr__ method for a pipeline that hasn't run."""
@@ -25,6 +24,25 @@ def test_pipeline_run(pipeline_fixture):
2524
assert pipeline_fixture.flags["ch"].__repr__()
2625

2726

27+
@pytest.mark.filterwarnings("ignore:Converting data files to EDF format")
28+
@pytest.mark.filterwarnings("ignore:The provided Epochs instance is not filtered between 1 and 100 Hz.")
29+
def test_pipeline_save(bids_dataset_fixture, tmp_path):
30+
"""Test running the pipeline."""
31+
config = ll.config.Config()
32+
config.load_default()
33+
config["filtering"]["filter_args"]["h_freq"] = 40
34+
del config["filtering"]["notch_filter_args"]
35+
36+
pipeline = ll.LosslessPipeline(config=config)
37+
pipeline.run(bids_dataset_fixture, save=True)
38+
39+
with pytest.raises(FileExistsError):
40+
pipeline.save(overwrite=False, format="EDF")
41+
pipeline.save(overwrite=True, format="EDF")
42+
43+
shutil.rmtree(bids_dataset_fixture.root)
44+
45+
2846
@pytest.mark.parametrize("logging", [True, False])
2947
def test_find_breaks(logging):
3048
"""Make sure MNE's annotate_break function can run."""

0 commit comments

Comments
 (0)