1717
1818from .utils ._utils import _icalabel_to_data_frame
1919
20+ IC_LABELS = mne_icalabel .config .ICA_LABELS_TO_MNE
21+ CH_LABELS : dict [str , str ] = {
22+ "Noisy" : "ch_sd" ,
23+ "Bridged" : "bridge" ,
24+ "Uncorrelated" : "low_r" ,
25+ "Rank" : "rank"
26+ }
27+ EPOCH_LABELS : dict [str , str ] = {
28+ "Noisy" : "noisy" ,
29+ "Noisy ICs" : "noisy_ICs" ,
30+ "Uncorrelated" : "uncorrelated" ,
31+ }
32+
33+
34+ class _Flagged (dict ):
35+
36+ def __init__ (self , key_map , kind_str , ll , * args , ** kwargs ):
37+ """Initialize class."""
38+ super ().__init__ (* args , ** kwargs )
39+ self .ll = ll
40+ self ._key_map = key_map
41+ self ._kind_str = kind_str
2042
21- class FlaggedChs (dict ):
43+ @property
44+ def valid_keys (self ):
45+ """Return the valid keys."""
46+ return tuple (self ._key_map .values ())
47+
48+ def __repr__ (self ):
49+ """Return a string representation."""
50+ ret_str = f"Flagged { self ._kind_str } s: |\n "
51+ for key , val in self ._key_map .items ():
52+ ret_str += f" { key } : { self .get (val , None )} \n "
53+ return ret_str
54+
55+ def __eq__ (self , other ):
56+ for key in self .valid_keys :
57+ if not np .array_equal (self .get (key , np .array ([])),
58+ other .get (key , np .array ([]))):
59+ return False
60+ return True
61+
62+ def __ne__ (self , other ):
63+ return not self == other
64+
65+
66+ class FlaggedChs (_Flagged ):
2267 """Object for handling flagged channels in an instance of mne.io.Raw.
2368
2469 Attributes
@@ -47,32 +92,9 @@ class FlaggedChs(dict):
4792 and methods for python dictionaries.
4893 """
4994
50- def __init__ (self , ll , * args , ** kwargs ):
95+ def __init__ (self , * args , ** kwargs ):
5196 """Initialize class."""
52- super ().__init__ (* args , ** kwargs )
53- self .ll = ll
54-
55- @property
56- def valid_keys (self ):
57- """Return the valid keys for FlaggedChs objects."""
58- return ('ch_sd' , 'bridge' , 'low_r' , 'rank' )
59-
60- def __repr__ (self ):
61- """Return a string representation of the FlaggedChs object."""
62- return (
63- f"Flagged channels: |\n "
64- f" Noisy: { self .get ('ch_sd' , None )} \n "
65- f" Bridged: { self .get ('bridge' , None )} \n "
66- f" Uncorrelated: { self .get ('low_r' , None )} \n "
67- f" Rank: { self .get ('rank' , None )} \n "
68- )
69-
70- def __eq__ (self , other ):
71- for key in self .valid_keys :
72- if not np .array_equal (self .get (key , np .array ([])),
73- other .get (key , np .array ([]))):
74- return False
75- return True
97+ super ().__init__ (CH_LABELS , "channel" , * args , ** kwargs )
7698
7799 def add_flag_cat (self , kind , bad_ch_names , * args ):
78100 """Store channel names that have been flagged by pipeline.
@@ -152,7 +174,7 @@ def load_tsv(self, fname):
152174 self [label ] = grp_df .ch_names .values
153175
154176
155- class FlaggedEpochs (dict ):
177+ class FlaggedEpochs (_Flagged ):
156178 """Object for handling flagged Epochs in an instance of mne.Epochs.
157179
158180 Methods
@@ -171,7 +193,7 @@ class FlaggedEpochs(dict):
171193 and methods for python dictionaries.
172194 """
173195
174- def __init__ (self , ll , * args , ** kwargs ):
196+ def __init__ (self , * args , ** kwargs ):
175197 """Initialize class.
176198
177199 Parameters
@@ -183,30 +205,7 @@ def __init__(self, ll, *args, **kwargs):
183205 kwargs : dict
184206 keyword arguments accepted by python's dictionary class.
185207 """
186- super ().__init__ (* args , ** kwargs )
187-
188- self .ll = ll
189-
190- @property
191- def valid_keys (self ):
192- """Return the valid keys for FlaggedEpochs objects."""
193- return ('noisy' , 'uncorrelated' , 'noisy_ICs' )
194-
195- def __repr__ (self ):
196- """Return a string representation of the FlaggedEpochs object."""
197- return (
198- f"Flagged channels: |\n "
199- f" Noisy: { self .get ('noisy' , None )} \n "
200- f" Noisy ICs: { self .get ('noisy_ICs' , None )} \n "
201- f" Uncorrelated: { self .get ('uncorrelated' , None )} \n "
202- )
203-
204- def __eq__ (self , other ):
205- for key in self .valid_keys :
206- if not np .array_equal (self .get (key , np .array ([])),
207- other .get (key , np .array ([]))):
208- return False
209- return True
208+ super ().__init__ (EPOCH_LABELS , "epoch" , * args , ** kwargs )
210209
211210 def add_flag_cat (self , kind , bad_epoch_inds , epochs ):
212211 """Add information on time periods flagged by pyLossless.
@@ -227,17 +226,25 @@ def add_flag_cat(self, kind, bad_epoch_inds, epochs):
227226 self [kind ] = bad_epoch_inds
228227 self .ll .add_pylossless_annotations (bad_epoch_inds , kind , epochs )
229228
230- def load_from_raw (self , raw ):
229+ def load_from_raw (self , raw , events , config ):
231230 """Load pylossless annotations from raw object."""
232231 sfreq = raw .info ["sfreq" ]
232+ tmax = config ["epoching" ]["epochs_args" ]["tmax" ]
233+ tmin = config ["epoching" ]["epochs_args" ]["tmin" ]
234+ starts = events [:, 0 ]/ sfreq - tmin
235+ stops = events [:, 0 ]/ sfreq + tmax
233236 for annot in raw .annotations :
234- if annot ["description" ].upper ().startswith ("BAD_LL" ):
235- ind_onset = int (np .round (annot ["onset" ] * sfreq ))
236- ind_dur = int (np .round (annot ["duration" ] * sfreq ))
237- inds = np .arange (ind_onset , ind_onset + ind_dur )
238- if annot ["description" ] not in self :
239- self [annot ["description" ]] = list ()
240- self [annot ["description" ]].append (inds )
237+ if annot ["description" ].upper ().startswith ("BAD_LL_" ):
238+ onset = annot ["onset" ]
239+ offset = annot ["onset" ]+ annot ["duration" ]
240+ mask = ((starts >= onset ) & (starts < offset ) |
241+ (stops > onset ) & (stops <= offset ) |
242+ (onset <= starts ) & (offset >= stops ))
243+ inds = np .where (mask )[0 ]
244+ desc = annot ["description" ].lower ().replace ("bad_ll_" , "" )
245+ if desc not in self :
246+ self [desc ] = np .array ([])
247+ self [desc ] = np .concatenate ((self [desc ], inds ))
241248
242249
243250class FlaggedICs (pd .DataFrame ):
0 commit comments