diff --git a/pylossless/datasets/datasets.py b/pylossless/datasets/datasets.py index abd63936..1e6b0f15 100644 --- a/pylossless/datasets/datasets.py +++ b/pylossless/datasets/datasets.py @@ -6,7 +6,7 @@ import pylossless as ll -def load_openneuro_bids(subject="pd6"): +def load_openneuro_bids(subject="pd6", timeout=20): """Download and Load BIDS dataset ds002778 from OpenNeuro. Parameters @@ -72,8 +72,14 @@ def load_openneuro_bids(subject="pd6"): root=bids_root, ) - while not bids_path.fpath.with_suffix(".bdf").exists(): - print(list(bids_path.fpath.glob("*"))) + for _ in range(timeout): + if bids_path.fpath.with_suffix(".bdf").exists(): + break + print("Waiting for .bdf files to be created. Current files available:", + list(bids_path.fpath.glob("*"))) sleep(1) + else: + raise TimeoutError("OpenNeuro failed to create the .bdf files.") + raw = mne_bids.read_raw_bids(bids_path, verbose="ERROR") return raw, config, bids_path diff --git a/pylossless/flagging.py b/pylossless/flagging.py index 02086dcd..4d8cf66f 100644 --- a/pylossless/flagging.py +++ b/pylossless/flagging.py @@ -17,8 +17,53 @@ from .utils._utils import _icalabel_to_data_frame +IC_LABELS = mne_icalabel.config.ICA_LABELS_TO_MNE +CH_LABELS: dict[str, str] = { + "Noisy": "ch_sd", + "Bridged": "bridge", + "Uncorrelated": "low_r", + "Rank": "rank" +} +EPOCH_LABELS: dict[str, str] = { + "Noisy": "noisy", + "Noisy ICs": "noisy_ICs", + "Uncorrelated": "uncorrelated", +} + + +class _Flagged(dict): + + def __init__(self, key_map, kind_str, ll, *args, **kwargs): + """Initialize class.""" + super().__init__(*args, **kwargs) + self.ll = ll + self._key_map = key_map + self._kind_str = kind_str + + @property + def valid_keys(self): + """Return the valid keys.""" + return tuple(self._key_map.values()) + + def __repr__(self): + """Return a string representation.""" + ret_str = f"Flagged {self._kind_str}s: |\n" + for key, val in self._key_map.items(): + ret_str += f" {key}: {self.get(val, None)}\n" + return ret_str + + def __eq__(self, other): + for key in self.valid_keys: + if not np.array_equal(self.get(key, np.array([])), + other.get(key, np.array([]))): + return False + return True + + def __ne__(self, other): + return not self == other -class FlaggedChs(dict): + +class FlaggedChs(_Flagged): """Object for handling flagged channels in an instance of mne.io.Raw. Attributes @@ -47,20 +92,9 @@ class FlaggedChs(dict): and methods for python dictionaries. """ - def __init__(self, ll, *args, **kwargs): + def __init__(self, *args, **kwargs): """Initialize class.""" - super().__init__(*args, **kwargs) - self.ll = ll - - def __repr__(self): - """Return a string representation of the FlaggedChs object.""" - return ( - f"Flagged channels: |\n" - f" Noisy: {self.get('ch_sd', None)}\n" - f" Bridged: {self.get('bridge', None)}\n" - f" Uncorrelated: {self.get('low_r', None)}\n" - f" Rank: {self.get('rank', None)}\n" - ) + super().__init__(CH_LABELS, "channel", *args, **kwargs) def add_flag_cat(self, kind, bad_ch_names, *args): """Store channel names that have been flagged by pipeline. @@ -68,7 +102,7 @@ def add_flag_cat(self, kind, bad_ch_names, *args): Parameters ---------- kind : str - Should be one of ``'outlier'``, ``'ch_sd'``, ``'low_r'``, + Should be one of ``'ch_sd'``, ``'low_r'``, ``'bridge'``, ``'rank'``. bad_ch_names : list | tuple Channel names. Will be the values corresponding to the ``kind`` @@ -140,7 +174,7 @@ def load_tsv(self, fname): self[label] = grp_df.ch_names.values -class FlaggedEpochs(dict): +class FlaggedEpochs(_Flagged): """Object for handling flagged Epochs in an instance of mne.Epochs. Methods @@ -159,7 +193,7 @@ class FlaggedEpochs(dict): and methods for python dictionaries. """ - def __init__(self, ll, *args, **kwargs): + def __init__(self, *args, **kwargs): """Initialize class. Parameters @@ -171,9 +205,7 @@ def __init__(self, ll, *args, **kwargs): kwargs : dict keyword arguments accepted by python's dictionary class. """ - super().__init__(*args, **kwargs) - - self.ll = ll + super().__init__(EPOCH_LABELS, "epoch", *args, **kwargs) def add_flag_cat(self, kind, bad_epoch_inds, epochs): """Add information on time periods flagged by pyLossless. @@ -194,17 +226,27 @@ def add_flag_cat(self, kind, bad_epoch_inds, epochs): self[kind] = bad_epoch_inds self.ll.add_pylossless_annotations(bad_epoch_inds, kind, epochs) - def load_from_raw(self, raw): + def load_from_raw(self, raw, events, config): """Load pylossless annotations from raw object.""" sfreq = raw.info["sfreq"] + tmax = config["epoching"]["epochs_args"]["tmax"] + tmin = config["epoching"]["epochs_args"]["tmin"] + starts = events[:, 0] / sfreq - tmin + stops = events[:, 0] / sfreq + tmax for annot in raw.annotations: - if annot["description"].upper().startswith("BAD_LL"): - ind_onset = int(np.round(annot["onset"] * sfreq)) - ind_dur = int(np.round(annot["duration"] * sfreq)) - inds = np.arange(ind_onset, ind_onset + ind_dur) - if annot["description"] not in self: - self[annot["description"]] = list() - self[annot["description"]].append(inds) + if annot["description"].upper().startswith("BAD_LL_"): + onset = annot["onset"] + offset = annot["onset"] + annot["duration"] + mask = ( + (starts >= onset) & (starts < offset) + | (stops > onset) & (stops <= offset) + | (onset <= starts) & (offset >= stops) + ) + inds = np.where(mask)[0] + desc = annot["description"].lower().replace("bad_ll_", "") + if desc not in self: + self[desc] = np.array([]) + self[desc] = np.concatenate((self[desc], inds)) class FlaggedICs(pd.DataFrame): diff --git a/pylossless/pipeline.py b/pylossless/pipeline.py index 7956e0a6..3cb6953e 100644 --- a/pylossless/pipeline.py +++ b/pylossless/pipeline.py @@ -1216,9 +1216,20 @@ def run_dataset(self, paths): for path in paths: self.run(path) - # TODO: Finish docstring def load_ll_derivative(self, derivatives_path): - """Load a completed pylossless derivative state.""" + """Load a completed pylossless derivative state. + + Parameters + ---------- + derivatives_path : str | mne_bids.BIDSPath + Path to a saved pylossless derivatives. + + Returns + ------- + :class:`~pylossless.pipeline.LosslessPipeline` + Returns an instance of :class:`~pylossless.pipeline.LosslessPipeline` + for the loaded pylossless derivative state. + """ if not isinstance(derivatives_path, BIDSPath): derivatives_path = get_bids_path_from_fname(derivatives_path) self.raw = mne_bids.read_raw_bids(derivatives_path) @@ -1247,7 +1258,7 @@ def load_ll_derivative(self, derivatives_path): self.flags["ch"].load_tsv(flagged_chs_fpath.fpath) # Load Flagged Epochs - self.flags["epoch"].load_from_raw(self.raw) + self.flags["epoch"].load_from_raw(self.raw, self.get_events(), self.config) return self diff --git a/pylossless/tests/test_pipeline.py b/pylossless/tests/test_pipeline.py index a9b3619b..cc76d213 100644 --- a/pylossless/tests/test_pipeline.py +++ b/pylossless/tests/test_pipeline.py @@ -1,5 +1,5 @@ from pathlib import Path - +import mne_bids import pytest import pylossless as ll @@ -54,3 +54,36 @@ def test_deprecation(): # with pytest.raises(DeprecationWarning, match=f"config_fname is deprecated"): # DeprecationWarning are currently ignored by pytest given our toml file pipeline.config_fname = pipeline.config_fname + + +@pytest.mark.filterwarnings("ignore:Converting data files to EDF format") +def test_load_flags(pipeline_fixture, tmp_path): + """Test running the pipeline.""" + bids_root = tmp_path / "derivatives" / "pylossless" + print(bids_root) + + subject = "pd6" + datatype = "eeg" + session = "off" + task = "rest" + suffix = "eeg" + bids_path = mne_bids.BIDSPath( + subject=subject, + session=session, + task=task, + suffix=suffix, + datatype=datatype, + root=bids_root + ) + + pipeline_fixture.save(bids_path, + overwrite=False, format="EDF", event_id=None) + pipeline = ll.LosslessPipeline().load_ll_derivative(bids_path) + + assert pipeline_fixture.flags['ch'] == pipeline.flags['ch'] + pipeline.flags['ch']["bridge"] = ["xx"] + assert pipeline_fixture.flags['ch'] != pipeline.flags['ch'] + + assert pipeline_fixture.flags['epoch'] == pipeline.flags['epoch'] + pipeline.flags['epoch']["bridge"] = ["noisy"] + assert pipeline_fixture.flags['epoch'] == pipeline.flags['epoch']