Skip to content

Commit cc88579

Browse files
Fix loading epoch flags.
1 parent 159a148 commit cc88579

File tree

3 files changed

+69
-63
lines changed

3 files changed

+69
-63
lines changed

pylossless/flagging.py

Lines changed: 67 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,53 @@
1717

1818
from .utils._utils import _icalabel_to_data_frame
1919

20+
IC_LABELS = mne_icalabel.config.ICA_LABELS_TO_MNE
21+
CH_LABELS: dict[str, str] = {
22+
"Noisy": "ch_sd",
23+
"Bridged": "bridge",
24+
"Uncorrelated": "low_r",
25+
"Rank": "rank"
26+
}
27+
EPOCH_LABELS: dict[str, str] = {
28+
"Noisy": "noisy",
29+
"Noisy ICs": "noisy_ICs",
30+
"Uncorrelated": "uncorrelated",
31+
}
32+
33+
34+
class _Flagged(dict):
35+
36+
def __init__(self, key_map, kind_str, ll, *args, **kwargs):
37+
"""Initialize class."""
38+
super().__init__(*args, **kwargs)
39+
self.ll = ll
40+
self._key_map = key_map
41+
self._kind_str = kind_str
2042

21-
class FlaggedChs(dict):
43+
@property
44+
def valid_keys(self):
45+
"""Return the valid keys."""
46+
return tuple(self._key_map.values())
47+
48+
def __repr__(self):
49+
"""Return a string representation."""
50+
ret_str = f"Flagged {self._kind_str}s: |\n"
51+
for key, val in self._key_map.items():
52+
ret_str += f" {key}: {self.get(val, None)}\n"
53+
return ret_str
54+
55+
def __eq__(self, other):
56+
for key in self.valid_keys:
57+
if not np.array_equal(self.get(key, np.array([])),
58+
other.get(key, np.array([]))):
59+
return False
60+
return True
61+
62+
def __ne__(self, other):
63+
return not self == other
64+
65+
66+
class FlaggedChs(_Flagged):
2267
"""Object for handling flagged channels in an instance of mne.io.Raw.
2368
2469
Attributes
@@ -47,32 +92,9 @@ class FlaggedChs(dict):
4792
and methods for python dictionaries.
4893
"""
4994

50-
def __init__(self, ll, *args, **kwargs):
95+
def __init__(self, *args, **kwargs):
5196
"""Initialize class."""
52-
super().__init__(*args, **kwargs)
53-
self.ll = ll
54-
55-
@property
56-
def valid_keys(self):
57-
"""Return the valid keys for FlaggedChs objects."""
58-
return ('ch_sd', 'bridge', 'low_r', 'rank')
59-
60-
def __repr__(self):
61-
"""Return a string representation of the FlaggedChs object."""
62-
return (
63-
f"Flagged channels: |\n"
64-
f" Noisy: {self.get('ch_sd', None)}\n"
65-
f" Bridged: {self.get('bridge', None)}\n"
66-
f" Uncorrelated: {self.get('low_r', None)}\n"
67-
f" Rank: {self.get('rank', None)}\n"
68-
)
69-
70-
def __eq__(self, other):
71-
for key in self.valid_keys:
72-
if not np.array_equal(self.get(key, np.array([])),
73-
other.get(key, np.array([]))):
74-
return False
75-
return True
97+
super().__init__(CH_LABELS, "channel", *args, **kwargs)
7698

