diff --git a/doc/Makefile b/doc/Makefile index 70d7429f4ad..3c251069045 100644 --- a/doc/Makefile +++ b/doc/Makefile @@ -76,6 +76,6 @@ doctest: "results in _build/doctest/output.txt." view: - @python -c "import webbrowser; webbrowser.open_new_tab('file://$(PWD)/_build/html/index.html')" + @python -c "import webbrowser; webbrowser.open_new_tab('file://$(PWD)/_build/html/sg_execution_times.html')" show: view diff --git a/doc/changes/devel.rst b/doc/changes/devel.rst index eaf1cb881ad..fadd872e621 100644 --- a/doc/changes/devel.rst +++ b/doc/changes/devel.rst @@ -36,6 +36,7 @@ Enhancements ~~~~~~~~~~~~ - Speed up export to .edf in :func:`mne.export.export_raw` by using ``edfio`` instead of ``EDFlib-Python`` (:gh:`12218` by :newcontrib:`Florian Hofer`) - We added type hints for the return values of :func:`mne.read_evokeds` and :func:`mne.io.read_raw`. Development environments like VS Code or PyCharm will now provide more help when using these functions in your code. (:gh:`12250` by `Richard Höchenberger`_ and `Eric Larson`_) +- Add ``method="polyphase"`` to :meth:`mne.io.Raw.resample` and related functions to allow resampling using :func:`scipy.signal.upfirdn` (:gh:`12268` by `Eric Larson`_) Bugs ~~~~ diff --git a/examples/datasets/spm_faces_dataset_sgskip.py b/examples/datasets/spm_faces_dataset.py similarity index 60% rename from examples/datasets/spm_faces_dataset_sgskip.py rename to examples/datasets/spm_faces_dataset.py index 1357fc513b6..32df7d1a9ed 100644 --- a/examples/datasets/spm_faces_dataset_sgskip.py +++ b/examples/datasets/spm_faces_dataset.py @@ -5,15 +5,8 @@ From raw data to dSPM on SPM Faces dataset ========================================== -Runs a full pipeline using MNE-Python: - - - artifact removal - - averaging Epochs - - forward model computation - - source reconstruction using dSPM on the contrast : "faces - scrambled" - -.. note:: This example does quite a bit of processing, so even on a - fast machine it can take several minutes to complete. +Runs a full pipeline using MNE-Python. This example does quite a bit of processing, so +even on a fast machine it can take several minutes to complete. """ # Authors: Alexandre Gramfort # Denis Engemann @@ -21,12 +14,6 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. -# %% - -# sphinx_gallery_thumbnail_number = 10 - -import matplotlib.pyplot as plt - import mne from mne import combine_evoked, io from mne.datasets import spm_face @@ -40,109 +27,72 @@ spm_path = data_path / "MEG" / "spm" # %% -# Load and filter data, set up epochs +# Load data, filter it, and fit ICA. raw_fname = spm_path / "SPM_CTF_MEG_example_faces1_3D.ds" - raw = io.read_raw_ctf(raw_fname, preload=True) # Take first run # Here to save memory and time we'll downsample heavily -- this is not # advised for real data as it can effectively jitter events! -raw.resample(120.0, npad="auto") - -picks = mne.pick_types(raw.info, meg=True, exclude="bads") -raw.filter(1, 30, method="fir", fir_design="firwin") +raw.resample(100) +raw.filter(1.0, None) # high-pass +reject = dict(mag=5e-12) +ica = ICA(n_components=0.95, max_iter="auto", random_state=0) +ica.fit(raw, reject=reject) +# compute correlation scores, get bad indices sorted by score +eog_epochs = create_eog_epochs(raw, ch_name="MRT31-2908", reject=reject) +eog_inds, eog_scores = ica.find_bads_eog(eog_epochs, ch_name="MRT31-2908") +ica.plot_scores(eog_scores, eog_inds) # see scores the selection is based on +ica.plot_components(eog_inds) # view topographic sensitivity of components +ica.exclude += eog_inds[:1] # we saw the 2nd ECG component looked too dipolar +ica.plot_overlay(eog_epochs.average()) # inspect artifact removal +# %% +# Epoch data and apply ICA. events = mne.find_events(raw, stim_channel="UPPT001") - -# plot the events to get an idea of the paradigm -mne.viz.plot_events(events, raw.info["sfreq"]) - event_ids = {"faces": 1, "scrambled": 2} - tmin, tmax = -0.2, 0.6 -baseline = None # no baseline as high-pass is applied -reject = dict(mag=5e-12) - epochs = mne.Epochs( raw, events, event_ids, tmin, tmax, - picks=picks, - baseline=baseline, + picks="meg", + baseline=None, preload=True, reject=reject, ) - -# Fit ICA, find and remove major artifacts -ica = ICA(n_components=0.95, max_iter="auto", random_state=0) -ica.fit(raw, decim=1, reject=reject) - -# compute correlation scores, get bad indices sorted by score -eog_epochs = create_eog_epochs(raw, ch_name="MRT31-2908", reject=reject) -eog_inds, eog_scores = ica.find_bads_eog(eog_epochs, ch_name="MRT31-2908") -ica.plot_scores(eog_scores, eog_inds) # see scores the selection is based on -ica.plot_components(eog_inds) # view topographic sensitivity of components -ica.exclude += eog_inds[:1] # we saw the 2nd ECG component looked too dipolar -ica.plot_overlay(eog_epochs.average()) # inspect artifact removal +del raw ica.apply(epochs) # clean data, default in place - evoked = [epochs[k].average() for k in event_ids] - contrast = combine_evoked(evoked, weights=[-1, 1]) # Faces - scrambled - evoked.append(contrast) - for e in evoked: e.plot(ylim=dict(mag=[-400, 400])) -plt.show() - -# estimate noise covarariance -noise_cov = mne.compute_covariance(epochs, tmax=0, method="shrunk", rank=None) - # %% -# Visualize fields on MEG helmet - -# The transformation here was aligned using the dig-montage. It's included in -# the spm_faces dataset and is named SPM_dig_montage.fif. -trans_fname = spm_path / "SPM_CTF_MEG_example_faces1_3D_raw-trans.fif" - -maps = mne.make_field_map( - evoked[0], trans_fname, subject="spm", subjects_dir=subjects_dir, n_jobs=None -) - -evoked[0].plot_field(maps, time=0.170, time_viewer=False) - -# %% -# Look at the whitened evoked daat +# Estimate noise covariance and look at the whitened evoked data +noise_cov = mne.compute_covariance(epochs, tmax=0, method="shrunk", rank=None) evoked[0].plot_white(noise_cov) # %% # Compute forward model +trans_fname = spm_path / "SPM_CTF_MEG_example_faces1_3D_raw-trans.fif" src = subjects_dir / "spm" / "bem" / "spm-oct-6-src.fif" bem = subjects_dir / "spm" / "bem" / "spm-5120-5120-5120-bem-sol.fif" forward = mne.make_forward_solution(contrast.info, trans_fname, src, bem) # %% -# Compute inverse solution +# Compute inverse solution and plot + +# sphinx_gallery_thumbnail_number = 8 snr = 3.0 lambda2 = 1.0 / snr**2 -method = "dSPM" - -inverse_operator = make_inverse_operator( - contrast.info, forward, noise_cov, loose=0.2, depth=0.8 -) - -# Compute inverse solution on contrast -stc = apply_inverse(contrast, inverse_operator, lambda2, method, pick_ori=None) -# stc.save('spm_%s_dSPM_inverse' % contrast.comment) - -# Plot contrast in 3D with mne.viz.Brain if available +inverse_operator = make_inverse_operator(contrast.info, forward, noise_cov) +stc = apply_inverse(contrast, inverse_operator, lambda2, method="dSPM", pick_ori=None) brain = stc.plot( hemi="both", subjects_dir=subjects_dir, @@ -150,4 +100,3 @@ views=["ven"], clim={"kind": "value", "lims": [3.0, 6.0, 9.0]}, ) -# brain.save_image('dSPM_map.png') diff --git a/examples/decoding/receptive_field_mtrf.py b/examples/decoding/receptive_field_mtrf.py index 24b459f192f..8dc04630753 100644 --- a/examples/decoding/receptive_field_mtrf.py +++ b/examples/decoding/receptive_field_mtrf.py @@ -17,7 +17,7 @@ .. _figure 1: https://www.frontiersin.org/articles/10.3389/fnhum.2016.00604/full#F1 .. _figure 2: https://www.frontiersin.org/articles/10.3389/fnhum.2016.00604/full#F2 .. _figure 5: https://www.frontiersin.org/articles/10.3389/fnhum.2016.00604/full#F5 -""" # noqa: E501 +""" # Authors: Chris Holdgraf # Eric Larson @@ -26,9 +26,6 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. -# %% -# sphinx_gallery_thumbnail_number = 3 - from os.path import join import matplotlib.pyplot as plt @@ -58,8 +55,8 @@ speech = data["envelope"].T sfreq = float(data["Fs"].item()) sfreq /= decim -speech = mne.filter.resample(speech, down=decim, npad="auto") -raw = mne.filter.resample(raw, down=decim, npad="auto") +speech = mne.filter.resample(speech, down=decim, method="polyphase") +raw = mne.filter.resample(raw, down=decim, method="polyphase") # Read in channel positions and create our MNE objects from the raw data montage = mne.channels.make_standard_montage("biosemi128") @@ -131,6 +128,8 @@ # across the scalp. We will recreate `figure 1`_ and `figure 2`_ from # :footcite:`CrosseEtAl2016`. +# sphinx_gallery_thumbnail_number = 3 + # Print mean coefficients across all time delays / channels (see Fig 1) time_plot = 0.180 # For highlighting a specific time. fig, ax = plt.subplots(figsize=(4, 8), layout="constrained") diff --git a/examples/time_frequency/source_power_spectrum_opm.py b/examples/time_frequency/source_power_spectrum_opm.py index dd142138784..11168cc08a5 100644 --- a/examples/time_frequency/source_power_spectrum_opm.py +++ b/examples/time_frequency/source_power_spectrum_opm.py @@ -58,16 +58,16 @@ raw_erms = dict() new_sfreq = 60.0 # Nyquist frequency (30 Hz) < line noise freq (50 Hz) raws["vv"] = mne.io.read_raw_fif(vv_fname, verbose="error") # ignore naming -raws["vv"].load_data().resample(new_sfreq) +raws["vv"].load_data().resample(new_sfreq, method="polyphase") raws["vv"].info["bads"] = ["MEG2233", "MEG1842"] raw_erms["vv"] = mne.io.read_raw_fif(vv_erm_fname, verbose="error") -raw_erms["vv"].load_data().resample(new_sfreq) +raw_erms["vv"].load_data().resample(new_sfreq, method="polyphase") raw_erms["vv"].info["bads"] = ["MEG2233", "MEG1842"] raws["opm"] = mne.io.read_raw_fif(opm_fname) -raws["opm"].load_data().resample(new_sfreq) +raws["opm"].load_data().resample(new_sfreq, method="polyphase") raw_erms["opm"] = mne.io.read_raw_fif(opm_erm_fname) -raw_erms["opm"].load_data().resample(new_sfreq) +raw_erms["opm"].load_data().resample(new_sfreq, method="polyphase") # Make sure our assumptions later hold assert raws["opm"].info["sfreq"] == raws["vv"].info["sfreq"] diff --git a/mne/cuda.py b/mne/cuda.py index b4aa7c37bf3..7d7634a6e4e 100644 --- a/mne/cuda.py +++ b/mne/cuda.py @@ -330,7 +330,7 @@ def _fft_resample(x, new_len, npads, to_removes, cuda_dict=None, pad="reflect_li Number of samples to remove after resampling. cuda_dict : dict Dictionary constructed using setup_cuda_multiply_repeated(). - %(pad)s + %(pad_resample)s The default is ``'reflect_limited'``. .. versionadded:: 0.15 diff --git a/mne/filter.py b/mne/filter.py index 528128822b8..3d9b3ecc7da 100644 --- a/mne/filter.py +++ b/mne/filter.py @@ -5,6 +5,7 @@ from collections import Counter from copy import deepcopy from functools import partial +from math import gcd import numpy as np from scipy import fft, signal @@ -1898,12 +1899,13 @@ def resample( x, up=1.0, down=1.0, - npad=100, + *, axis=-1, - window="boxcar", + window="auto", n_jobs=None, - pad="reflect_limited", - *, + pad="auto", + npad=100, + method="fft", verbose=None, ): """Resample an array. @@ -1918,15 +1920,18 @@ def resample( Factor to upsample by. down : float Factor to downsample by. - %(npad)s axis : int Axis along which to resample (default is the last axis). %(window_resample)s %(n_jobs_cuda)s - %(pad)s - The default is ``'reflect_limited'``. + ``n_jobs='cuda'`` is only supported when ``method="fft"``. + %(pad_resample_auto)s .. versionadded:: 0.15 + %(npad_resample)s + %(method_resample)s + + .. versionadded:: 1.7 %(verbose)s Returns @@ -1936,26 +1941,16 @@ def resample( Notes ----- - This uses (hopefully) intelligent edge padding and frequency-domain - windowing improve scipy.signal.resample's resampling method, which + When using ``method="fft"`` (default), + this uses (hopefully) intelligent edge padding and frequency-domain + windowing improve :func:`scipy.signal.resample`'s resampling method, which we have adapted for our use here. Choices of npad and window have important consequences, and the default choices should work well for most natural signals. - - Resampling arguments are broken into "up" and "down" components for future - compatibility in case we decide to use an upfirdn implementation. The - current implementation is functionally equivalent to passing - up=up/down and down=1. """ - # check explicitly for backwards compatibility - if not isinstance(axis, int): - err = ( - "The axis parameter needs to be an integer (got %s). " - "The axis parameter was missing from this function for a " - "period of time, you might be intending to specify the " - "subsequent window parameter." % repr(axis) - ) - raise TypeError(err) + _validate_type(method, str, "method") + _validate_type(pad, str, "pad") + _check_option("method", method, ("fft", "polyphase")) # make sure our arithmetic will work x = _check_filterable(x, "resampled", "resample") @@ -1963,31 +1958,88 @@ def resample( del up, down if axis < 0: axis = x.ndim + axis - orig_last_axis = x.ndim - 1 - if axis != orig_last_axis: - x = x.swapaxes(axis, orig_last_axis) - orig_shape = x.shape - x_len = orig_shape[-1] - if x_len == 0: - warn("x has zero length along last axis, returning a copy of x") + if x.shape[axis] == 0: + warn(f"x has zero length along axis={axis}, returning a copy of x") return x.copy() - bad_msg = 'npad must be "auto" or an integer' + + # prep for resampling along the last axis (swap axis with last then reshape) + out_shape = list(x.shape) + out_shape.pop(axis) + out_shape.append(final_len) + x = np.atleast_2d(x.swapaxes(axis, -1).reshape((-1, x.shape[axis]))) + + # do the resampling using FFT or polyphase methods + kwargs = dict(pad=pad, window=window, n_jobs=n_jobs) + if method == "fft": + y = _resample_fft(x, npad=npad, ratio=ratio, final_len=final_len, **kwargs) + else: + up, down, kwargs["window"] = _prep_polyphase( + ratio, x.shape[-1], final_len, window + ) + half_len = len(window) // 2 + logger.info( + f"Polyphase resampling locality: ±{half_len} input sample{_pl(half_len)}" + ) + y = _resample_polyphase(x, up=up, down=down, **kwargs) + assert y.shape[-1] == final_len + + # restore dimensions (reshape then swap axis with last) + y = y.reshape(out_shape).swapaxes(axis, -1) + + return y + + +def _prep_polyphase(ratio, x_len, final_len, window): + if isinstance(window, str) and window == "auto": + window = ("kaiser", 5.0) # SciPy default + up = final_len + down = x_len + g_ = gcd(up, down) + up = up // g_ + down = down // g_ + # Figure out our signal locality and design window (adapted from SciPy) + if not isinstance(window, (list, np.ndarray)): + # Design a linear-phase low-pass FIR filter + max_rate = max(up, down) + f_c = 1.0 / max_rate # cutoff of FIR filter (rel. to Nyquist) + half_len = 10 * max_rate # reasonable cutoff for sinc-like function + window = signal.firwin(2 * half_len + 1, f_c, window=window) + return up, down, window + + +def _resample_polyphase(x, *, up, down, pad, window, n_jobs): + if pad == "auto": + pad = "reflect" + kwargs = dict(padtype=pad, window=window, up=up, down=down) + _validate_type( + n_jobs, (None, "int-like"), "n_jobs", extra="when method='polyphase'" + ) + parallel, p_fun, n_jobs = parallel_func(signal.resample_poly, n_jobs) + if n_jobs == 1: + y = signal.resample_poly(x, axis=-1, **kwargs) + else: + y = np.array(parallel(p_fun(x_, **kwargs) for x_ in x)) + return y + + +def _resample_fft(x_flat, *, ratio, final_len, pad, window, npad, n_jobs): + x_len = x_flat.shape[-1] + pad = "reflect_limited" if pad == "auto" else pad + if (isinstance(window, str) and window == "auto") or window is None: + window = "boxcar" if isinstance(npad, str): - if npad != "auto": - raise ValueError(bad_msg) + _check_option("npad", npad, ("auto",), extra="when a string") # Figure out reasonable pad that gets us to a power of 2 min_add = min(x_len // 8, 100) * 2 npad = 2 ** int(np.ceil(np.log2(x_len + min_add))) - x_len npad, extra = divmod(npad, 2) npads = np.array([npad, npad + extra], int) else: - if npad != int(npad): - raise ValueError(bad_msg) + npad = _ensure_int(npad, "npad", extra="or 'auto'") npads = np.array([npad, npad], int) del npad # prep for resampling now - x_flat = x.reshape((-1, x_len)) orig_len = x_len + npads.sum() # length after padding new_len = max(int(round(ratio * orig_len)), 1) # length after resampling to_removes = [int(round(ratio * npads[0]))] @@ -1997,15 +2049,12 @@ def resample( # assert np.abs(to_removes[1] - to_removes[0]) <= int(np.ceil(ratio)) # figure out windowing function - if window is not None: - if callable(window): - W = window(fft.fftfreq(orig_len)) - elif isinstance(window, np.ndarray) and window.shape == (orig_len,): - W = window - else: - W = fft.ifftshift(signal.get_window(window, orig_len)) + if callable(window): + W = window(fft.fftfreq(orig_len)) + elif isinstance(window, np.ndarray) and window.shape == (orig_len,): + W = window else: - W = np.ones(orig_len) + W = fft.ifftshift(signal.get_window(window, orig_len)) W *= float(new_len) / float(orig_len) # figure out if we should use CUDA @@ -2015,7 +2064,7 @@ def resample( # use of the 'flat' window is recommended for minimal ringing parallel, p_fun, n_jobs = parallel_func(_fft_resample, n_jobs) if n_jobs == 1: - y = np.zeros((len(x_flat), new_len - to_removes.sum()), dtype=x.dtype) + y = np.zeros((len(x_flat), new_len - to_removes.sum()), dtype=x_flat.dtype) for xi, x_ in enumerate(x_flat): y[xi] = _fft_resample(x_, new_len, npads, to_removes, cuda_dict, pad) else: @@ -2024,12 +2073,6 @@ def resample( ) y = np.array(y) - # Restore the original array shape (modified for resampling) - y.shape = orig_shape[:-1] + (y.shape[1],) - if axis != orig_last_axis: - y = y.swapaxes(axis, orig_last_axis) - assert y.shape[axis] == final_len - return y @@ -2635,11 +2678,12 @@ def filter( def resample( self, sfreq, + *, npad="auto", - window="boxcar", + window="auto", n_jobs=None, pad="edge", - *, + method="fft", verbose=None, ): """Resample data. @@ -2656,11 +2700,12 @@ def resample( %(npad)s %(window_resample)s %(n_jobs_cuda)s - %(pad)s - The default is ``'edge'``, which pads with the edge values of each - vector. + %(pad_resample)s .. versionadded:: 0.15 + %(method_resample)s + + .. versionadded:: 1.7 %(verbose)s Returns @@ -2691,7 +2736,14 @@ def resample( _check_preload(self, "inst.resample") self._data = resample( - self._data, sfreq, o_sfreq, npad, window=window, n_jobs=n_jobs, pad=pad + self._data, + sfreq, + o_sfreq, + npad=npad, + window=window, + n_jobs=n_jobs, + pad=pad, + method=method, ) lowpass = self.info.get("lowpass") lowpass = np.inf if lowpass is None else lowpass diff --git a/mne/io/base.py b/mne/io/base.py index 95ba7038865..652b747a8ac 100644 --- a/mne/io/base.py +++ b/mne/io/base.py @@ -1260,12 +1260,14 @@ def notch_filter( def resample( self, sfreq, + *, npad="auto", - window="boxcar", + window="auto", stim_picks=None, n_jobs=None, events=None, - pad="reflect_limited", + pad="auto", + method="fft", verbose=None, ): """Resample all channels. @@ -1294,7 +1296,7 @@ def resample( ---------- sfreq : float New sample rate to use. - %(npad)s + %(npad_resample)s %(window_resample)s stim_picks : list of int | None Stim channels. These channels are simply subsampled or @@ -1307,10 +1309,12 @@ def resample( An optional event matrix. When specified, the onsets of the events are resampled jointly with the data. NB: The input events are not modified, but a new array is returned with the raw instead. - %(pad)s - The default is ``'reflect_limited'``. + %(pad_resample_auto)s .. versionadded:: 0.15 + %(method_resample)s + + .. versionadded:: 1.7 %(verbose)s Returns @@ -1364,7 +1368,13 @@ def resample( ) kwargs = dict( - up=sfreq, down=o_sfreq, npad=npad, window=window, n_jobs=n_jobs, pad=pad + up=sfreq, + down=o_sfreq, + npad=npad, + window=window, + n_jobs=n_jobs, + pad=pad, + method=method, ) ratio, n_news = zip( *( diff --git a/mne/io/fiff/tests/test_raw_fiff.py b/mne/io/fiff/tests/test_raw_fiff.py index 5c760735800..2c302eac3ad 100644 --- a/mne/io/fiff/tests/test_raw_fiff.py +++ b/mne/io/fiff/tests/test_raw_fiff.py @@ -42,6 +42,7 @@ _record_warnings, assert_and_remove_boundary_annot, assert_object_equal, + catch_logging, requires_mne, run_subprocess, ) @@ -1290,23 +1291,28 @@ def test_resample_equiv(): @pytest.mark.slowtest @testing.requires_testing_data @pytest.mark.parametrize( - "preload, n, npad", + "preload, n, npad, method", [ - (True, 512, "auto"), - (False, 512, 0), + (True, 512, "auto", "fft"), + (True, 512, "auto", "polyphase"), + (False, 512, 0, "fft"), # only test one with non-preload because it's slow ], ) -def test_resample(tmp_path, preload, n, npad): +def test_resample(tmp_path, preload, n, npad, method): """Test resample (with I/O and multiple files).""" + kwargs = dict(npad=npad, method=method) raw = read_raw_fif(fif_fname) raw.crop(0, raw.times[n - 1]) + # Reduce to a few MEG channels and a few stim channels to speed up + n_meg = 5 + raw.pick(raw.ch_names[:n_meg] + raw.ch_names[312:320]) # 10 MEG + 3 STIM + 5 EEG assert len(raw.times) == n if preload: raw.load_data() raw_resamp = raw.copy() sfreq = raw.info["sfreq"] # test parallel on upsample - raw_resamp.resample(sfreq * 2, n_jobs=2, npad=npad) + raw_resamp.resample(sfreq * 2, n_jobs=2, **kwargs) assert raw_resamp.n_times == len(raw_resamp.times) raw_resamp.save(tmp_path / "raw_resamp-raw.fif") raw_resamp = read_raw_fif(tmp_path / "raw_resamp-raw.fif", preload=True) @@ -1315,7 +1321,13 @@ def test_resample(tmp_path, preload, n, npad): assert raw_resamp.get_data().shape[1] == raw_resamp.n_times assert raw.get_data().shape[0] == raw_resamp._data.shape[0] # test non-parallel on downsample - raw_resamp.resample(sfreq, n_jobs=None, npad=npad) + with catch_logging() as log: + raw_resamp.resample(sfreq, n_jobs=None, verbose=True, **kwargs) + log = log.getvalue() + if method == "fft": + assert "locality" not in log + else: + assert "locality" in log assert raw_resamp.info["sfreq"] == sfreq assert raw.get_data().shape == raw_resamp._data.shape assert raw.first_samp == raw_resamp.first_samp @@ -1324,18 +1336,12 @@ def test_resample(tmp_path, preload, n, npad): # works (hooray). Note that the stim channels had to be sub-sampled # without filtering to be accurately preserved # note we have to treat MEG and EEG+STIM channels differently (tols) - assert_allclose( - raw.get_data()[:306, 200:-200], - raw_resamp._data[:306, 200:-200], - rtol=1e-2, - atol=1e-12, - ) - assert_allclose( - raw.get_data()[306:, 200:-200], - raw_resamp._data[306:, 200:-200], - rtol=1e-2, - atol=1e-7, - ) + want_meg = raw.get_data()[:n_meg, 200:-200] + got_meg = raw_resamp._data[:n_meg, 200:-200] + want_non_meg = raw.get_data()[n_meg:, 200:-200] + got_non_meg = raw_resamp._data[n_meg:, 200:-200] + assert_allclose(got_meg, want_meg, rtol=1e-2, atol=1e-12) + assert_allclose(want_non_meg, got_non_meg, rtol=1e-2, atol=1e-7) # now check multiple file support w/resampling, as order of operations # (concat, resample) should not affect our data @@ -1344,9 +1350,9 @@ def test_resample(tmp_path, preload, n, npad): raw3 = raw.copy() raw4 = raw.copy() raw1 = concatenate_raws([raw1, raw2]) - raw1.resample(10.0, npad=npad) - raw3.resample(10.0, npad=npad) - raw4.resample(10.0, npad=npad) + raw1.resample(10.0, **kwargs) + raw3.resample(10.0, **kwargs) + raw4.resample(10.0, **kwargs) raw3 = concatenate_raws([raw3, raw4]) assert_array_equal(raw1._data, raw3._data) assert_array_equal(raw1._first_samps, raw3._first_samps) @@ -1364,12 +1370,12 @@ def test_resample(tmp_path, preload, n, npad): # basic decimation stim = [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0] raw = RawArray([stim], create_info(1, len(stim), ["stim"])) - assert_allclose(raw.resample(8.0, npad=npad)._data, [[1, 1, 0, 0, 1, 1, 0, 0]]) + assert_allclose(raw.resample(8.0, **kwargs)._data, [[1, 1, 0, 0, 1, 1, 0, 0]]) # decimation of multiple stim channels raw = RawArray(2 * [stim], create_info(2, len(stim), 2 * ["stim"])) assert_allclose( - raw.resample(8.0, npad=npad, verbose="error")._data, + raw.resample(8.0, **kwargs, verbose="error")._data, [[1, 1, 0, 0, 1, 1, 0, 0], [1, 1, 0, 0, 1, 1, 0, 0]], ) @@ -1377,19 +1383,19 @@ def test_resample(tmp_path, preload, n, npad): # done naively stim = [0, 0, 0, 1, 1, 0, 0, 0] raw = RawArray([stim], create_info(1, len(stim), ["stim"])) - assert_allclose(raw.resample(4.0, npad=npad)._data, [[0, 1, 1, 0]]) + assert_allclose(raw.resample(4.0, **kwargs)._data, [[0, 1, 1, 0]]) # two events are merged in this case (warning) stim = [0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0] raw = RawArray([stim], create_info(1, len(stim), ["stim"])) with pytest.warns(RuntimeWarning, match="become unreliable"): - raw.resample(8.0, npad=npad) + raw.resample(8.0, **kwargs) # events are dropped in this case (warning) stim = [0, 1, 1, 0, 0, 1, 1, 0] raw = RawArray([stim], create_info(1, len(stim), ["stim"])) with pytest.warns(RuntimeWarning, match="become unreliable"): - raw.resample(4.0, npad=npad) + raw.resample(4.0, **kwargs) # test resampling events: this should no longer give a warning # we often have first_samp != 0, include it here too @@ -1400,7 +1406,7 @@ def test_resample(tmp_path, preload, n, npad): first_samp = len(stim) // 2 raw = RawArray([stim], create_info(1, o_sfreq, ["stim"]), first_samp=first_samp) events = find_events(raw) - raw, events = raw.resample(n_sfreq, events=events, npad=npad) + raw, events = raw.resample(n_sfreq, events=events, **kwargs) # Try index into raw.times with resampled events: raw.times[events[:, 0] - raw.first_samp] n_fsamp = int(first_samp * sfreq_ratio) # how it's calc'd in base.py @@ -1425,16 +1431,16 @@ def test_resample(tmp_path, preload, n, npad): # test copy flag stim = [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0] raw = RawArray([stim], create_info(1, len(stim), ["stim"])) - raw_resampled = raw.copy().resample(4.0, npad=npad) + raw_resampled = raw.copy().resample(4.0, **kwargs) assert raw_resampled is not raw - raw_resampled = raw.resample(4.0, npad=npad) + raw_resampled = raw.resample(4.0, **kwargs) assert raw_resampled is raw # resample should still work even when no stim channel is present raw = RawArray(np.random.randn(1, 100), create_info(1, 100, ["eeg"])) with raw.info._unlock(): raw.info["lowpass"] = 50.0 - raw.resample(10, npad=npad) + raw.resample(10, **kwargs) assert raw.info["lowpass"] == 5.0 assert len(raw) == 10 diff --git a/mne/source_estimate.py b/mne/source_estimate.py index 213d00e5baa..50734817431 100644 --- a/mne/source_estimate.py +++ b/mne/source_estimate.py @@ -819,7 +819,17 @@ def crop(self, tmin=None, tmax=None, include_tmax=True): return self # return self for chaining methods @verbose - def resample(self, sfreq, npad="auto", window="boxcar", n_jobs=None, verbose=None): + def resample( + self, + sfreq, + *, + npad=100, + method="fft", + window="auto", + pad="auto", + n_jobs=None, + verbose=None, + ): """Resample data. If appropriate, an anti-aliasing filter is applied before resampling. @@ -833,8 +843,15 @@ def resample(self, sfreq, npad="auto", window="boxcar", n_jobs=None, verbose=Non Amount to pad the start and end of the data. Can also be "auto" to use a padding that will result in a power-of-two size (can be much faster). - window : str | tuple - Window to use in resampling. See :func:`scipy.signal.resample`. + %(method_resample)s + + .. versionadded:: 1.7 + %(window_resample)s + + .. versionadded:: 1.7 + %(pad_resample_auto)s + + .. versionadded:: 1.7 %(n_jobs)s %(verbose)s @@ -863,7 +880,9 @@ def resample(self, sfreq, npad="auto", window="boxcar", n_jobs=None, verbose=Non data = self.data if data.dtype == np.float32: data = data.astype(np.float64) - self.data = resample(data, sfreq, o_sfreq, npad, n_jobs=n_jobs) + self.data = resample( + data, sfreq, o_sfreq, npad=npad, window=window, n_jobs=n_jobs, method=method + ) # adjust indirectly affected variables self.tstep = 1.0 / sfreq diff --git a/mne/tests/test_filter.py b/mne/tests/test_filter.py index 110a8f136c3..3ab60dba055 100644 --- a/mne/tests/test_filter.py +++ b/mne/tests/test_filter.py @@ -32,6 +32,8 @@ from mne.io import RawArray, read_raw_fif from mne.utils import catch_logging, requires_mne, run_subprocess, sum_squared +resample_method_parametrize = pytest.mark.parametrize("method", ("fft", "polyphase")) + def test_filter_array(): """Test filtering an array.""" @@ -372,20 +374,27 @@ def test_notch_filters(method, filter_length, line_freq, tol): assert_almost_equal(new_power, orig_power, tol) -def test_resample(): +@resample_method_parametrize +def test_resample(method): """Test resampling.""" rng = np.random.RandomState(0) x = rng.normal(0, 1, (10, 10, 10)) - x_rs = resample(x, 1, 2, 10) + with catch_logging() as log: + x_rs = resample(x, 1, 2, npad=10, method=method, verbose=True) + log = log.getvalue() + if method == "fft": + assert "locality" not in log + else: + assert "locality" in log assert x.shape == (10, 10, 10) assert x_rs.shape == (10, 10, 5) x_2 = x.swapaxes(0, 1) - x_2_rs = resample(x_2, 1, 2, 10) + x_2_rs = resample(x_2, 1, 2, npad=10, method=method) assert_array_equal(x_2_rs.swapaxes(0, 1), x_rs) x_3 = x.swapaxes(0, 2) - x_3_rs = resample(x_3, 1, 2, 10, 0) + x_3_rs = resample(x_3, 1, 2, npad=10, axis=0, method=method) assert_array_equal(x_3_rs.swapaxes(0, 2), x_rs) # make sure we cast to array if necessary @@ -401,12 +410,12 @@ def test_resample_scipy(): err_msg = "%s: %s" % (N, window) x_2_sp = sp_resample(x, 2 * N, window=window) for n_jobs in n_jobs_test: - x_2 = resample(x, 2, 1, 0, window=window, n_jobs=n_jobs) + x_2 = resample(x, 2, 1, npad=0, window=window, n_jobs=n_jobs) assert_allclose(x_2, x_2_sp, atol=1e-12, err_msg=err_msg) new_len = int(round(len(x) * (1.0 / 2.0))) x_p5_sp = sp_resample(x, new_len, window=window) for n_jobs in n_jobs_test: - x_p5 = resample(x, 1, 2, 0, window=window, n_jobs=n_jobs) + x_p5 = resample(x, 1, 2, npad=0, window=window, n_jobs=n_jobs) assert_allclose(x_p5, x_p5_sp, atol=1e-12, err_msg=err_msg) @@ -450,23 +459,25 @@ def test_resamp_stim_channel(): assert new_data.shape[1] == new_data_len -def test_resample_raw(): +@resample_method_parametrize +def test_resample_raw(method): """Test resampling using RawArray.""" x = np.zeros((1, 1001)) sfreq = 2048.0 raw = RawArray(x, create_info(1, sfreq, "eeg")) - raw.resample(128, npad=10) + raw.resample(128, npad=10, method=method) data = raw.get_data() assert data.shape == (1, 63) -def test_resample_below_1_sample(): +@resample_method_parametrize +def test_resample_below_1_sample(method): """Test resampling doesn't yield datapoints.""" # Raw x = np.zeros((1, 100)) sfreq = 1000.0 raw = RawArray(x, create_info(1, sfreq, "eeg")) - raw.resample(5) + raw.resample(5, method=method) assert len(raw.times) == 1 assert raw.get_data().shape[1] == 1 @@ -487,7 +498,13 @@ def test_resample_below_1_sample(): preload=True, verbose=False, ) - epochs.resample(1) + with catch_logging() as log: + epochs.resample(1, method=method, verbose=True) + log = log.getvalue() + if method == "fft": + assert "locality" not in log + else: + assert "locality" in log assert len(epochs.times) == 1 assert epochs.get_data(copy=False).shape[2] == 1 diff --git a/mne/tests/test_source_estimate.py b/mne/tests/test_source_estimate.py index 8c9e7df9389..be31fd1501b 100644 --- a/mne/tests/test_source_estimate.py +++ b/mne/tests/test_source_estimate.py @@ -558,61 +558,73 @@ def test_stc_arithmetic(): @pytest.mark.slowtest @testing.requires_testing_data -def test_stc_methods(): +@pytest.mark.parametrize("kind", ("scalar", "vector")) +@pytest.mark.parametrize("method", ("fft", "polyphase")) +def test_stc_methods(kind, method): """Test stc methods lh_data, rh_data, bin(), resample().""" - stc_ = read_source_estimate(fname_stc) + stc = read_source_estimate(fname_stc) - # Make a vector version of the above source estimate - x = stc_.data[:, np.newaxis, :] - yz = np.zeros((x.shape[0], 2, x.shape[2])) - vec_stc_ = VectorSourceEstimate( - np.concatenate((x, yz), 1), stc_.vertices, stc_.tmin, stc_.tstep, stc_.subject - ) + if kind == "vector": + # Make a vector version of the above source estimate + x = stc.data[:, np.newaxis, :] + yz = np.zeros((x.shape[0], 2, x.shape[2])) + stc = VectorSourceEstimate( + np.concatenate((x, yz), 1), + stc.vertices, + stc.tmin, + stc.tstep, + stc.subject, + ) - for stc in [stc_, vec_stc_]: - # lh_data / rh_data - assert_array_equal(stc.lh_data, stc.data[: len(stc.lh_vertno)]) - assert_array_equal(stc.rh_data, stc.data[len(stc.lh_vertno) :]) - - # bin - binned = stc.bin(0.12) - a = np.mean(stc.data[..., : np.searchsorted(stc.times, 0.12)], axis=-1) - assert_array_equal(a, binned.data[..., 0]) - - stc = read_source_estimate(fname_stc) - stc.subject = "sample" - label_lh = read_labels_from_annot( - "sample", "aparc", "lh", subjects_dir=subjects_dir - )[0] - label_rh = read_labels_from_annot( - "sample", "aparc", "rh", subjects_dir=subjects_dir - )[0] - label_both = label_lh + label_rh - for label in (label_lh, label_rh, label_both): - assert isinstance(stc.shape, tuple) and len(stc.shape) == 2 - stc_label = stc.in_label(label) - if label.hemi != "both": - if label.hemi == "lh": - verts = stc_label.vertices[0] - else: # label.hemi == 'rh': - verts = stc_label.vertices[1] - n_vertices_used = len(label.get_vertices_used(verts)) - assert_equal(len(stc_label.data), n_vertices_used) - stc_lh = stc.in_label(label_lh) - pytest.raises(ValueError, stc_lh.in_label, label_rh) - label_lh.subject = "foo" - pytest.raises(RuntimeError, stc.in_label, label_lh) - - stc_new = deepcopy(stc) - o_sfreq = 1.0 / stc.tstep - # note that using no padding for this STC reduces edge ringing... - stc_new.resample(2 * o_sfreq, npad=0) - assert stc_new.data.shape[1] == 2 * stc.data.shape[1] - assert stc_new.tstep == stc.tstep / 2 - stc_new.resample(o_sfreq, npad=0) - assert stc_new.data.shape[1] == stc.data.shape[1] - assert stc_new.tstep == stc.tstep - assert_array_almost_equal(stc_new.data, stc.data, 5) + # lh_data / rh_data + assert_array_equal(stc.lh_data, stc.data[: len(stc.lh_vertno)]) + assert_array_equal(stc.rh_data, stc.data[len(stc.lh_vertno) :]) + + # bin + binned = stc.bin(0.12) + a = np.mean(stc.data[..., : np.searchsorted(stc.times, 0.12)], axis=-1) + assert_array_equal(a, binned.data[..., 0]) + + stc = read_source_estimate(fname_stc) + stc.subject = "sample" + label_lh = read_labels_from_annot( + "sample", "aparc", "lh", subjects_dir=subjects_dir + )[0] + label_rh = read_labels_from_annot( + "sample", "aparc", "rh", subjects_dir=subjects_dir + )[0] + label_both = label_lh + label_rh + for label in (label_lh, label_rh, label_both): + assert isinstance(stc.shape, tuple) and len(stc.shape) == 2 + stc_label = stc.in_label(label) + if label.hemi != "both": + if label.hemi == "lh": + verts = stc_label.vertices[0] + else: # label.hemi == 'rh': + verts = stc_label.vertices[1] + n_vertices_used = len(label.get_vertices_used(verts)) + assert_equal(len(stc_label.data), n_vertices_used) + stc_lh = stc.in_label(label_lh) + pytest.raises(ValueError, stc_lh.in_label, label_rh) + label_lh.subject = "foo" + pytest.raises(RuntimeError, stc.in_label, label_lh) + + stc_new = deepcopy(stc) + o_sfreq = 1.0 / stc.tstep + # note that using no padding for this STC reduces edge ringing... + stc_new.resample(2 * o_sfreq, npad=0, method=method) + assert stc_new.data.shape[1] == 2 * stc.data.shape[1] + assert stc_new.tstep == stc.tstep / 2 + stc_new.resample(o_sfreq, npad=0, method=method) + assert stc_new.data.shape[1] == stc.data.shape[1] + assert stc_new.tstep == stc.tstep + if method == "fft": + # no low-passing so survives round-trip + assert_allclose(stc_new.data, stc.data, atol=1e-5) + else: + # low-passing means we need something more flexible + corr = np.corrcoef(stc_new.data.ravel(), stc.data.ravel())[0, 1] + assert 0.99 < corr < 1 @testing.requires_testing_data diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 806d774f221..6d26d01dc40 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -2245,6 +2245,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict["method_psd"] = _method_psd.format("", "") docdict["method_psd_auto"] = _method_psd.format(" | ``'auto'``", "") +docdict["method_resample"] = """ +method : str + Resampling method to use. Can be ``"fft"`` (default) or ``"polyphase"`` + to use FFT-based on polyphase FIR resampling, respectively. These wrap to + :func:`scipy.signal.resample` and :func:`scipy.signal.resample_poly`, respectively. +""" + docdict["mode_eltc"] = """ mode : str Extraction mode, see Notes. @@ -2488,11 +2495,16 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict["npad"] = """ npad : int | str - Amount to pad the start and end of the data. - Can also be ``"auto"`` to use a padding that will result in - a power-of-two size (can be much faster). + Amount to pad the start and end of the data. Can also be ``"auto"`` to use a padding + that will result in a power-of-two size (can be much faster). """ +docdict["npad_resample"] = ( + docdict["npad"] + + """ + Only used when ``method="fft"``. +""" +) docdict["nrows_ncols_ica_components"] = """ nrows, ncols : int | 'auto' The number of rows and columns of topographies to plot. If both ``nrows`` @@ -2698,22 +2710,38 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): # P _pad_base = """ -pad : str - The type of padding to use. Supports all :func:`numpy.pad` ``mode`` - options. Can also be ``"reflect_limited"``, which pads with a - reflected version of each vector mirrored on the first and last values + all :func:`numpy.pad` ``mode`` options. Can also be ``"reflect_limited"``, which + pads with a reflected version of each vector mirrored on the first and last values of the vector, followed by zeros. """ -docdict["pad"] = _pad_base - docdict["pad_fir"] = ( - _pad_base - + """ + """ +pad : str + The type of padding to use. Supports """ + + _pad_base + + """\ Only used for ``method='fir'``. """ ) +docdict["pad_resample"] = ( # used when default is not "auto" + """ +pad : str + The type of padding to use. When ``method="fft"``, supports """ + + _pad_base + + """\ + When ``method="polyphase"``, supports all modes of :func:`scipy.signal.upfirdn`. +""" +) + +docdict["pad_resample_auto"] = ( # used when default is "auto" + docdict["pad_resample"] + + """\ + The default ("auto") means ``'reflect_limited'`` for ``method='fft'`` and + ``'reflect'`` for ``method='polyphase'``. +""" +) docdict["pca_vars_pctf"] = """ pca_vars : array, shape (n_comp,) | list of array The explained variances of the first n_comp SVD components across the @@ -4331,8 +4359,12 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict["window_resample"] = """ window : str | tuple - Frequency-domain window to use in resampling. - See :func:`scipy.signal.resample`. + When ``method="fft"``, this is the *frequency-domain* window to use in resampling, + and should be the same length as the signal; see :func:`scipy.signal.resample` + for details. When ``method="polyphase"``, this is the *time-domain* linear-phase + window to use after upsampling the signal; see :func:`scipy.signal.resample_poly` + for details. The default ``"auto"`` will use ``"boxcar"`` for ``method="fft"`` and + ``("kaiser", 5.0)`` for ``method="polyphase"``. """ # %% diff --git a/tutorials/preprocessing/30_filtering_resampling.py b/tutorials/preprocessing/30_filtering_resampling.py index 530b92741f6..6c118c99180 100644 --- a/tutorials/preprocessing/30_filtering_resampling.py +++ b/tutorials/preprocessing/30_filtering_resampling.py @@ -206,16 +206,59 @@ def add_arrows(axes): # frequency`_ of the desired new sampling rate. This can be clearly seen in the # PSD plot, where a dashed vertical line indicates the filter cutoff; the # original data had an existing lowpass at around 172 Hz (see -# ``raw.info['lowpass']``), and the data resampled from 600 Hz to 200 Hz gets +# ``raw.info['lowpass']``), and the data resampled from ~600 Hz to 200 Hz gets # automatically lowpass filtered at 100 Hz (the `Nyquist frequency`_ for a # target rate of 200 Hz): raw_downsampled = raw.copy().resample(sfreq=200) +# choose n_fft for Welch PSD to make frequency axes similar resolution +n_ffts = [4096, int(round(4096 * 200 / raw.info["sfreq"]))] +fig, axes = plt.subplots(2, 1, sharey=True, layout="constrained", figsize=(10, 6)) +for ax, data, title, n_fft in zip( + axes, [raw, raw_downsampled], ["Original", "Downsampled"], n_ffts +): + fig = data.compute_psd(n_fft=n_fft).plot( + average=True, picks="data", exclude="bads", axes=ax + ) + ax.set(title=title, xlim=(0, 300)) -for data, title in zip([raw, raw_downsampled], ["Original", "Downsampled"]): - fig = data.compute_psd().plot(average=True, picks="data", exclude="bads") - fig.suptitle(title) - plt.setp(fig.axes, xlim=(0, 300)) +# %% +# By default, MNE-Python resamples using ``method="fft"``, which performs FFT-based +# resampling via :func:`scipy.signal.resample`. While efficient and good for most +# biological signals, it has two main potential drawbacks: +# +# 1. It assumes periodicity of the signal. We try to overcome this with appropriate +# signal padding, but some signal leakage may still occur. +# 2. It treats the entire signal as a single block. This means that in general effects +# are not guaranteed to be localized in time, though in practice they often are. +# +# Alternatively, resampling can be performed using ``method="polyphase"`` instead. +# This uses :func:`scipy.signal.resample_poly` under the hood, which in turn utilizes +# a three-step process to resample signals (see :func:`scipy.signal.upfirdn` for +# details). This process guarantees that each resampled output value is only affected by +# input values within a limited range. In other words, output values are guaranteed to +# be a result of a specific set of input values. +# +# In general, using ``method="polyphase"`` can also be faster than ``method="fft"`` in +# cases where the desired sampling rate is an integer factor different from the input +# sampling rate. For example: + +# sphinx_gallery_thumbnail_number = 11 + +n_ffts = [4096, 2048] # factor of 2 smaller n_fft +raw_downsampled_poly = raw.copy().resample( + sfreq=raw.info["sfreq"] / 2.0, + method="polyphase", + verbose=True, +) +fig, axes = plt.subplots(2, 1, sharey=True, layout="constrained", figsize=(10, 6)) +for ax, data, title, n_fft in zip( + axes, [raw, raw_downsampled_poly], ["Original", "Downsampled (polyphase)"], n_ffts +): + data.compute_psd(n_fft=n_fft).plot( + average=True, picks="data", exclude="bads", axes=ax + ) + ax.set(title=title, xlim=(0, 300)) # %% # Because resampling involves filtering, there are some pitfalls to resampling