Skip to content

Commit 9052a63

Browse files
RFC, WIP: public function to find and return bad channels (#214)
* RFC, WIP: public function to find and return bad channels * TST, DOC: Add test, clean it up * TST: add flag_channels_fixed_threshold step to full pipeline run test * Christian Suggestions (More DRY) [skip actions] [ci skip] Co-authored-by: Christian O'Reilly <[email protected]> * FIX: line length... --------- Co-authored-by: Christian O'Reilly <[email protected]>
1 parent d835e40 commit 9052a63

File tree

3 files changed

+145
-15
lines changed

3 files changed

+145
-15
lines changed

pylossless/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def pipeline_fixture():
3131
config["find_breaks"]["min_break_duration"] = 9
3232
config["find_breaks"]["t_start_after_previous"] = 1
3333
config["find_breaks"]["t_stop_before_next"] = 0
34+
config["flag_channels_fixed_threshold"] = {"threshold": 10_000}
3435
config["ica"]["ica_args"]["run1"]["max_iter"] = 5000
3536

3637
# Testing when passing the config object directly...

pylossless/pipeline.py

Lines changed: 117 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,64 @@ def _detect_outliers(
227227
return prop_outliers[prop_outliers > flag_crit].coords.to_index().values
228228

229229

230+
def find_bads_by_threshold(epochs, threshold=5e-5):
    """Return channels with a standard deviation consistently above a fixed threshold.

    Parameters
    ----------
    epochs : mne.Epochs
        an instance of mne.Epochs with a single channel type.
    threshold : float
        the threshold in volts. If the standard deviation of a channel's voltage
        at a specific epoch is above the threshold, then that channel x epoch
        will be flagged as an "outlier". If more than 20% of epochs are flagged as
        outliers for a specific channel, then that channel will be flagged as bad.
        Default threshold is 5e-5 (0.00005), i.e. 50 microvolts.

    Returns
    -------
    array of str
        the names of the channels that are considered outliers. The value is
        passed through unchanged from the private ``_threshold_volt_std`` helper.

    Notes
    -----
    A warning is emitted (but detection still runs) when ``epochs`` contains
    more than one channel type, because a single voltage threshold is unlikely
    to be meaningful across channel types.

    If you are having trouble converting between exponential notation and
    decimal notation, you can use the following code to convert between the two:

    >>> import numpy as np
    >>> threshold = 5e-5
    >>> print(np.format_float_positional(threshold))
    0.00005

    .. seealso::

        :meth:`~pylossless.LosslessPipeline.flag_channels_fixed_threshold` to use
        this function within the lossless pipeline.

    Examples
    --------
    >>> import mne
    >>> import pylossless as ll
    >>> fname = mne.datasets.sample.data_path() / "MEG/sample/sample_audvis_raw.fif"
    >>> raw = mne.io.read_raw(fname, preload=True).pick("eeg")
    >>> raw.apply_function(lambda x: x * 3, picks=["EEG 001"]) # Make a noisy channel
    >>> epochs = mne.make_fixed_length_epochs(raw, preload=True)
    >>> bad_chs = ll.pipeline.find_bads_by_threshold(epochs)
    """
    # TODO: We should make this function handle multiple channel types.
    # TODO: but I'd like to avoid making a copy of the epochs object
    ch_types = np.unique(epochs.get_channel_types()).tolist()
    if len(ch_types) > 1:
        # Only warn: picking here would require copying or mutating the
        # caller's epochs object, so the choice is left to the user.
        warn(
            f"The epochs object contains multiple channel types: {ch_types}.\n"
            " This will likely bias the results of the threshold detection."
            " Use the `mne.Epochs.pick` to select a single channel type."
        )
    return _threshold_volt_std(epochs, flag_dim="ch", threshold=threshold)
286+
287+
230288
def _threshold_volt_std(epochs, flag_dim, threshold=5e-5):
231289
"""Detect epochs or channels whose voltage std is above threshold.
232290
@@ -247,7 +305,7 @@ def _threshold_volt_std(epochs, flag_dim, threshold=5e-5):
247305
assert len(threshold) == 2
248306
l_out, u_out = threshold
249307
init_dir = "both"
250-
elif isinstance(threshold, float):
308+
elif isinstance(threshold, (float, int)):
251309
l_out, u_out = (0, threshold)
252310
init_dir = "pos"
253311
else:
@@ -765,9 +823,20 @@ def _flag_volt_std(self, flag_dim, threshold=5e-5, picks="eeg"):
765823
on. You may need to assess a more appropriate value for your own data.
766824
"""
767825
epochs = self.get_epochs(picks=picks)
768-
above_threshold = _threshold_volt_std(
769-
epochs, flag_dim=flag_dim, threshold=threshold
770-
)
826+
if flag_dim == "ch":
827+
above_threshold = find_bads_by_threshold(epochs, threshold=threshold)
828+
if above_threshold.any():
829+
logger.info(
830+
f"🚩 Found {len(above_threshold)} channels with "
831+
f"voltage variance above {threshold} volts: {above_threshold}"
832+
)
833+
else:
834+
msg = f"No channels with standard deviation above {threshold} volts."
835+
logger.info(msg)
836+
else: # TODO: Implement an annotate_bads_by_threshold for epochs
837+
above_threshold = _threshold_volt_std(
838+
epochs, flag_dim=flag_dim, threshold=threshold
839+
)
771840
self.flags[flag_dim].add_flag_cat("volt_std", above_threshold, epochs)
772841

773842
def find_outlier_chs(self, epochs=None, picks="eeg"):
@@ -834,6 +903,7 @@ def find_outlier_chs(self, epochs=None, picks="eeg"):
834903

835904
return mean_ch_dist.ch[mean_ch_dist > mdn + 6 * deviation].values.tolist()
836905

906+
@lossless_logger
837907
def flag_channels_fixed_threshold(self, threshold=5e-5, picks="eeg"):
838908
"""Flag channels based on the stdev value across the time dimension.
839909
@@ -845,26 +915,50 @@ def flag_channels_fixed_threshold(self, threshold=5e-5, picks="eeg"):
845915
threshold : float
846916
threshold, in volts. If the standard deviation across time in
847917
any channel x epoch indice is above this threshold, then the
848-
channel x epoch indices will considered an outlier. Defaults
918+
channel x epoch indices will be considered an outlier. Defaults
849919
to 5e-5, or 50 microvolts. Note that here, 'time' refers to
850920
the samples in an epoch. For each channel, if its std value is
851921
above the given threshold in more than 20% of the epochs, it
852922
is flagged.
853923
picks : str (default "eeg")
854924
Type of channels to pick.
855925
926+
Returns
927+
-------
928+
None
929+
If any channels are flagged, those channel names will be logged
930+
in the `flags` attribute of the `LosslessPipeline` object,
931+
under the key ``'volt_std'``, e.g.
932+
``my_pipeline.flags["ch"]["volt_std"]``.
933+
856934
Notes
857935
-----
858-
WARNING: the default threshold of 50 microvolts may not be appropriate
859-
for a particular dataset or data file, as the baseline voltage variance
860-
is affected by the impedance of the system that the data was recorded
861-
with. You may need to assess a more appropriate value for your own
862-
data.
936+
.. warning::
937+
938+
the default threshold of 50 microvolts may not be appropriate
939+
for a particular dataset or data file, as the baseline voltage variance
940+
is affected by the impedance of the system that the data was recorded
941+
with. You may need to assess a more appropriate value for your own
942+
data. You can use the :func:`~pylossless.pipeline.find_bads_by_threshold`
943+
function to quickly assess a more appropriate threshold.
944+
945+
.. seealso::
946+
947+
:func:`~pylossless.pipeline.find_bads_by_threshold`
948+
949+
Examples
950+
--------
951+
>>> import mne
952+
>>> import pylossless as ll
953+
>>> config = ll.Config().load_default()
954+
>>> config["flag_channels_fixed_threshold"] = {"threshold": 5e-5}
955+
>>> pipeline = ll.LosslessPipeline(config=config)
956+
>>> sample_fpath = mne.datasets.sample.data_path()
957+
>>> fpath = sample_fpath / "MEG" / "sample" / "sample_audvis_raw.fif"
958+
>>> raw = mne.io.read_raw(fpath).pick("eeg")
959+
>>> pipeline.raw = raw
960+
>>> pipeline.flag_channels_fixed_threshold()
863961
"""
864-
if "flag_channels_fixed_threshold" not in self.config:
865-
return
866-
if "threshold" in self.config["flag_channels_fixed_threshold"]:
867-
threshold = self.config["flag_channels_fixed_threshold"]["threshold"]
868962
self._flag_volt_std(flag_dim="ch", threshold=threshold, picks=picks)
869963

870964
def flag_epochs_fixed_threshold(self, threshold=5e-5, picks="eeg"):
@@ -1268,7 +1362,15 @@ def _run(self):
12681362

12691363
# OPTIONAL: Flag chs/epochs based off fixed std threshold of time axis
12701364
self.flag_epochs_fixed_threshold(picks=picks)
1271-
self.flag_channels_fixed_threshold(picks=picks)
1365+
if "flag_channels_fixed_threshold" in self.config:
1366+
msg = "Flagging Channels by fixed threshold"
1367+
kwargs = dict(picks=picks, message=msg)
1368+
if "threshold" in self.config["flag_channels_fixed_threshold"]:
1369+
threshold = self.config["flag_channels_fixed_threshold"][
1370+
"threshold"
1371+
]
1372+
kwargs["threshold"] = threshold
1373+
self.flag_channels_fixed_threshold(**kwargs)
12721374

12731375
# 3.flag channels based on large Stdev. across time
12741376
msg = "Flagging Noisy Channels"

pylossless/tests/test_pipeline.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from pathlib import Path
22
import mne
33
import mne_bids
4+
import numpy as np
45
import pytest
56

67
import pylossless as ll
@@ -81,6 +82,32 @@ def test_find_outliers():
8182
chs_to_leave_out = pipeline.find_outlier_chs()
8283
assert chs_to_leave_out == ['EEG 001']
8384

85+
def test_find_bads_by_threshold():
    """Check fixed-threshold bad channel detection (function and pipeline method)."""
    sample_dir = mne.datasets.sample.data_path() / 'MEG' / 'sample'
    raw = mne.io.read_raw_fif(sample_dir / 'sample_audvis_raw.fif', preload=True)
    # Triple the amplitude of one EEG channel so that it stands out as noisy.
    raw.apply_function(lambda x: x * 3, picks=["EEG 001"])
    epochs = mne.make_fixed_length_epochs(raw, preload=True)

    # Public function: epochs holding several channel types must warn.
    expected_warning = "The epochs object contains multiple channel types"
    with pytest.warns(RuntimeWarning, match=expected_warning):
        ll.pipeline.find_bads_by_threshold(epochs)

    # Restricted to EEG, only the artificially noisy channel is returned.
    epochs.pick("eeg")
    np.testing.assert_array_equal(
        ll.pipeline.find_bads_by_threshold(epochs), ['EEG 001']
    )

    # Pipeline method: flags are stored under the "volt_std" key.
    pipeline = ll.LosslessPipeline(config=ll.config.Config().load_default())
    pipeline.raw = raw

    # A threshold that is far too high flags nothing.
    pipeline.flag_channels_fixed_threshold(threshold=10_000)
    np.testing.assert_array_equal(pipeline.flags["ch"]["volt_std"], [])

    # The default threshold flags the noisy channel.
    pipeline.flag_channels_fixed_threshold()
    np.testing.assert_array_equal(pipeline.flags["ch"]["volt_std"], ['EEG 001'])
110+
84111

85112
def test_deprecation():
86113
"""Test the config_name property added for deprecation."""

0 commit comments

Comments
 (0)