Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pylossless/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

@pytest.fixture(scope="session")
def pipeline_fixture():
"""Return a namedTuple containing MNE eyetracking raw data and events."""
"""Return a LosslessPipeline object."""
raw, config, bids_path = load_openneuro_bids()
# raw.crop(tmin=0, tmax=60) # Too short for ICA to converge in some tests.
annots = Annotations(
Expand Down
18 changes: 14 additions & 4 deletions pylossless/flagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,11 @@ def add_flag_cat(self, kind, bad_ch_names, *args):
logger.debug(f"NEW BAD CHANNELS {bad_ch_names}")
if isinstance(bad_ch_names, xr.DataArray):
bad_ch_names = bad_ch_names.values
self[kind] = bad_ch_names
if kind in self:
self[kind] = list(np.unique(np.concatenate((self[kind], bad_ch_names))))
else:
self[kind] = bad_ch_names


def get_flagged(self):
"""Return a list of channels flagged by the lossless pipeline."""
Expand All @@ -138,9 +142,12 @@ def rereference(self, inst, **kwargs):
"""
# Concatenate and remove duplicates
bad_chs = list(
set(self.ll.find_outlier_chs(inst) + self.get_flagged() + inst.info["bads"])
set(self.ll.find_outlier_chs(inst, picks="eeg") +
self.get_flagged() +
inst.info["bads"])
)
ref_chans = [ch for ch in inst.copy().pick("eeg").ch_names if ch not in bad_chs]
ref_chans = [ch for ch in inst.copy().pick("eeg").ch_names
if ch not in bad_chs]
inst.set_eeg_reference(ref_channels=ref_chans, **kwargs)

def save_tsv(self, fname):
Expand Down Expand Up @@ -223,7 +230,10 @@ def add_flag_cat(self, kind, bad_epoch_inds, epochs):
The :class:`mne.Epochs` object created from the Raw object that is
being assessed by the LosslessPipeline.
"""
self[kind] = bad_epoch_inds
if kind in self:
self[kind] = list(np.unique(np.concatenate((self[kind], bad_epoch_inds))))
else:
self[kind] = bad_epoch_inds
self.ll.add_pylossless_annotations(bad_epoch_inds, kind, epochs)

def load_from_raw(self, raw, events, config):
Expand Down
Loading
Loading