@@ -232,13 +232,14 @@ def find_bads_by_threshold(epochs, threshold=5e-5):
232232
233233 Parameters
234234 ----------
235- inst : mne.Epochs
236- an instance of mne.Epochs
235+ epochs : mne.Epochs
236+ an instance of mne.Epochs with a single channel type.
237237 threshold : float
238238 the threshold in volts. If the standard deviation of a channel's voltage
239- varianceat a specific epoch is above the threshold, then that channel x epoch
240- will be flagged as an "outlier". Default is 5e-5 (0.00005), i.e.
241- 50 microvolts.
239+ variance at a specific epoch is above the threshold, then that channel x epoch
240+ will be flagged as an "outlier". If more than 20% of epochs are flagged as
241+ outliers for a specific channel, then that channel will be flagged as bad.
242+ Default threshold is 5e-5 (0.00005), i.e. 50 microvolts.
242243
243244 Returns
244245 -------
@@ -256,6 +257,11 @@ def find_bads_by_threshold(epochs, threshold=5e-5):
256257 ... print(threshold)
257258 0.00005
258259
260+ .. seealso::
261+
262+ :func:`~pylossless.LosslessPipeline.flag_channels_fixed_threshold` to use
263+ this function within the lossless pipeline.
264+
259265 Examples
260266 --------
261267 >>> import mne
@@ -266,11 +272,17 @@ def find_bads_by_threshold(epochs, threshold=5e-5):
266272 >>> epochs = mne.make_fixed_length_epochs(raw, preload=True)
267273 >>> bad_chs = ll.pipeline.find_bads_by_threshold(epochs)
268274 """
269- bads = _threshold_volt_std (epochs , flag_dim = "ch" , threshold = threshold )
270- logger .info (
271- f"Found { len (bads )} channels with high voltage variance: { bads } "
275+ # XXX: We should make this function handle multiple channel types.
276+ # XXX: but I'd like to avoid making a copy of the epochs object
277+ ch_types = np .unique (epochs .get_channel_types ()).tolist ()
278+ if len (ch_types ) > 1 :
279+ warn (
280+ f"The epochs object contains multiple channel types: { ch_types } .\n "
281+ " This will likely bias the results of the threshold detection."
282+ " Use the `mne.Epochs.pick` to select a single channel type."
272283 )
273- return _threshold_volt_std (epochs , flag_dim = "ch" , threshold = threshold )
284+ bads = _threshold_volt_std (epochs , flag_dim = "ch" , threshold = threshold )
285+ return bads
274286
275287
276288def _threshold_volt_std (epochs , flag_dim , threshold = 5e-5 ):
@@ -293,7 +305,7 @@ def _threshold_volt_std(epochs, flag_dim, threshold=5e-5):
293305 assert len (threshold ) == 2
294306 l_out , u_out = threshold
295307 init_dir = "both"
296- elif isinstance (threshold , float ):
308+ elif isinstance (threshold , ( float , int ) ):
297309 l_out , u_out = (0 , threshold )
298310 init_dir = "pos"
299311 else :
@@ -814,6 +826,14 @@ def _flag_volt_std(self, flag_dim, threshold=5e-5, picks="eeg"):
814826 # So yes this add a few LOC, but IMO it's worth it for readability
815827 if flag_dim == "ch" :
816828 above_threshold = find_bads_by_threshold (epochs , threshold = threshold )
829+ if above_threshold .any ():
830+ logger .info (
831+ f"🚩 Found { len (above_threshold )} channels with "
832+ f"voltage variance above { threshold } volts: { above_threshold } "
833+ )
834+ else :
835+ msg = f"No channels with standard deviation above { threshold } volts."
836+ logger .info (msg )
817837 else : # TODO: Implement an annotate_bads_by_threshold for epochs
818838 above_threshold = _threshold_volt_std (
819839 epochs , flag_dim = flag_dim , threshold = threshold
@@ -884,6 +904,7 @@ def find_outlier_chs(self, epochs=None, picks="eeg"):
884904
885905 return mean_ch_dist .ch [mean_ch_dist > mdn + 6 * deviation ].values .tolist ()
886906
907+ @lossless_logger
887908 def flag_channels_fixed_threshold (self , threshold = 5e-5 , picks = "eeg" ):
888909 """Flag channels based on the stdev value across the time dimension.
889910
@@ -895,26 +916,50 @@ def flag_channels_fixed_threshold(self, threshold=5e-5, picks="eeg"):
895916 threshold : float
896917 threshold, in volts. If the standard deviation across time in
897918 any channel x epoch indice is above this threshold, then the
898- channel x epoch indices will considered an outlier. Defaults
919+ channel x epoch indices will be considered an outlier. Defaults
899920 to 5e-5, or 50 microvolts. Note that here, 'time' refers to
900921 the samples in an epoch. For each channel, if its std value is
901922 above the given threshold in more than 20% of the epochs, it
902923 is flagged.
903924 picks : str (default "eeg")
904925 Type of channels to pick.
905926
927+ Returns
928+ -------
929+ None
930+ If any channels are flagged, those channel names will be logged
931+ in the `flags` attribute of the `LosslessPipeline` object,
932+ under the key ``'volt_std'``, e.g.
933+ ``my_pipeline.flags["ch"]["volt_std"]``.
934+
906935 Notes
907936 -----
908- WARNING: the default threshold of 50 microvolts may not be appropriate
909- for a particular dataset or data file, as the baseline voltage variance
910- is affected by the impedance of the system that the data was recorded
911- with. You may need to assess a more appropriate value for your own
912- data.
937+ .. warning::
938+
939+ the default threshold of 50 microvolts may not be appropriate
940+ for a particular dataset or data file, as the baseline voltage variance
941+ is affected by the impedance of the system that the data was recorded
942+ with. You may need to assess a more appropriate value for your own
943+ data. You can use the :func:`~pylossless.pipeline.find_bads_by_threshold`
944+ function to quickly assess a more appropriate threshold.
945+
946+ .. seealso::
947+
948+ :func:`~pylossless.pipeline.find_bads_by_threshold`
949+
950+ Examples
951+ --------
952+ >>> import mne
953+ >>> import pylossless as ll
954+ >>> config = ll.Config().load_default()
955+ >>> config["flag_channels_fixed_threshold"] = {"threshold": 5e-5}
956+ >>> pipeline = ll.LosslessPipeline(config=config)
957+ >>> sample_fpath = mne.datasets.sample.data_path()
958+ >>> fpath = sample_fpath / "MEG" / "sample" / "sample_audvis_raw.fif"
959+ >>> raw = mne.io.read_raw(fpath).pick("eeg")
960+ >>> pipeline.raw = raw
961+ >>> pipeline.flag_channels_fixed_threshold()
913962 """
914- if "flag_channels_fixed_threshold" not in self .config :
915- return
916- if "threshold" in self .config ["flag_channels_fixed_threshold" ]:
917- threshold = self .config ["flag_channels_fixed_threshold" ]["threshold" ]
918963 self ._flag_volt_std (flag_dim = "ch" , threshold = threshold , picks = picks )
919964
920965 def flag_epochs_fixed_threshold (self , threshold = 5e-5 , picks = "eeg" ):
@@ -1318,7 +1363,15 @@ def _run(self):
13181363
13191364 # OPTIONAL: Flag chs/epochs based off fixed std threshold of time axis
13201365 self .flag_epochs_fixed_threshold (picks = picks )
1321- self .flag_channels_fixed_threshold (picks = picks )
1366+ if "flag_channels_fixed_threshold" in self .config :
1367+ if "threshold" in self .config ["flag_channels_fixed_threshold" ]:
1368+ thresh = self .config ["flag_channels_fixed_threshold" ]["threshold" ]
1369+ else :
1370+ thresh = 5e-5
1371+ msg = "Flagging Channels by fixed threshold"
1372+ self .flag_channels_fixed_threshold (
1373+ threshold = thresh , picks = picks , message = msg
1374+ )
13221375
13231376 # 3.flag channels based on large Stdev. across time
13241377 msg = "Flagging Noisy Channels"
0 commit comments