Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pylossless/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def pipeline_fixture():
config["find_breaks"]["min_break_duration"] = 9
config["find_breaks"]["t_start_after_previous"] = 1
config["find_breaks"]["t_stop_before_next"] = 0
config["flag_channels_fixed_threshold"] = {"threshold": 10_000}
config["ica"]["ica_args"]["run1"]["max_iter"] = 5000

# Testing when passing the config object directly...
Expand Down
133 changes: 118 additions & 15 deletions pylossless/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,64 @@
return prop_outliers[prop_outliers > flag_crit].coords.to_index().values


def find_bads_by_threshold(epochs, threshold=5e-5):
    """Return channels whose standard deviation is consistently above a threshold.

    Parameters
    ----------
    epochs : mne.Epochs
        An instance of mne.Epochs with a single channel type.
    threshold : float
        The threshold in volts. If the standard deviation of a channel's
        voltage within a specific epoch is above the threshold, then that
        channel x epoch will be flagged as an "outlier". If more than 20% of
        epochs are flagged as outliers for a specific channel, then that
        channel will be flagged as bad. Default threshold is 5e-5 (0.00005),
        i.e. 50 microvolts.

    Returns
    -------
    array-like of str
        The names of the channels that are considered outliers. The value is
        passed through unchanged from ``_threshold_volt_std``.
        NOTE(review): presumably a ``numpy.ndarray`` of channel names rather
        than a plain ``list`` -- confirm against ``_threshold_volt_std``.

    Warns
    -----
    A warning is issued (through the ``warn`` helper used by this module) when
    ``epochs`` contains more than one channel type. Detection still runs in
    that case, but a single fixed threshold is unlikely to suit channel types
    that are measured on different scales.

    Notes
    -----
    If you are having trouble converting between exponential notation and
    decimal notation, you can format the value with a fixed number of decimal
    places. Note that ``numpy.printoptions(suppress=True)`` only affects how
    NumPy arrays are printed; it does not change how a built-in ``float`` is
    displayed.

    >>> threshold = 5e-5
    >>> print(threshold)
    5e-05
    >>> print(f"{threshold:.5f}")
    0.00005

    .. seealso::

        :meth:`~pylossless.LosslessPipeline.flag_channels_fixed_threshold` to use
        this function within the lossless pipeline.

    Examples
    --------
    >>> import mne
    >>> import pylossless as ll
    >>> fname = mne.datasets.sample.data_path() / "MEG/sample/sample_audvis_raw.fif"
    >>> raw = mne.io.read_raw(fname, preload=True).pick("eeg")
    >>> raw.apply_function(lambda x: x * 3, picks=["EEG 001"]) # Make a noisy channel
    >>> epochs = mne.make_fixed_length_epochs(raw, preload=True)
    >>> bad_chs = ll.pipeline.find_bads_by_threshold(epochs)
    """
    # XXX: We should make this function handle multiple channel types.
    # XXX: but I'd like to avoid making a copy of the epochs object
    # Collapse the per-channel type labels into a sorted list of unique types
    # so that it can be both counted and shown in the warning message.
    ch_types = np.unique(epochs.get_channel_types()).tolist()
    if len(ch_types) > 1:
        # Only warn (do not raise): the caller may knowingly pass mixed types.
        warn(
            f"The epochs object contains multiple channel types: {ch_types}.\n"
            " This will likely bias the results of the threshold detection."
            " Use the `mne.Epochs.pick` to select a single channel type."
        )
    # flag_dim="ch" asks the shared helper to flag channels rather than epochs.
    return _threshold_volt_std(epochs, flag_dim="ch", threshold=threshold)


def _threshold_volt_std(epochs, flag_dim, threshold=5e-5):
"""Detect epochs or channels whose voltage std is above threshold.
Expand All @@ -247,7 +305,7 @@
assert len(threshold) == 2
l_out, u_out = threshold
init_dir = "both"
elif isinstance(threshold, float):
elif isinstance(threshold, (float, int)):
l_out, u_out = (0, threshold)
init_dir = "pos"
else:
Expand Down Expand Up @@ -765,9 +823,21 @@
on. You may need to assess a more appropriate value for your own data.
"""
epochs = self.get_epochs(picks=picks)
above_threshold = _threshold_volt_std(
epochs, flag_dim=flag_dim, threshold=threshold
)
# So yes this adds a few LOC, but IMO it's worth it for readability
if flag_dim == "ch":
above_threshold = find_bads_by_threshold(epochs, threshold=threshold)
if above_threshold.any():
logger.info(
f"🚩 Found {len(above_threshold)} channels with "
f"voltage variance above {threshold} volts: {above_threshold}"
)
else:
msg = f"No channels with standard deviation above {threshold} volts."
logger.info(msg)
else: # TODO: Implement an annotate_bads_by_threshold for epochs
above_threshold = _threshold_volt_std(

Check warning on line 838 in pylossless/pipeline.py

View check run for this annotation

Codecov / codecov/patch

pylossless/pipeline.py#L838

Added line #L838 was not covered by tests
epochs, flag_dim=flag_dim, threshold=threshold
)
self.flags[flag_dim].add_flag_cat("volt_std", above_threshold, epochs)

def find_outlier_chs(self, epochs=None, picks="eeg"):
Expand Down Expand Up @@ -834,6 +904,7 @@

return mean_ch_dist.ch[mean_ch_dist > mdn + 6 * deviation].values.tolist()

