Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions pylossless/datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pylossless as ll


def load_openneuro_bids(subject="pd6"):
def load_openneuro_bids(subject="pd6", timeout=20):
"""Download and Load BIDS dataset ds002778 from OpenNeuro.

Parameters
Expand Down Expand Up @@ -72,8 +72,14 @@
root=bids_root,
)

while not bids_path.fpath.with_suffix(".bdf").exists():
print(list(bids_path.fpath.glob("*")))
for _ in range(timeout):
if bids_path.fpath.with_suffix(".bdf").exists():
break
print("Waiting for .bdf files to be created. Current files available:",

Check warning on line 78 in pylossless/datasets/datasets.py

View check run for this annotation

Codecov / codecov/patch

pylossless/datasets/datasets.py#L78

Added line #L78 was not covered by tests
list(bids_path.fpath.glob("*")))
sleep(1)
else:
raise TimeoutError("OpenNeuro failed to create the .bdf files.")

Check warning on line 82 in pylossless/datasets/datasets.py

View check run for this annotation

Codecov / codecov/patch

pylossless/datasets/datasets.py#L82

Added line #L82 was not covered by tests

raw = mne_bids.read_raw_bids(bids_path, verbose="ERROR")
return raw, config, bids_path
98 changes: 70 additions & 28 deletions pylossless/flagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,53 @@

from .utils._utils import _icalabel_to_data_frame

IC_LABELS = mne_icalabel.config.ICA_LABELS_TO_MNE
CH_LABELS: dict[str, str] = {
"Noisy": "ch_sd",
"Bridged": "bridge",
"Uncorrelated": "low_r",
"Rank": "rank"
}
EPOCH_LABELS: dict[str, str] = {
"Noisy": "noisy",
"Noisy ICs": "noisy_ICs",
"Uncorrelated": "uncorrelated",
}


class _Flagged(dict):

def __init__(self, key_map, kind_str, ll, *args, **kwargs):
"""Initialize class."""
super().__init__(*args, **kwargs)
self.ll = ll
self._key_map = key_map
self._kind_str = kind_str

@property
def valid_keys(self):
"""Return the valid keys."""
return tuple(self._key_map.values())

def __repr__(self):
"""Return a string representation."""
ret_str = f"Flagged {self._kind_str}s: |\n"
for key, val in self._key_map.items():
ret_str += f" {key}: {self.get(val, None)}\n"
return ret_str

def __eq__(self, other):
for key in self.valid_keys:
if not np.array_equal(self.get(key, np.array([])),
other.get(key, np.array([]))):
return False
return True

def __ne__(self, other):
return not self == other

class FlaggedChs(dict):

class FlaggedChs(_Flagged):
"""Object for handling flagged channels in an instance of mne.io.Raw.

Attributes
Expand Down Expand Up @@ -47,28 +92,17 @@ class FlaggedChs(dict):
and methods for python dictionaries.
"""

def __init__(self, ll, *args, **kwargs):
def __init__(self, *args, **kwargs):
"""Initialize class."""
super().__init__(*args, **kwargs)
self.ll = ll

def __repr__(self):
"""Return a string representation of the FlaggedChs object."""
return (
f"Flagged channels: |\n"
f" Noisy: {self.get('ch_sd', None)}\n"
f" Bridged: {self.get('bridge', None)}\n"
f" Uncorrelated: {self.get('low_r', None)}\n"
f" Rank: {self.get('rank', None)}\n"
)
super().__init__(CH_LABELS, "channel", *args, **kwargs)

def add_flag_cat(self, kind, bad_ch_names, *args):
"""Store channel names that have been flagged by pipeline.

Parameters
----------
kind : str
Should be one of ``'outlier'``, ``'ch_sd'``, ``'low_r'``,
Should be one of ``'ch_sd'``, ``'low_r'``,
``'bridge'``, ``'rank'``.
bad_ch_names : list | tuple
Channel names. Will be the values corresponding to the ``kind``
Expand Down Expand Up @@ -140,7 +174,7 @@ def load_tsv(self, fname):
self[label] = grp_df.ch_names.values


class FlaggedEpochs(dict):
class FlaggedEpochs(_Flagged):
"""Object for handling flagged Epochs in an instance of mne.Epochs.

Methods
Expand All @@ -159,7 +193,7 @@ class FlaggedEpochs(dict):
and methods for python dictionaries.
"""

def __init__(self, ll, *args, **kwargs):
def __init__(self, *args, **kwargs):
"""Initialize class.

Parameters
Expand All @@ -171,9 +205,7 @@ def __init__(self, ll, *args, **kwargs):
kwargs : dict
keyword arguments accepted by python's dictionary class.
"""
super().__init__(*args, **kwargs)

