Skip to content

Commit 1c9589c

Browse files
authored
Implement EEGLAB-style interpolation in Reference for matlab_strict (#96)
* Add initial implementation of EEGLAB interpolation * Update matlab_differences.rst * Add basic MATLAB comparison tests * Fix spelling mistake * Update whats_new.rst * Make isort happy * Improve wording around matlab_strict parameter * Fix Perrin citation RST * Add explicit refs to reimplemented EEGLAB code * Add references header in matlab_differences
1 parent 3625b26 commit 1c9589c

File tree

5 files changed

+318
-32
lines changed

5 files changed

+318
-32
lines changed

docs/matlab_differences.rst

+40-8
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@ Although PyPREP aims to be a faithful reimplementation of the original MATLAB
99
version of PREP, there are a few places where PyPREP has deliberately chosen
1010
to use different defaults than the MATLAB PREP.
1111

12-
To override these differerences, you can set the ``matlab_strict`` argument to
13-
:class:`~pyprep.PrepPipeline`, :class:`~pyprep.Reference`, or
14-
:class:`~pyprep.NoisyChannels` as ``True`` to match the original PREP's
15-
internal math.
12+
To override these differerences, you can set the ``matlab_strict`` parameter
13+
for :class:`~pyprep.PrepPipeline`, :class:`~pyprep.Reference`, or
14+
:class:`~pyprep.NoisyChannels` to ``True`` in order to match the original
15+
PREP's internal math.
1616

1717
.. contents:: Table of Contents
1818
:depth: 3
@@ -39,8 +39,8 @@ Because the practical differences are small and MNE's filtering is fast and
3939
well-tested, PyPREP defaults to using :func:`mne.filter.filter_data` for
4040
high-pass trend removal. However, for exact numerical compatibility, PyPREP
4141
has a basic re-implementation of EEGLAB's ``pop_eegfiltnew`` in Python that
42-
produces identical results to MATLAB PREP's ``removeTrend`` when
43-
``matlab_strict`` is set to ``True``.
42+
produces identical results to MATLAB PREP's ``removeTrend`` when the
43+
``matlab_strict`` parameter is set to ``True``.
4444

4545

4646
Differences in RANSAC
@@ -93,8 +93,8 @@ approach has the benefit of better randomness, but may also lead to more
9393
variability in PREP results between different seed values. More testing is
9494
required to determine which approach produces better results.
9595

96-
Note that to match MATLAB PREP exactly when ``matlab_strict`` is ``True``, the
97-
random seed ``435656`` must be used.
96+
Note that to match MATLAB PREP exactly when the ``matlab_strict`` parameter is
97+
set to ``True``, the random seed ``435656`` must be used.
9898

9999

100100
Calculation of median estimated signal
@@ -188,3 +188,35 @@ of flat signal) are detected on each iteration of the reference loop, but are
188188
currently not factored into the full set of "bad" channels to be interpolated.
189189
By contrast, PyPREP will detect and interpolate any bad-by-dropout channels
190190
detected during robust referencing.
191+
192+
193+
Bad channel interpolation
194+
^^^^^^^^^^^^^^^^^^^^^^^^^
195+
196+
MATLAB PREP uses EEGLAB's internal ``eeg_interp`` method of spherical spline
197+
interpolation for interpolating identified bad channels during robust reference
198+
estimation and (if enabled) immediately after the robust reference signal is
199+
applied in order to remove any remaining detected bad channels once referencing
200+
is complete.
201+
202+
However, ``eeg_interp``'s method of spherical interpolations differs quite a bit
203+
numerically from MNE's implementation as well as the interpolation method used
204+
by MATLAB PREP for RANSAC predictions, both of which are numerically identical
205+
and based directly on the formulas in Perrin et al. (1989) [1]_. ``eeg_interp``
206+
seems to use a modified variation of the Perrin et al. method, but diverges in
207+
a number of ways that are not clearly documented or cited in the code.
208+
209+
To keep with the more established method of spherical interpolation and stay
210+
consistent with the interpolation code used in RANSAC, PyPREP defaults to using
211+
MNE's :meth:`~mne.io.Raw.interpolate_bads` method for interpolation during and
212+
following robust referencing. However, for full numeric equivalence with
213+
MATLAB PREP, PyPREP will use a Python reimplementation of ``eeg_interp`` instead
214+
when the ``matlab_strict`` parameter is set to ``True``.
215+
216+
217+
References
218+
----------
219+
220+
.. [1] Perrin, F., Pernier, J., Bertrand, O. and Echallier, JF. (1989).
221+
Spherical splines for scalp potential and current density mapping.
222+
Electroencephalography Clinical Neurophysiology, Feb; 72(2):184-7.

docs/whats_new.rst

