Skip to content

Commit 58f33c0

Browse files
authored
[BUG FIX] Fix CircleCI for Fetching and Spectral Connectivity Saving After averaging frequencies (#91)
1 parent f819d0f commit 58f33c0

File tree

5 files changed

+109
-37
lines changed

5 files changed

+109
-37
lines changed

.circleci/config.yml

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ jobs:
2424
if ! git remote -v | grep upstream; then
2525
git remote add upstream https://github.com/mne-tools/mne-connectivity.git
2626
fi
27+
git remote set-url upstream https://github.com/mne-tools/mne-connectivity.git
2728
git fetch upstream
2829
2930
- save_cache:

doc/authors.inc

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@
33
.. _Britta Westner: https://github.com/britta-wstnr
44
.. _Alexander Kroner: https://github.com/alexanderkroner
55
.. _Richard Höchenberger: https://github.com/hoechenberger
6-
.. _Alex Rockhill: https://github.com/alexrockhill
6+
.. _Alex Rockhill: https://github.com/alexrockhill
7+
.. _Szonja Weigl: https://github.com/weiglszonja

doc/whats_new.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ Enhancements
2929
Bug
3030
~~~
3131

32-
-
32+
- Fix the output of :func:`mne_connectivity.spectral_connectivity_epochs` when ``faverage=True``, allowing one to save the Connectivity object, by `Adam Li`_ and `Szonja Weigl`_ :gh:`91`
3333

3434
API
3535
~~~

mne_connectivity/spectral/epochs.py

+48-28
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,47 @@
2323
from ..utils import fill_doc, check_indices
2424

2525

26+
def _compute_freqs(n_times, sfreq, cwt_freqs, mode):
27+
from scipy.fft import rfftfreq
28+
# get frequencies of interest for the different modes
29+
if mode in ('multitaper', 'fourier'):
30+
# fmin fmax etc is only supported for these modes
31+
# decide which frequencies to keep
32+
freqs_all = rfftfreq(n_times, 1. / sfreq)
33+
elif mode == 'cwt_morlet':
34+
# cwt_morlet mode
35+
if cwt_freqs is None:
36+
raise ValueError('define frequencies of interest using '
37+
'cwt_freqs')
38+
else:
39+
cwt_freqs = cwt_freqs.astype(np.float64)
40+
if any(cwt_freqs > (sfreq / 2.)):
41+
raise ValueError('entries in cwt_freqs cannot be '
42+
'larger than Nyquist (sfreq / 2)')
43+
freqs_all = cwt_freqs
44+
else:
45+
raise ValueError('mode has an invalid value')
46+
47+
return freqs_all
48+
49+
50+
def _compute_freq_mask(freqs_all, fmin, fmax, fskip):
51+
# create a frequency mask for all bands
52+
freq_mask = np.zeros(len(freqs_all), dtype=bool)
53+
for f_lower, f_upper in zip(fmin, fmax):
54+
freq_mask |= ((freqs_all >= f_lower) & (freqs_all <= f_upper))
55+
56+
# possibly skip frequency points
57+
for pos in range(fskip):
58+
freq_mask[pos + 1::fskip + 1] = False
59+
return freq_mask
60+
61+
2662
def _prepare_connectivity(epoch_block, times_in, tmin, tmax,
2763
fmin, fmax, sfreq, indices,
2864
mode, fskip, n_bands,
2965
cwt_freqs, faverage):
3066
"""Check and precompute dimensions of results data."""
31-
from scipy.fft import rfftfreq
3267
first_epoch = epoch_block[0]
3368

3469
# get the data size and time scale
@@ -68,25 +103,6 @@ def _prepare_connectivity(epoch_block, times_in, tmin, tmax,
68103
logger.info(' using t=%0.3fs..%0.3fs for estimation (%d points)'
69104
% (tmin_true, tmax_true, n_times))
70105

71-
# get frequencies of interest for the different modes
72-
if mode in ('multitaper', 'fourier'):
73-
# fmin fmax etc is only supported for these modes
74-
# decide which frequencies to keep
75-
freqs_all = rfftfreq(n_times, 1. / sfreq)
76-
elif mode == 'cwt_morlet':
77-
# cwt_morlet mode
78-
if cwt_freqs is None:
79-
raise ValueError('define frequencies of interest using '
80-
'cwt_freqs')
81-
else:
82-
cwt_freqs = cwt_freqs.astype(np.float64)
83-
if any(cwt_freqs > (sfreq / 2.)):
84-
raise ValueError('entries in cwt_freqs cannot be '
85-
'larger than Nyquist (sfreq / 2)')
86-
freqs_all = cwt_freqs
87-
else:
88-
raise ValueError('mode has an invalid value')
89-
90106
# check that fmin corresponds to at least 5 cycles
91107
dur = float(n_times) / sfreq
92108
five_cycle_freq = 5. / dur
@@ -101,17 +117,15 @@ def _prepare_connectivity(epoch_block, times_in, tmin, tmax,
101117
'unreliable.' % (np.min(fmin), dur * np.min(fmin), dur,
102118
5. / np.min(fmin), five_cycle_freq))
103119

104-
# create a frequency mask for all bands
105-
freq_mask = np.zeros(len(freqs_all), dtype=bool)
106-
for f_lower, f_upper in zip(fmin, fmax):
107-
freq_mask |= ((freqs_all >= f_lower) & (freqs_all <= f_upper))
120+
# compute frequencies to analyze based on number of samples,
121+
# sampling rate, specified wavelet frequencies and mode
122+
freqs = _compute_freqs(n_times, sfreq, cwt_freqs, mode)
108123

109-
# possibly skip frequency points
110-
for pos in range(fskip):
111-
freq_mask[pos + 1::fskip + 1] = False
124+
# compute the mask based on specified min/max and decimation factor
125+
freq_mask = _compute_freq_mask(freqs, fmin, fmax, fskip)
112126

113127
# the frequency points where we compute connectivity
114-
freqs = freqs_all[freq_mask]
128+
freqs = freqs[freq_mask]
115129
n_freqs = len(freqs)
116130

117131
# get the freq. indices and points for each band
@@ -1107,7 +1121,13 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None,
11071121
if faverage:
11081122
# for each band we return the frequencies that were averaged
11091123
freqs = [np.mean(x) for x in freqs_bands]
1124+
1125+
# make sure freq_bands is a list of equal-length lists
1126+
# XXX: we lose information on which frequency points went into the
1127+
# computation. If h5netcdf supports numpy objects in the future, then
1128+
# we can change the min/max to just make it a list of lists.
11101129
freqs_used = freqs_bands
1130+
freqs_used = [[np.min(band), np.max(band)] for band in freqs_used]
11111131

11121132
if indices is None:
11131133
# return all-to-all connectivity matrices

mne_connectivity/spectral/tests/test_spectral.py

+57-7
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
SpectralConnectivity, spectral_connectivity_epochs,
1616
read_connectivity, spectral_connectivity_time)
1717
from mne_connectivity.spectral.epochs import _CohEst, _get_n_epochs
18+
from mne_connectivity.spectral.epochs import (
19+
_compute_freq_mask, _compute_freqs)
1820

1921

2022
def create_test_dataset(sfreq, n_signals, n_epochs, n_times, tmin, tmax,
@@ -336,6 +338,7 @@ def test_spectral_connectivity(method, mode):
336338
assert (n == n2)
337339
assert_array_almost_equal(times_data, times2)
338340

341+
# Test with faverage
339342
# compute same connections for two bands, fskip=1, and f. avg.
340343
fmin = (5., 15.)
341344
fmax = (15., 30.)
@@ -354,21 +357,51 @@ def test_spectral_connectivity(method, mode):
354357
assert (isinstance(freqs3, list))
355358
assert (len(freqs3) == len(fmin))
356359
for i in range(len(freqs3)):
357-
assert np.all((freqs3[i] >= fmin[i]) &
358-
(freqs3[i] <= fmax[i]))
360+
_fmin = max(fmin[i], min(cwt_freqs))
361+
_fmax = min(fmax[i], max(cwt_freqs))
362+
assert_allclose(freqs3[i][0], _fmin, atol=1)
363+
assert_allclose(freqs3[i][1], _fmax, atol=1)
359364

360365
# average con2 "manually" and we get the same result
366+
fskip = 1
361367
if not isinstance(method, list):
362368
for i in range(len(freqs3)):
363-
freq_idx = np.searchsorted(freqs2, freqs3[i])
364-
con2_avg = np.mean(con2.get_data()[:, freq_idx], axis=1)
369+
# now we want to get the frequency indices
370+
# create a frequency mask for all bands
371+
n_times = len(con2.attrs.get('times_used'))
372+
373+
# compute frequencies to analyze based on number of samples,
374+
# sampling rate, specified wavelet frequencies and mode
375+
freqs = _compute_freqs(n_times, sfreq, cwt_freqs, mode)
376+
377+
# compute the mask based on specified min/max and decim factor
378+
freq_mask = _compute_freq_mask(
379+
freqs, [fmin[i]], [fmax[i]], fskip)
380+
freqs = freqs[freq_mask]
381+
freqs_idx = np.searchsorted(freqs2, freqs)
382+
con2_avg = np.mean(con2.get_data()[:, freqs_idx], axis=1)
365383
assert_array_almost_equal(con2_avg, con3.get_data()[:, i])
366384
else:
367385
for j in range(len(con2)):
368386
for i in range(len(freqs3)):
369-
freq_idx = np.searchsorted(freqs2, freqs3[i])
370-
con2_avg = np.mean(con2[j].get_data()[:, freq_idx],
371-
axis=1)
387+
# now we want to get the frequency indices
388+
# create a frequency mask for all bands
389+
n_times = len(con2[0].attrs.get('times_used'))
390+
391+
# compute frequencies to analyze based on number of
392+
# samples, sampling rate, specified wavelet frequencies
393+
# and mode
394+
freqs = _compute_freqs(n_times, sfreq, cwt_freqs, mode)
395+
396+
# compute the mask based on specified min/max and
397+
# decim factor
398+
freq_mask = _compute_freq_mask(
399+
freqs, [fmin[i]], [fmax[i]], fskip)
400+
freqs = freqs[freq_mask]
401+
freqs_idx = np.searchsorted(freqs2, freqs)
402+
403+
con2_avg = np.mean(con2[j].get_data()[
404+
:, freqs_idx], axis=1)
372405
assert_array_almost_equal(
373406
con2_avg, con3[j].get_data()[:, i])
374407

@@ -551,3 +584,20 @@ def test_time_resolved_spectral_conn_regression(method, mode):
551584
conn_data = conn.get_data(output='dense')[
552585
:, row_triu_inds, col_triu_inds, ...]
553586
assert_array_almost_equal(conn_data, test_conn)
587+
588+
589+
def test_save(tmp_path):
590+
"""Test saving results of spectral connectivity."""
591+
rng = np.random.RandomState(0)
592+
n_epochs, n_chs, n_times, sfreq, f = 10, 2, 2000, 1000., 20.
593+
data = rng.randn(n_epochs, n_chs, n_times)
594+
sig = np.sin(2 * np.pi * f * np.arange(1000) / sfreq) * np.hanning(1000)
595+
data[:, :, 500:1500] += sig
596+
info = create_info(n_chs, sfreq, 'eeg')
597+
tmin = -1
598+
epochs = EpochsArray(data, info, tmin=tmin)
599+
600+
conn = spectral_connectivity_epochs(
601+
epochs, fmin=(4, 8, 13, 30), fmax=(8, 13, 30, 45),
602+
faverage=True)
603+
conn.save(tmp_path / 'foo.nc')

0 commit comments

Comments
 (0)