def find_bads_by_threshold(epochs, threshold=5e-5):
    """Return channels whose standard deviation is consistently above a threshold.

    Parameters
    ----------
    epochs : mne.Epochs
        An instance of mne.Epochs that should contain a single channel type.
        If several channel types are present, a warning is emitted and the
        detection still runs on all of them.
    threshold : float
        The threshold in volts. If the standard deviation of a channel's
        voltage within a specific epoch is above the threshold, then that
        channel x epoch is flagged as an "outlier". If more than 20% of epochs
        are flagged as outliers for a specific channel, then that channel is
        flagged as bad. Default threshold is 5e-5 (0.00005), i.e. 50 microvolts.
        The value is passed unchanged to ``_threshold_volt_std``.

    Returns
    -------
    array-like of str
        The names of the channels that are considered bad, exactly as returned
        by ``_threshold_volt_std``. Callers in this module treat the result as
        an array (e.g. they call ``.any()`` on it), not as a plain ``list``.

    Notes
    -----
    If you are having trouble converting between exponential notation and
    decimal notation, note that printing a Python float always uses Python's
    own formatting (``print(5e-5)`` shows ``5e-05``); ``np.printoptions`` only
    affects how NumPy arrays are printed. To see the decimal form, use:

    >>> import numpy as np
    >>> threshold = 5e-5
    >>> np.format_float_positional(threshold)
    '0.00005'

    .. seealso::

        :func:`~pylossless.LosslessPipeline.flag_channels_fixed_threshold` to use
        this function within the lossless pipeline.

    Examples
    --------
    >>> import mne
    >>> import pylossless as ll
    >>> fname = mne.datasets.sample.data_path() / "MEG/sample/sample_audvis_raw.fif"
    >>> raw = mne.io.read_raw(fname, preload=True).pick("eeg")
    >>> raw = raw.apply_function(lambda x: x * 3, picks=["EEG 001"])  # Noisy channel
    >>> epochs = mne.make_fixed_length_epochs(raw, preload=True)
    >>> bad_chs = ll.pipeline.find_bads_by_threshold(epochs)
    """
    # TODO: We should make this function handle multiple channel types.
    # TODO: but I'd like to avoid making a copy of the epochs object
    ch_types = np.unique(epochs.get_channel_types()).tolist()
    if len(ch_types) > 1:
        # A single fixed threshold is applied to every channel, so mixing
        # channel types is reported to the user rather than silently accepted.
        warn(
            f"The epochs object contains multiple channel types: {ch_types}.\n"
            " This will likely bias the results of the threshold detection."
            " Use the `mne.Epochs.pick` to select a single channel type."
        )
    return _threshold_volt_std(epochs, flag_dim="ch", threshold=threshold)
