Skip to content

Commit 8912cfe

Browse files
Adding test for flags after loading. (lina-usc#178)
* Adding test for flags after loading. * Fix loading epoch flags. * STY: minor linting --------- Co-authored-by: Scott Huberty <[email protected]>
1 parent 5d09232 commit 8912cfe

File tree

4 files changed

+127
-35
lines changed

4 files changed

+127
-35
lines changed

pylossless/datasets/datasets.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pylossless as ll
77

88

9-
def load_openneuro_bids(subject="pd6"):
9+
def load_openneuro_bids(subject="pd6", timeout=20):
1010
"""Download and Load BIDS dataset ds002778 from OpenNeuro.
1111
1212
Parameters
@@ -72,8 +72,14 @@ def load_openneuro_bids(subject="pd6"):
7272
root=bids_root,
7373
)
7474

75-
while not bids_path.fpath.with_suffix(".bdf").exists():
76-
print(list(bids_path.fpath.glob("*")))
75+
for _ in range(timeout):
76+
if bids_path.fpath.with_suffix(".bdf").exists():
77+
break
78+
print("Waiting for .bdf files to be created. Current files available:",
79+
list(bids_path.fpath.glob("*")))
7780
sleep(1)
81+
else:
82+
raise TimeoutError("OpenNeuro failed to create the .bdf files.")
83+
7884
raw = mne_bids.read_raw_bids(bids_path, verbose="ERROR")
7985
return raw, config, bids_path

pylossless/flagging.py

Lines changed: 70 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,53 @@
1717

1818
from .utils._utils import _icalabel_to_data_frame
1919

20+
IC_LABELS = mne_icalabel.config.ICA_LABELS_TO_MNE
21+
CH_LABELS: dict[str, str] = {
22+
"Noisy": "ch_sd",
23+
"Bridged": "bridge",
24+
"Uncorrelated": "low_r",
25+
"Rank": "rank"
26+
}
27+
EPOCH_LABELS: dict[str, str] = {
28+
"Noisy": "noisy",
29+
"Noisy ICs": "noisy_ICs",
30+
"Uncorrelated": "uncorrelated",
31+
}
32+
33+
34+
class _Flagged(dict):
35+
36+
def __init__(self, key_map, kind_str, ll, *args, **kwargs):
37+
"""Initialize class."""
38+
super().__init__(*args, **kwargs)
39+
self.ll = ll
40+
self._key_map = key_map
41+
self._kind_str = kind_str
42+
43+
@property
44+
def valid_keys(self):
45+
"""Return the valid keys."""
46+
return tuple(self._key_map.values())
47+
48+
def __repr__(self):
49+
"""Return a string representation."""
50+
ret_str = f"Flagged {self._kind_str}s: |\n"
51+
for key, val in self._key_map.items():
52+
ret_str += f" {key}: {self.get(val, None)}\n"
53+
return ret_str
54+
55+
def __eq__(self, other):
56+
for key in self.valid_keys:
57+
if not np.array_equal(self.get(key, np.array([])),
58+
other.get(key, np.array([]))):
59+
return False
60+
return True
61+
62+
def __ne__(self, other):
63+
return not self == other
2064

21-
class FlaggedChs(dict):
65+
66+
class FlaggedChs(_Flagged):
2267
"""Object for handling flagged channels in an instance of mne.io.Raw.
2368
2469
Attributes
@@ -47,28 +92,17 @@ class FlaggedChs(dict):
4792
and methods for python dictionaries.
4893
"""
4994

50-
def __init__(self, ll, *args, **kwargs):
95+
def __init__(self, *args, **kwargs):
5196
"""Initialize class."""
52-
super().__init__(*args, **kwargs)
53-
self.ll = ll
54-
55-
def __repr__(self):
56-
"""Return a string representation of the FlaggedChs object."""
57-
return (
58-
f"Flagged channels: |\n"
59-
f" Noisy: {self.get('ch_sd', None)}\n"
60-
f" Bridged: {self.get('bridge', None)}\n"
61-
f" Uncorrelated: {self.get('low_r', None)}\n"
62-
f" Rank: {self.get('rank', None)}\n"
63-
)
97+
super().__init__(CH_LABELS, "channel", *args, **kwargs)
6498

6599
def add_flag_cat(self, kind, bad_ch_names, *args):
66100
"""Store channel names that have been flagged by pipeline.
67101
68102
Parameters
69103
----------
70104
kind : str
71-
Should be one of ``'outlier'``, ``'ch_sd'``, ``'low_r'``,
105+
Should be one of ``'ch_sd'``, ``'low_r'``,
72106
``'bridge'``, ``'rank'``.
73107
bad_ch_names : list | tuple
74108
Channel names. Will be the values corresponding to the ``kind``
@@ -140,7 +174,7 @@ def load_tsv(self, fname):
140174
self[label] = grp_df.ch_names.values
141175

142176

143-
class FlaggedEpochs(dict):
177+
class FlaggedEpochs(_Flagged):
144178
"""Object for handling flagged Epochs in an instance of mne.Epochs.
145179
146180
Methods
@@ -159,7 +193,7 @@ class FlaggedEpochs(dict):
159193
and methods for python dictionaries.
160194
"""
161195

162-
def __init__(self, ll, *args, **kwargs):
196+
def __init__(self, *args, **kwargs):
163197
"""Initialize class.
164198
165199
Parameters
@@ -171,9 +205,7 @@ def __init__(self, ll, *args, **kwargs):
171205
kwargs : dict
172206
keyword arguments accepted by python's dictionary class.
173207
"""
174-
super().__init__(*args, **kwargs)
175-
176-
self.ll = ll
208+
super().__init__(EPOCH_LABELS, "epoch", *args, **kwargs)
177209

178210
def add_flag_cat(self, kind, bad_epoch_inds, epochs):
179211
"""Add information on time periods flagged by pyLossless.
@@ -194,17 +226,27 @@ def add_flag_cat(self, kind, bad_epoch_inds, epochs):
194226
self[kind] = bad_epoch_inds
195227
self.ll.add_pylossless_annotations(bad_epoch_inds, kind, epochs)
196228

197-
def load_from_raw(self, raw):
229+
def load_from_raw(self, raw, events, config):
198230
"""Load pylossless annotations from raw object."""
199231
sfreq = raw.info["sfreq"]
232+
tmax = config["epoching"]["epochs_args"]["tmax"]
233+
tmin = config["epoching"]["epochs_args"]["tmin"]
234+
starts = events[:, 0] / sfreq - tmin
235+
stops = events[:, 0] / sfreq + tmax
200236
for annot in raw.annotations:
201-
if annot["description"].upper().startswith("BAD_LL"):
202-
ind_onset = int(np.round(annot["onset"] * sfreq))
203-
ind_dur = int(np.round(annot["duration"] * sfreq))
204-
inds = np.arange(ind_onset, ind_onset + ind_dur)
205-
if annot["description"] not in self:
206-
self[annot["description"]] = list()
207-
self[annot["description"]].append(inds)
237+
if annot["description"].upper().startswith("BAD_LL_"):
238+
onset = annot["onset"]
239+
offset = annot["onset"] + annot["duration"]
240+
mask = (
241+
(starts >= onset) & (starts < offset)
242+
| (stops > onset) & (stops <= offset)
243+
| (onset <= starts) & (offset >= stops)
244+
)
245+
inds = np.where(mask)[0]
246+
desc = annot["description"].lower().replace("bad_ll_", "")
247+
if desc not in self:
248+
self[desc] = np.array([])
249+
self[desc] = np.concatenate((self[desc], inds))
208250

209251

210252
class FlaggedICs(pd.DataFrame):

pylossless/pipeline.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,9 +1229,20 @@ def run_dataset(self, paths):
12291229
for path in paths:
12301230
self.run(path)
12311231

1232-
# TODO: Finish docstring
12331232
def load_ll_derivative(self, derivatives_path):
1234-
"""Load a completed pylossless derivative state."""
1233+
"""Load a completed pylossless derivative state.
1234+
1235+
Parameters
1236+
----------
1237+
derivatives_path : str | mne_bids.BIDSPath
1238+
Path to a saved pylossless derivatives.
1239+
1240+
Returns
1241+
-------
1242+
:class:`~pylossless.pipeline.LosslessPipeline`
1243+
Returns an instance of :class:`~pylossless.pipeline.LosslessPipeline`
1244+
for the loaded pylossless derivative state.
1245+
"""
12351246
if not isinstance(derivatives_path, BIDSPath):
12361247
derivatives_path = get_bids_path_from_fname(derivatives_path)
12371248
self.raw = mne_bids.read_raw_bids(derivatives_path)
@@ -1260,7 +1271,7 @@ def load_ll_derivative(self, derivatives_path):
12601271
self.flags["ch"].load_tsv(flagged_chs_fpath.fpath)
12611272

12621273
# Load Flagged Epochs
1263-
self.flags["epoch"].load_from_raw(self.raw)
1274+
self.flags["epoch"].load_from_raw(self.raw, self.get_events(), self.config)
12641275

12651276
return self
12661277

pylossless/tests/test_pipeline.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pathlib import Path
2-
2+
import mne_bids
33
import pytest
44

55
import pylossless as ll
@@ -54,3 +54,36 @@ def test_deprecation():
5454
# with pytest.raises(DeprecationWarning, match=f"config_fname is deprecated"):
5555
# DeprecationWarning are currently ignored by pytest given our toml file
5656
pipeline.config_fname = pipeline.config_fname
57+
58+
59+
@pytest.mark.filterwarnings("ignore:Converting data files to EDF format")
60+
def test_load_flags(pipeline_fixture, tmp_path):
61+
"""Test running the pipeline."""
62+
bids_root = tmp_path / "derivatives" / "pylossless"
63+
print(bids_root)
64+
65+
subject = "pd6"
66+
datatype = "eeg"
67+
session = "off"
68+
task = "rest"
69+
suffix = "eeg"
70+
bids_path = mne_bids.BIDSPath(
71+
subject=subject,
72+
session=session,
73+
task=task,
74+
suffix=suffix,
75+
datatype=datatype,
76+
root=bids_root
77+
)
78+
79+
pipeline_fixture.save(bids_path,
80+
overwrite=False, format="EDF", event_id=None)
81+
pipeline = ll.LosslessPipeline().load_ll_derivative(bids_path)
82+
83+
assert pipeline_fixture.flags['ch'] == pipeline.flags['ch']
84+
pipeline.flags['ch']["bridge"] = ["xx"]
85+
assert pipeline_fixture.flags['ch'] != pipeline.flags['ch']
86+
87+
assert pipeline_fixture.flags['epoch'] == pipeline.flags['epoch']
88+
pipeline.flags['epoch']["bridge"] = ["noisy"]
89+
assert pipeline_fixture.flags['epoch'] == pipeline.flags['epoch']

0 commit comments

Comments
 (0)