@@ -227,6 +227,64 @@ def _detect_outliers(
227227 return prop_outliers [prop_outliers > flag_crit ].coords .to_index ().values
228228
229229
def find_bads_by_threshold(epochs, threshold=5e-5):
    """Return channels whose standard deviation is consistently above a threshold.

    Parameters
    ----------
    epochs : mne.Epochs
        An instance of mne.Epochs with a single channel type.
    threshold : float | int
        The threshold in volts. If the standard deviation of a channel's
        voltage within a specific epoch is above the threshold, then that
        channel x epoch will be flagged as an "outlier". If more than 20% of
        epochs are flagged as outliers for a specific channel, then that
        channel will be flagged as bad.
        Default threshold is 5e-5 (0.00005), i.e. 50 microvolts.

    Returns
    -------
    array-like
        The names of the channels that are considered outliers, exactly as
        returned by ``_threshold_volt_std(epochs, flag_dim="ch", ...)``.
        NOTE(review): the pipeline calls ``.any()`` on this value, so it is
        presumably a NumPy array rather than a plain ``list`` -- confirm.

    Warns
    -----
    A warning is emitted when ``epochs`` contains more than one channel type,
    because the same fixed threshold is then applied to all of them.

    Notes
    -----
    If you are having trouble converting between exponential notation and
    decimal notation, you can format the value with a fixed number of
    decimals:

    >>> threshold = 5e-5
    >>> print(f"{threshold:.5f}")
    0.00005

    .. seealso::

        :meth:`~pylossless.LosslessPipeline.flag_channels_fixed_threshold` to use
        this function within the lossless pipeline.

    Examples
    --------
    >>> import mne
    >>> import pylossless as ll
    >>> fname = mne.datasets.sample.data_path() / "MEG/sample/sample_audvis_raw.fif"
    >>> raw = mne.io.read_raw(fname, preload=True).pick("eeg")
    >>> raw.apply_function(lambda x: x * 3, picks=["EEG 001"]) # Make a noisy channel
    >>> epochs = mne.make_fixed_length_epochs(raw, preload=True)
    >>> bad_chs = ll.pipeline.find_bads_by_threshold(epochs)
    """
    # TODO: We should make this function handle multiple channel types.
    # TODO: but I'd like to avoid making a copy of the epochs object
    ch_types = np.unique(epochs.get_channel_types()).tolist()
    if len(ch_types) > 1:
        # Only warn (do not raise): the detection still runs, but a single
        # volt threshold is applied to every channel type present.
        warn(
            f"The epochs object contains multiple channel types: {ch_types}.\n"
            " This will likely bias the results of the threshold detection."
            " Use the `mne.Epochs.pick` to select a single channel type."
        )
    return _threshold_volt_std(epochs, flag_dim="ch", threshold=threshold)
286+
287+
230288def _threshold_volt_std (epochs , flag_dim , threshold = 5e-5 ):
231289 """Detect epochs or channels whose voltage std is above threshold.
232290
@@ -247,7 +305,7 @@ def _threshold_volt_std(epochs, flag_dim, threshold=5e-5):
247305 assert len (threshold ) == 2
248306 l_out , u_out = threshold
249307 init_dir = "both"
250- elif isinstance (threshold , float ):
308+ elif isinstance (threshold , ( float , int ) ):
251309 l_out , u_out = (0 , threshold )
252310 init_dir = "pos"
253311 else :
@@ -765,9 +823,20 @@ def _flag_volt_std(self, flag_dim, threshold=5e-5, picks="eeg"):
765823 on. You may need to assess a more appropriate value for your own data.
766824 """
767825 epochs = self .get_epochs (picks = picks )
768- above_threshold = _threshold_volt_std (
769- epochs , flag_dim = flag_dim , threshold = threshold
770- )
826+ if flag_dim == "ch" :
827+ above_threshold = find_bads_by_threshold (epochs , threshold = threshold )
828+ if above_threshold .any ():
829+ logger .info (
830+ f"🚩 Found { len (above_threshold )} channels with "
831+ f"voltage variance above { threshold } volts: { above_threshold } "
832+ )
833+ else :
834+ msg = f"No channels with standard deviation above { threshold } volts."
835+ logger .info (msg )
836+ else : # TODO: Implement an annotate_bads_by_threshold for epochs
837+ above_threshold = _threshold_volt_std (
838+ epochs , flag_dim = flag_dim , threshold = threshold
839+ )
771840 self .flags [flag_dim ].add_flag_cat ("volt_std" , above_threshold , epochs )
772841
773842 def find_outlier_chs (self , epochs = None , picks = "eeg" ):
@@ -834,6 +903,7 @@ def find_outlier_chs(self, epochs=None, picks="eeg"):
834903
835904 return mean_ch_dist .ch [mean_ch_dist > mdn + 6 * deviation ].values .tolist ()
836905
906+ @lossless_logger
837907 def flag_channels_fixed_threshold (self , threshold = 5e-5 , picks = "eeg" ):
838908 """Flag channels based on the stdev value across the time dimension.
839909
@@ -845,26 +915,50 @@ def flag_channels_fixed_threshold(self, threshold=5e-5, picks="eeg"):
845915 threshold : float
846916 threshold, in volts. If the standard deviation across time in
847917 any channel x epoch index is above this threshold, then the
848- channel x epoch indices will considered an outlier. Defaults
918+ channel x epoch indices will be considered an outlier. Defaults
849919 to 5e-5, or 50 microvolts. Note that here, 'time' refers to
850920 the samples in an epoch. For each channel, if its std value is
851921 above the given threshold in more than 20% of the epochs, it
852922 is flagged.
853923 picks : str (default "eeg")
854924 Type of channels to pick.
855925
926+ Returns
927+ -------
928+ None
929+ If any channels are flagged, those channel names will be stored
930+ in the `flags` attribute of the `LosslessPipeline` object,
931+ under the key ``'volt_std'``, e.g.
932+ ``my_pipeline.flags["ch"]["volt_std"]``.
933+
856934 Notes
857935 -----
858- WARNING: the default threshold of 50 microvolts may not be appropriate
859- for a particular dataset or data file, as the baseline voltage variance
860- is affected by the impedance of the system that the data was recorded
861- with. You may need to assess a more appropriate value for your own
862- data.
936+ .. warning::
937+
938+ the default threshold of 50 microvolts may not be appropriate
939+ for a particular dataset or data file, as the baseline voltage variance
940+ is affected by the impedance of the system that the data was recorded
941+ with. You may need to assess a more appropriate value for your own
942+ data. You can use the :func:`~pylossless.pipeline.find_bads_by_threshold`
943+ function to quickly assess a more appropriate threshold.
944+
945+ .. seealso::
946+
947+ :func:`~pylossless.pipeline.find_bads_by_threshold`
948+
949+ Examples
950+ --------
951+ >>> import mne
952+ >>> import pylossless as ll
953+ >>> config = ll.Config().load_default()
954+ >>> config["flag_channels_fixed_threshold"] = {"threshold": 5e-5}
955+ >>> pipeline = ll.LosslessPipeline(config=config)
956+ >>> sample_fpath = mne.datasets.sample.data_path()
957+ >>> fpath = sample_fpath / "MEG" / "sample" / "sample_audvis_raw.fif"
958+ >>> raw = mne.io.read_raw(fpath).pick("eeg")
959+ >>> pipeline.raw = raw
960+ >>> pipeline.flag_channels_fixed_threshold()
863961 """
864- if "flag_channels_fixed_threshold" not in self .config :
865- return
866- if "threshold" in self .config ["flag_channels_fixed_threshold" ]:
867- threshold = self .config ["flag_channels_fixed_threshold" ]["threshold" ]
868962 self ._flag_volt_std (flag_dim = "ch" , threshold = threshold , picks = picks )
869963
870964 def flag_epochs_fixed_threshold (self , threshold = 5e-5 , picks = "eeg" ):
@@ -1268,7 +1362,15 @@ def _run(self):
12681362
12691363 # OPTIONAL: Flag chs/epochs based off fixed std threshold of time axis
12701364 self .flag_epochs_fixed_threshold (picks = picks )
1271- self .flag_channels_fixed_threshold (picks = picks )
1365+ if "flag_channels_fixed_threshold" in self .config :
1366+ msg = "Flagging Channels by fixed threshold"
1367+ kwargs = dict (picks = picks , message = msg )
1368+ if "threshold" in self .config ["flag_channels_fixed_threshold" ]:
1369+ threshold = self .config ["flag_channels_fixed_threshold" ][
1370+ "threshold"
1371+ ]
1372+ kwargs ["threshold" ] = threshold
1373+ self .flag_channels_fixed_threshold (** kwargs )
12721374
12731375 # 3.flag channels based on large Stdev. across time
12741376 msg = "Flagging Noisy Channels"
0 commit comments