+ logger.info(msg) + else: # TODO: Implement an annotate_bads_by_threshold for epochs + above_threshold = _threshold_volt_std( + epochs, flag_dim=flag_dim, threshold=threshold + ) self.flags[flag_dim].add_flag_cat("volt_std", above_threshold, epochs) def find_outlier_chs(self, epochs=None, picks="eeg"): @@ -834,6 +903,7 @@ def find_outlier_chs(self, epochs=None, picks="eeg"): return mean_ch_dist.ch[mean_ch_dist > mdn + 6 * deviation].values.tolist() + @lossless_logger def flag_channels_fixed_threshold(self, threshold=5e-5, picks="eeg"): """Flag channels based on the stdev value across the time dimension. @@ -845,7 +915,7 @@ def flag_channels_fixed_threshold(self, threshold=5e-5, picks="eeg"): threshold : float threshold, in volts. If the standard deviation across time in any channel x epoch indice is above this threshold, then the - channel x epoch indices will considered an outlier. Defaults + channel x epoch indices will be considered an outlier. Defaults to 5e-5, or 50 microvolts. Note that here, 'time' refers to the samples in an epoch. For each channel, if its std value is above the given threshold in more than 20% of the epochs, it @@ -853,18 +923,42 @@ def flag_channels_fixed_threshold(self, threshold=5e-5, picks="eeg"): picks : str (default "eeg") Type of channels to pick. + Returns + ------- + None + If any channels are flagged, those channel names will be logged + in the `flags` attribute of the `LosslessPipeline` object, + under the key ``'volt_std'``, e.g. + ``my_pipeline.flags["ch"]["volt_std"]``. + Notes ----- - WARNING: the default threshold of 50 microvolts may not be appropriate - for a particular dataset or data file, as the baseline voltage variance - is affected by the impedance of the system that the data was recorded - with. You may need to assess a more appropriate value for your own - data. + .. 
warning:: + + the default threshold of 50 microvolts may not be appropriate + for a particular dataset or data file, as the baseline voltage variance + is affected by the impedance of the system that the data was recorded + with. You may need to assess a more appropriate value for your own + data. You can use the :func:`~pylossless.pipeline.find_bads_by_threshold` + function to quickly assess a more appropriate threshold. + + .. seealso:: + + :func:`~pylossless.pipeline.find_bads_by_threshold` + + Examples + -------- + >>> import mne + >>> import pylossless as ll + >>> config = ll.Config().load_default() + >>> config["flag_channels_fixed_threshold"] = {"threshold": 5e-5} + >>> pipeline = ll.LosslessPipeline(config=config) + >>> sample_fpath = mne.datasets.sample.data_path() + >>> fpath = sample_fpath / "MEG" / "sample" / "sample_audvis_raw.fif" + >>> raw = mne.io.read_raw(fpath).pick("eeg") + >>> pipeline.raw = raw + >>> pipeline.flag_channels_fixed_threshold() """ - if "flag_channels_fixed_threshold" not in self.config: - return - if "threshold" in self.config["flag_channels_fixed_threshold"]: - threshold = self.config["flag_channels_fixed_threshold"]["threshold"] self._flag_volt_std(flag_dim="ch", threshold=threshold, picks=picks) def flag_epochs_fixed_threshold(self, threshold=5e-5, picks="eeg"): @@ -1268,7 +1362,15 @@ def _run(self): # OPTIONAL: Flag chs/epochs based off fixed std threshold of time axis self.flag_epochs_fixed_threshold(picks=picks) - self.flag_channels_fixed_threshold(picks=picks) + if "flag_channels_fixed_threshold" in self.config: + msg = "Flagging Channels by fixed threshold" + kwargs = dict(picks=picks, message=msg) + if "threshold" in self.config["flag_channels_fixed_threshold"]: + threshold = self.config["flag_channels_fixed_threshold"][ + "threshold" + ] + kwargs["threshold"] = threshold + self.flag_channels_fixed_threshold(**kwargs) # 3.flag channels based on large Stdev. 
def test_find_bads_by_threshold():
    """Check fixed-threshold bad-channel detection via the function and the method."""
    noisy_ch = "EEG 001"
    expected_bads = [noisy_ch]

    sample_dir = mne.datasets.sample.data_path() / "MEG" / "sample"
    raw = mne.io.read_raw_fif(sample_dir / "sample_audvis_raw.fif", preload=True)
    # Amplify one EEG channel so that it stands out as noisy.
    raw.apply_function(lambda data: data * 3, picks=[noisy_ch])
    epochs = mne.make_fixed_length_epochs(raw, preload=True)

    # Function: epochs that mix channel types must trigger a warning...
    mixed_types_msg = "The epochs object contains multiple channel types"
    with pytest.warns(RuntimeWarning, match=mixed_types_msg):
        ll.pipeline.find_bads_by_threshold(epochs)

    # ...and once restricted to EEG, only the amplified channel is returned.
    epochs.pick("eeg")
    np.testing.assert_array_equal(
        ll.pipeline.find_bads_by_threshold(epochs), expected_bads
    )

    # Method: same detection, driven through a LosslessPipeline instance.
    pipeline = ll.LosslessPipeline(config=ll.config.Config().load_default())
    pipeline.raw = raw

    # A threshold that is far too high flags nothing.
    pipeline.flag_channels_fixed_threshold(threshold=10_000)
    np.testing.assert_array_equal(pipeline.flags["ch"]["volt_std"], [])

    # The default threshold flags the amplified channel.
    pipeline.flag_channels_fixed_threshold()
    np.testing.assert_array_equal(pipeline.flags["ch"]["volt_std"], expected_bads)