Skip to content

Commit 6f21f27

Browse files
Test save (lina-usc#186)
* Add test for pipeline saving. * Removing some code duplication. * Fix test issue due to file deletion.
1 parent 675aa2d commit 6f21f27

File tree

4 files changed

+67
-46
lines changed

4 files changed

+67
-46
lines changed

pylossless/conftest.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
"""Pytest fixtures that can be reused across our unit tests."""
22
# Author: Scott Huberty <[email protected]>
3+
# Christian O'Reilly <[email protected]>
34
#
45
# License: MIT
56

67
from pathlib import Path
78
import shutil
89

10+
import mne
11+
import numpy as np
912
from mne import Annotations
1013

1114
import pylossless as ll
@@ -48,3 +51,30 @@ def pipeline_fixture():
4851
Path("test_config.yaml").unlink() # delete config file
4952
shutil.rmtree(bids_path.root)
5053
return pipeline
54+
55+
56+
@pytest.fixture(scope="session")
57+
@pytest.mark.filterwarnings("ignore:Converting data files to EDF format")
58+
def bids_dataset_fixture(tmpdir_factory):
59+
"""Return a BIDS path for a test recording."""
60+
def edf_import_fct(path_in):
61+
# read in a file
62+
raw = mne.io.read_raw_edf(path_in, preload=True)
63+
match_alias = {ch_name: ch_name.strip(".") for ch_name in raw.ch_names}
64+
raw.set_montage("standard_1005", match_alias=match_alias, match_case=False)
65+
return raw, np.array([[0, 0, 0]]), {"test": 0, "T0": 1, "T1": 2, "T2": 3}
66+
67+
tmp_path = tmpdir_factory.mktemp('bids_dataset')
68+
testing_path = mne.datasets.testing.data_path()
69+
fname = testing_path / "EDF" / "test_edf_overlapping_annotations.edf"
70+
import_args = [{"path_in": fname}]
71+
bids_path_args = [{'subject': '001', 'run': '01', 'session': '01',
72+
"task": "test"}]
73+
bids_path = ll.bids.convert_dataset_to_bids(
74+
edf_import_fct,
75+
import_args,
76+
bids_path_args,
77+
bids_root=tmp_path,
78+
overwrite=True
79+
)[0]
80+
return bids_path

pylossless/pipeline.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,7 @@ def __init__(self, config_path=None, config=None):
457457
config : pylossless.config.Config | None
458458
:class:`pylossless.config.Config` object for the pipeline.
459459
"""
460+
self.bids_path = None
460461
self.flags = {
461462
"ch": FlaggedChs(self),
462463
"epoch": FlaggedEpochs(self),
@@ -1063,12 +1064,12 @@ def flag_noisy_ics(self):
10631064

10641065
# icsd_epoch_flags=padflags(raw, icsd_epoch_flags,1,'value',.5);
10651066

1066-
def save(self, derivatives_path, overwrite=False, format="EDF", event_id=None):
1067+
def save(self, derivatives_path=None, overwrite=False, format="EDF", event_id=None):
10671068
"""Save the file at the end of the pipeline.
10681069
10691070
Parameters
10701071
----------
1071-
derivatives_path : mne_bids.BIDSPath
1072+
derivatives_path : None | mne_bids.BIDSPath
10721073
path of the derivatives folder to save the file to.
10731074
overwrite : bool (default False)
10741075
whether to overwrite existing files with the same name.
@@ -1078,6 +1079,9 @@ def save(self, derivatives_path, overwrite=False, format="EDF", event_id=None):
10781079
event_id : dict | None (default None)
10791080
Dictionary mapping annotation descriptions to event codes.
10801081
"""
1082+
if derivatives_path is None:
1083+
derivatives_path = self.get_derivative_path(self.bids_path)
1084+
10811085
mne_bids.write_raw_bids(
10821086
self.raw,
10831087
derivatives_path,
@@ -1124,22 +1128,17 @@ def filter(self):
11241128
# 5.a. Filter lowpass/highpass
11251129
self.raw.filter(**self.config["filtering"]["filter_args"])
11261130

1131+
# 5.b. Filter notch
11271132
if "notch_filter_args" in self.config["filtering"]:
11281133
notch_args = self.config["filtering"]["notch_filter_args"]
1129-
# in raw.notch_filter, freqs=None is ok if method=spectrum_fit
1130-
if not notch_args["freqs"] and "method" not in notch_args:
1131-
logger.info("No notch filter arguments provided. Skipping")
1132-
else:
1134+
spectrum_fit_method = (
1135+
"method" in notch_args and notch_args["method"] == "spectrum_fit"
1136+
)
1137+
if notch_args["freqs"] or spectrum_fit_method:
1138+
# in raw.notch_filter, freqs=None is ok if method=='spectrum_fit'
11331139
self.raw.notch_filter(**notch_args)
1134-
1135-
# 5.b. Filter notch
1136-
notch_args = self.config["filtering"]["notch_filter_args"]
1137-
spectrum_fit_method = (
1138-
"method" in notch_args and notch_args["method"] == "spectrum_fit"
1139-
)
1140-
if notch_args["freqs"] or spectrum_fit_method:
1141-
# in raw.notch_filter, freqs=None is ok if method=='spectrum_fit'
1142-
self.raw.notch_filter(**notch_args)
1140+
else:
1141+
logger.info("No notch filter arguments provided. Skipping")
11431142
else:
11441143
logger.info("No notch filter arguments provided. Skipping")
11451144

pylossless/tests/test_bids.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +0,0 @@
1-
import pylossless as ll
2-
import mne
3-
import numpy as np
4-
import pytest
5-
import shutil
6-
7-
8-
@pytest.mark.filterwarnings("ignore:Converting data files to EDF format")
9-
def test_convert_dataset_to_bids(tmp_path):
10-
"""Make sure MNE's annotate_break function can run."""
11-
def edf_import_fct(path_in):
12-
# read in a file
13-
raw = mne.io.read_raw_edf(path_in, preload=True)
14-
print(raw.annotations)
15-
return raw, np.array([[0, 0, 0]]), {"test": 0, "T0": 1, "T1": 2, "T2": 3}
16-
17-
testing_path = mne.datasets.testing.data_path()
18-
fname = testing_path / "EDF" / "test_edf_overlapping_annotations.edf"
19-
import_args = [{"path_in": fname}]
20-
bids_path_args = [{'subject': '001', 'run': '01', 'session': '01',
21-
"task": "test"}]
22-
ll.bids.convert_dataset_to_bids(
23-
edf_import_fct,
24-
import_args,
25-
bids_path_args,
26-
bids_root=tmp_path / "bids_dataset",
27-
overwrite=True
28-
)
29-
shutil.rmtree(tmp_path / "bids_dataset")

pylossless/tests/test_pipeline.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
from pathlib import Path
2+
import mne
23
import mne_bids
34
import pytest
45

56
import pylossless as ll
67

7-
import mne
8-
98

109
def test_empty_repr(tmp_path):
1110
"""Test the __repr__ method for a pipeline that hasn't run."""
@@ -25,6 +24,28 @@ def test_pipeline_run(pipeline_fixture):
2524
assert pipeline_fixture.flags["ch"].__repr__()
2625

2726

27+
@pytest.mark.filterwarnings("ignore:Converting data files to EDF format")
28+
@pytest.mark.filterwarnings("ignore:The provided Epochs instance is not"
29+
" filtered between 1 and 100 Hz.")
30+
def test_pipeline_save(bids_dataset_fixture):
31+
"""Test running the pipeline."""
32+
config = ll.config.Config()
33+
config.load_default()
34+
config["filtering"]["filter_args"]["h_freq"] = 40
35+
del config["filtering"]["notch_filter_args"]
36+
37+
pipeline = ll.LosslessPipeline(config=config)
38+
pipeline.run(bids_dataset_fixture, save=True)
39+
40+
with pytest.raises(FileExistsError):
41+
pipeline.save(overwrite=False, format="EDF")
42+
pipeline.save(overwrite=True, format="EDF")
43+
44+
# Files are created in a tmp folder so no need
45+
# to clean up...
46+
# shutil.rmtree(bids_dataset_fixture.root)
47+
48+
2849
@pytest.mark.parametrize("logging", [True, False])
2950
def test_find_breaks(logging):
3051
"""Make sure MNE's annotate_break function can run."""

0 commit comments

Comments
 (0)