Skip to content

Commit fbbda0a

Browse files
committed
TST, DOC: Add test, clean it up
1 parent 35db8bd commit fbbda0a

File tree

2 files changed

+101
-21
lines changed

2 files changed

+101
-21
lines changed

pylossless/pipeline.py

Lines changed: 74 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -232,13 +232,14 @@ def find_bads_by_threshold(epochs, threshold=5e-5):
232232
233233
Parameters
234234
----------
235-
inst : mne.Epochs
236-
an instance of mne.Epochs
235+
epochs : mne.Epochs
236+
an instance of mne.Epochs with a single channel type.
237237
threshold : float
238238
the threshold in volts. If the standard deviation of a channel's voltage
239-
varianceat a specific epoch is above the threshold, then that channel x epoch
240-
will be flagged as an "outlier". Default is 5e-5 (0.00005), i.e.
241-
50 microvolts.
239+
variance at a specific epoch is above the threshold, then that channel x epoch
240+
will be flagged as an "outlier". If more than 20% of epochs are flagged as
241+
outliers for a specific channel, then that channel will be flagged as bad.
242+
Default threshold is 5e-5 (0.00005), i.e. 50 microvolts.
242243
243244
Returns
244245
-------
@@ -256,6 +257,11 @@ def find_bads_by_threshold(epochs, threshold=5e-5):
256257
... print(threshold)
257258
0.00005
258259
260+
.. seealso::
261+
262+
:func:`~pylossless.LosslessPipeline.flag_channels_fixed_threshold` to use
263+
this function within the lossless pipeline.
264+
259265
Examples
260266
--------
261267
>>> import mne
@@ -266,11 +272,17 @@ def find_bads_by_threshold(epochs, threshold=5e-5):
266272
>>> epochs = mne.make_fixed_length_epochs(raw, preload=True)
267273
>>> bad_chs = ll.pipeline.find_bads_by_threshold(epochs)
268274
"""
269-
bads = _threshold_volt_std(epochs, flag_dim="ch", threshold=threshold)
270-
logger.info(
271-
f"Found {len(bads)} channels with high voltage variance: {bads}"
275+
# XXX: We should make this function handle multiple channel types.
276+
# XXX: but I'd like to avoid making a copy of the epochs object
277+
ch_types = np.unique(epochs.get_channel_types()).tolist()
278+
if len(ch_types) > 1:
279+
warn(
280+
f"The epochs object contains multiple channel types: {ch_types}.\n"
281+
" This will likely bias the results of the threshold detection."
282+
" Use the `mne.Epochs.pick` to select a single channel type."
272283
)
273-
return _threshold_volt_std(epochs, flag_dim="ch", threshold=threshold)
284+
bads = _threshold_volt_std(epochs, flag_dim="ch", threshold=threshold)
285+
return bads
274286

275287

276288
def _threshold_volt_std(epochs, flag_dim, threshold=5e-5):
@@ -293,7 +305,7 @@ def _threshold_volt_std(epochs, flag_dim, threshold=5e-5):
293305
assert len(threshold) == 2
294306
l_out, u_out = threshold
295307
init_dir = "both"
296-
elif isinstance(threshold, float):
308+
elif isinstance(threshold, (float, int)):
297309
l_out, u_out = (0, threshold)
298310
init_dir = "pos"
299311
else:
@@ -814,6 +826,14 @@ def _flag_volt_std(self, flag_dim, threshold=5e-5, picks="eeg"):
814826
# So yes this add a few LOC, but IMO it's worth it for readability
815827
if flag_dim == "ch":
816828
above_threshold = find_bads_by_threshold(epochs, threshold=threshold)
829+
if above_threshold.any():
830+
logger.info(
831+
f"🚩 Found {len(above_threshold)} channels with "
832+
f"voltage variance above {threshold} volts: {above_threshold}"
833+
)
834+
else:
835+
msg = f"No channels with standard deviation above {threshold} volts."
836+
logger.info(msg)
817837
else: # TODO: Implement an annotate_bads_by_threshold for epochs
818838
above_threshold = _threshold_volt_std(
819839
epochs, flag_dim=flag_dim, threshold=threshold
@@ -884,6 +904,7 @@ def find_outlier_chs(self, epochs=None, picks="eeg"):
884904

885905
return mean_ch_dist.ch[mean_ch_dist > mdn + 6 * deviation].values.tolist()
886906

907+
@lossless_logger
887908
def flag_channels_fixed_threshold(self, threshold=5e-5, picks="eeg"):
888909
"""Flag channels based on the stdev value across the time dimension.
889910
@@ -895,26 +916,50 @@ def flag_channels_fixed_threshold(self, threshold=5e-5, picks="eeg"):
895916
threshold : float
896917
threshold, in volts. If the standard deviation across time in
897918
any channel x epoch indice is above this threshold, then the
898-
channel x epoch indices will considered an outlier. Defaults
919+
channel x epoch indices will be considered an outlier. Defaults
899920
to 5e-5, or 50 microvolts. Note that here, 'time' refers to
900921
the samples in an epoch. For each channel, if its std value is
901922
above the given threshold in more than 20% of the epochs, it
902923
is flagged.
903924
picks : str (default "eeg")
904925
Type of channels to pick.
905926
927+
Returns
928+
-------
929+
None
930+
If any channels are flagged, those channel names will be logged
931+
in the `flags` attribute of the `LosslessPipeline` object,
932+
under the key ``'volt_std'``, e.g.
933+
``my_pipeline.flags["ch"]["volt_std"]``.
934+
906935
Notes
907936
-----
908-
WARNING: the default threshold of 50 microvolts may not be appropriate
909-
for a particular dataset or data file, as the baseline voltage variance
910-
is affected by the impedance of the system that the data was recorded
911-
with. You may need to assess a more appropriate value for your own
912-
data.
937+
.. warning::
938+
939+
the default threshold of 50 microvolts may not be appropriate
940+
for a particular dataset or data file, as the baseline voltage variance
941+
is affected by the impedance of the system that the data was recorded
942+
with. You may need to assess a more appropriate value for your own
943+
data. You can use the :func:`~pylossless.pipeline.find_bads_by_threshold`
944+
function to quickly assess a more appropriate threshold.
945+
946+
.. seealso::
947+
948+
:func:`~pylossless.pipeline.find_bads_by_threshold`
949+
950+
Examples
951+
--------
952+
>>> import mne
953+
>>> import pylossless as ll
954+
>>> config = ll.Config().load_default()
955+
>>> config["flag_channels_fixed_threshold"] = {"threshold": 5e-5}
956+
>>> pipeline = ll.LosslessPipeline(config=config)
957+
>>> sample_fpath = mne.datasets.sample.data_path()
958+
>>> fpath = sample_fpath / "MEG" / "sample" / "sample_audvis_raw.fif"
959+
>>> raw = mne.io.read_raw(fpath).pick("eeg")
960+
>>> pipeline.raw = raw
961+
>>> pipeline.flag_channels_fixed_threshold()
913962
"""
914-
if "flag_channels_fixed_threshold" not in self.config:
915-
return
916-
if "threshold" in self.config["flag_channels_fixed_threshold"]:
917-
threshold = self.config["flag_channels_fixed_threshold"]["threshold"]
918963
self._flag_volt_std(flag_dim="ch", threshold=threshold, picks=picks)
919964

920965
def flag_epochs_fixed_threshold(self, threshold=5e-5, picks="eeg"):
@@ -1318,7 +1363,15 @@ def _run(self):
13181363

13191364
# OPTIONAL: Flag chs/epochs based off fixed std threshold of time axis
13201365
self.flag_epochs_fixed_threshold(picks=picks)
1321-
self.flag_channels_fixed_threshold(picks=picks)
1366+
if "flag_channels_fixed_threshold" in self.config:
1367+
if "threshold" in self.config["flag_channels_fixed_threshold"]:
1368+
thresh = self.config["flag_channels_fixed_threshold"]["threshold"]
1369+
else:
1370+
thresh = 5e-5
1371+
msg = "Flagging Channels by fixed threshold"
1372+
self.flag_channels_fixed_threshold(
1373+
threshold=thresh, picks=picks, message=msg
1374+
)
13221375

13231376
# 3.flag channels based on large Stdev. across time
13241377
msg = "Flagging Noisy Channels"

pylossless/tests/test_pipeline.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from pathlib import Path
22
import mne
33
import mne_bids
4+
import numpy as np
45
import pytest
56

67
import pylossless as ll
@@ -81,6 +82,32 @@ def test_find_outliers():
8182
chs_to_leave_out = pipeline.find_outlier_chs()
8283
assert chs_to_leave_out == ['EEG 001']
8384

85+
def test_find_bads_by_threshold():
86+
"""Test the find bads by threshold function and method."""
87+
fname = mne.datasets.sample.data_path() / 'MEG' / 'sample' / 'sample_audvis_raw.fif'
88+
raw = mne.io.read_raw_fif(fname, preload=True)
89+
# Make a noisy channel
90+
raw.apply_function(lambda x: x * 3, picks=["EEG 001"])
91+
epochs = mne.make_fixed_length_epochs(raw, preload=True)
92+
93+
# First test the function
94+
with pytest.warns(
95+
RuntimeWarning, match="The epochs object contains multiple channel types"
96+
):
97+
_ = ll.pipeline.find_bads_by_threshold(epochs)
98+
epochs.pick("eeg")
99+
bads = ll.pipeline.find_bads_by_threshold(epochs)
100+
np.testing.assert_array_equal(bads, ['EEG 001'])
101+
102+
# Now test the method
103+
config = ll.config.Config().load_default()
104+
pipeline = ll.LosslessPipeline(config=config)
105+
pipeline.raw = raw
106+
pipeline.flag_channels_fixed_threshold(threshold=10_000) # too high
107+
np.testing.assert_array_equal(pipeline.flags["ch"]["volt_std"], [])
108+
pipeline.flag_channels_fixed_threshold()
109+
np.testing.assert_array_equal(pipeline.flags["ch"]["volt_std"], ['EEG 001'])
110+
84111

85112
def test_deprecation():
86113
"""Test the config_name property added for deprecation."""

0 commit comments

Comments
 (0)