self.ll = ll
super().__init__(EPOCH_LABELS, "epoch", *args, **kwargs)

def add_flag_cat(self, kind, bad_epoch_inds, epochs):
"""Add information on time periods flagged by pyLossless.
Expand All @@ -194,17 +226,27 @@ def add_flag_cat(self, kind, bad_epoch_inds, epochs):
self[kind] = bad_epoch_inds
self.ll.add_pylossless_annotations(bad_epoch_inds, kind, epochs)

def load_from_raw(self, raw):
def load_from_raw(self, raw, events, config):
"""Load pylossless annotations from raw object."""
sfreq = raw.info["sfreq"]
tmax = config["epoching"]["epochs_args"]["tmax"]
tmin = config["epoching"]["epochs_args"]["tmin"]
starts = events[:, 0] / sfreq - tmin
stops = events[:, 0] / sfreq + tmax
for annot in raw.annotations:
if annot["description"].upper().startswith("BAD_LL"):
ind_onset = int(np.round(annot["onset"] * sfreq))
ind_dur = int(np.round(annot["duration"] * sfreq))
inds = np.arange(ind_onset, ind_onset + ind_dur)
if annot["description"] not in self:
self[annot["description"]] = list()
self[annot["description"]].append(inds)
if annot["description"].upper().startswith("BAD_LL_"):
onset = annot["onset"]
offset = annot["onset"] + annot["duration"]
mask = (
(starts >= onset) & (starts < offset)
| (stops > onset) & (stops <= offset)
| (onset <= starts) & (offset >= stops)
)
inds = np.where(mask)[0]
desc = annot["description"].lower().replace("bad_ll_", "")
if desc not in self:
self[desc] = np.array([])
self[desc] = np.concatenate((self[desc], inds))


class FlaggedICs(pd.DataFrame):
Expand Down
17 changes: 14 additions & 3 deletions pylossless/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1216,9 +1216,20 @@ def run_dataset(self, paths):
for path in paths:
self.run(path)

# TODO: Finish docstring
def load_ll_derivative(self, derivatives_path):
"""Load a completed pylossless derivative state."""
"""Load a completed pylossless derivative state.

Parameters
----------
derivatives_path : str | mne_bids.BIDSPath
Path to a saved pylossless derivatives.

Returns
-------
:class:`~pylossless.pipeline.LosslessPipeline`
Returns an instance of :class:`~pylossless.pipeline.LosslessPipeline`
for the loaded pylossless derivative state.
"""
if not isinstance(derivatives_path, BIDSPath):
derivatives_path = get_bids_path_from_fname(derivatives_path)
self.raw = mne_bids.read_raw_bids(derivatives_path)
Expand Down Expand Up @@ -1247,7 +1258,7 @@ def load_ll_derivative(self, derivatives_path):
self.flags["ch"].load_tsv(flagged_chs_fpath.fpath)

# Load Flagged Epochs
self.flags["epoch"].load_from_raw(self.raw)
self.flags["epoch"].load_from_raw(self.raw, self.get_events(), self.config)

return self

Expand Down
35 changes: 34 additions & 1 deletion pylossless/tests/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path

import mne_bids
import pytest

import pylossless as ll
Expand Down Expand Up @@ -54,3 +54,36 @@ def test_deprecation():
# with pytest.raises(DeprecationWarning, match=f"config_fname is deprecated"):
# DeprecationWarning are currently ignored by pytest given our toml file
pipeline.config_fname = pipeline.config_fname


@pytest.mark.filterwarnings("ignore:Converting data files to EDF format")
def test_load_flags(pipeline_fixture, tmp_path):
"""Test running the pipeline."""
bids_root = tmp_path / "derivatives" / "pylossless"
print(bids_root)

subject = "pd6"
datatype = "eeg"
session = "off"
task = "rest"
suffix = "eeg"
bids_path = mne_bids.BIDSPath(
subject=subject,
session=session,
task=task,
suffix=suffix,
datatype=datatype,
root=bids_root
)

pipeline_fixture.save(bids_path,
overwrite=False, format="EDF", event_id=None)
pipeline = ll.LosslessPipeline().load_ll_derivative(bids_path)

assert pipeline_fixture.flags['ch'] == pipeline.flags['ch']
pipeline.flags['ch']["bridge"] = ["xx"]
assert pipeline_fixture.flags['ch'] != pipeline.flags['ch']

assert pipeline_fixture.flags['epoch'] == pipeline.flags['epoch']
pipeline.flags['epoch']["bridge"] = ["noisy"]
assert pipeline_fixture.flags['epoch'] == pipeline.flags['epoch']
Loading