Skip to content

Commit cde4b85

Browse files
Merge branch 'main' into test_save
2 parents 835f1f9 + 7ed857f commit cde4b85

File tree

6 files changed

+135
-34
lines changed

6 files changed

+135
-34
lines changed

.readthedocs.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,4 @@ python:
2121
- requirements: requirements.txt
2222
- requirements: docs/requirements_doc.txt
2323
- requirements: requirements_testing.txt
24+
- requirements: requirements_rtd.txt

pylossless/datasets/datasets.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pylossless as ll
77

88

9-
def load_openneuro_bids(subject="pd6"):
9+
def load_openneuro_bids(subject="pd6", timeout=20):
1010
"""Download and Load BIDS dataset ds002778 from OpenNeuro.
1111
1212
Parameters
@@ -72,8 +72,14 @@ def load_openneuro_bids(subject="pd6"):
7272
root=bids_root,
7373
)
7474

75-
while not bids_path.fpath.with_suffix(".bdf").exists():
76-
print(list(bids_path.fpath.glob("*")))
75+
for _ in range(timeout):
76+
if bids_path.fpath.with_suffix(".bdf").exists():
77+
break
78+
print("Waiting for .bdf files to be created. Current files available:",
79+
list(bids_path.fpath.glob("*")))
7780
sleep(1)
81+
else:
82+
raise TimeoutError("OpenNeuro failed to create the .bdf files.")
83+
7884
raw = mne_bids.read_raw_bids(bids_path, verbose="ERROR")
7985
return raw, config, bids_path

pylossless/flagging.py

Lines changed: 70 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,53 @@
1717

1818
from .utils._utils import _icalabel_to_data_frame
1919

20+
IC_LABELS = mne_icalabel.config.ICA_LABELS_TO_MNE
21+
CH_LABELS: dict[str, str] = {
22+
"Noisy": "ch_sd",
23+
"Bridged": "bridge",
24+
"Uncorrelated": "low_r",
25+
"Rank": "rank"
26+
}
27+
EPOCH_LABELS: dict[str, str] = {
28+
"Noisy": "noisy",
29+
"Noisy ICs": "noisy_ICs",
30+
"Uncorrelated": "uncorrelated",
31+
}
32+
33+
34+
class _Flagged(dict):
35+
36+
def __init__(self, key_map, kind_str, ll, *args, **kwargs):
37+
"""Initialize class."""
38+
super().__init__(*args, **kwargs)
39+
self.ll = ll
40+
self._key_map = key_map
41+
self._kind_str = kind_str
42+
43+
@property
44+
def valid_keys(self):
45+
"""Return the valid keys."""
46+
return tuple(self._key_map.values())
47+
48+
def __repr__(self):
49+
"""Return a string representation."""
50+
ret_str = f"Flagged {self._kind_str}s: |\n"
51+
for key, val in self._key_map.items():
52+
ret_str += f" {key}: {self.get(val, None)}\n"
53+
return ret_str
54+
55+
def __eq__(self, other):
56+
for key in self.valid_keys:
57+
if not np.array_equal(self.get(key, np.array([])),
58+
other.get(key, np.array([]))):
59+
return False
60+
return True
61+
62+
def __ne__(self, other):
63+
return not self == other
2064

21-
class FlaggedChs(dict):
65+
66+
class FlaggedChs(_Flagged):
2267
"""Object for handling flagged channels in an instance of mne.io.Raw.
2368
2469
Attributes
@@ -47,28 +92,17 @@ class FlaggedChs(dict):
4792
and methods for python dictionaries.
4893
"""
4994

50-
def __init__(self, ll, *args, **kwargs):
95+
def __init__(self, *args, **kwargs):
5196
"""Initialize class."""
52-
super().__init__(*args, **kwargs)
53-
self.ll = ll
54-
55-
def __repr__(self):
56-
"""Return a string representation of the FlaggedChs object."""
57-
return (
58-
f"Flagged channels: |\n"
59-
f" Noisy: {self.get('ch_sd', None)}\n"
60-
f" Bridged: {self.get('bridge', None)}\n"
61-
f" Uncorrelated: {self.get('low_r', None)}\n"
62-
f" Rank: {self.get('rank', None)}\n"
63-
)
97+
super().__init__(CH_LABELS, "channel", *args, **kwargs)
6498

6599
def add_flag_cat(self, kind, bad_ch_names, *args):
66100
"""Store channel names that have been flagged by pipeline.
67101
68102
Parameters
69103
----------
70104
kind : str
71-
Should be one of ``'outlier'``, ``'ch_sd'``, ``'low_r'``,
105+
Should be one of ``'ch_sd'``, ``'low_r'``,
72106
``'bridge'``, ``'rank'``.
73107
bad_ch_names : list | tuple
74108
Channel names. Will be the values corresponding to the ``kind``
@@ -140,7 +174,7 @@ def load_tsv(self, fname):
140174
self[label] = grp_df.ch_names.values
141175

