diff --git a/.github/workflows/python_tests.yml b/.github/workflows/python_tests.yml index 3a5a3a2b..8b8df107 100644 --- a/.github/workflows/python_tests.yml +++ b/.github/workflows/python_tests.yml @@ -30,7 +30,7 @@ jobs: - uses: actions/cache@v2 with: path: ${{ env.pythonLocation }} - key: ${{ env.pythonLocation }}-${{ hashFiles('setup.cfg') }}-${{ hashFiles('requirements-dev.txt') }} + key: ${{ env.pythonLocation }}-${{ hashFiles('setup.cfg') }}-${{ hashFiles('requirements-dev.txt') }}-version-1 - name: Install dependencies run: | diff --git a/docs/whats_new.rst b/docs/whats_new.rst index ed0d7e2d..903d3637 100644 --- a/docs/whats_new.rst +++ b/docs/whats_new.rst @@ -82,7 +82,7 @@ Changelog - Channel types are now available from a new ``ch_types_all`` attribute, and non-EEG channel names are now available from a new ``ch_names_non_eeg`` attribute from :class:`PrepPipeline `, by `Yorguin Mantilla`_ (:gh:`34`) - Renaming of ``ch_names`` attribute of :class:`PrepPipeline ` to ``ch_names_all``, by `Yorguin Mantilla`_ (:gh:`34`) - It's now possible to pass ``'eeg'`` to ``ref_chs`` and ``reref_chs`` keywords to the ``prep_params`` parameter of :class:`PrepPipeline ` to select only eeg channels for referencing, by `Yorguin Mantilla`_ (:gh:`34`) -- :class:`PrepPipeline ` will retain the non eeg channels through the ``raw`` attribute. The eeg-only and non-eeg parts will be in raw_eeg and raw_non_eeg respectively. See the ``raw`` attribute, by `Christian Oreilly`_ (:gh:`34`) +- :class:`PrepPipeline ` will retain the non eeg channels through the ``raw`` attribute. The eeg-only and non-eeg parts will be in raw_eeg and raw_non_eeg respectively. See the ``raw`` attribute, by `Christian O'Reilly`_ (:gh:`34`) - When a ransac call needs more memory than available, pyprep will now automatically switch to a slower but less memory-consuming version of ransac, by `Yorguin Mantilla`_ (:gh:`32`) - It's now possible to pass an empty list for the ``line_freqs`` param in :class:`PrepPipeline ` to skip the line noise removal, by `Yorguin Mantilla`_ (:gh:`29`) - The three main classes :class:`~pyprep.PrepPipeline`, :class:`~pyprep.NoisyChannels`, and :class:`pyprep.Reference` now have a ``random_state`` parameter to set a seed that gets passed on to all their internal methods and class calls, by `Stefan Appelhoff`_ (:gh:`31`) diff --git a/examples/ransac_comparison_autoreject.py b/examples/ransac_comparison_autoreject.py new file mode 100644 index 00000000..b2cb1925 --- /dev/null +++ b/examples/ransac_comparison_autoreject.py @@ -0,0 +1,212 @@ +""" +=============================================== +RANSAC comparison between pyprep and autoreject +=============================================== + +Next to the RANSAC implementation in ``pyprep``, +there is another implementation that makes use of MNE-Python. +That alternative RANSAC implementation can be found in the +`"autoreject" package `_. + +In this example, we make a basic comparison between the two implementations. + +#. by running them on the same simulated data +#. by running them on the same "real" data + + +.. currentmodule:: pyprep +""" + +# Authors: Yorguin Mantilla +# +# License: MIT + +# %% +# First we import what we need for this example. +import numpy as np +import mne +from scipy import signal as signal +from time import perf_counter +from autoreject import Ransac +import pyprep.ransac as ransac_pyprep + + +# %% +# Now let's make some arbitrary MNE raw object for demonstration purposes. +# We will think of good channels as sine waves and bad channels correlated with +# each other as sawtooths. The RANSAC will be biased towards sines in its +# prediction (they are the majority) so it will identify the sawtooths as bad. + +# Set a random seed to make this example reproducible +rng = np.random.RandomState(435656) + +# start defining some key aspects for our simulated data +sfreq = 1000.0 +montage = mne.channels.make_standard_montage("standard_1020") +ch_names = montage.ch_names +n_chans = len(ch_names) +info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=["eeg"] * n_chans) +time = np.arange(0, 30, 1.0 / sfreq) # 30 seconds of recording + +# randomly pick some "bad" channels (sawtooths) +n_bad_chans = 3 +bad_channels = rng.choice(np.arange(n_chans), n_bad_chans, replace=False) +bad_channels = [int(i) for i in bad_channels] +bad_ch_names = [ch_names[i] for i in bad_channels] + +# The frequency components to use in the signal for good and bad channels +freq_good = 20 +freq_bad = 20 + +# Generate the data: sinewaves for "good", sawtooths for "bad" channels +X = [ + signal.sawtooth(2 * np.pi * freq_bad * time) + if i in bad_channels + else np.sin(2 * np.pi * freq_good * time) + for i in range(n_chans) +] + +# Scale the signal amplitude and add noise. +X = 2e-5 * np.array(X) + 1e-5 * np.random.random((n_chans, time.shape[0])) + +# Finally, put it all together as an mne "Raw" object. +raw = mne.io.RawArray(X, info) +raw.set_montage(montage, verbose=False) + +# Print, which channels are simulated as "bad" +print(bad_ch_names) + +# %% +# Configure RANSAC parameters +n_samples = 50 +fraction_good = 0.25 +corr_thresh = 0.75 +fraction_bad = 0.4 +corr_window_secs = 5.0 + +# %% +# autoreject's RANSAC +ransac_ar = Ransac( + picks=None, + n_resample=n_samples, + min_channels=fraction_good, + min_corr=corr_thresh, + unbroken_time=fraction_bad, + n_jobs=1, + random_state=rng, +) +epochs = mne.make_fixed_length_epochs( + raw, + duration=corr_window_secs, + preload=True, + reject_by_annotation=False, + verbose=None, +) + +start_time = perf_counter() +ransac_ar = ransac_ar.fit(epochs) +print("--- %s seconds ---" % (perf_counter() - start_time)) + +corr_ar = ransac_ar.corr_ +bad_by_ransac_ar = ransac_ar.bad_chs_ + +# Check channels that go bad together by RANSAC +print("autoreject bad chs:", bad_by_ransac_ar) +assert set(bad_ch_names) == set(bad_by_ransac_ar) + +# %% +# pyprep's RANSAC + +start_time = perf_counter() +bad_by_ransac_pyprep, corr_pyprep = ransac_pyprep.find_bad_by_ransac( + data=raw._data.copy(), + sample_rate=raw.info["sfreq"], + complete_chn_labs=np.asarray(raw.info["ch_names"]), + chn_pos=raw._get_channel_positions(), + exclude=[], + n_samples=n_samples, + sample_prop=fraction_good, + corr_thresh=corr_thresh, + frac_bad=fraction_bad, + corr_window_secs=corr_window_secs, + channel_wise=False, + random_state=rng, +) +print("--- %s seconds ---" % (perf_counter() - start_time)) + +# Check channels that go bad together by RANSAC +print("pyprep bad chs:", bad_by_ransac_pyprep) +assert set(bad_ch_names) == set(bad_by_ransac_pyprep) + +# %% +# Now we test the algorithms on real EEG data. +# Let's download some data for testing. +data_paths = mne.datasets.eegbci.load_data(subject=4, runs=1, update_path=True) +fname_test_file = data_paths[0] + +# %% +# Load data and prepare it + +raw = mne.io.read_raw_edf(fname_test_file, preload=True) + +# The eegbci data has non-standard channel names. We need to rename them: +mne.datasets.eegbci.standardize(raw) + +# Add a montage to the data +montage_kind = "standard_1005" +montage = mne.channels.make_standard_montage(montage_kind) +raw.set_montage(montage) + + +# %% +# autoreject's RANSAC +ransac_ar = Ransac( + picks=None, + n_resample=n_samples, + min_channels=fraction_good, + min_corr=corr_thresh, + unbroken_time=fraction_bad, + n_jobs=1, + random_state=rng, +) +epochs = mne.make_fixed_length_epochs( + raw, + duration=corr_window_secs, + preload=True, + reject_by_annotation=False, + verbose=None, +) + +start_time = perf_counter() +ransac_ar = ransac_ar.fit(epochs) +print("--- %s seconds ---" % (perf_counter() - start_time)) + +corr_ar = ransac_ar.corr_ +bad_by_ransac_ar = ransac_ar.bad_chs_ + +# Check channels that go bad together by RANSAC +print("autoreject bad chs:", bad_by_ransac_ar) + + +# %% +# pyprep's RANSAC + +start_time = perf_counter() +bad_by_ransac_pyprep, corr_pyprep = ransac_pyprep.find_bad_by_ransac( + data=raw._data.copy(), + sample_rate=raw.info["sfreq"], + complete_chn_labs=np.asarray(raw.info["ch_names"]), + chn_pos=raw._get_channel_positions(), + exclude=[], + n_samples=n_samples, + sample_prop=fraction_good, + corr_thresh=corr_thresh, + frac_bad=fraction_bad, + corr_window_secs=corr_window_secs, + channel_wise=False, + random_state=rng, +) +print("--- %s seconds ---" % (perf_counter() - start_time)) + +# Check channels that go bad together by RANSAC +print("pyprep bad chs:", bad_by_ransac_pyprep) diff --git a/examples/run_ransac.py b/examples/run_ransac.py index 93f2705d..4a9b53f5 100644 --- a/examples/run_ransac.py +++ b/examples/run_ransac.py @@ -90,7 +90,7 @@ # `raw` object. For more information, we can access attributes of the ``nd`` # instance: -# Check channels that go bad together by correlation (RANSAC) +# Check channels that go bad together by RANSAC print(nd.bad_by_ransac) assert set(bad_ch_names) == set(nd.bad_by_ransac) diff --git a/requirements-dev.txt b/requirements-dev.txt index 7b9feef0..4737af7d 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,3 +1,4 @@ +git+git://github.com/autoreject/autoreject.git@master#egg=autoreject black check-manifest flake8