Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pylossless/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def pipeline_fixture():
config["find_breaks"]["min_break_duration"] = 9
config["find_breaks"]["t_start_after_previous"] = 1
config["find_breaks"]["t_stop_before_next"] = 0
config["flag_channels_fixed_threshold"] = {"threshold": 10_000}
config["ica"]["ica_args"]["run1"]["max_iter"] = 5000

# Testing when passing the config object directly...
Expand Down
132 changes: 117 additions & 15 deletions pylossless/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,64 @@
return prop_outliers[prop_outliers > flag_crit].coords.to_index().values


def find_bads_by_threshold(epochs, threshold=5e-5):
    """Return channels whose standard deviation is consistently above a threshold.

    Parameters
    ----------
    epochs : mne.Epochs
        an instance of mne.Epochs with a single channel type.
    threshold : float
        the threshold in volts. If the standard deviation of a channel's
        voltage within a specific epoch is above the threshold, then that
        channel x epoch will be flagged as an "outlier". If more than 20% of
        epochs are flagged as outliers for a specific channel, then that
        channel will be flagged as bad. Default threshold is 5e-5 (0.00005),
        i.e. 50 microvolts.

    Returns
    -------
    array-like of str
        the names of the channels that are considered outliers. The value is
        passed through unchanged from ``_threshold_volt_std``; callers in this
        module invoke ``.any()`` on it, so it is presumably a numpy array
        rather than a plain list.

    Notes
    -----
    If you are having trouble converting between exponential notation and
    decimal notation, you can use the following code to convert between the two
    (``np.printoptions`` only affects the printing of arrays, not of Python
    floats, so ``np.format_float_positional`` is used instead):

    >>> import numpy as np
    >>> threshold = 5e-5
    >>> np.format_float_positional(threshold)
    '0.00005'

    .. seealso::

        :meth:`~pylossless.LosslessPipeline.flag_channels_fixed_threshold` to use
        this function within the lossless pipeline.

    Examples
    --------
    >>> import mne
    >>> import pylossless as ll
    >>> fname = mne.datasets.sample.data_path() / "MEG/sample/sample_audvis_raw.fif"
    >>> raw = mne.io.read_raw(fname, preload=True).pick("eeg")
    >>> raw.apply_function(lambda x: x * 3, picks=["EEG 001"]) # Make a noisy channel
    >>> epochs = mne.make_fixed_length_epochs(raw, preload=True)
    >>> bad_chs = ll.pipeline.find_bads_by_threshold(epochs)
    """
    # TODO: We should make this function handle multiple channel types.
    # TODO: but I'd like to avoid making a copy of the epochs object
    # A single volt threshold is applied to every channel, so mixing channel
    # types is only warned about (not rejected) and detection still runs.
    ch_types = sorted(set(epochs.get_channel_types()))
    if len(ch_types) > 1:
        warn(
            f"The epochs object contains multiple channel types: {ch_types}.\n"
            " This will likely bias the results of the threshold detection."
            " Use the `mne.Epochs.pick` to select a single channel type."
        )
    return _threshold_volt_std(epochs, flag_dim="ch", threshold=threshold)


def _threshold_volt_std(epochs, flag_dim, threshold=5e-5):
"""Detect epochs or channels whose voltage std is above threshold.

Expand All @@ -247,7 +305,7 @@
assert len(threshold) == 2
l_out, u_out = threshold
init_dir = "both"
elif isinstance(threshold, float):
elif isinstance(threshold, (float, int)):
l_out, u_out = (0, threshold)
init_dir = "pos"
else:
Expand Down Expand Up @@ -765,9 +823,20 @@
on. You may need to assess a more appropriate value for your own data.
"""
epochs = self.get_epochs(picks=picks)
above_threshold = _threshold_volt_std(
epochs, flag_dim=flag_dim, threshold=threshold
)
if flag_dim == "ch":
above_threshold = find_bads_by_threshold(epochs, threshold=threshold)
if above_threshold.any():
logger.info(
f"🚩 Found {len(above_threshold)} channels with "
f"voltage variance above {threshold} volts: {above_threshold}"
)
else:
msg = f"No channels with standard deviation above {threshold} volts."
logger.info(msg)
else: # TODO: Implement an annotate_bads_by_threshold for epochs
above_threshold = _threshold_volt_std(

Check warning on line 837 in pylossless/pipeline.py

View check run for this annotation

Codecov / codecov/patch

pylossless/pipeline.py#L837

Added line #L837 was not covered by tests
epochs, flag_dim=flag_dim, threshold=threshold
)
self.flags[flag_dim].add_flag_cat("volt_std", above_threshold, epochs)

def find_outlier_chs(self, epochs=None, picks="eeg"):
Expand Down Expand Up @@ -834,6 +903,7 @@

return mean_ch_dist.ch[mean_ch_dist > mdn + 6 * deviation].values.tolist()