+1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ Changelog
5151
- Added a new argument `max_iterations` for :meth:`~pyprep.Reference.perform_reference` and :meth:`~pyprep.Reference.robust_reference`, allowing the maximum number of referencing iterations to be user-configurable, by `Austin Hurst`_ (:gh:`93`)
5252
- Changed :meth:`~pyprep.Reference.robust_reference` to ignore bad-by-dropout channels during referencing if ``matlab_strict`` is ``True``, matching MATLAB PREP behaviour, by `Austin Hurst`_ (:gh:`93`)
5353
- Changed :meth:`~pyprep.Reference.robust_reference` to allow initial bad-by-SNR channels to be used for rereferencing interpolation if no longer bad following initial average reference, matching MATLAB PREP behaviour, by `Austin Hurst`_ (:gh:`93`)
54+
- Added a ``matlab_strict`` method for bad channel interpolation, allowing for full numeric equivalence with MATLAB PREP's robust referencing, by `Austin Hurst`_ (:gh:`96`)
5455

5556
.. _matprep_artifacts: https://github.com/a-hurst/matprep_artifacts
5657

pyprep/reference.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from pyprep.find_noisy_channels import NoisyChannels
88
from pyprep.removeTrend import removeTrend
9-
from pyprep.utils import _set_diff, _union
9+
from pyprep.utils import _eeglab_interpolate_bads, _set_diff, _union
1010

