diff --git a/pylossless/conftest.py b/pylossless/conftest.py index 2042b943..50dfa27f 100644 --- a/pylossless/conftest.py +++ b/pylossless/conftest.py @@ -19,7 +19,7 @@ @pytest.fixture(scope="session") def pipeline_fixture(): - """Return a namedTuple containing MNE eyetracking raw data and events.""" + """Return a LosslessPipeline object.""" raw, config, bids_path = load_openneuro_bids() # raw.crop(tmin=0, tmax=60) # Too short for ICA to converge in some tests. annots = Annotations( diff --git a/pylossless/flagging.py b/pylossless/flagging.py index 4d8cf66f..46a23982 100644 --- a/pylossless/flagging.py +++ b/pylossless/flagging.py @@ -115,7 +115,11 @@ def add_flag_cat(self, kind, bad_ch_names, *args): logger.debug(f"NEW BAD CHANNELS {bad_ch_names}") if isinstance(bad_ch_names, xr.DataArray): bad_ch_names = bad_ch_names.values - self[kind] = bad_ch_names + if kind in self: + self[kind] = list(np.unique(np.concatenate((self[kind], bad_ch_names)))) + else: + self[kind] = bad_ch_names + def get_flagged(self): """Return a list of channels flagged by the lossless pipeline.""" @@ -138,9 +142,12 @@ def rereference(self, inst, **kwargs): """ # Concatenate and remove duplicates bad_chs = list( - set(self.ll.find_outlier_chs(inst) + self.get_flagged() + inst.info["bads"]) + set(self.ll.find_outlier_chs(inst, picks="eeg") + + self.get_flagged() + + inst.info["bads"]) ) - ref_chans = [ch for ch in inst.copy().pick("eeg").ch_names if ch not in bad_chs] + ref_chans = [ch for ch in inst.copy().pick("eeg").ch_names + if ch not in bad_chs] inst.set_eeg_reference(ref_channels=ref_chans, **kwargs) def save_tsv(self, fname): @@ -223,7 +230,10 @@ def add_flag_cat(self, kind, bad_epoch_inds, epochs): The :class:`mne.Epochs` object created from the Raw object that is being assessed by the LosslessPipeline. """ - self[kind] = bad_epoch_inds + if kind in self: + self[kind] = list(np.unique(np.concatenate((self[kind], bad_epoch_inds)))) + else: + self[kind] = bad_epoch_inds self.ll.add_pylossless_annotations(bad_epoch_inds, kind, epochs) def load_from_raw(self, raw, events, config): diff --git a/pylossless/pipeline.py b/pylossless/pipeline.py index e3a4aa78..1e261a72 100644 --- a/pylossless/pipeline.py +++ b/pylossless/pipeline.py @@ -650,7 +650,8 @@ def get_events(self): self.raw, duration=tmax - tmin, overlap=overlap ) - def get_epochs(self, detrend=None, preload=True, rereference=True, picks="eeg"): + def get_epochs(self, detrend=None, preload=True, + rereference=True, picks="eeg"): """Create mne.Epochs according to user arguments. Parameters @@ -666,6 +667,8 @@ def get_epochs(self, detrend=None, preload=True, rereference=True, picks="eeg"): preload : bool (default True) Load epochs from disk when creating the object or wait before accessing each epoch (more memory efficient but can be slower). + picks : str (default "eeg") + Type of channels to pick. Returns ------- @@ -687,7 +690,7 @@ def get_epochs(self, detrend=None, preload=True, rereference=True, picks="eeg"): epochs = epochs.pick(picks=picks, exclude="bads").pick( picks=None, exclude=list(self.flags["ch"].get_flagged()) ) - if rereference: + if rereference and picks=="eeg": self.flags["ch"].rereference(epochs) return epochs @@ -725,7 +728,7 @@ def find_breaks(self): breaks = annotate_break(self.raw, **self.config["find_breaks"]) self.raw.set_annotations(breaks + self.raw.annotations) - def _flag_volt_std(self, flag_dim, threshold=5e-5): + def _flag_volt_std(self, flag_dim, threshold=5e-5, picks="eeg"): """Determine if voltage standard deviation is above threshold. Parameters @@ -739,6 +742,8 @@ def _flag_volt_std(self, flag_dim, threshold=5e-5): channel x epoch indices will considered an outlier. Defaults to 5e-5, or 50 microvolts. Note that here, 'time' refers to the samples in an epoch. + picks : str (default "eeg") + Type of channels to pick. Notes ----- @@ -758,20 +763,20 @@ def _flag_volt_std(self, flag_dim, threshold=5e-5): is affected by the impedance of the system that the data was recording on. You may need to assess a more appropriate value for your own data. """ - epochs = self.get_epochs() + epochs = self.get_epochs(picks=picks) above_threshold = _threshold_volt_std( epochs, flag_dim=flag_dim, threshold=threshold ) self.flags[flag_dim].add_flag_cat("volt_std", above_threshold, epochs) - def find_outlier_chs(self, inst): + def find_outlier_chs(self, inst, picks="eeg"): """Detect outlier Channels to leave out of rereference.""" # TODO: Reuse _detect_outliers here. logger.info("🔍 Detecting channels to leave out of reference.") if isinstance(inst, mne.Epochs): epochs = inst elif isinstance(inst, mne.Raw): - epochs = self.get_epochs(rereference=False) + epochs = self.get_epochs(rereference=False, picks=picks) else: raise TypeError( "inst must be an MNE Raw or Epochs object," f" but got {type(inst)}." @@ -797,7 +802,7 @@ def find_outlier_chs(self, inst): return mean_ch_dist.ch[mean_ch_dist > mdn + 6 * deviation].values.tolist() - def flag_channels_fixed_threshold(self, threshold=5e-5): + def flag_channels_fixed_threshold(self, threshold=5e-5, picks="eeg"): """Flag channels based on the stdev value across the time dimension. Flags channels if the voltage-variance standard deviation is above @@ -813,6 +818,8 @@ def flag_channels_fixed_threshold(self, threshold=5e-5): the samples in an epoch. For each channel, if its std value is above the given threshold in more than 20% of the epochs, it is flagged. + picks : str (default "eeg") + Type of channels to pick. Notes ----- @@ -826,9 +833,9 @@ def flag_channels_fixed_threshold(self, threshold=5e-5): return if "threshold" in self.config["flag_channels_fixed_threshold"]: threshold = self.config["flag_channels_fixed_threshold"]["threshold"] - self._flag_volt_std(flag_dim="ch", threshold=threshold) + self._flag_volt_std(flag_dim="ch", threshold=threshold, picks=picks) - def flag_epochs_fixed_threshold(self, threshold=5e-5): + def flag_epochs_fixed_threshold(self, threshold=5e-5, picks="eeg"): """Flag epochs based on the stdev value across the time dimension. Flags an epoch if the voltage-variance standard deviation is above @@ -844,6 +851,8 @@ def flag_epochs_fixed_threshold(self, threshold=5e-5): the samples in an epoch. For each epoch, if the std value of more than 20% of channels (in that epoch) are above the given threshold, the epoch is flagged. + picks : str (default "eeg") + Type of channels to pick. Notes ----- @@ -857,10 +866,10 @@ def flag_epochs_fixed_threshold(self, threshold=5e-5): return if "threshold" in self.config["flag_epochs_fixed_threshold"]: threshold = self.config["flag_epochs_fixed_threshold"]["threshold"] - self._flag_volt_std(flag_dim="epoch", threshold=threshold) + self._flag_volt_std(flag_dim="epoch", threshold=threshold, picks="eeg") @lossless_logger - def flag_noisy_channels(self): + def flag_noisy_channels(self, picks="eeg"): """Flag channels with outlying standard deviation. Calculates the standard deviation of the voltage-variance for @@ -871,10 +880,15 @@ def flag_noisy_channels(self): current epoch) is flagged. If a channel is flagged as an outlier in more than n_percent of epochs (default: 20%), the channel is flagged for removal. + + Parameters + ---------- + picks : str (default "eeg") + Type of channels to pick. """ # TODO: flag "ch_sd" should be renamed "time_sd" # TODO: doc for step 3 and 4 need to be updated - epochs_xr = epochs_to_xr(self.get_epochs(), kind="ch") + epochs_xr = epochs_to_xr(self.get_epochs(picks=picks), kind="ch") data_sd = epochs_xr.std("time") # flag noisy channels @@ -886,10 +900,16 @@ def flag_noisy_channels(self): self.flags["ch"].add_flag_cat(kind="noisy", bad_ch_names=bad_ch_names) @lossless_logger - def flag_noisy_epochs(self): - """Flag epochs with outlying standard deviation.""" + def flag_noisy_epochs(self, picks="eeg"): + """Flag epochs with outlying standard deviation. + + Parameters + ---------- + picks : str (default "eeg") + Type of channels to pick. + """ outlier_methods = ("quantile", "trimmed", "fixed") - epochs = self.get_epochs() + epochs = self.get_epochs(picks=picks) epochs_xr = epochs_to_xr(epochs, kind="ch") data_sd = epochs_xr.std("time") @@ -906,18 +926,29 @@ def flag_noisy_epochs(self): ) self.flags["epoch"].add_flag_cat("noisy", bad_epoch_inds, epochs) - def get_n_nbr(self): - """Calculate nearest neighbour correlation for channels.""" + def get_n_nbr(self, picks="eeg"): + """Calculate nearest neighbour correlation for channels. + + Parameters + ---------- + picks : str (default "eeg") + Type of channels to pick. + """ # Calculate nearest neighbour correlation on # non-flagged channels and epochs... - epochs = self.get_epochs() + epochs = self.get_epochs(picks=picks) n_nbr_ch = self.config["nearest_neighbors"]["n_nbr_ch"] return chan_neighbour_r(epochs, n_nbr_ch, "max"), epochs @lossless_logger - def flag_uncorrelated_channels(self): + def flag_uncorrelated_channels(self, picks="eeg"): """Check neighboring channels for too high or low of a correlation. + Parameters + ---------- + picks : str (default "eeg") + Type of channels to pick. + Returns ------- data array : `numpy.array` @@ -925,7 +956,7 @@ def flag_uncorrelated_channels(self): """ # Calculate nearest neighbour correlation on # non-flagged channels and epochs... - data_r_ch = self.get_n_nbr()[0] + data_r_ch = self.get_n_nbr(picks=picks)[0] # Create the window criteria vector for flagging low_r chan_info... bad_ch_names = _detect_outliers( @@ -995,9 +1026,14 @@ def flag_rank_channel(self, data_r_ch): self.flags["ch"].add_flag_cat(kind="rank", bad_ch_names=bad_ch_names) @lossless_logger - def flag_uncorrelated_epochs(self): + def flag_uncorrelated_epochs(self, picks="eeg"): """Flag epochs where too many channels are uncorrelated. + Parameters + ---------- + picks : str (default "eeg") + Type of channels to pick. + Notes ----- Similarly to the neighbor r calculation done between channels this @@ -1006,7 +1042,7 @@ def flag_uncorrelated_epochs(self): """ # Calculate nearest neighbour correlation on # non-flagged channels and epochs... - data_r_ch, epochs = self.get_n_nbr() + data_r_ch, epochs = self.get_n_nbr(picks=picks) bad_epoch_inds = _detect_outliers( data_r_ch, @@ -1017,7 +1053,7 @@ def flag_uncorrelated_epochs(self): self.flags["epoch"].add_flag_cat("uncorrelated", bad_epoch_inds, epochs) @lossless_logger - def run_ica(self, run): + def run_ica(self, run, picks="eeg"): """Run ICA. Parameters @@ -1026,6 +1062,8 @@ def run_ica(self, run): Must be 'run1' or 'run2'. 'run1' is the initial ICA use to flag epochs, 'run2' is the final ICA used to classify components with `mne_icalabel`. + picks : str (default "eeg") + Type of channels to pick. """ ica_kwargs = self.config["ica"]["ica_args"][run] if "max_iter" not in ica_kwargs: @@ -1033,7 +1071,7 @@ def run_ica(self, run): if "random_state" not in ica_kwargs: ica_kwargs["random_state"] = 97 - epochs = self.get_epochs() + epochs = self.get_epochs(picks=picks) if run == "run1": self.ica1 = ICA(**ica_kwargs) self.ica1.fit(epochs) @@ -1041,18 +1079,24 @@ def run_ica(self, run): elif run == "run2": self.ica2 = ICA(**ica_kwargs) self.ica2.fit(epochs) - self.flags["ic"].label_components(epochs, self.ica2) + if picks == "eeg": + self.flags["ic"].label_components(epochs, self.ica2) else: raise ValueError("The `run` argument must be 'run1' or 'run2'") @lossless_logger - def flag_noisy_ics(self): + def flag_noisy_ics(self, picks="eeg"): """Calculate the IC standard Deviation by epoch window. Flags windows with too many ICs with outlying standard deviations. + + Parameters + ---------- + picks : str (default "eeg") + Type of channels to pick. """ # Calculate IC sd by window - epochs = self.get_epochs() + epochs = self.get_epochs(picks=picks) epochs_xr = epochs_to_xr(epochs, kind="ic", ica=self.ica1) data_sd = epochs_xr.std("time") @@ -1173,51 +1217,77 @@ def run_with_raw(self, raw): @lossless_time def _run(self): + # Make sure sampling frequency is an integer self._check_sfreq() - self.set_montage() - # 1. Execute the staging script if specified. - self.run_staging_script() + if "modality" not in self.config: + self.config["modality"] = ["eeg"] + if isinstance(self.config["modality"], str): + self.config["modality"] = [self.config["modality"]] + + for picks in self.config["modality"]: + # 1. Execute the staging script if specified. + self.run_staging_script() + + # find breaks + self.find_breaks(message="Looking for break periods between tasks") + + # OPTIONAL: Flag chs/epochs based off fixed std threshold of time axis + self.flag_epochs_fixed_threshold(picks=picks) + self.flag_channels_fixed_threshold(picks=picks) - # find breaks - self.find_breaks(message="Looking for break periods between tasks") + # 3.flag channels based on large Stdev. across time + msg = "Flagging Noisy Channels" + self.flag_noisy_channels(message=msg, picks=picks) - # OPTIONAL: Flag chs/epochs based off fixed std threshold of time axis - self.flag_epochs_fixed_threshold() - self.flag_channels_fixed_threshold() + # 4.flag epochs based on large Channel Stdev. across time + msg = "Flagging Noisy Time periods" + self.flag_noisy_epochs(message=msg, picks=picks) - # 3.flag channels based on large Stdev. across time - self.flag_noisy_channels(message="Flagging Noisy Channels") + # 5. Filtering + self.filter(message="Filtering") - # 4.flag epochs based on large Channel Stdev. across time - self.flag_noisy_epochs(message="Flagging Noisy Time periods") + if picks == "eeg": + # These steps are relevant only for EEG. For example, + # MEG channels don't get bridged or high impedance. + # Further, MEG doesn't have a montage so + # flag_uncorrelated_channels would crash. We could use + # mne.channels.read_layout() for MEG channel if we need + # at some point to implement such a functionality. But it + # seems irrelevant. - # 5. Filtering - self.filter(message="Filtering") + # 6. calculate nearest neighbort r values + msg = "Flagging uncorrelated channels" + data_r_ch = self.flag_uncorrelated_channels(message=msg, picks=picks) - # 6. calculate nearest neighbort r values - msg = "Flagging uncorrelated channels" - data_r_ch = self.flag_uncorrelated_channels(message=msg) + # 7. Identify bridged channels + msg = "Flagging Bridged channels" + self.flag_bridged_channels(data_r_ch, message=msg) - # 7. Identify bridged channels - self.flag_bridged_channels(data_r_ch, message="Flagging Bridged channels") + # 8. Flag rank channels + self.flag_rank_channel(data_r_ch, message="Flagging the rank channel") - # 8. Flag rank channels - self.flag_rank_channel(data_r_ch, message="Flagging the rank channel") + if picks == "eeg": + # 9. Calculate nearest neighbour R values for epochs + msg = "Flagging Uncorrelated epochs" + self.flag_uncorrelated_epochs(message=msg, picks=picks) - # 9. Calculate nearest neighbour R values for epochs - self.flag_uncorrelated_epochs(message="Flagging Uncorrelated epochs") + if self.config["ica"] is None: + # Skip ICA steps. + continue - # 10. Run ICA - self.run_ica("run1", message="Running Initial ICA") + # 10. Run ICA + self.run_ica("run1", message="Running Initial ICA", picks=picks) - # 11. Calculate IC SD - self.flag_noisy_ics(message="Flagging time periods with noisy IC's.") + # 11. Calculate IC SD + msg = "Flagging time periods with noisy IC's." + self.flag_noisy_ics(message=msg, picks=picks) - # 12. TODO: integrate labels from IClabels to self.flags["ic"] - self.run_ica("run2", message="Running Final ICA and ICLabel.") + # 12. TODO: integrate labels from IClabels to self.flags["ic"] + msg = "Running Final ICA and ICLabel." + self.run_ica("run2", message=msg, picks=picks) def run_dataset(self, paths): """Run a full dataset. diff --git a/pylossless/tests/test_pipeline.py b/pylossless/tests/test_pipeline.py index 2ee07c00..2686aa1e 100644 --- a/pylossless/tests/test_pipeline.py +++ b/pylossless/tests/test_pipeline.py @@ -80,6 +80,22 @@ def test_deprecation(): pipeline.config_fname = pipeline.config_fname +def test_multimodality(): + """Test running the pipeline on a multimodal (EEG, MEG) dataset.""" + fname = mne.datasets.sample.data_path() / 'MEG' / 'sample' / 'sample_audvis_raw.fif' + raw = mne.io.read_raw_fif(fname, preload=True) + raw.crop(tmin=0, tmax=60) + + config = ll.config.Config() + config.load_default() + config["modality"] = ["eeg", "meg"] + config["ica"] = None + pipeline = ll.LosslessPipeline(config=config) + pipeline.run_with_raw(raw) + + assert pipeline.flags["ch"]["noisy"] == ['EEG 007', 'MEG 1032'] + + @pytest.mark.filterwarnings("ignore:Converting data files to EDF format") def test_load_flags(pipeline_fixture, tmp_path): """Test running the pipeline."""