diff --git a/pylossless/pipeline.py b/pylossless/pipeline.py index efca737b..45931cf0 100644 --- a/pylossless/pipeline.py +++ b/pylossless/pipeline.py @@ -770,18 +770,49 @@ def _flag_volt_std(self, flag_dim, threshold=5e-5, picks="eeg"): ) self.flags[flag_dim].add_flag_cat("volt_std", above_threshold, epochs) - def find_outlier_chs(self, inst, picks="eeg"): - """Detect outlier Channels to leave out of rereference.""" + def find_outlier_chs(self, epochs=None, picks="eeg"): + """Detect outlier Channels to leave out of rereference. + + Parameters + ---------- + epochs : mne.Epochs | None + An instance of :class:`mne.Epochs`, or ``None``. If ``None``, then + :attr:`pylossless.LosslessPipeline.raw` should be set, and this + method will call :meth:`pylossless.LosslessPipeline.get_epochs` + to create epochs to use for outlier detection. + picks : str (default "eeg") + Channels to include in the outlier detection process. You can pass any + argument that is valid for the :meth:`~mne.Epochs.pick` method, but + you should avoid passing a mix of channel types with differing units of + measurement (e.g. EEG and MEG), as this would likely lead to incorrect + outlier detection (e.g. all EEG channels would be flagged as outliers). + + Returns + ------- + list + a list of channel names that are considered outliers. + + Notes + ----- + - This method is used to detect channels that are so noisy that they + should be left out of the robust average rereference process. + + Examples + -------- + >>> import mne + >>> import pylossless as ll + >>> config = ll.Config().load_default() + >>> pipeline = ll.LosslessPipeline(config=config) + >>> fname = mne.datasets.sample.data_path() / "MEG/sample/sample_audvis_raw.fif" + >>> raw = mne.io.read_raw(fname) + >>> epochs = mne.make_fixed_length_epochs(raw, preload=True) + >>> chs_to_leave_out = pipeline.find_outlier_chs(epochs=epochs) + """ # TODO: Reuse _detect_outliers here. logger.info("🔍 Detecting channels to leave out of reference.") - if isinstance(inst, mne.Epochs): - epochs = inst - elif isinstance(inst, mne.io.Raw): - epochs = self.get_epochs(rereference=False, picks=picks) - else: - raise TypeError( - "inst must be an MNE Raw or Epochs object," f" but got {type(inst)}." - ) + if epochs is None: + epochs = self.get_epochs(rereference=False) + epochs = epochs.copy().pick(picks=picks) epochs_xr = epochs_to_xr(epochs, kind="ch") # Determines comically bad channels, diff --git a/pylossless/tests/test_pipeline.py b/pylossless/tests/test_pipeline.py index a3538f03..2e334dea 100644 --- a/pylossless/tests/test_pipeline.py +++ b/pylossless/tests/test_pipeline.py @@ -70,6 +70,18 @@ def test_find_breaks(logging): Path(config_fname).unlink() # delete config file +def test_find_outliers(): + """Test the find_outliers method for the case that epochs is None.""" + fname = mne.datasets.sample.data_path() / 'MEG' / 'sample' / 'sample_audvis_raw.fif' + raw = mne.io.read_raw_fif(fname, preload=True) + raw.apply_function(lambda x: x * 10, picks="EEG 001") # create an outlier + config = ll.config.Config().load_default() + pipeline = ll.LosslessPipeline(config=config) + pipeline.raw = raw + chs_to_leave_out = pipeline.find_outlier_chs() + assert chs_to_leave_out == ['EEG 001'] + + def test_deprecation(): """Test the config_name property added for deprecation.""" config = ll.config.Config()