1717
1818from .utils ._utils import _icalabel_to_data_frame
1919
20+ IC_LABELS = mne_icalabel .config .ICA_LABELS_TO_MNE
21+ CH_LABELS : dict [str , str ] = {
22+ "Noisy" : "ch_sd" ,
23+ "Bridged" : "bridge" ,
24+ "Uncorrelated" : "low_r" ,
25+ "Rank" : "rank"
26+ }
27+ EPOCH_LABELS : dict [str , str ] = {
28+ "Noisy" : "noisy" ,
29+ "Noisy ICs" : "noisy_ICs" ,
30+ "Uncorrelated" : "uncorrelated" ,
31+ }
32+
33+
34+ class _Flagged (dict ):
35+
36+ def __init__ (self , key_map , kind_str , ll , * args , ** kwargs ):
37+ """Initialize class."""
38+ super ().__init__ (* args , ** kwargs )
39+ self .ll = ll
40+ self ._key_map = key_map
41+ self ._kind_str = kind_str
42+
43+ @property
44+ def valid_keys (self ):
45+ """Return the valid keys."""
46+ return tuple (self ._key_map .values ())
47+
48+ def __repr__ (self ):
49+ """Return a string representation."""
50+ ret_str = f"Flagged { self ._kind_str } s: |\n "
51+ for key , val in self ._key_map .items ():
52+ ret_str += f" { key } : { self .get (val , None )} \n "
53+ return ret_str
54+
55+ def __eq__ (self , other ):
56+ for key in self .valid_keys :
57+ if not np .array_equal (self .get (key , np .array ([])),
58+ other .get (key , np .array ([]))):
59+ return False
60+ return True
61+
62+ def __ne__ (self , other ):
63+ return not self == other
2064
21- class FlaggedChs (dict ):
65+
66+ class FlaggedChs (_Flagged ):
2267 """Object for handling flagged channels in an instance of mne.io.Raw.
2368
2469 Attributes
@@ -47,28 +92,17 @@ class FlaggedChs(dict):
4792 and methods for python dictionaries.
4893 """
4994
50- def __init__ (self , ll , * args , ** kwargs ):
95+ def __init__ (self , * args , ** kwargs ):
5196 """Initialize class."""
52- super ().__init__ (* args , ** kwargs )
53- self .ll = ll
54-
55- def __repr__ (self ):
56- """Return a string representation of the FlaggedChs object."""
57- return (
58- f"Flagged channels: |\n "
59- f" Noisy: { self .get ('ch_sd' , None )} \n "
60- f" Bridged: { self .get ('bridge' , None )} \n "
61- f" Uncorrelated: { self .get ('low_r' , None )} \n "
62- f" Rank: { self .get ('rank' , None )} \n "
63- )
97+ super ().__init__ (CH_LABELS , "channel" , * args , ** kwargs )
6498
6599 def add_flag_cat (self , kind , bad_ch_names , * args ):
66100 """Store channel names that have been flagged by pipeline.
67101
68102 Parameters
69103 ----------
70104 kind : str
71- Should be one of ``'outlier'``, ``' ch_sd'``, ``'low_r'``,
105+ Should be one of ``'ch_sd'``, ``'low_r'``,
72106 ``'bridge'``, ``'rank'``.
73107 bad_ch_names : list | tuple
74108 Channel names. Will be the values corresponding to the ``kind``
@@ -140,7 +174,7 @@ def load_tsv(self, fname):
140174 self [label ] = grp_df .ch_names .values
141175
142176
143- class FlaggedEpochs (dict ):
177+ class FlaggedEpochs (_Flagged ):
144178 """Object for handling flagged Epochs in an instance of mne.Epochs.
145179
146180 Methods
@@ -159,7 +193,7 @@ class FlaggedEpochs(dict):
159193 and methods for python dictionaries.
160194 """
161195
162- def __init__ (self , ll , * args , ** kwargs ):
196+ def __init__ (self , * args , ** kwargs ):
163197 """Initialize class.
164198
165199 Parameters
@@ -171,9 +205,7 @@ def __init__(self, ll, *args, **kwargs):
171205 kwargs : dict
172206 keyword arguments accepted by python's dictionary class.
173207 """
174- super ().__init__ (* args , ** kwargs )
175-
176- self .ll = ll
208+ super ().__init__ (EPOCH_LABELS , "epoch" , * args , ** kwargs )
177209
178210 def add_flag_cat (self , kind , bad_epoch_inds , epochs ):
179211 """Add information on time periods flagged by pyLossless.
@@ -194,17 +226,27 @@ def add_flag_cat(self, kind, bad_epoch_inds, epochs):
194226 self [kind ] = bad_epoch_inds
195227 self .ll .add_pylossless_annotations (bad_epoch_inds , kind , epochs )
196228
197- def load_from_raw (self , raw ):
229+ def load_from_raw (self , raw , events , config ):
198230 """Load pylossless annotations from raw object."""
199231 sfreq = raw .info ["sfreq" ]
232+ tmax = config ["epoching" ]["epochs_args" ]["tmax" ]
233+ tmin = config ["epoching" ]["epochs_args" ]["tmin" ]
234+ starts = events [:, 0 ] / sfreq - tmin
235+ stops = events [:, 0 ] / sfreq + tmax
200236 for annot in raw .annotations :
201- if annot ["description" ].upper ().startswith ("BAD_LL" ):
202- ind_onset = int (np .round (annot ["onset" ] * sfreq ))
203- ind_dur = int (np .round (annot ["duration" ] * sfreq ))
204- inds = np .arange (ind_onset , ind_onset + ind_dur )
205- if annot ["description" ] not in self :
206- self [annot ["description" ]] = list ()
207- self [annot ["description" ]].append (inds )
237+ if annot ["description" ].upper ().startswith ("BAD_LL_" ):
238+ onset = annot ["onset" ]
239+ offset = annot ["onset" ] + annot ["duration" ]
240+ mask = (
241+ (starts >= onset ) & (starts < offset )
242+ | (stops > onset ) & (stops <= offset )
243+ | (onset <= starts ) & (offset >= stops )
244+ )
245+ inds = np .where (mask )[0 ]
246+ desc = annot ["description" ].lower ().replace ("bad_ll_" , "" )
247+ if desc not in self :
248+ self [desc ] = np .array ([])
249+ self [desc ] = np .concatenate ((self [desc ], inds ))
208250
209251
210252class FlaggedICs (pd .DataFrame ):
0 commit comments