From 9a251bbd73b3732745cca217749e46bbde897cc3 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 5 Dec 2023 12:06:03 -0500 Subject: [PATCH 1/4] ENH: Add polyphase resampling --- doc/Makefile | 2 +- doc/changes/devel.rst | 1 + ...dataset_sgskip.py => spm_faces_dataset.py} | 106 +++-------- examples/decoding/receptive_field_mtrf.py | 4 +- .../source_power_spectrum_opm.py | 8 +- mne/cuda.py | 2 +- mne/filter.py | 164 ++++++++++++------ mne/io/base.py | 22 ++- mne/io/fiff/tests/test_raw_fiff.py | 66 +++---- mne/source_estimate.py | 27 ++- mne/tests/test_filter.py | 39 +++-- mne/tests/test_source_estimate.py | 35 ++-- mne/utils/docs.py | 58 +++++-- .../preprocessing/30_filtering_resampling.py | 47 ++++- 14 files changed, 359 insertions(+), 222 deletions(-) rename examples/datasets/{spm_faces_dataset_sgskip.py => spm_faces_dataset.py} (62%) diff --git a/doc/Makefile b/doc/Makefile index 70d7429f4ad..3c251069045 100644 --- a/doc/Makefile +++ b/doc/Makefile @@ -76,6 +76,6 @@ doctest: "results in _build/doctest/output.txt." view: - @python -c "import webbrowser; webbrowser.open_new_tab('file://$(PWD)/_build/html/index.html')" + @python -c "import webbrowser; webbrowser.open_new_tab('file://$(PWD)/_build/html/sg_execution_times.html')" show: view diff --git a/doc/changes/devel.rst b/doc/changes/devel.rst index eaf1cb881ad..fadd872e621 100644 --- a/doc/changes/devel.rst +++ b/doc/changes/devel.rst @@ -36,6 +36,7 @@ Enhancements ~~~~~~~~~~~~ - Speed up export to .edf in :func:`mne.export.export_raw` by using ``edfio`` instead of ``EDFlib-Python`` (:gh:`12218` by :newcontrib:`Florian Hofer`) - We added type hints for the return values of :func:`mne.read_evokeds` and :func:`mne.io.read_raw`. Development environments like VS Code or PyCharm will now provide more help when using these functions in your code. (:gh:`12250` by `Richard Höchenberger`_ and `Eric Larson`_) +- Add ``method="polyphase"`` to :meth:`mne.io.Raw.resample` and related functions to allow resampling using :func:`scipy.signal.upfirdn` (:gh:`12268` by `Eric Larson`_) Bugs ~~~~ diff --git a/examples/datasets/spm_faces_dataset_sgskip.py b/examples/datasets/spm_faces_dataset.py similarity index 62% rename from examples/datasets/spm_faces_dataset_sgskip.py rename to examples/datasets/spm_faces_dataset.py index 1357fc513b6..cf332e5f7d8 100644 --- a/examples/datasets/spm_faces_dataset_sgskip.py +++ b/examples/datasets/spm_faces_dataset.py @@ -5,15 +5,8 @@ From raw data to dSPM on SPM Faces dataset ========================================== -Runs a full pipeline using MNE-Python: - - - artifact removal - - averaging Epochs - - forward model computation - - source reconstruction using dSPM on the contrast : "faces - scrambled" - -.. note:: This example does quite a bit of processing, so even on a - fast machine it can take several minutes to complete. +Runs a full pipeline using MNE-Python. This example does quite a bit of processing, so +even on a fast machine it can take several minutes to complete. """ # Authors: Alexandre Gramfort # Denis Engemann @@ -21,12 +14,6 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. -# %% - -# sphinx_gallery_thumbnail_number = 10 - -import matplotlib.pyplot as plt - import mne from mne import combine_evoked, io from mne.datasets import spm_face @@ -40,109 +27,73 @@ spm_path = data_path / "MEG" / "spm" # %% -# Load and filter data, set up epochs +# Load data, filter it, and fit ICA. raw_fname = spm_path / "SPM_CTF_MEG_example_faces1_3D.ds" - raw = io.read_raw_ctf(raw_fname, preload=True) # Take first run # Here to save memory and time we'll downsample heavily -- this is not # advised for real data as it can effectively jitter events! -raw.resample(120.0, npad="auto") - -picks = mne.pick_types(raw.info, meg=True, exclude="bads") -raw.filter(1, 30, method="fir", fir_design="firwin") +raw.resample(100) +raw.filter(1.0, None) # high-pass +reject = dict(mag=5e-12) +ica = ICA(n_components=0.95, max_iter="auto", random_state=0) +ica.fit(raw, reject=reject) +# compute correlation scores, get bad indices sorted by score +eog_epochs = create_eog_epochs(raw, ch_name="MRT31-2908", reject=reject) +eog_inds, eog_scores = ica.find_bads_eog(eog_epochs, ch_name="MRT31-2908") +ica.plot_scores(eog_scores, eog_inds) # see scores the selection is based on +ica.plot_components(eog_inds) # view topographic sensitivity of components +ica.exclude += eog_inds[:1] # we saw the 2nd ECG component looked too dipolar +ica.plot_overlay(eog_epochs.average()) # inspect artifact removal +# %% +# Epoch data and apply ICA. events = mne.find_events(raw, stim_channel="UPPT001") - -# plot the events to get an idea of the paradigm -mne.viz.plot_events(events, raw.info["sfreq"]) - event_ids = {"faces": 1, "scrambled": 2} - tmin, tmax = -0.2, 0.6 baseline = None # no baseline as high-pass is applied -reject = dict(mag=5e-12) - epochs = mne.Epochs( raw, events, event_ids, tmin, tmax, - picks=picks, - baseline=baseline, + picks="meg", + baseline=None, preload=True, reject=reject, ) - -# Fit ICA, find and remove major artifacts -ica = ICA(n_components=0.95, max_iter="auto", random_state=0) -ica.fit(raw, decim=1, reject=reject) - -# compute correlation scores, get bad indices sorted by score -eog_epochs = create_eog_epochs(raw, ch_name="MRT31-2908", reject=reject) -eog_inds, eog_scores = ica.find_bads_eog(eog_epochs, ch_name="MRT31-2908") -ica.plot_scores(eog_scores, eog_inds) # see scores the selection is based on -ica.plot_components(eog_inds) # view topographic sensitivity of components -ica.exclude += eog_inds[:1] # we saw the 2nd ECG component looked too dipolar -ica.plot_overlay(eog_epochs.average()) # inspect artifact removal +del raw ica.apply(epochs) # clean data, default in place - evoked = [epochs[k].average() for k in event_ids] - contrast = combine_evoked(evoked, weights=[-1, 1]) # Faces - scrambled - evoked.append(contrast) - for e in evoked: e.plot(ylim=dict(mag=[-400, 400])) -plt.show() - -# estimate noise covarariance -noise_cov = mne.compute_covariance(epochs, tmax=0, method="shrunk", rank=None) - # %% -# Visualize fields on MEG helmet - -# The transformation here was aligned using the dig-montage. It's included in -# the spm_faces dataset and is named SPM_dig_montage.fif. -trans_fname = spm_path / "SPM_CTF_MEG_example_faces1_3D_raw-trans.fif" - -maps = mne.make_field_map( - evoked[0], trans_fname, subject="spm", subjects_dir=subjects_dir, n_jobs=None -) - -evoked[0].plot_field(maps, time=0.170, time_viewer=False) - -# %% -# Look at the whitened evoked daat +# Estimate noise covariance and look at the whitened evoked data +noise_cov = mne.compute_covariance(epochs, tmax=0, method="shrunk", rank=None) evoked[0].plot_white(noise_cov) # %% # Compute forward model +trans_fname = spm_path / "SPM_CTF_MEG_example_faces1_3D_raw-trans.fif" src = subjects_dir / "spm" / "bem" / "spm-oct-6-src.fif" bem = subjects_dir / "spm" / "bem" / "spm-5120-5120-5120-bem-sol.fif" forward = mne.make_forward_solution(contrast.info, trans_fname, src, bem) # %% -# Compute inverse solution +# Compute inverse solution and plot + +# sphinx_gallery_thumbnail_number = 8 snr = 3.0 lambda2 = 1.0 / snr**2 -method = "dSPM" - -inverse_operator = make_inverse_operator( - contrast.info, forward, noise_cov, loose=0.2, depth=0.8 -) - -# Compute inverse solution on contrast -stc = apply_inverse(contrast, inverse_operator, lambda2, method, pick_ori=None) -# stc.save('spm_%s_dSPM_inverse' % contrast.comment) - -# Plot contrast in 3D with mne.viz.Brain if available +inverse_operator = make_inverse_operator(contrast.info, forward, noise_cov) +stc = apply_inverse(contrast, inverse_operator, lambda2, method="dSPM", pick_ori=None) brain = stc.plot( hemi="both", subjects_dir=subjects_dir, @@ -150,4 +101,3 @@ views=["ven"], clim={"kind": "value", "lims": [3.0, 6.0, 9.0]}, ) -# brain.save_image('dSPM_map.png') diff --git a/examples/decoding/receptive_field_mtrf.py b/examples/decoding/receptive_field_mtrf.py index 24b459f192f..1727d0f107c 100644 --- a/examples/decoding/receptive_field_mtrf.py +++ b/examples/decoding/receptive_field_mtrf.py @@ -58,8 +58,8 @@ speech = data["envelope"].T sfreq = float(data["Fs"].item()) sfreq /= decim -speech = mne.filter.resample(speech, down=decim, npad="auto") -raw = mne.filter.resample(raw, down=decim, npad="auto") +speech = mne.filter.resample(speech, down=decim, method="polyphase") +raw = mne.filter.resample(raw, down=decim, method="polyphase") # Read in channel positions and create our MNE objects from the raw data montage = mne.channels.make_standard_montage("biosemi128") diff --git a/examples/time_frequency/source_power_spectrum_opm.py b/examples/time_frequency/source_power_spectrum_opm.py index dd142138784..11168cc08a5 100644 --- a/examples/time_frequency/source_power_spectrum_opm.py +++ b/examples/time_frequency/source_power_spectrum_opm.py @@ -58,16 +58,16 @@ raw_erms = dict() new_sfreq = 60.0 # Nyquist frequency (30 Hz) < line noise freq (50 Hz) raws["vv"] = mne.io.read_raw_fif(vv_fname, verbose="error") # ignore naming -raws["vv"].load_data().resample(new_sfreq) +raws["vv"].load_data().resample(new_sfreq, method="polyphase") raws["vv"].info["bads"] = ["MEG2233", "MEG1842"] raw_erms["vv"] = mne.io.read_raw_fif(vv_erm_fname, verbose="error") -raw_erms["vv"].load_data().resample(new_sfreq) +raw_erms["vv"].load_data().resample(new_sfreq, method="polyphase") raw_erms["vv"].info["bads"] = ["MEG2233", "MEG1842"] raws["opm"] = mne.io.read_raw_fif(opm_fname) -raws["opm"].load_data().resample(new_sfreq) +raws["opm"].load_data().resample(new_sfreq, method="polyphase") raw_erms["opm"] = mne.io.read_raw_fif(opm_erm_fname) -raw_erms["opm"].load_data().resample(new_sfreq) +raw_erms["opm"].load_data().resample(new_sfreq, method="polyphase") # Make sure our assumptions later hold assert raws["opm"].info["sfreq"] == raws["vv"].info["sfreq"] diff --git a/mne/cuda.py b/mne/cuda.py index b4aa7c37bf3..7d7634a6e4e 100644 --- a/mne/cuda.py +++ b/mne/cuda.py @@ -330,7 +330,7 @@ def _fft_resample(x, new_len, npads, to_removes, cuda_dict=None, pad="reflect_li Number of samples to remove after resampling. cuda_dict : dict Dictionary constructed using setup_cuda_multiply_repeated(). - %(pad)s + %(pad_resample)s The default is ``'reflect_limited'``. .. versionadded:: 0.15 diff --git a/mne/filter.py b/mne/filter.py index 528128822b8..d4cd6011aa6 100644 --- a/mne/filter.py +++ b/mne/filter.py @@ -5,6 +5,7 @@ from collections import Counter from copy import deepcopy from functools import partial +from math import gcd import numpy as np from scipy import fft, signal @@ -1898,12 +1899,13 @@ def resample( x, up=1.0, down=1.0, - npad=100, + *, axis=-1, - window="boxcar", + window="auto", n_jobs=None, - pad="reflect_limited", - *, + pad="auto", + npad=100, + method="fft", verbose=None, ): """Resample an array. @@ -1918,15 +1920,18 @@ def resample( Factor to upsample by. down : float Factor to downsample by. - %(npad)s axis : int Axis along which to resample (default is the last axis). %(window_resample)s %(n_jobs_cuda)s - %(pad)s - The default is ``'reflect_limited'``. + ``n_jobs='cuda'`` is only supported when ``method="fft"``. + %(pad_resample_auto)s .. versionadded:: 0.15 + %(npad_resample)s + %(method_resample)s + + .. versionadded:: 1.7 %(verbose)s Returns @@ -1936,26 +1941,16 @@ def resample( Notes ----- + When using ``method="fft"`` (default), This uses (hopefully) intelligent edge padding and frequency-domain - windowing improve scipy.signal.resample's resampling method, which + windowing improve :func:`scipy.signal.resample`'s resampling method, which we have adapted for our use here. Choices of npad and window have important consequences, and the default choices should work well for most natural signals. - - Resampling arguments are broken into "up" and "down" components for future - compatibility in case we decide to use an upfirdn implementation. The - current implementation is functionally equivalent to passing - up=up/down and down=1. """ - # check explicitly for backwards compatibility - if not isinstance(axis, int): - err = ( - "The axis parameter needs to be an integer (got %s). " - "The axis parameter was missing from this function for a " - "period of time, you might be intending to specify the " - "subsequent window parameter." % repr(axis) - ) - raise TypeError(err) + _validate_type(method, str, "method") + _validate_type(pad, str, "pad") + _check_option("method", method, ("fft", "polyphase")) # make sure our arithmetic will work x = _check_filterable(x, "resampled", "resample") @@ -1963,31 +1958,88 @@ def resample( del up, down if axis < 0: axis = x.ndim + axis - orig_last_axis = x.ndim - 1 - if axis != orig_last_axis: - x = x.swapaxes(axis, orig_last_axis) - orig_shape = x.shape - x_len = orig_shape[-1] - if x_len == 0: - warn("x has zero length along last axis, returning a copy of x") + if x.shape[axis] == 0: + warn(f"x has zero length along axis={axis}, returning a copy of x") return x.copy() - bad_msg = 'npad must be "auto" or an integer' + + # prep for resampling along the last axis (swap axis with last then reshape) + out_shape = list(x.shape) + out_shape.pop(axis) + out_shape.append(final_len) + x = np.atleast_2d(x.swapaxes(axis, -1).reshape((-1, x.shape[axis]))) + + # do the resampling using FFT or polyphase methods + kwargs = dict(pad=pad, window=window, n_jobs=n_jobs) + if method == "fft": + y = _resample_fft(x, npad=npad, ratio=ratio, final_len=final_len, **kwargs) + else: + up, down, kwargs["window"] = _prep_polyphase( + ratio, x.shape[-1], final_len, window + ) + half_len = len(window) // 2 + logger.info( + f"Polyphase resampling locality: ±{half_len} input sample{_pl(half_len)}" + ) + y = _resample_polyphase(x, up=up, down=down, **kwargs) + assert y.shape[-1] == final_len + + # restore dimensions (reshape then swap axis with last) + y = y.reshape(out_shape).swapaxes(axis, -1) + + return y + + +def _prep_polyphase(ratio, x_len, final_len, window): + if isinstance(window, str) and window == "auto": + window = ("kaiser", 5.0) # SciPy default + up = final_len + down = x_len + g_ = gcd(up, down) + up = up // g_ + down = down // g_ + # Figure out our signal locality and design window (adapted from SciPy) + if not isinstance(window, (list, np.ndarray)): + # Design a linear-phase low-pass FIR filter + max_rate = max(up, down) + f_c = 1.0 / max_rate # cutoff of FIR filter (rel. to Nyquist) + half_len = 10 * max_rate # reasonable cutoff for sinc-like function + window = signal.firwin(2 * half_len + 1, f_c, window=window) + return up, down, window + + +def _resample_polyphase(x, *, up, down, pad, window, n_jobs): + if pad == "auto": + pad = "reflect" + kwargs = dict(padtype=pad, window=window, up=up, down=down) + _validate_type( + n_jobs, (None, "int-like"), "n_jobs", extra="when method='polyphase'" + ) + parallel, p_fun, n_jobs = parallel_func(signal.resample_poly, n_jobs) + if n_jobs == 1: + y = signal.resample_poly(x, axis=-1, **kwargs) + else: + y = np.array(parallel(p_fun(x_, **kwargs) for x_ in x)) + return y + + +def _resample_fft(x_flat, *, ratio, final_len, pad, window, npad, n_jobs): + x_len = x_flat.shape[-1] + pad = "reflect_limited" if pad == "auto" else pad + if (isinstance(window, str) and window == "auto") or window is None: + window = "boxcar" if isinstance(npad, str): - if npad != "auto": - raise ValueError(bad_msg) + _check_option("npad", npad, ("auto",), extra="when a string") # Figure out reasonable pad that gets us to a power of 2 min_add = min(x_len // 8, 100) * 2 npad = 2 ** int(np.ceil(np.log2(x_len + min_add))) - x_len npad, extra = divmod(npad, 2) npads = np.array([npad, npad + extra], int) else: - if npad != int(npad): - raise ValueError(bad_msg) + npad = _ensure_int(npad, "npad", extra="or 'auto'") npads = np.array([npad, npad], int) del npad # prep for resampling now - x_flat = x.reshape((-1, x_len)) orig_len = x_len + npads.sum() # length after padding new_len = max(int(round(ratio * orig_len)), 1) # length after resampling to_removes = [int(round(ratio * npads[0]))] @@ -1997,15 +2049,12 @@ def resample( # assert np.abs(to_removes[1] - to_removes[0]) <= int(np.ceil(ratio)) # figure out windowing function - if window is not None: - if callable(window): - W = window(fft.fftfreq(orig_len)) - elif isinstance(window, np.ndarray) and window.shape == (orig_len,): - W = window - else: - W = fft.ifftshift(signal.get_window(window, orig_len)) + if callable(window): + W = window(fft.fftfreq(orig_len)) + elif isinstance(window, np.ndarray) and window.shape == (orig_len,): + W = window else: - W = np.ones(orig_len) + W = fft.ifftshift(signal.get_window(window, orig_len)) W *= float(new_len) / float(orig_len) # figure out if we should use CUDA @@ -2015,7 +2064,7 @@ def resample( # use of the 'flat' window is recommended for minimal ringing parallel, p_fun, n_jobs = parallel_func(_fft_resample, n_jobs) if n_jobs == 1: - y = np.zeros((len(x_flat), new_len - to_removes.sum()), dtype=x.dtype) + y = np.zeros((len(x_flat), new_len - to_removes.sum()), dtype=x_flat.dtype) for xi, x_ in enumerate(x_flat): y[xi] = _fft_resample(x_, new_len, npads, to_removes, cuda_dict, pad) else: @@ -2024,12 +2073,6 @@ def resample( ) y = np.array(y) - # Restore the original array shape (modified for resampling) - y.shape = orig_shape[:-1] + (y.shape[1],) - if axis != orig_last_axis: - y = y.swapaxes(axis, orig_last_axis) - assert y.shape[axis] == final_len - return y @@ -2635,11 +2678,12 @@ def filter( def resample( self, sfreq, + *, npad="auto", - window="boxcar", + window="auto", n_jobs=None, pad="edge", - *, + method="fft", verbose=None, ): """Resample data. @@ -2656,11 +2700,12 @@ def resample( %(npad)s %(window_resample)s %(n_jobs_cuda)s - %(pad)s - The default is ``'edge'``, which pads with the edge values of each - vector. + %(pad_resample)s .. versionadded:: 0.15 + %(method_resample)s + + .. versionadded:: 1.7 %(verbose)s Returns @@ -2691,7 +2736,14 @@ def resample( _check_preload(self, "inst.resample") self._data = resample( - self._data, sfreq, o_sfreq, npad, window=window, n_jobs=n_jobs, pad=pad + self._data, + sfreq, + o_sfreq, + npad=npad, + window=window, + n_jobs=n_jobs, + pad=pad, + method=method, ) lowpass = self.info.get("lowpass") lowpass = np.inf if lowpass is None else lowpass diff --git a/mne/io/base.py b/mne/io/base.py index 95ba7038865..652b747a8ac 100644 --- a/mne/io/base.py +++ b/mne/io/base.py @@ -1260,12 +1260,14 @@ def notch_filter( def resample( self, sfreq, + *, npad="auto", - window="boxcar", + window="auto", stim_picks=None, n_jobs=None, events=None, - pad="reflect_limited", + pad="auto", + method="fft", verbose=None, ): """Resample all channels. @@ -1294,7 +1296,7 @@ def resample( ---------- sfreq : float New sample rate to use. - %(npad)s + %(npad_resample)s %(window_resample)s stim_picks : list of int | None Stim channels. These channels are simply subsampled or @@ -1307,10 +1309,12 @@ def resample( An optional event matrix. When specified, the onsets of the events are resampled jointly with the data. NB: The input events are not modified, but a new array is returned with the raw instead. - %(pad)s - The default is ``'reflect_limited'``. + %(pad_resample_auto)s .. versionadded:: 0.15 + %(method_resample)s + + .. versionadded:: 1.7 %(verbose)s Returns @@ -1364,7 +1368,13 @@ def resample( ) kwargs = dict( - up=sfreq, down=o_sfreq, npad=npad, window=window, n_jobs=n_jobs, pad=pad + up=sfreq, + down=o_sfreq, + npad=npad, + window=window, + n_jobs=n_jobs, + pad=pad, + method=method, ) ratio, n_news = zip( *( diff --git a/mne/io/fiff/tests/test_raw_fiff.py b/mne/io/fiff/tests/test_raw_fiff.py index 5c760735800..2c302eac3ad 100644 --- a/mne/io/fiff/tests/test_raw_fiff.py +++ b/mne/io/fiff/tests/test_raw_fiff.py @@ -42,6 +42,7 @@ _record_warnings, assert_and_remove_boundary_annot, assert_object_equal, + catch_logging, requires_mne, run_subprocess, ) @@ -1290,23 +1291,28 @@ def test_resample_equiv(): @pytest.mark.slowtest @testing.requires_testing_data @pytest.mark.parametrize( - "preload, n, npad", + "preload, n, npad, method", [ - (True, 512, "auto"), - (False, 512, 0), + (True, 512, "auto", "fft"), + (True, 512, "auto", "polyphase"), + (False, 512, 0, "fft"), # only test one with non-preload because it's slow ], ) -def test_resample(tmp_path, preload, n, npad): +def test_resample(tmp_path, preload, n, npad, method): """Test resample (with I/O and multiple files).""" + kwargs = dict(npad=npad, method=method) raw = read_raw_fif(fif_fname) raw.crop(0, raw.times[n - 1]) + # Reduce to a few MEG channels and a few stim channels to speed up + n_meg = 5 + raw.pick(raw.ch_names[:n_meg] + raw.ch_names[312:320]) # 10 MEG + 3 STIM + 5 EEG assert len(raw.times) == n if preload: raw.load_data() raw_resamp = raw.copy() sfreq = raw.info["sfreq"] # test parallel on upsample - raw_resamp.resample(sfreq * 2, n_jobs=2, npad=npad) + raw_resamp.resample(sfreq * 2, n_jobs=2, **kwargs) assert raw_resamp.n_times == len(raw_resamp.times) raw_resamp.save(tmp_path / "raw_resamp-raw.fif") raw_resamp = read_raw_fif(tmp_path / "raw_resamp-raw.fif", preload=True) @@ -1315,7 +1321,13 @@ def test_resample(tmp_path, preload, n, npad): assert raw_resamp.get_data().shape[1] == raw_resamp.n_times assert raw.get_data().shape[0] == raw_resamp._data.shape[0] # test non-parallel on downsample - raw_resamp.resample(sfreq, n_jobs=None, npad=npad) + with catch_logging() as log: + raw_resamp.resample(sfreq, n_jobs=None, verbose=True, **kwargs) + log = log.getvalue() + if method == "fft": + assert "locality" not in log + else: + assert "locality" in log assert raw_resamp.info["sfreq"] == sfreq assert raw.get_data().shape == raw_resamp._data.shape assert raw.first_samp == raw_resamp.first_samp @@ -1324,18 +1336,12 @@ def test_resample(tmp_path, preload, n, npad): # works (hooray). Note that the stim channels had to be sub-sampled # without filtering to be accurately preserved # note we have to treat MEG and EEG+STIM channels differently (tols) - assert_allclose( - raw.get_data()[:306, 200:-200], - raw_resamp._data[:306, 200:-200], - rtol=1e-2, - atol=1e-12, - ) - assert_allclose( - raw.get_data()[306:, 200:-200], - raw_resamp._data[306:, 200:-200], - rtol=1e-2, - atol=1e-7, - ) + want_meg = raw.get_data()[:n_meg, 200:-200] + got_meg = raw_resamp._data[:n_meg, 200:-200] + want_non_meg = raw.get_data()[n_meg:, 200:-200] + got_non_meg = raw_resamp._data[n_meg:, 200:-200] + assert_allclose(got_meg, want_meg, rtol=1e-2, atol=1e-12) + assert_allclose(want_non_meg, got_non_meg, rtol=1e-2, atol=1e-7) # now check multiple file support w/resampling, as order of operations # (concat, resample) should not affect our data @@ -1344,9 +1350,9 @@ def test_resample(tmp_path, preload, n, npad): raw3 = raw.copy() raw4 = raw.copy() raw1 = concatenate_raws([raw1, raw2]) - raw1.resample(10.0, npad=npad) - raw3.resample(10.0, npad=npad) - raw4.resample(10.0, npad=npad) + raw1.resample(10.0, **kwargs) + raw3.resample(10.0, **kwargs) + raw4.resample(10.0, **kwargs) raw3 = concatenate_raws([raw3, raw4]) assert_array_equal(raw1._data, raw3._data) assert_array_equal(raw1._first_samps, raw3._first_samps) @@ -1364,12 +1370,12 @@ def test_resample(tmp_path, preload, n, npad): # basic decimation stim = [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0] raw = RawArray([stim], create_info(1, len(stim), ["stim"])) - assert_allclose(raw.resample(8.0, npad=npad)._data, [[1, 1, 0, 0, 1, 1, 0, 0]]) + assert_allclose(raw.resample(8.0, **kwargs)._data, [[1, 1, 0, 0, 1, 1, 0, 0]]) # decimation of multiple stim channels raw = RawArray(2 * [stim], create_info(2, len(stim), 2 * ["stim"])) assert_allclose( - raw.resample(8.0, npad=npad, verbose="error")._data, + raw.resample(8.0, **kwargs, verbose="error")._data, [[1, 1, 0, 0, 1, 1, 0, 0], [1, 1, 0, 0, 1, 1, 0, 0]], ) @@ -1377,19 +1383,19 @@ def test_resample(tmp_path, preload, n, npad): # done naively stim = [0, 0, 0, 1, 1, 0, 0, 0] raw = RawArray([stim], create_info(1, len(stim), ["stim"])) - assert_allclose(raw.resample(4.0, npad=npad)._data, [[0, 1, 1, 0]]) + assert_allclose(raw.resample(4.0, **kwargs)._data, [[0, 1, 1, 0]]) # two events are merged in this case (warning) stim = [0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0] raw = RawArray([stim], create_info(1, len(stim), ["stim"])) with pytest.warns(RuntimeWarning, match="become unreliable"): - raw.resample(8.0, npad=npad) + raw.resample(8.0, **kwargs) # events are dropped in this case (warning) stim = [0, 1, 1, 0, 0, 1, 1, 0] raw = RawArray([stim], create_info(1, len(stim), ["stim"])) with pytest.warns(RuntimeWarning, match="become unreliable"): - raw.resample(4.0, npad=npad) + raw.resample(4.0, **kwargs) # test resampling events: this should no longer give a warning # we often have first_samp != 0, include it here too @@ -1400,7 +1406,7 @@ def test_resample(tmp_path, preload, n, npad): first_samp = len(stim) // 2 raw = RawArray([stim], create_info(1, o_sfreq, ["stim"]), first_samp=first_samp) events = find_events(raw) - raw, events = raw.resample(n_sfreq, events=events, npad=npad) + raw, events = raw.resample(n_sfreq, events=events, **kwargs) # Try index into raw.times with resampled events: raw.times[events[:, 0] - raw.first_samp] n_fsamp = int(first_samp * sfreq_ratio) # how it's calc'd in base.py @@ -1425,16 +1431,16 @@ def test_resample(tmp_path, preload, n, npad): # test copy flag stim = [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0] raw = RawArray([stim], create_info(1, len(stim), ["stim"])) - raw_resampled = raw.copy().resample(4.0, npad=npad) + raw_resampled = raw.copy().resample(4.0, **kwargs) assert raw_resampled is not raw - raw_resampled = raw.resample(4.0, npad=npad) + raw_resampled = raw.resample(4.0, **kwargs) assert raw_resampled is raw # resample should still work even when no stim channel is present raw = RawArray(np.random.randn(1, 100), create_info(1, 100, ["eeg"])) with raw.info._unlock(): raw.info["lowpass"] = 50.0 - raw.resample(10, npad=npad) + raw.resample(10, **kwargs) assert raw.info["lowpass"] == 5.0 assert len(raw) == 10 diff --git a/mne/source_estimate.py b/mne/source_estimate.py index 213d00e5baa..50734817431 100644 --- a/mne/source_estimate.py +++ b/mne/source_estimate.py @@ -819,7 +819,17 @@ def crop(self, tmin=None, tmax=None, include_tmax=True): return self # return self for chaining methods @verbose - def resample(self, sfreq, npad="auto", window="boxcar", n_jobs=None, verbose=None): + def resample( + self, + sfreq, + *, + npad=100, + method="fft", + window="auto", + pad="auto", + n_jobs=None, + verbose=None, + ): """Resample data. If appropriate, an anti-aliasing filter is applied before resampling. @@ -833,8 +843,15 @@ def resample(self, sfreq, npad="auto", window="boxcar", n_jobs=None, verbose=Non Amount to pad the start and end of the data. Can also be "auto" to use a padding that will result in a power-of-two size (can be much faster). - window : str | tuple - Window to use in resampling. See :func:`scipy.signal.resample`. + %(method_resample)s + + .. versionadded:: 1.7 + %(window_resample)s + + .. versionadded:: 1.7 + %(pad_resample_auto)s + + .. versionadded:: 1.7 %(n_jobs)s %(verbose)s @@ -863,7 +880,9 @@ def resample(self, sfreq, npad="auto", window="boxcar", n_jobs=None, verbose=Non data = self.data if data.dtype == np.float32: data = data.astype(np.float64) - self.data = resample(data, sfreq, o_sfreq, npad, n_jobs=n_jobs) + self.data = resample( + data, sfreq, o_sfreq, npad=npad, window=window, n_jobs=n_jobs, method=method + ) # adjust indirectly affected variables self.tstep = 1.0 / sfreq diff --git a/mne/tests/test_filter.py b/mne/tests/test_filter.py index 110a8f136c3..3ab60dba055 100644 --- a/mne/tests/test_filter.py +++ b/mne/tests/test_filter.py @@ -32,6 +32,8 @@ from mne.io import RawArray, read_raw_fif from mne.utils import catch_logging, requires_mne, run_subprocess, sum_squared +resample_method_parametrize = pytest.mark.parametrize("method", ("fft", "polyphase")) + def test_filter_array(): """Test filtering an array.""" @@ -372,20 +374,27 @@ def test_notch_filters(method, filter_length, line_freq, tol): assert_almost_equal(new_power, orig_power, tol) -def test_resample(): +@resample_method_parametrize +def test_resample(method): """Test resampling.""" rng = np.random.RandomState(0) x = rng.normal(0, 1, (10, 10, 10)) - x_rs = resample(x, 1, 2, 10) + with catch_logging() as log: + x_rs = resample(x, 1, 2, npad=10, method=method, verbose=True) + log = log.getvalue() + if method == "fft": + assert "locality" not in log + else: + assert "locality" in log assert x.shape == (10, 10, 10) assert x_rs.shape == (10, 10, 5) x_2 = x.swapaxes(0, 1) - x_2_rs = resample(x_2, 1, 2, 10) + x_2_rs = resample(x_2, 1, 2, npad=10, method=method) assert_array_equal(x_2_rs.swapaxes(0, 1), x_rs) x_3 = x.swapaxes(0, 2) - x_3_rs = resample(x_3, 1, 2, 10, 0) + x_3_rs = resample(x_3, 1, 2, npad=10, axis=0, method=method) assert_array_equal(x_3_rs.swapaxes(0, 2), x_rs) # make sure we cast to array if necessary @@ -401,12 +410,12 @@ def test_resample_scipy(): err_msg = "%s: %s" % (N, window) x_2_sp = sp_resample(x, 2 * N, window=window) for n_jobs in n_jobs_test: - x_2 = resample(x, 2, 1, 0, window=window, n_jobs=n_jobs) + x_2 = resample(x, 2, 1, npad=0, window=window, n_jobs=n_jobs) assert_allclose(x_2, x_2_sp, atol=1e-12, err_msg=err_msg) new_len = int(round(len(x) * (1.0 / 2.0))) x_p5_sp = sp_resample(x, new_len, window=window) for n_jobs in n_jobs_test: - x_p5 = resample(x, 1, 2, 0, window=window, n_jobs=n_jobs) + x_p5 = resample(x, 1, 2, npad=0, window=window, n_jobs=n_jobs) assert_allclose(x_p5, x_p5_sp, atol=1e-12, err_msg=err_msg) @@ -450,23 +459,25 @@ def test_resamp_stim_channel(): assert new_data.shape[1] == new_data_len -def test_resample_raw(): +@resample_method_parametrize +def test_resample_raw(method): """Test resampling using RawArray.""" x = np.zeros((1, 1001)) sfreq = 2048.0 raw = RawArray(x, create_info(1, sfreq, "eeg")) - raw.resample(128, npad=10) + raw.resample(128, npad=10, method=method) data = raw.get_data() assert data.shape == (1, 63) -def test_resample_below_1_sample(): +@resample_method_parametrize +def test_resample_below_1_sample(method): """Test resampling doesn't yield datapoints.""" # Raw x = np.zeros((1, 100)) sfreq = 1000.0 raw = RawArray(x, create_info(1, sfreq, "eeg")) - raw.resample(5) + raw.resample(5, method=method) assert len(raw.times) == 1 assert raw.get_data().shape[1] == 1 @@ -487,7 +498,13 @@ def test_resample_below_1_sample(): preload=True, verbose=False, ) - epochs.resample(1) + with catch_logging() as log: + epochs.resample(1, method=method, verbose=True) + log = log.getvalue() + if method == "fft": + assert "locality" not in log + else: + assert "locality" in log assert len(epochs.times) == 1 assert epochs.get_data(copy=False).shape[2] == 1 diff --git a/mne/tests/test_source_estimate.py b/mne/tests/test_source_estimate.py index 8c9e7df9389..ae893f3c108 100644 --- a/mne/tests/test_source_estimate.py +++ b/mne/tests/test_source_estimate.py @@ -558,18 +558,25 @@ def test_stc_arithmetic(): @pytest.mark.slowtest @testing.requires_testing_data -def test_stc_methods(): +@pytest.mark.parametrize("kind", ("scalar", "vector")) +@pytest.mark.parametrize("method", ("fft", "polyphase")) +def test_stc_methods(kind, method): """Test stc methods lh_data, rh_data, bin(), resample().""" stc_ = read_source_estimate(fname_stc) - # Make a vector version of the above source estimate - x = stc_.data[:, np.newaxis, :] - yz = np.zeros((x.shape[0], 2, x.shape[2])) - vec_stc_ = VectorSourceEstimate( - np.concatenate((x, yz), 1), stc_.vertices, stc_.tmin, stc_.tstep, stc_.subject - ) + if kind == "vector": + # Make a vector version of the above source estimate + x = stc_.data[:, np.newaxis, :] + yz = np.zeros((x.shape[0], 2, x.shape[2])) + stc_ = VectorSourceEstimate( + np.concatenate((x, yz), 1), + stc_.vertices, + stc_.tmin, + stc_.tstep, + stc_.subject, + ) - for stc in [stc_, vec_stc_]: + for stc in [stc_]: # noop to keep diff small # lh_data / rh_data assert_array_equal(stc.lh_data, stc.data[: len(stc.lh_vertno)]) assert_array_equal(stc.rh_data, stc.data[len(stc.lh_vertno) :]) @@ -606,13 +613,19 @@ def test_stc_methods(): stc_new = deepcopy(stc) o_sfreq = 1.0 / stc.tstep # note that using no padding for this STC reduces edge ringing... - stc_new.resample(2 * o_sfreq, npad=0) + stc_new.resample(2 * o_sfreq, npad=0, method=method) assert stc_new.data.shape[1] == 2 * stc.data.shape[1] assert stc_new.tstep == stc.tstep / 2 - stc_new.resample(o_sfreq, npad=0) + stc_new.resample(o_sfreq, npad=0, method=method) assert stc_new.data.shape[1] == stc.data.shape[1] assert stc_new.tstep == stc.tstep - assert_array_almost_equal(stc_new.data, stc.data, 5) + if method == "fft": + # no low-passing so survives round-trip + assert_allclose(stc_new.data, stc.data, atol=1e-5) + else: + # low-passing means we need something more flexible + corr = np.corrcoef(stc_new.data.ravel(), stc.data.ravel())[0, 1] + assert 0.99 < corr < 1 @testing.requires_testing_data diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 806d774f221..6d26d01dc40 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -2245,6 +2245,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict["method_psd"] = _method_psd.format("", "") docdict["method_psd_auto"] = _method_psd.format(" | ``'auto'``", "") +docdict["method_resample"] = """ +method : str + Resampling method to use. Can be ``"fft"`` (default) or ``"polyphase"`` + to use FFT-based on polyphase FIR resampling, respectively. These wrap to + :func:`scipy.signal.resample` and :func:`scipy.signal.resample_poly`, respectively. +""" + docdict["mode_eltc"] = """ mode : str Extraction mode, see Notes. @@ -2488,11 +2495,16 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict["npad"] = """ npad : int | str - Amount to pad the start and end of the data. - Can also be ``"auto"`` to use a padding that will result in - a power-of-two size (can be much faster). + Amount to pad the start and end of the data. Can also be ``"auto"`` to use a padding + that will result in a power-of-two size (can be much faster). """ +docdict["npad_resample"] = ( + docdict["npad"] + + """ + Only used when ``method="fft"``. +""" +) docdict["nrows_ncols_ica_components"] = """ nrows, ncols : int | 'auto' The number of rows and columns of topographies to plot. If both ``nrows`` @@ -2698,22 +2710,38 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): # P _pad_base = """ -pad : str - The type of padding to use. Supports all :func:`numpy.pad` ``mode`` - options. Can also be ``"reflect_limited"``, which pads with a - reflected version of each vector mirrored on the first and last values + all :func:`numpy.pad` ``mode`` options. Can also be ``"reflect_limited"``, which + pads with a reflected version of each vector mirrored on the first and last values of the vector, followed by zeros. """ -docdict["pad"] = _pad_base - docdict["pad_fir"] = ( - _pad_base - + """ + """ +pad : str + The type of padding to use. Supports """ + + _pad_base + + """\ Only used for ``method='fir'``. """ ) +docdict["pad_resample"] = ( # used when default is not "auto" + """ +pad : str + The type of padding to use. When ``method="fft"``, supports """ + + _pad_base + + """\ + When ``method="polyphase"``, supports all modes of :func:`scipy.signal.upfirdn`. +""" +) + +docdict["pad_resample_auto"] = ( # used when default is "auto" + docdict["pad_resample"] + + """\ + The default ("auto") means ``'reflect_limited'`` for ``method='fft'`` and + ``'reflect'`` for ``method='polyphase'``. +""" +) docdict["pca_vars_pctf"] = """ pca_vars : array, shape (n_comp,) | list of array The explained variances of the first n_comp SVD components across the @@ -4331,8 +4359,12 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict["window_resample"] = """ window : str | tuple - Frequency-domain window to use in resampling. - See :func:`scipy.signal.resample`. + When ``method="fft"``, this is the *frequency-domain* window to use in resampling, + and should be the same length as the signal; see :func:`scipy.signal.resample` + for details. When ``method="polyphase"``, this is the *time-domain* linear-phase + window to use after upsampling the signal; see :func:`scipy.signal.resample_poly` + for details. The default ``"auto"`` will use ``"boxcar"`` for ``method="fft"`` and + ``("kaiser", 5.0)`` for ``method="polyphase"``. """ # %% diff --git a/tutorials/preprocessing/30_filtering_resampling.py b/tutorials/preprocessing/30_filtering_resampling.py index 530b92741f6..ed1df059d32 100644 --- a/tutorials/preprocessing/30_filtering_resampling.py +++ b/tutorials/preprocessing/30_filtering_resampling.py @@ -206,16 +206,53 @@ def add_arrows(axes): # frequency`_ of the desired new sampling rate. This can be clearly seen in the # PSD plot, where a dashed vertical line indicates the filter cutoff; the # original data had an existing lowpass at around 172 Hz (see -# ``raw.info['lowpass']``), and the data resampled from 600 Hz to 200 Hz gets +# ``raw.info['lowpass']``), and the data resampled from ~600 Hz to 200 Hz gets # automatically lowpass filtered at 100 Hz (the `Nyquist frequency`_ for a # target rate of 200 Hz): raw_downsampled = raw.copy().resample(sfreq=200) +# choose n_fft for Welch PSD to make frequency axes similar resolution +n_ffts = [4096, int(round(4096 * 200 / raw.info["sfreq"]))] +fig, axes = plt.subplots(2, 1, sharey=True, layout="constrained", figsize=(10, 6)) +for ax, data, title, n_fft in zip( + axes, [raw, raw_downsampled], ["Original", "Downsampled"], n_ffts +): + fig = data.compute_psd(n_fft=n_fft).plot( + average=True, picks="data", exclude="bads", axes=ax + ) + ax.set(title=title, xlim=(0, 300)) -for data, title in zip([raw, raw_downsampled], ["Original", "Downsampled"]): - fig = data.compute_psd().plot(average=True, picks="data", exclude="bads") - fig.suptitle(title) - plt.setp(fig.axes, xlim=(0, 300)) +# %% +# By default, MNE-Python resamples using ``method="fft"``, which performs FFT-based +# resampling via :func:`scipy.signal.resample`. While efficient and good for most +# biological signals, it has two main potential drawbacks: +# +# 1. It assumes periodicity of the signal. We try to overcome this with appropriate +# signal padding, but some signal leakage may still occur. +# 2. It treats the entire signal as a single block. This means that in general effects +# are not guaranteed to be localized in time, though in practice they often are. +# +# Alternatively, resampling can be performed using ``method="polyphase"`` instead. +# This uses :func:`scipy.signal.resample_poly` under the hood, which in turn utilizes +# a three-step process to resample signals (see :func:`scipy.signal.upfirdn` for +# details). This process guarantees that each resampled output value is only affected by +# input values within a limited range. In other words, output values are guaranteed to +# be a result of a specific set of input values. +# +# In general, using ``method="polyphase"`` can also be faster than ``method="fft"`` in +# cases where the desired sampling rate is an integer factor different from the input +# sampling rate. For example: + +n_ffts = [4096, 2048] # factor of 2 smaller n_fft +raw_downsampled_poly = raw.copy().resample(sfreq=raw.info["sfreq"] / 2.0, verbose=True) +fig, axes = plt.subplots(2, 1, sharey=True, layout="constrained", figsize=(10, 6)) +for ax, data, title, n_fft in zip( + axes, [raw, raw_downsampled_poly], ["Original", "Downsampled (polyphase)"], n_ffts +): + data.compute_psd(n_fft=n_fft).plot( + average=True, picks="data", exclude="bads", axes=ax + ) + ax.set(title=title, xlim=(0, 300)) # %% # Because resampling involves filtering, there are some pitfalls to resampling From 5628c72ccc4068e1b3ef3bda1ce9fcb8ad83497c Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 5 Dec 2023 20:18:53 -0500 Subject: [PATCH 2/4] FIX: Actually use polyphase --- tutorials/preprocessing/30_filtering_resampling.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tutorials/preprocessing/30_filtering_resampling.py b/tutorials/preprocessing/30_filtering_resampling.py index ed1df059d32..6c118c99180 100644 --- a/tutorials/preprocessing/30_filtering_resampling.py +++ b/tutorials/preprocessing/30_filtering_resampling.py @@ -243,8 +243,14 @@ def add_arrows(axes): # cases where the desired sampling rate is an integer factor different from the input # sampling rate. For example: +# sphinx_gallery_thumbnail_number = 11 + n_ffts = [4096, 2048] # factor of 2 smaller n_fft -raw_downsampled_poly = raw.copy().resample(sfreq=raw.info["sfreq"] / 2.0, verbose=True) +raw_downsampled_poly = raw.copy().resample( + sfreq=raw.info["sfreq"] / 2.0, + method="polyphase", + verbose=True, +) fig, axes = plt.subplots(2, 1, sharey=True, layout="constrained", figsize=(10, 6)) for ax, data, title, n_fft in zip( axes, [raw, raw_downsampled_poly], ["Original", "Downsampled (polyphase)"], n_ffts From 90dee7acfd54d14862eefc7f92b98668d561aa8d Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Wed, 6 Dec 2023 11:56:57 -0500 Subject: [PATCH 3/4] Apply suggestions from code review Co-authored-by: Daniel McCloy --- examples/datasets/spm_faces_dataset.py | 1 - mne/filter.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/datasets/spm_faces_dataset.py b/examples/datasets/spm_faces_dataset.py index cf332e5f7d8..32df7d1a9ed 100644 --- a/examples/datasets/spm_faces_dataset.py +++ b/examples/datasets/spm_faces_dataset.py @@ -51,7 +51,6 @@ events = mne.find_events(raw, stim_channel="UPPT001") event_ids = {"faces": 1, "scrambled": 2} tmin, tmax = -0.2, 0.6 -baseline = None # no baseline as high-pass is applied epochs = mne.Epochs( raw, events, diff --git a/mne/filter.py b/mne/filter.py index d4cd6011aa6..3d9b3ecc7da 100644 --- a/mne/filter.py +++ b/mne/filter.py @@ -1942,7 +1942,7 @@ def resample( Notes ----- When using ``method="fft"`` (default), - This uses (hopefully) intelligent edge padding and frequency-domain + this uses (hopefully) intelligent edge padding and frequency-domain windowing improve :func:`scipy.signal.resample`'s resampling method, which we have adapted for our use here. Choices of npad and window have important consequences, and the default choices should work well From 335d6c86de6a092a613135fa635cb6af6c96aabb Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Wed, 6 Dec 2023 11:59:12 -0500 Subject: [PATCH 4/4] FIX: Move --- examples/decoding/receptive_field_mtrf.py | 7 +- mne/tests/test_source_estimate.py | 109 +++++++++++----------- 2 files changed, 57 insertions(+), 59 deletions(-) diff --git a/examples/decoding/receptive_field_mtrf.py b/examples/decoding/receptive_field_mtrf.py index 1727d0f107c..8dc04630753 100644 --- a/examples/decoding/receptive_field_mtrf.py +++ b/examples/decoding/receptive_field_mtrf.py @@ -17,7 +17,7 @@ .. _figure 1: https://www.frontiersin.org/articles/10.3389/fnhum.2016.00604/full#F1 .. _figure 2: https://www.frontiersin.org/articles/10.3389/fnhum.2016.00604/full#F2 .. _figure 5: https://www.frontiersin.org/articles/10.3389/fnhum.2016.00604/full#F5 -""" # noqa: E501 +""" # Authors: Chris Holdgraf # Eric Larson @@ -26,9 +26,6 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. -# %% -# sphinx_gallery_thumbnail_number = 3 - from os.path import join import matplotlib.pyplot as plt @@ -131,6 +128,8 @@ # across the scalp. We will recreate `figure 1`_ and `figure 2`_ from # :footcite:`CrosseEtAl2016`. +# sphinx_gallery_thumbnail_number = 3 + # Print mean coefficients across all time delays / channels (see Fig 1) time_plot = 0.180 # For highlighting a specific time. fig, ax = plt.subplots(figsize=(4, 8), layout="constrained") diff --git a/mne/tests/test_source_estimate.py b/mne/tests/test_source_estimate.py index ae893f3c108..be31fd1501b 100644 --- a/mne/tests/test_source_estimate.py +++ b/mne/tests/test_source_estimate.py @@ -562,70 +562,69 @@ def test_stc_arithmetic(): @pytest.mark.parametrize("method", ("fft", "polyphase")) def test_stc_methods(kind, method): """Test stc methods lh_data, rh_data, bin(), resample().""" - stc_ = read_source_estimate(fname_stc) + stc = read_source_estimate(fname_stc) if kind == "vector": # Make a vector version of the above source estimate - x = stc_.data[:, np.newaxis, :] + x = stc.data[:, np.newaxis, :] yz = np.zeros((x.shape[0], 2, x.shape[2])) - stc_ = VectorSourceEstimate( + stc = VectorSourceEstimate( np.concatenate((x, yz), 1), - stc_.vertices, - stc_.tmin, - stc_.tstep, - stc_.subject, + stc.vertices, + stc.tmin, + stc.tstep, + stc.subject, ) - for stc in [stc_]: # noop to keep diff small - # lh_data / rh_data - assert_array_equal(stc.lh_data, stc.data[: len(stc.lh_vertno)]) - assert_array_equal(stc.rh_data, stc.data[len(stc.lh_vertno) :]) + # lh_data / rh_data + assert_array_equal(stc.lh_data, stc.data[: len(stc.lh_vertno)]) + assert_array_equal(stc.rh_data, stc.data[len(stc.lh_vertno) :]) - # bin - binned = stc.bin(0.12) - a = np.mean(stc.data[..., : np.searchsorted(stc.times, 0.12)], axis=-1) - assert_array_equal(a, binned.data[..., 0]) + # bin + binned = stc.bin(0.12) + a = np.mean(stc.data[..., : np.searchsorted(stc.times, 0.12)], axis=-1) + assert_array_equal(a, binned.data[..., 0]) - stc = read_source_estimate(fname_stc) - stc.subject = "sample" - label_lh = read_labels_from_annot( - "sample", "aparc", "lh", subjects_dir=subjects_dir - )[0] - label_rh = read_labels_from_annot( - "sample", "aparc", "rh", subjects_dir=subjects_dir - )[0] - label_both = label_lh + label_rh - for label in (label_lh, label_rh, label_both): - assert isinstance(stc.shape, tuple) and len(stc.shape) == 2 - stc_label = stc.in_label(label) - if label.hemi != "both": - if label.hemi == "lh": - verts = stc_label.vertices[0] - else: # label.hemi == 'rh': - verts = stc_label.vertices[1] - n_vertices_used = len(label.get_vertices_used(verts)) - assert_equal(len(stc_label.data), n_vertices_used) - stc_lh = stc.in_label(label_lh) - pytest.raises(ValueError, stc_lh.in_label, label_rh) - label_lh.subject = "foo" - pytest.raises(RuntimeError, stc.in_label, label_lh) - - stc_new = deepcopy(stc) - o_sfreq = 1.0 / stc.tstep - # note that using no padding for this STC reduces edge ringing... - stc_new.resample(2 * o_sfreq, npad=0, method=method) - assert stc_new.data.shape[1] == 2 * stc.data.shape[1] - assert stc_new.tstep == stc.tstep / 2 - stc_new.resample(o_sfreq, npad=0, method=method) - assert stc_new.data.shape[1] == stc.data.shape[1] - assert stc_new.tstep == stc.tstep - if method == "fft": - # no low-passing so survives round-trip - assert_allclose(stc_new.data, stc.data, atol=1e-5) - else: - # low-passing means we need something more flexible - corr = np.corrcoef(stc_new.data.ravel(), stc.data.ravel())[0, 1] - assert 0.99 < corr < 1 + stc = read_source_estimate(fname_stc) + stc.subject = "sample" + label_lh = read_labels_from_annot( + "sample", "aparc", "lh", subjects_dir=subjects_dir + )[0] + label_rh = read_labels_from_annot( + "sample", "aparc", "rh", subjects_dir=subjects_dir + )[0] + label_both = label_lh + label_rh + for label in (label_lh, label_rh, label_both): + assert isinstance(stc.shape, tuple) and len(stc.shape) == 2 + stc_label = stc.in_label(label) + if label.hemi != "both": + if label.hemi == "lh": + verts = stc_label.vertices[0] + else: # label.hemi == 'rh': + verts = stc_label.vertices[1] + n_vertices_used = len(label.get_vertices_used(verts)) + assert_equal(len(stc_label.data), n_vertices_used) + stc_lh = stc.in_label(label_lh) + pytest.raises(ValueError, stc_lh.in_label, label_rh) + label_lh.subject = "foo" + pytest.raises(RuntimeError, stc.in_label, label_lh) + + stc_new = deepcopy(stc) + o_sfreq = 1.0 / stc.tstep + # note that using no padding for this STC reduces edge ringing... + stc_new.resample(2 * o_sfreq, npad=0, method=method) + assert stc_new.data.shape[1] == 2 * stc.data.shape[1] + assert stc_new.tstep == stc.tstep / 2 + stc_new.resample(o_sfreq, npad=0, method=method) + assert stc_new.data.shape[1] == stc.data.shape[1] + assert stc_new.tstep == stc.tstep + if method == "fft": + # no low-passing so survives round-trip + assert_allclose(stc_new.data, stc.data, atol=1e-5) + else: + # low-passing means we need something more flexible + corr = np.corrcoef(stc_new.data.ravel(), stc.data.ravel())[0, 1] + assert 0.99 < corr < 1 @testing.requires_testing_data