Skip to content

Commit 159a148

Browse files
Adding test for flags after loading.
1 parent ec1cedb commit 159a148

File tree

4 files changed

+91
-7
lines changed

4 files changed

+91
-7
lines changed

pylossless/datasets/datasets.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pylossless as ll
77

88

9-
def load_openneuro_bids(subject="pd6"):
9+
def load_openneuro_bids(subject="pd6", timeout=20):
1010
"""Download and Load BIDS dataset ds002778 from OpenNeuro.
1111
1212
Parameters
@@ -72,8 +72,14 @@ def load_openneuro_bids(subject="pd6"):
7272
root=bids_root,
7373
)
7474

75-
while not bids_path.fpath.with_suffix(".bdf").exists():
76-
print(list(bids_path.fpath.glob("*")))
75+
for _ in range(timeout):
76+
if bids_path.fpath.with_suffix(".bdf").exists():
77+
break
78+
print("Waiting for .bdf files to be created. Current files available:",
79+
list(bids_path.fpath.glob("*")))
7780
sleep(1)
81+
else:
82+
raise TimeoutError("OpenNeuro failed to create the .bdf files.")
83+
7884
raw = mne_bids.read_raw_bids(bids_path, verbose="ERROR")
7985
return raw, config, bids_path

pylossless/flagging.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@ def __init__(self, ll, *args, **kwargs):
5252
super().__init__(*args, **kwargs)
5353
self.ll = ll
5454

55+
@property
56+
def valid_keys(self):
57+
"""Return the valid keys for FlaggedChs objects."""
58+
return ('ch_sd', 'bridge', 'low_r', 'rank')
59+
5560
def __repr__(self):
5661
"""Return a string representation of the FlaggedChs object."""
5762
return (
@@ -62,13 +67,20 @@ def __repr__(self):
6267
f" Rank: {self.get('rank', None)}\n"
6368
)
6469

70+
def __eq__(self, other):
71+
for key in self.valid_keys:
72+
if not np.array_equal(self.get(key, np.array([])),
73+
other.get(key, np.array([]))):
74+
return False
75+
return True
76+
6577
def add_flag_cat(self, kind, bad_ch_names, *args):
6678
"""Store channel names that have been flagged by pipeline.
6779
6880
Parameters
6981
----------
7082
kind : str
71-
Should be one of ``'outlier'``, ``'ch_sd'``, ``'low_r'``,
83+
Should be one of ``'ch_sd'``, ``'low_r'``,
7284
``'bridge'``, ``'rank'``.
7385
bad_ch_names : list | tuple
7486
Channel names. Will be the values corresponding to the ``kind``
@@ -175,6 +187,27 @@ def __init__(self, ll, *args, **kwargs):
175187

176188
self.ll = ll
177189

190+
@property
191+
def valid_keys(self):
192+
"""Return the valid keys for FlaggedEpochs objects."""
193+
return ('noisy', 'uncorrelated', 'noisy_ICs')
194+
195+
def __repr__(self):
196+
"""Return a string representation of the FlaggedEpochs object."""
197+
return (
198+
f"Flagged channels: |\n"
199+
f" Noisy: {self.get('noisy', None)}\n"
200+
f" Noisy ICs: {self.get('noisy_ICs', None)}\n"
201+
f" Uncorrelated: {self.get('uncorrelated', None)}\n"
202+
)
203+
204+
def __eq__(self, other):
205+
for key in self.valid_keys:
206+
if not np.array_equal(self.get(key, np.array([])),
207+
other.get(key, np.array([]))):
208+
return False
209+
return True
210+
178211
def add_flag_cat(self, kind, bad_epoch_inds, epochs):
179212
"""Add information on time periods flagged by pyLossless.
180213

pylossless/pipeline.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1216,9 +1216,20 @@ def run_dataset(self, paths):
12161216
for path in paths:
12171217
self.run(path)
12181218

1219-
# TODO: Finish docstring
12201219
def load_ll_derivative(self, derivatives_path):
1221-
"""Load a completed pylossless derivative state."""
1220+
"""Load a completed pylossless derivative state.
1221+
1222+
Parameters
1223+
----------
1224+
derivatives_path : str | mne_bids.BIDSPath
1225+
Path to a saved pylossless derivatives.
1226+
1227+
Returns
1228+
-------
1229+
:class:`~pylossless.pipeline.LosslessPipeline`
1230+
Returns an instance of :class:`~pylossless.pipeline.LosslessPipeline`
1231+
for the loaded pylossless derivative state.
1232+
"""
12221233
if not isinstance(derivatives_path, BIDSPath):
12231234
derivatives_path = get_bids_path_from_fname(derivatives_path)
12241235
self.raw = mne_bids.read_raw_bids(derivatives_path)

pylossless/tests/test_pipeline.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pathlib import Path
2-
2+
import mne_bids
33
import pytest
44

55
import pylossless as ll
@@ -54,3 +54,37 @@ def test_deprecation():
5454
# with pytest.raises(DeprecationWarning, match=f"config_fname is deprecated"):
5555
# DeprecationWarning are currently ignored by pytest given our toml file
5656
pipeline.config_fname = pipeline.config_fname
57+
58+
59+
@pytest.mark.filterwarnings("ignore:Converting data files to EDF format")
60+
def test_load_flags(pipeline_fixture, tmp_path):
61+
"""Test running the pipeline."""
62+
bids_root = tmp_path / "derivatives" / "pylossless"
63+
print(bids_root)
64+
65+
subject = "pd6"
66+
datatype = "eeg"
67+
session = "off"
68+
task = "rest"
69+
suffix = "eeg"
70+
bids_path = mne_bids.BIDSPath(
71+
subject=subject,
72+
session=session,
73+
task=task,
74+
suffix=suffix,
75+
datatype=datatype,
76+
root=bids_root
77+
)
78+
79+
pipeline_fixture.save(bids_path,
80+
overwrite=False, format="EDF", event_id=None)
81+
pipeline = ll.LosslessPipeline().load_ll_derivative(bids_path)
82+
assert pipeline_fixture.flags['ch'] == pipeline.flags['ch']
83+
pipeline.flags['ch']["bridge"] = ["xx"]
84+
assert pipeline_fixture.flags['ch'] != pipeline.flags['ch']
85+
86+
assert pipeline_fixture.flags['epoch'] == pipeline.flags['epoch']
87+
pipeline.flags['epoch']["bridge"] = ["noisy"]
88+
assert pipeline_fixture.flags['epoch'] == pipeline.flags['epoch']
89+
90+
assert pipeline_fixture.flags['ic'] == pipeline.flags['ic']

0 commit comments

Comments
 (0)