1111
logging.basicConfig(
1212
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
@@ -118,7 +118,10 @@ def perform_reference(self, max_iterations=4):
118118
# more than what we later actually account for (in interpolated channels).
119119
dummy = self.raw.copy()
120120
dummy.info["bads"] = self.noisy_channels["bad_all"]
121-
dummy.interpolate_bads()
121+
if self.matlab_strict:
122+
_eeglab_interpolate_bads(dummy)
123+
else:
124+
dummy.interpolate_bads()
122125
self.reference_signal = (
123126
np.nanmean(dummy.get_data(picks=self.reference_channels), axis=0) * 1e6
124127
)
@@ -145,7 +148,10 @@ def perform_reference(self, max_iterations=4):
145148

146149
bad_channels = _union(self.bad_before_interpolation, self.unusable_channels)
147150
self.raw.info["bads"] = bad_channels
148-
self.raw.interpolate_bads()
151+
if self.matlab_strict:
152+
_eeglab_interpolate_bads(self.raw)
153+
else:
154+
self.raw.interpolate_bads()
149155
reference_correct = (
150156
np.nanmean(self.raw.get_data(picks=self.reference_channels), axis=0) * 1e6
151157
)
@@ -293,7 +299,10 @@ def robust_reference(self, max_iterations=4):
293299
if len(bad_chans) > 0:
294300
raw_tmp._data = signal * 1e-6
295301
raw_tmp.info["bads"] = list(bad_chans)
296-
raw_tmp.interpolate_bads()
302+
if self.matlab_strict:
303+
_eeglab_interpolate_bads(raw_tmp)
304+
else:
305+
raw_tmp.interpolate_bads()
297306
signal_tmp = raw_tmp.get_data() * 1e6
298307
else:
299308
signal_tmp = signal

pyprep/utils.py

+141
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,13 @@
22
import math
33
from cmath import sqrt
44

5+
import mne
56
import numpy as np
67
import scipy.interpolate
8+
from mne.surface import _normalize_vectors
9+
from numpy.polynomial.legendre import legval
710
from psutil import virtual_memory
11+
from scipy import linalg
812
from scipy.signal import firwin, lfilter, lfilter_zi
913

1014

@@ -235,6 +239,143 @@ def _eeglab_fir_filter(data, filt):
235239
return out
236240

237241

242+
def _eeglab_calc_g(pos_from, pos_to, stiffness=4, num_lterms=7):
243+
"""Calculate spherical spline g function between points on a sphere.
244+
245+
Parameters
246+
----------
247+
pos_from : np.ndarray of float, shape(n_good_sensors, 3)
248+
The electrode positions to interpolate from.
249+
pos_to : np.ndarray of float, shape(n_bad_sensors, 3)
250+
The electrode positions to interpolate.
251+
stiffness : float
252+
Stiffness of the spline.
253+
num_lterms : int
254+
Number of Legendre terms to evaluate.
255+
256+
Returns
257+
-------
258+
G : np.ndarray of float, shape(n_channels, n_channels)
259+
The G matrix.
260+
261+
Notes
262+
-----
263+
Produces identical output to the private ``computeg`` function in EEGLAB's
264+
``eeg_interp.m``.
265+
266+
"""
267+
# https://github.com/sccn/eeglab/blob/167dfc8/functions/popfunc/eeg_interp.m#L347
268+
269+
n_to = pos_to.shape[0]
270+
n_from = pos_from.shape[0]
271+
272+
# Calculate the Euclidian distances between the 'to' and 'from' electrodes
273+
dxyz = []
274+
for i in range(0, 3):
275+
d1 = np.repeat(pos_to[:, i], n_from).reshape((n_to, n_from))
276+
d2 = np.repeat(pos_from[:, i], n_to).reshape((n_from, n_to)).T
277+
dxyz.append((d1 - d2) ** 2)
278+
elec_dists = np.sqrt(sum(dxyz))
279+
280+
# Subtract all the Euclidian electrode distances from 1 (why?)
281+
EI = np.ones([n_to, n_from]) - elec_dists
282+
283+
# Calculate Legendre coefficients for the given degree and stiffness
284+
factors = [0]
285+
for n in range(1, num_lterms + 1):
286+
f = (2 * n + 1) / (n ** stiffness * (n + 1) ** stiffness * 4 * np.pi)
287+
factors.append(f)
288+
289+
return legval(EI, factors)
290+
291+
292+
def _eeglab_interpolate(data, pos_from, pos_to):
293+
"""Interpolate bad channels using EEGLAB's custom method.
294+
295+
Parameters
296+
----------
297+
data : np.ndarray
298+
A 2-D array containing signals from currently-good EEG channels with
299+
which to interpolate signals for bad channels.
300+
pos_from : np.ndarray of float, shape(n_good_sensors, 3)
301+
The electrode positions to interpolate from.
302+
pos_to : np.ndarray of float, shape(n_bad_sensors, 3)
303+
The electrode positions to interpolate.
304+
305+
Returns
306+
-------
307+
interpolated : np.ndarray
308+
The interpolated signals for all bad channels.
309+
310+
Notes
311+
-----
312+
Produces identical output to the private ``spheric_spline`` function in
313+
EEGLAB's ``eeg_interp.m`` (with minor rounding errors).
314+
315+
"""
316+
# https://github.com/sccn/eeglab/blob/167dfc8/functions/popfunc/eeg_interp.m#L314
317+
318+
# Calculate G for distances between good electrodes + between goods & bads
319+
G_from = _eeglab_calc_g(pos_from, pos_from)
320+
G_to_from = _eeglab_calc_g(pos_from, pos_to)
321+
322+
# Get average reference signal for all good channels and subtract from data
323+
avg_ref = np.mean(data, axis=0)
324+
data_tmp = data - avg_ref
325+
326+
# Calculate interpolation matrix from electrode locations
327+
pad_ones = np.ones((1, pos_from.shape[0]))
328+
C_inv = linalg.pinv(np.vstack([G_from, pad_ones]))
329+
interp_mat = np.matmul(G_to_from, C_inv[:, :-1])
330+
331+
# Interpolate bad channels and add average good reference to them
332+
interpolated = np.matmul(interp_mat, data_tmp) + avg_ref
333+
334+
return interpolated
335+
336+
337+
def _eeglab_interpolate_bads(raw):
338+
"""Interpolate bad channels using EEGLAB's custom method.
339+
340+
This method modifies the provided Raw object in place.
341+
342+
Parameters
343+
----------
344+
raw : mne.io.Raw
345+
An MNE Raw object for which channels marked as "bad" should be
346+
interpolated.
347+
348+
Notes
349+
-----
350+
Produces identical results as EEGLAB's ``eeg_interp`` function when using
351+
the default spheric spline method (with minor rounding errors). This method
352+
appears to be loosely based on the same general Perrin et al. (1989) method
353+
as MNE's interpolation, but there are several quirks with the implementation
354+
that cause it to produce fairly different numbers.
355+
356+
"""
357+
# Get the indices of good and bad EEG channels
358+
eeg_chans = mne.pick_types(raw.info, eeg=True, exclude=[])
359+
good_idx = mne.pick_types(raw.info, eeg=True, exclude="bads")
360+
bad_idx = sorted(_set_diff(eeg_chans, good_idx))
361+
362+
# Get the spatial coordinates of the good and bad electrodes
363+
elec_pos = raw._get_channel_positions(picks=eeg_chans)
364+
pos_good = elec_pos[good_idx, :].copy()
365+
pos_bad = elec_pos[bad_idx, :].copy()
366+
_normalize_vectors(pos_good)
367+
_normalize_vectors(pos_bad)
368+
369+
# Interpolate bad channels
370+
interp = _eeglab_interpolate(raw._data[good_idx, :], pos_good, pos_bad)
371+
raw._data[bad_idx, :] = interp
372+
373+
# Clear all bad EEG channels
374+
eeg_bad_names = [raw.info["ch_names"][i] for i in bad_idx]
375+
bads_non_eeg = _set_diff(raw.info["bads"], eeg_bad_names)
376+
raw.info["bads"] = bads_non_eeg
377+
378+
238379
def _get_random_subset(x, size, rand_state):
239380
"""Get a random subset of items from a list or array, without replacement.
240381

0 commit comments

Comments
 (0)