@lossless_logger
def flag_channels_fixed_threshold(self, threshold=5e-5, picks="eeg"):
"""Flag channels based on the stdev value across the time dimension.

Expand All @@ -845,26 +915,50 @@
threshold : float
threshold, in volts. If the standard deviation across time in
any channel x epoch indice is above this threshold, then the
channel x epoch indices will considered an outlier. Defaults
channel x epoch indices will be considered an outlier. Defaults
to 5e-5, or 50 microvolts. Note that here, 'time' refers to
the samples in an epoch. For each channel, if its std value is
above the given threshold in more than 20% of the epochs, it
is flagged.
picks : str (default "eeg")
Type of channels to pick.

Returns
-------
None
If any channels are flagged, those channel names will be logged
in the `flags` attribute of the `LosslessPipeline` object,
under the key ``'volt_std'``, e.g.
``my_pipeline.flags["ch"]["volt_std"]``.

Notes
-----
WARNING: the default threshold of 50 microvolts may not be appropriate
for a particular dataset or data file, as the baseline voltage variance
is affected by the impedance of the system that the data was recorded
with. You may need to assess a more appropriate value for your own
data.
.. warning::

the default threshold of 50 microvolts may not be appropriate
for a particular dataset or data file, as the baseline voltage variance
is affected by the impedance of the system that the data was recorded
with. You may need to assess a more appropriate value for your own
data. You can use the :func:`~pylossless.pipeline.find_bads_by_threshold`
function to quickly assess a more appropriate threshold.

.. seealso::

:func:`~pylossless.pipeline.find_bads_by_threshold`

Examples
--------
>>> import mne
>>> import pylossless as ll
>>> config = ll.Config().load_default()
>>> config["flag_channels_fixed_threshold"] = {"threshold": 5e-5}
>>> pipeline = ll.LosslessPipeline(config=config)
>>> sample_fpath = mne.datasets.sample.data_path()
>>> fpath = sample_fpath / "MEG" / "sample" / "sample_audvis_raw.fif"
>>> raw = mne.io.read_raw(fpath).pick("eeg")
>>> pipeline.raw = raw
>>> pipeline.flag_channels_fixed_threshold()
"""
if "flag_channels_fixed_threshold" not in self.config:
return
if "threshold" in self.config["flag_channels_fixed_threshold"]:
threshold = self.config["flag_channels_fixed_threshold"]["threshold"]
self._flag_volt_std(flag_dim="ch", threshold=threshold, picks=picks)

def flag_epochs_fixed_threshold(self, threshold=5e-5, picks="eeg"):
Expand Down Expand Up @@ -1268,7 +1362,15 @@

# OPTIONAL: Flag chs/epochs based off fixed std threshold of time axis
self.flag_epochs_fixed_threshold(picks=picks)
self.flag_channels_fixed_threshold(picks=picks)
if "flag_channels_fixed_threshold" in self.config:
msg = "Flagging Channels by fixed threshold"
kwargs = dict(picks=picks, message=msg)
if "threshold" in self.config["flag_channels_fixed_threshold"]:
threshold = self.config["flag_channels_fixed_threshold"][
"threshold"
]
kwargs["threshold"] = threshold
self.flag_channels_fixed_threshold(**kwargs)

# 3.flag channels based on large Stdev. across time
msg = "Flagging Noisy Channels"
Expand Down
27 changes: 27 additions & 0 deletions pylossless/tests/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path
import mne
import mne_bids
import numpy as np
import pytest

import pylossless as ll
Expand Down Expand Up @@ -81,6 +82,32 @@ def test_find_outliers():
chs_to_leave_out = pipeline.find_outlier_chs()
assert chs_to_leave_out == ['EEG 001']

def test_find_bads_by_threshold():
    """Exercise fixed-threshold bad-channel detection (function and method)."""
    sample_dir = mne.datasets.sample.data_path() / "MEG" / "sample"
    raw = mne.io.read_raw_fif(sample_dir / "sample_audvis_raw.fif", preload=True)
    # Triple the amplitude of one EEG channel so that it stands out as noisy.
    raw.apply_function(lambda x: x * 3, picks=["EEG 001"])
    epochs = mne.make_fixed_length_epochs(raw, preload=True)
    noisy_chs = ["EEG 001"]

    # Stand-alone function: mixed channel types must trigger a warning...
    multi_type_msg = "The epochs object contains multiple channel types"
    with pytest.warns(RuntimeWarning, match=multi_type_msg):
        ll.pipeline.find_bads_by_threshold(epochs)
    # ...and with EEG channels only, the noisy channel must be reported.
    epochs.pick("eeg")
    np.testing.assert_array_equal(
        ll.pipeline.find_bads_by_threshold(epochs), noisy_chs
    )

    # Pipeline method: an absurdly high threshold flags nothing, whereas the
    # default threshold flags the noisy channel.
    pipeline = ll.LosslessPipeline(config=ll.config.Config().load_default())
    pipeline.raw = raw
    pipeline.flag_channels_fixed_threshold(threshold=10_000)  # too high
    np.testing.assert_array_equal(pipeline.flags["ch"]["volt_std"], [])
    pipeline.flag_channels_fixed_threshold()
    np.testing.assert_array_equal(pipeline.flags["ch"]["volt_std"], noisy_chs)


def test_deprecation():
"""Test the config_name property added for deprecation."""
Expand Down
Loading