142176

143-
class FlaggedEpochs(dict):
177+
class FlaggedEpochs(_Flagged):
144178
"""Object for handling flagged Epochs in an instance of mne.Epochs.
145179
146180
Methods
@@ -159,7 +193,7 @@ class FlaggedEpochs(dict):
159193
and methods for python dictionaries.
160194
"""
161195

162-
def __init__(self, ll, *args, **kwargs):
196+
def __init__(self, *args, **kwargs):
163197
"""Initialize class.
164198
165199
Parameters
@@ -171,9 +205,7 @@ def __init__(self, ll, *args, **kwargs):
171205
kwargs : dict
172206
keyword arguments accepted by python's dictionary class.
173207
"""
174-
super().__init__(*args, **kwargs)
175-
176-
self.ll = ll
208+
super().__init__(EPOCH_LABELS, "epoch", *args, **kwargs)
177209

178210
def add_flag_cat(self, kind, bad_epoch_inds, epochs):
179211
"""Add information on time periods flagged by pyLossless.
@@ -194,17 +226,27 @@ def add_flag_cat(self, kind, bad_epoch_inds, epochs):
194226
self[kind] = bad_epoch_inds
195227
self.ll.add_pylossless_annotations(bad_epoch_inds, kind, epochs)
196228

197-
def load_from_raw(self, raw):
229+
def load_from_raw(self, raw, events, config):
198230
"""Load pylossless annotations from raw object."""
199231
sfreq = raw.info["sfreq"]
232+
tmax = config["epoching"]["epochs_args"]["tmax"]
233+
tmin = config["epoching"]["epochs_args"]["tmin"]
234+
starts = events[:, 0] / sfreq - tmin
235+
stops = events[:, 0] / sfreq + tmax
200236
for annot in raw.annotations:
201-
if annot["description"].upper().startswith("BAD_LL"):
202-
ind_onset = int(np.round(annot["onset"] * sfreq))
203-
ind_dur = int(np.round(annot["duration"] * sfreq))
204-
inds = np.arange(ind_onset, ind_onset + ind_dur)
205-
if annot["description"] not in self:
206-
self[annot["description"]] = list()
207-
self[annot["description"]].append(inds)
237+
if annot["description"].upper().startswith("BAD_LL_"):
238+
onset = annot["onset"]
239+
offset = annot["onset"] + annot["duration"]
240+
mask = (
241+
(starts >= onset) & (starts < offset)
242+
| (stops > onset) & (stops <= offset)
243+
| (onset <= starts) & (offset >= stops)
244+
)
245+
inds = np.where(mask)[0]
246+
desc = annot["description"].lower().replace("bad_ll_", "")
247+
if desc not in self:
248+
self[desc] = np.array([])
249+
self[desc] = np.concatenate((self[desc], inds))
208250

209251

210252
class FlaggedICs(pd.DataFrame):

pylossless/pipeline.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -719,6 +719,9 @@ def find_breaks(self):
719719
"""
720720
if "find_breaks" not in self.config or not self.config["find_breaks"]:
721721
return
722+
if not self.raw.annotations:
723+
logger.debug("No annotations found in raw object. Skipping find_breaks.")
724+
return
722725
breaks = annotate_break(self.raw, **self.config["find_breaks"])
723726
self.raw.set_annotations(breaks + self.raw.annotations)
724727

