Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions pylossless/config/rejection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# License: MIT

import numpy as np
from importlib.metadata import version
import warnings

from .config import ConfigMixin

Expand Down Expand Up @@ -90,7 +92,7 @@ def __init__(
)

def __repr__(self):
"""Return a summary of the Calibration object."""
"""Return a summary of the RejectionPolicy object."""
return (
f"RejectionPolicy: |\n"
f" config_fname: {self['config_fname']}\n"
Expand All @@ -101,7 +103,7 @@ def __repr__(self):
f" remove_flagged_ics: {self['remove_flagged_ics']}\n"
)

def apply(self, pipeline, return_ica=False):
def apply(self, pipeline, return_ica=False, version_mismatch="raise"):
"""Return a cleaned new raw object based on the rejection policy.

Parameters
Expand All @@ -119,6 +121,21 @@ def apply(self, pipeline, return_ica=False):
An :class:`~mne.io.Raw` instance with the appropriate channels and ICs
added to mne bads, interpolated, or dropped.
"""
if pipeline.config["version"] != version("pylossless"):
error_message = (
"The output of the pipeline was saved with pylossless version "
f"{pipeline.config['version']} and you are currently using "
f"version {version('pylossless')}. The behavior is undefined."
)
if version_mismatch == "raise":
raise RuntimeError(error_message)
elif version_mismatch == "warning":
warnings.warn(error_message, RuntimeWarning)
elif version_mismatch != "ignore":
raise ValueError("version_mismatch can take values 'raise', "
"'warning', or 'ignore'. Received "
f"{version_mismatch}.")

# Get the raw object
raw = pipeline.raw.copy()

Expand Down
13 changes: 13 additions & 0 deletions pylossless/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from copy import deepcopy
from pathlib import Path
from functools import partial
from importlib.metadata import version

# Math and data structures
import numpy as np
Expand Down Expand Up @@ -461,6 +462,8 @@ def __init__(self, config_path=None, config=None):
"epoch": FlaggedEpochs(self),
"ic": FlaggedICs(),
}
self._config = None

if config:
self.config = config
if config_path is None:
Expand Down Expand Up @@ -526,6 +529,15 @@ def _repr_html_(self):

return html

@property
def config(self):
return self._config

@config.setter
def config(self, config):
self._config = config
self._config["version"] = version("pylossless")

@property
def config_fname(self):
warn('config_fname is deprecated and will be removed from future versions.',
Expand Down Expand Up @@ -1094,6 +1106,7 @@ def save(self, derivatives_path, overwrite=False, format="EDF", event_id=None):
config_bidspath = bpath.update(
extension=".yaml", suffix="ll_config", check=False
)

self.config.save(config_bidspath)

# Save flag["ch"]
Expand Down
14 changes: 13 additions & 1 deletion pylossless/tests/test_rejection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,19 @@ def test_rejection_policy(clean_ch_mode, pipeline_fixture):
want_flags = ["noisy", "uncorrelated", "bridged"]
assert rejection_config["ch_flags_to_reject"] == want_flags

raw, ica = rejection_config.apply(pipeline_fixture, return_ica=True)
pipeline_fixture.config["version"] = "-1"
with pytest.raises(RuntimeError, match="The output of the pipeline was"):
raw, ica = rejection_config.apply(pipeline_fixture,
version_mismatch="raise")
with pytest.raises(RuntimeWarning, match="The output of the pipeline was"):
raw, ica = rejection_config.apply(pipeline_fixture,
version_mismatch="warning")
with pytest.raises(ValueError, match="version_mismatch can take values"):
raw, ica = rejection_config.apply(pipeline_fixture,
version_mismatch="sdfdf")
raw, ica = rejection_config.apply(pipeline_fixture, return_ica=True,
version_mismatch="ignore")

flagged_chs = []
for key in rejection_config["ch_flags_to_reject"]:
flagged_chs.extend(pipeline_fixture.flags["ch"][key].tolist())
Expand Down
Loading