@lossless_logger
def flag_channels_fixed_threshold(self, threshold=5e-5, picks="eeg"):
"""Flag channels based on the stdev value across the time dimension.
Expand All @@ -845,26 +916,50 @@
threshold : float
threshold, in volts. If the standard deviation across time in
any channel x epoch indice is above this threshold, then the
channel x epoch indices will considered an outlier. Defaults
channel x epoch indices will be considered an outlier. Defaults
to 5e-5, or 50 microvolts. Note that here, 'time' refers to
the samples in an epoch. For each channel, if its std value is
above the given threshold in more than 20% of the epochs, it
is flagged.
picks : str (default "eeg")
Type of channels to pick.
Returns
-------
None
If any channels are flagged, those channel names will be logged
in the `flags` attribute of the `LosslessPipeline` object,
under the key ``'volt_std'``, e.g.
``my_pipeline.flags["ch"]["volt_std"]``.
Notes
-----
WARNING: the default threshold of 50 microvolts may not be appropriate
for a particular dataset or data file, as the baseline voltage variance
is affected by the impedance of the system that the data was recorded
with. You may need to assess a more appropriate value for your own
data.
.. warning::
the default threshold of 50 microvolts may not be appropriate
for a particular dataset or data file, as the baseline voltage variance
is affected by the impedance of the system that the data was recorded
with. You may need to assess a more appropriate value for your own
data. You can use the :func:`~pylossless.pipeline.find_bads_by_threshold`
function to quickly assess a more appropriate threshold.
.. seealso::
:func:`~pylossless.pipeline.find_bads_by_threshold`
Examples
--------
>>> import mne
>>> import pylossless as ll
>>> config = ll.Config().load_default()
>>> config["flag_channels_fixed_threshold"] = {"threshold": 5e-5}
>>> pipeline = ll.LosslessPipeline(config=config)
>>> sample_fpath = mne.datasets.sample.data_path()
>>> fpath = sample_fpath / "MEG" / "sample" / "sample_audvis_raw.fif"
>>> raw = mne.io.read_raw(fpath).pick("eeg")
>>> pipeline.raw = raw
>>> pipeline.flag_channels_fixed_threshold()
"""
if "flag_channels_fixed_threshold" not in self.config:
return
if "threshold" in self.config["flag_channels_fixed_threshold"]:
threshold = self.config["flag_channels_fixed_threshold"]["threshold"]
self._flag_volt_std(flag_dim="ch", threshold=threshold, picks=picks)

def flag_epochs_fixed_threshold(self, threshold=5e-5, picks="eeg"):
Expand Down Expand Up @@ -1268,7 +1363,15 @@

# OPTIONAL: Flag chs/epochs based off fixed std threshold of time axis
self.flag_epochs_fixed_threshold(picks=picks)
self.flag_channels_fixed_threshold(picks=picks)
if "flag_channels_fixed_threshold" in self.config:
if "threshold" in self.config["flag_channels_fixed_threshold"]:
thresh = self.config["flag_channels_fixed_threshold"]["threshold"]
else:
thresh = 5e-5

Check warning on line 1370 in pylossless/pipeline.py

View check run for this annotation

Codecov / codecov/patch

pylossless/pipeline.py#L1370

Added line #L1370 was not covered by tests
msg = "Flagging Channels by fixed threshold"
self.flag_channels_fixed_threshold(
threshold=thresh, picks=picks, message=msg
)

# 3.flag channels based on large Stdev. across time
msg = "Flagging Noisy Channels"
Expand Down
27 changes: 27 additions & 0 deletions pylossless/tests/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path
import mne
import mne_bids
import numpy as np
import pytest

import pylossless as ll
Expand Down Expand Up @@ -81,6 +82,32 @@ def test_find_outliers():
chs_to_leave_out = pipeline.find_outlier_chs()
assert chs_to_leave_out == ['EEG 001']

def test_find_bads_by_threshold():
    """Check fixed-threshold bad-channel detection via the function and the method."""
    sample_dir = mne.datasets.sample.data_path() / 'MEG' / 'sample'
    raw = mne.io.read_raw_fif(sample_dir / 'sample_audvis_raw.fif', preload=True)
    # Triple the amplitude of a single EEG channel so that it stands out as noisy.
    raw.apply_function(lambda x: x * 3, picks=["EEG 001"])
    epochs = mne.make_fixed_length_epochs(raw, preload=True)

    # --- Stand-alone function ---
    # With mixed channel types still present, a warning is expected.
    expected_warning = "The epochs object contains multiple channel types"
    with pytest.warns(RuntimeWarning, match=expected_warning):
        ll.pipeline.find_bads_by_threshold(epochs)
    # Restricted to EEG, only the artificially noisy channel should be returned.
    epochs.pick("eeg")
    flagged = ll.pipeline.find_bads_by_threshold(epochs)
    np.testing.assert_array_equal(flagged, ['EEG 001'])

    # --- Pipeline method ---
    pipeline = ll.LosslessPipeline(config=ll.config.Config().load_default())
    pipeline.raw = raw
    # A deliberately huge threshold must not flag any channel.
    pipeline.flag_channels_fixed_threshold(threshold=10_000)  # too high
    np.testing.assert_array_equal(pipeline.flags["ch"]["volt_std"], [])
    # The default threshold must flag the noisy channel.
    pipeline.flag_channels_fixed_threshold()
    np.testing.assert_array_equal(pipeline.flags["ch"]["volt_std"], ['EEG 001'])


def test_deprecation():
"""Test the config_name property added for deprecation."""
Expand Down
Loading