
Commit 3625b26

Make noisy channel exclusion during Reference compatible with MATLAB PREP (#93)
* Add "as_dict" option for get_bads
* Don't permanently exclude initial bad-by-SNRs
* Match PREP's noisy channel updating logic
* Minor variable name cleanup
* use as_dict throughout reference.py
* Fix 'ignore' logic
* Add back "bad_all" (whoops)
* Update matlab_differences.rst
* Add max_iterations args, link to dropout issue
* Fix diffs mistake, update whats_new.rst
* Add PrepPipeline API for max_iterations
* Improve test coverage for Reference
* Add whats_new entry for SNR changes
* Fix quotes to make black happy
* Improve Reference test coverage some more
* remove unused import
1 parent 1abda4d commit 3625b26

6 files changed: +158 -111 lines changed
docs/matlab_differences.rst

Lines changed: 40 additions & 0 deletions
@@ -148,3 +148,43 @@ roughly mean-centered. and will thus produce similar values to normal Pearson
 correlation. However, to avoid making any assumptions about the signal for any
 given channel / window, PyPREP defaults to normal Pearson correlation unless
 strict MATLAB equivalence is requested.
+
+
+Differences in Robust Referencing
+---------------------------------
+
+During the robust referencing part of the pipeline, PREP tries to estimate a
+"clean" average reference signal for the dataset, excluding any channels
+flagged as noisy from contaminating the reference. The robust referencing
+process is performed using the following logic:
+
+1) First, an initial pass of noisy channel detection is performed to identify
+   channels bad by NaN values, flat signal, or low SNR: the data is then
+   average-referenced excluding these channels. These channels are subsequently
+   marked as "unusable" and are excluded from any future average referencing.
+
+2) Noisy channel detection is performed on a copy of the re-referenced signal,
+   and any newly detected bad channels are added to the full set of channels
+   to be excluded from the reference signal.
+
+3) After noisy channel detection, all bad channels detected so far are
+   interpolated, and a new estimate of the robust average reference is
+   calculated using the mean signal of all good channels and all interpolated
+   bad channels (except those flagged as "unusable" during the first step).
+
+4) A fresh copy of the re-referenced signal from Step 1 is re-referenced using
+   the new reference signal calculated in Step 3.
+
+5) Steps 2 through 4 are repeated until either two iterations have passed and
+   no new noisy channels have been detected since the previous iteration, or
+   the maximum number of reference iterations has been exceeded (default: 4).
+
+
+Exclusion of dropout channels
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+In MATLAB PREP, dropout channels (i.e., channels that have intermittent periods
+of flat signal) are detected on each iteration of the reference loop, but are
+currently not factored into the full set of "bad" channels to be interpolated.
+By contrast, PyPREP will detect and interpolate any bad-by-dropout channels
+detected during robust referencing.
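
The stopping rule in Step 5 can be sketched in isolation. The following is a minimal standalone Python illustration, not PyPREP's actual code: `should_stop` and `count_passes` are hypothetical names, and the channels detected on each pass are supplied as plain sets rather than computed from EEG data. The boolean condition mirrors the one used in this commit's `robust_reference` loop.

```python
def should_stop(iterations, bad_chans, previous_bads, max_iterations=4):
    """Return True when the referencing loop should end (hypothetical helper)."""
    # "No change" means nothing is bad, or the bad set matches the previous pass
    no_change = len(bad_chans) == 0 or bad_chans == previous_bads
    return (iterations > 1 and no_change) or iterations > max_iterations


def count_passes(detected_per_pass, max_iterations=4):
    """Count how many detection passes run before the loop stops.

    `detected_per_pass` is an assumed input: one set of channel names per
    pass, standing in for the output of noisy channel detection.
    """
    passes = 0
    bad_chans = set()
    previous_bads = set()
    for newly_detected in detected_per_pass:
        passes += 1
        # Bad channels accumulate across passes
        bad_chans = bad_chans | set(newly_detected)
        # The iteration counter starts at 0 on the first pass
        if should_stop(passes - 1, bad_chans, previous_bads, max_iterations):
            break
        previous_bads = set(bad_chans)
    return passes


if __name__ == "__main__":
    # One bad channel found on the first pass, nothing new afterwards
    print(count_passes([{"Fp1"}, set(), set(), set()]))
```

In this sketch the iteration cap is checked with a strict `>` against a counter that starts at zero, so with the default of 4 the loop can run up to six detection passes when new channels keep appearing.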

docs/whats_new.rst

Lines changed: 5 additions & 1 deletion
@@ -41,12 +41,16 @@ Changelog
 - Changed RANSAC so that "bad by high-frequency noise" channels are retained when making channel predictions (provided they aren't flagged as bad by any other metric), matching MATLAB PREP behaviour, by `Austin Hurst`_ (:gh:`64`)
 - Added a new flag ``matlab_strict`` to :class:`~pyprep.PrepPipeline`, :class:`~pyprep.Reference`, :class:`~pyprep.NoisyChannels`, and :func:`~pyprep.ransac.find_bad_by_ransac` for optionally matching MATLAB PREP's internal math as closely as possible, overriding areas where PyPREP attempts to improve on the original, by `Austin Hurst`_ (:gh:`70`)
 - Added a ``matlab_strict`` method for high-pass trend removal, exactly matching MATLAB PREP's values if ``matlab_strict`` is enabled, by `Austin Hurst`_ (:gh:`71`)
-- Added a window-wise implementaion of RANSAC and made it the default method, reducing the typical RAM demands of robust re-referencing considerably, by `Austin Hurst`_ (:gh:`66`)
+- Added a window-wise implementation of RANSAC and made it the default method, reducing the typical RAM demands of robust re-referencing considerably, by `Austin Hurst`_ (:gh:`66`)
 - Added `max_chunk_size` parameter for specifying the maximum chunk size to use for channel-wise RANSAC, allowing more control over PyPREP RAM usage, by `Austin Hurst`_ (:gh:`66`)
 - Changed :class:`~pyprep.Reference` to exclude "bad-by-SNR" channels from initial average referencing, matching MATLAB PREP behaviour, by `Austin Hurst`_ (:gh:`78`)
 - Changed :class:`~pyprep.Reference` to only flag "unusable" channels (bad by flat, NaNs, or low SNR) from the first pass of noisy detection for permanent exclusion from the reference signal, matching MATLAB PREP behaviour, by `Austin Hurst`_ (:gh:`78`)
 - Added a framework for automated testing of PyPREP's components against their MATLAB PREP counterparts (using ``.mat`` and ``.set`` files generated with the `matprep_artifacts`_ script), helping verify that the two PREP implementations are numerically equivalent when `matlab_strict` is ``True``, by `Austin Hurst`_ (:gh:`79`)
 - Changed :class:`~pyprep.NoisyChannels` to reuse the same random state for each run of RANSAC when ``matlab_strict`` is ``True``, matching MATLAB PREP behaviour, by `Austin Hurst`_ (:gh:`89`)
+- Added a new argument `as_dict` for :meth:`~pyprep.NoisyChannels.get_bads`, allowing easier retrieval of flagged noisy channels by category, by `Austin Hurst`_ (:gh:`93`)
+- Added a new argument `max_iterations` for :meth:`~pyprep.Reference.perform_reference` and :meth:`~pyprep.Reference.robust_reference`, allowing the maximum number of referencing iterations to be user-configurable, by `Austin Hurst`_ (:gh:`93`)
+- Changed :meth:`~pyprep.Reference.robust_reference` to ignore bad-by-dropout channels during referencing if ``matlab_strict`` is ``True``, matching MATLAB PREP behaviour, by `Austin Hurst`_ (:gh:`93`)
+- Changed :meth:`~pyprep.Reference.robust_reference` to allow initial bad-by-SNR channels to be used for rereferencing interpolation if no longer bad following initial average reference, matching MATLAB PREP behaviour, by `Austin Hurst`_ (:gh:`93`)

 .. _matprep_artifacts: https://github.com/a-hurst/matprep_artifacts

pyprep/find_noisy_channels.py

Lines changed: 29 additions & 14 deletions
@@ -125,46 +125,61 @@ def _get_filtered_data(self):

         return EEG_filt

-    def get_bads(self, verbose=False):
-        """Get a list of all channels currently flagged as bad.
+    def get_bads(self, verbose=False, as_dict=False):
+        """Get the names of all channels currently flagged as bad.

         Note that this method does not perform any bad channel detection itself,
         and only reports channels already detected as bad by other methods.

         Parameters
         ----------
-        verbose : bool
+        verbose : bool, optional
             If ``True``, a summary of the channels currently flagged as by bad per
             category is printed. Defaults to ``False``.
+        as_dict: bool, optional
+            If ``True``, this method will return a dict of the channels currently
+            flagged as bad by each individual bad channel type. If ``False``, this
+            method will return a list of all unique bad channels detected so far.
+            Defaults to ``False``.

         Returns
         -------
-        bads : list
-            THe names of all bad channels detected by any method so far.
+        bads : list or dict
+            The names of all bad channels detected so far, either as a combined
+            list or a dict indicating the channels flagged bad by each type.

         """
         bads = {
-            "n/a": self.bad_by_nan,
-            "flat": self.bad_by_flat,
-            "deviation": self.bad_by_deviation,
-            "hf noise": self.bad_by_hf_noise,
-            "correl": self.bad_by_correlation,
-            "SNR": self.bad_by_SNR,
-            "dropout": self.bad_by_dropout,
-            "RANSAC": self.bad_by_ransac,
+            "bad_by_nan": self.bad_by_nan,
+            "bad_by_flat": self.bad_by_flat,
+            "bad_by_deviation": self.bad_by_deviation,
+            "bad_by_hf_noise": self.bad_by_hf_noise,
+            "bad_by_correlation": self.bad_by_correlation,
+            "bad_by_SNR": self.bad_by_SNR,
+            "bad_by_dropout": self.bad_by_dropout,
+            "bad_by_ransac": self.bad_by_ransac,
         }

         all_bads = set()
         for bad_chs in bads.values():
             all_bads.update(bad_chs)

+        name_map = {"nan": "NaN", "hf_noise": "HF noise", "ransac": "RANSAC"}
         if verbose:
             out = f"Found {len(all_bads)} uniquely bad channels:\n"
             for bad_type, bad_chs in bads.items():
+                bad_type = bad_type.replace("bad_by_", "")
+                if bad_type in name_map.keys():
+                    bad_type = name_map[bad_type]
                 out += f"\n{len(bad_chs)} by {bad_type}: {bad_chs}\n"
             print(out)

-        return list(all_bads)
+        if as_dict:
+            bads["bad_all"] = list(all_bads)
+        else:
+            bads = list(all_bads)
+
+        return bads

     def find_all_bads(self, ransac=True, channel_wise=False, max_chunk_size=None):
         """Call all the functions to detect bad channels.

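To illustrate the two return shapes of `get_bads` after this change, here is a small standalone sketch that reproduces the aggregation logic without importing PyPREP. `collect_bads` is a hypothetical helper and the channel names are made up. The combined list is sorted here only to make the result deterministic; the method in the diff returns `list(all_bads)` in set order.

```python
def collect_bads(bads_by_type, as_dict=False):
    """Combine per-category bad channels into a list or a dict (hypothetical).

    `bads_by_type` maps a category name such as "bad_by_flat" to a list of
    channel names, like the `bads` dict built inside `get_bads` in the diff.
    """
    all_bads = set()
    for bad_chs in bads_by_type.values():
        all_bads.update(bad_chs)

    if as_dict:
        # Copy so the caller's dict is not modified, then add the union
        out = dict(bads_by_type)
        out["bad_all"] = sorted(all_bads)
        return out
    return sorted(all_bads)


if __name__ == "__main__":
    example = {
        "bad_by_flat": ["Cz"],
        "bad_by_deviation": ["Fp1", "Cz"],
        "bad_by_ransac": [],
    }
    print(collect_bads(example))
    print(collect_bads(example, as_dict=True))
```

The dict form is what lets `reference.py` replace its three hand-written category dicts with a single `get_bads(as_dict=True)` call.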
pyprep/prep_pipeline.py

Lines changed: 6 additions & 1 deletion
@@ -33,6 +33,9 @@ class PrepPipeline:
               For example, for 60Hz you may specify
               ``np.arange(60, sfreq / 2, 60)``. Specify an empty list to
               skip the line noise removal step.
+        - max_iterations : int, optional
+            - The maximum number of iterations of noisy channel removal to
+              perform during robust referencing. Defaults to ``4``.
     montage : mne.channels.DigMontage
         Digital montage of EEG data.
     ransac : bool, optional
@@ -150,6 +153,8 @@ def __init__(
             self.prep_params["ref_chs"] = self.ch_names_eeg
         if self.prep_params["reref_chs"] == "eeg":
             self.prep_params["reref_chs"] = self.ch_names_eeg
+        if "max_iterations" not in prep_params.keys():
+            self.prep_params["max_iterations"] = 4
         self.sfreq = self.raw_eeg.info["sfreq"]
         self.ransac_settings = {
             "ransac": ransac,
@@ -215,7 +220,7 @@ def fit(self):
             matlab_strict=self.matlab_strict,
             **self.ransac_settings,
         )
-        reference.perform_reference()
+        reference.perform_reference(self.prep_params["max_iterations"])
         self.raw_eeg = reference.raw
         self.noisy_channels_original = reference.noisy_channels_original
         self.noisy_channels_before_interpolation = (

pyprep/reference.py

Lines changed: 49 additions & 74 deletions
@@ -93,9 +93,15 @@ def __init__(
         self._extra_info = {}
         self.matlab_strict = matlab_strict

-    def perform_reference(self):
+    def perform_reference(self, max_iterations=4):
         """Estimate the true signal mean and interpolate bad channels.

+        Parameters
+        ----------
+        max_iterations : int, optional
+            The maximum number of iterations of noisy channel removal to perform
+            during robust referencing. Defaults to ``4``.
+
         This function implements the functionality of the `performReference` function
         as part of the PREP pipeline on mne raw object.

@@ -107,7 +113,7 @@ def perform_reference(self):

         """
         # Phase 1: Estimate the true signal mean with robust referencing
-        self.robust_reference()
+        self.robust_reference(max_iterations)
         # If we interpolate the raw here we would be interpolating
         # more than what we later actually account for (in interpolated channels).
         dummy = self.raw.copy()
@@ -134,17 +140,7 @@ def perform_reference(self):
         # Record Noisy channels and EEG before interpolation
         self.bad_before_interpolation = noisy_detector.get_bads(verbose=True)
         self.EEG_before_interpolation = self.EEG.copy()
-        self.noisy_channels_before_interpolation = {
-            "bad_by_nan": noisy_detector.bad_by_nan,
-            "bad_by_flat": noisy_detector.bad_by_flat,
-            "bad_by_deviation": noisy_detector.bad_by_deviation,
-            "bad_by_hf_noise": noisy_detector.bad_by_hf_noise,
-            "bad_by_correlation": noisy_detector.bad_by_correlation,
-            "bad_by_SNR": noisy_detector.bad_by_SNR,
-            "bad_by_dropout": noisy_detector.bad_by_dropout,
-            "bad_by_ransac": noisy_detector.bad_by_ransac,
-            "bad_all": noisy_detector.get_bads(),
-        }
+        self.noisy_channels_before_interpolation = noisy_detector.get_bads(as_dict=True)
         self._extra_info["interpolated"] = noisy_detector._extra_info

         bad_channels = _union(self.bad_before_interpolation, self.unusable_channels)
@@ -170,27 +166,23 @@ def perform_reference(self):
         noisy_detector.find_all_bads(**self.ransac_settings)
         self.still_noisy_channels = noisy_detector.get_bads()
         self.raw.info["bads"] = self.still_noisy_channels
-        self.noisy_channels_after_interpolation = {
-            "bad_by_nan": noisy_detector.bad_by_nan,
-            "bad_by_flat": noisy_detector.bad_by_flat,
-            "bad_by_deviation": noisy_detector.bad_by_deviation,
-            "bad_by_hf_noise": noisy_detector.bad_by_hf_noise,
-            "bad_by_correlation": noisy_detector.bad_by_correlation,
-            "bad_by_SNR": noisy_detector.bad_by_SNR,
-            "bad_by_dropout": noisy_detector.bad_by_dropout,
-            "bad_by_ransac": noisy_detector.bad_by_ransac,
-            "bad_all": noisy_detector.get_bads(),
-        }
+        self.noisy_channels_after_interpolation = noisy_detector.get_bads(as_dict=True)
         self._extra_info["remaining_bad"] = noisy_detector._extra_info

         return self

-    def robust_reference(self):
+    def robust_reference(self, max_iterations=4):
         """Detect bad channels and estimate the robust reference signal.

         This function implements the functionality of the `robustReference` function
         as part of the PREP pipeline on mne raw object.

+        Parameters
+        ----------
+        max_iterations : int, optional
+            The maximum number of iterations of noisy channel removal to perform
+            during robust referencing. Defaults to ``4``.
+
         Returns
         -------
         noisy_channels: dict
@@ -213,17 +205,7 @@ def robust_reference(self):
             matlab_strict=self.matlab_strict,
         )
         noisy_detector.find_all_bads(**self.ransac_settings)
-        self.noisy_channels_original = {
-            "bad_by_nan": noisy_detector.bad_by_nan,
-            "bad_by_flat": noisy_detector.bad_by_flat,
-            "bad_by_deviation": noisy_detector.bad_by_deviation,
-            "bad_by_hf_noise": noisy_detector.bad_by_hf_noise,
-            "bad_by_correlation": noisy_detector.bad_by_correlation,
-            "bad_by_SNR": noisy_detector.bad_by_SNR,
-            "bad_by_dropout": noisy_detector.bad_by_dropout,
-            "bad_by_ransac": noisy_detector.bad_by_ransac,
-            "bad_all": noisy_detector.get_bads(),
-        }
+        self.noisy_channels_original = noisy_detector.get_bads(as_dict=True)
         self._extra_info["initial_bad"] = noisy_detector._extra_info
         logger.info("Bad channels: {}".format(self.noisy_channels_original))

@@ -235,16 +217,16 @@ def robust_reference(self):
         reference_channels = _set_diff(self.reference_channels, self.unusable_channels)

         # Initialize channels to permanently flag as bad during referencing
-        self.noisy_channels = {
+        noisy = {
             "bad_by_nan": noisy_detector.bad_by_nan,
             "bad_by_flat": noisy_detector.bad_by_flat,
             "bad_by_deviation": [],
             "bad_by_hf_noise": [],
             "bad_by_correlation": [],
-            "bad_by_SNR": noisy_detector.bad_by_SNR,
+            "bad_by_SNR": [],
             "bad_by_dropout": [],
             "bad_by_ransac": [],
-            "bad_all": self.unusable_channels,
+            "bad_all": [],
         }

         # Get initial estimate of the reference by the specified method
@@ -260,8 +242,7 @@ def robust_reference(self):
         # Remove reference from signal, iteratively interpolating bad channels
         raw_tmp = raw.copy()
         iterations = 0
-        noisy_channels_old = []
-        max_iteration_num = 4
+        previous_bads = set()

         while True:
             raw_tmp._data = signal_tmp * 1e-6
@@ -272,51 +253,46 @@ def robust_reference(self):
                 matlab_strict=self.matlab_strict,
             )
             # Detrend applied at the beginning of the function.
+
+            # Detect all currently bad channels
             noisy_detector.find_all_bads(**self.ransac_settings)
-            self.noisy_channels["bad_by_nan"] = _union(
-                self.noisy_channels["bad_by_nan"], noisy_detector.bad_by_nan
-            )
-            self.noisy_channels["bad_by_flat"] = _union(
-                self.noisy_channels["bad_by_flat"], noisy_detector.bad_by_flat
-            )
-            self.noisy_channels["bad_by_deviation"] = _union(
-                self.noisy_channels["bad_by_deviation"], noisy_detector.bad_by_deviation
-            )
-            self.noisy_channels["bad_by_hf_noise"] = _union(
-                self.noisy_channels["bad_by_hf_noise"], noisy_detector.bad_by_hf_noise
-            )
-            self.noisy_channels["bad_by_correlation"] = _union(
-                self.noisy_channels["bad_by_correlation"],
-                noisy_detector.bad_by_correlation,
-            )
-            self.noisy_channels["bad_by_ransac"] = _union(
-                self.noisy_channels["bad_by_ransac"], noisy_detector.bad_by_ransac
-            )
-            self.noisy_channels["bad_all"] = _union(
-                self.noisy_channels["bad_all"], noisy_detector.get_bads()
-            )
-            logger.info("Bad channels: {}".format(self.noisy_channels))
+            noisy_new = noisy_detector.get_bads(as_dict=True)
+
+            # Specify bad channel types to ignore when updating noisy channels
+            # NOTE: MATLAB PREP ignores dropout channels, possibly by mistake?
+            # see: https://github.com/VisLab/EEG-Clean-Tools/issues/28
+            ignore = ["bad_by_SNR", "bad_all"]
+            if self.matlab_strict:
+                ignore += ["bad_by_dropout"]
+
+            # Update set of all noisy channels detected so far with any new ones
+            bad_chans = set()
+            for bad_type in noisy_new.keys():
+                noisy[bad_type] = _union(noisy[bad_type], noisy_new[bad_type])
+                if bad_type not in ignore:
+                    bad_chans.update(noisy[bad_type])
+            noisy["bad_all"] = list(bad_chans)
+            logger.info("Bad channels: {}".format(noisy))

             if (
                 iterations > 1
-                and (
-                    not self.noisy_channels["bad_all"]
-                    or set(self.noisy_channels["bad_all"]) == set(noisy_channels_old)
-                )
-                or iterations > max_iteration_num
+                and (len(bad_chans) == 0 or bad_chans == previous_bads)
+                or iterations > max_iterations
             ):
+                logger.info("Robust reference done")
+                self.noisy_channels = noisy
                 break
-            noisy_channels_old = self.noisy_channels["bad_all"].copy()
+            previous_bads = bad_chans.copy()

-            if raw_tmp.info["nchan"] - len(self.noisy_channels["bad_all"]) < 2:
+            if raw_tmp.info["nchan"] - len(bad_chans) < 2:
                 raise ValueError(
                     "RobustReference:TooManyBad "
                     "Could not perform a robust reference -- not enough good channels"
                 )

-            if self.noisy_channels["bad_all"]:
+            if len(bad_chans) > 0:
                 raw_tmp._data = signal * 1e-6
-                raw_tmp.info["bads"] = self.noisy_channels["bad_all"]
+                raw_tmp.info["bads"] = list(bad_chans)
                 raw_tmp.interpolate_bads()
                 signal_tmp = raw_tmp.get_data() * 1e6
             else:
@@ -331,7 +307,6 @@ def robust_reference(self):
             iterations = iterations + 1
             logger.info("Iterations: {}".format(iterations))

-        logger.info("Robust reference done")
         return self.noisy_channels, self.reference_signal

     @staticmethod
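
The new per-iteration update in `robust_reference` (merge the newly detected channels into each category, then build the set to interpolate while skipping ignored categories) can be tried out on plain dicts. This is a hedged standalone sketch rather than the library code: `update_noisy` is a hypothetical name, a set union stands in for the module's `_union` helper, and lists are sorted only for deterministic output.

```python
def update_noisy(noisy, noisy_new, matlab_strict=False):
    """Merge one pass of detections into `noisy` and return channels to interpolate.

    Both arguments are dicts mapping category names (e.g. "bad_by_SNR") to
    lists of channel names, in the shape returned by get_bads(as_dict=True).
    """
    # Categories that are tracked but never added to the interpolation set
    ignore = ["bad_by_SNR", "bad_all"]
    if matlab_strict:
        # Per the diff, MATLAB PREP does not factor dropouts into the bad set
        ignore.append("bad_by_dropout")

    bad_chans = set()
    for bad_type, new_chs in noisy_new.items():
        merged = set(noisy.get(bad_type, [])) | set(new_chs)
        noisy[bad_type] = sorted(merged)
        if bad_type not in ignore:
            bad_chans.update(merged)

    # "bad_all" is rebuilt from the non-ignored categories only
    noisy["bad_all"] = sorted(bad_chans)
    return bad_chans


if __name__ == "__main__":
    detected = {
        "bad_by_deviation": ["Fp1"],
        "bad_by_SNR": ["Cz"],
        "bad_by_dropout": ["O1"],
        "bad_all": ["Cz", "Fp1", "O1"],
    }
    print(sorted(update_noisy({}, detected)))
    print(sorted(update_noisy({}, detected, matlab_strict=True)))
```

Note that a bad-by-SNR channel is still recorded under its own key; it is only left out of the set passed to interpolation, which is how initially bad-by-SNR channels can become usable again after the first average reference.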