@@ -1229,9 +1232,20 @@ def run_dataset(self, paths):
12291232
for path in paths:
12301233
self.run(path)
12311234

1232-
# TODO: Finish docstring
12331235
def load_ll_derivative(self, derivatives_path):
1234-
"""Load a completed pylossless derivative state."""
1236+
"""Load a completed pylossless derivative state.
1237+
1238+
Parameters
1239+
----------
1240+
derivatives_path : str | mne_bids.BIDSPath
1241+
Path to a saved pylossless derivatives.
1242+
1243+
Returns
1244+
-------
1245+
:class:`~pylossless.pipeline.LosslessPipeline`
1246+
Returns an instance of :class:`~pylossless.pipeline.LosslessPipeline`
1247+
for the loaded pylossless derivative state.
1248+
"""
12351249
if not isinstance(derivatives_path, BIDSPath):
12361250
derivatives_path = get_bids_path_from_fname(derivatives_path)
12371251
self.raw = mne_bids.read_raw_bids(derivatives_path)
@@ -1260,7 +1274,7 @@ def load_ll_derivative(self, derivatives_path):
12601274
self.flags["ch"].load_tsv(flagged_chs_fpath.fpath)
12611275

12621276
# Load Flagged Epochs
1263-
self.flags["epoch"].load_from_raw(self.raw)
1277+
self.flags["epoch"].load_from_raw(self.raw, self.get_events(), self.config)
12641278

12651279
return self
12661280

pylossless/tests/test_pipeline.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from pathlib import Path
22
import mne
3+
import mne_bids
34
import pytest
45
import shutil
56

@@ -61,6 +62,9 @@ def test_find_breaks(logging):
6162
pipeline.find_breaks(message="Looking for break periods between tasks")
6263
else:
6364
pipeline.find_breaks()
65+
# Now explicitly remove annotations and make sure we avoid MNE's error.
66+
pipeline.raw.set_annotations(None)
67+
pipeline.find_breaks()
6468
Path(config_fname).unlink() # delete config file
6569

6670

@@ -72,3 +76,36 @@ def test_deprecation():
7276
# with pytest.raises(DeprecationWarning, match=f"config_fname is deprecated"):
7377
# DeprecationWarning are currently ignored by pytest given our toml file
7478
pipeline.config_fname = pipeline.config_fname
79+
80+
81+
@pytest.mark.filterwarnings("ignore:Converting data files to EDF format")
82+
def test_load_flags(pipeline_fixture, tmp_path):
83+
"""Test running the pipeline."""
84+
bids_root = tmp_path / "derivatives" / "pylossless"
85+
print(bids_root)
86+
87+
subject = "pd6"
88+
datatype = "eeg"
89+
session = "off"
90+
task = "rest"
91+
suffix = "eeg"
92+
bids_path = mne_bids.BIDSPath(
93+
subject=subject,
94+
session=session,
95+
task=task,
96+
suffix=suffix,
97+
datatype=datatype,
98+
root=bids_root
99+
)
100+
101+
pipeline_fixture.save(bids_path,
102+
overwrite=False, format="EDF", event_id=None)
103+
pipeline = ll.LosslessPipeline().load_ll_derivative(bids_path)
104+
105+
assert pipeline_fixture.flags['ch'] == pipeline.flags['ch']
106+
pipeline.flags['ch']["bridge"] = ["xx"]
107+
assert pipeline_fixture.flags['ch'] != pipeline.flags['ch']
108+
109+
assert pipeline_fixture.flags['epoch'] == pipeline.flags['epoch']
110+
pipeline.flags['epoch']["bridge"] = ["noisy"]
111+
assert pipeline_fixture.flags['epoch'] == pipeline.flags['epoch']

requirements_rtd.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
-e .

0 commit comments

Comments
 (0)