7799
def add_flag_cat(self, kind, bad_ch_names, *args):
78100
"""Store channel names that have been flagged by pipeline.
@@ -152,7 +174,7 @@ def load_tsv(self, fname):
152174
self[label] = grp_df.ch_names.values
153175

154176

155-
class FlaggedEpochs(dict):
177+
class FlaggedEpochs(_Flagged):
156178
"""Object for handling flagged Epochs in an instance of mne.Epochs.
157179
158180
Methods
@@ -171,7 +193,7 @@ class FlaggedEpochs(dict):
171193
and methods for python dictionaries.
172194
"""
173195

174-
def __init__(self, ll, *args, **kwargs):
196+
def __init__(self, *args, **kwargs):
175197
"""Initialize class.
176198
177199
Parameters
@@ -183,30 +205,7 @@ def __init__(self, ll, *args, **kwargs):
183205
kwargs : dict
184206
keyword arguments accepted by python's dictionary class.
185207
"""
186-
super().__init__(*args, **kwargs)
187-
188-
self.ll = ll
189-
190-
@property
191-
def valid_keys(self):
192-
"""Return the valid keys for FlaggedEpochs objects."""
193-
return ('noisy', 'uncorrelated', 'noisy_ICs')
194-
195-
def __repr__(self):
196-
"""Return a string representation of the FlaggedEpochs object."""
197-
return (
198-
f"Flagged channels: |\n"
199-
f" Noisy: {self.get('noisy', None)}\n"
200-
f" Noisy ICs: {self.get('noisy_ICs', None)}\n"
201-
f" Uncorrelated: {self.get('uncorrelated', None)}\n"
202-
)
203-
204-
def __eq__(self, other):
205-
for key in self.valid_keys:
206-
if not np.array_equal(self.get(key, np.array([])),
207-
other.get(key, np.array([]))):
208-
return False
209-
return True
208+
super().__init__(EPOCH_LABELS, "epoch", *args, **kwargs)
210209

211210
def add_flag_cat(self, kind, bad_epoch_inds, epochs):
212211
"""Add information on time periods flagged by pyLossless.
@@ -227,17 +226,25 @@ def add_flag_cat(self, kind, bad_epoch_inds, epochs):
227226
self[kind] = bad_epoch_inds
228227
self.ll.add_pylossless_annotations(bad_epoch_inds, kind, epochs)
229228

230-
def load_from_raw(self, raw):
229+
def load_from_raw(self, raw, events, config):
231230
"""Load pylossless annotations from raw object."""
232231
sfreq = raw.info["sfreq"]
232+
tmax = config["epoching"]["epochs_args"]["tmax"]
233+
tmin = config["epoching"]["epochs_args"]["tmin"]
234+
starts = events[:, 0]/sfreq - tmin
235+
stops = events[:, 0]/sfreq + tmax
233236
for annot in raw.annotations:
234-
if annot["description"].upper().startswith("BAD_LL"):
235-
ind_onset = int(np.round(annot["onset"] * sfreq))
236-
ind_dur = int(np.round(annot["duration"] * sfreq))
237-
inds = np.arange(ind_onset, ind_onset + ind_dur)
238-
if annot["description"] not in self:
239-
self[annot["description"]] = list()
240-
self[annot["description"]].append(inds)
237+
if annot["description"].upper().startswith("BAD_LL_"):
238+
onset = annot["onset"]
239+
offset = annot["onset"]+annot["duration"]
240+
mask = ((starts >= onset) & (starts < offset) |
241+
(stops > onset) & (stops <= offset) |
242+
(onset <= starts) & (offset >= stops))
243+
inds = np.where(mask)[0]
244+
desc = annot["description"].lower().replace("bad_ll_", "")
245+
if desc not in self:
246+
self[desc] = np.array([])
247+
self[desc] = np.concatenate((self[desc], inds))
241248

242249

243250
class FlaggedICs(pd.DataFrame):

pylossless/pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1258,7 +1258,7 @@ def load_ll_derivative(self, derivatives_path):
12581258
self.flags["ch"].load_tsv(flagged_chs_fpath.fpath)
12591259

12601260
# Load Flagged Epochs
1261-
self.flags["epoch"].load_from_raw(self.raw)
1261+
self.flags["epoch"].load_from_raw(self.raw, self.get_events(), self.config)
12621262

12631263
return self
12641264

pylossless/tests/test_pipeline.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,11 @@ def test_load_flags(pipeline_fixture, tmp_path):
7979
pipeline_fixture.save(bids_path,
8080
overwrite=False, format="EDF", event_id=None)
8181
pipeline = ll.LosslessPipeline().load_ll_derivative(bids_path)
82+
8283
assert pipeline_fixture.flags['ch'] == pipeline.flags['ch']
8384
pipeline.flags['ch']["bridge"] = ["xx"]
8485
assert pipeline_fixture.flags['ch'] != pipeline.flags['ch']
8586

8687
assert pipeline_fixture.flags['epoch'] == pipeline.flags['epoch']
8788
pipeline.flags['epoch']["bridge"] = ["noisy"]
8889
assert pipeline_fixture.flags['epoch'] == pipeline.flags['epoch']
89-
90-
assert pipeline_fixture.flags['ic'] == pipeline.flags['ic']

0 commit comments

Comments
 (0)