15
15
SpectralConnectivity , spectral_connectivity_epochs ,
16
16
read_connectivity , spectral_connectivity_time )
17
17
from mne_connectivity .spectral .epochs import _CohEst , _get_n_epochs
18
+ from mne_connectivity .spectral .epochs import (
19
+ _compute_freq_mask , _compute_freqs )
18
20
19
21
20
22
def create_test_dataset (sfreq , n_signals , n_epochs , n_times , tmin , tmax ,
@@ -336,6 +338,7 @@ def test_spectral_connectivity(method, mode):
336
338
assert (n == n2 )
337
339
assert_array_almost_equal (times_data , times2 )
338
340
341
+ # Test with faverage
339
342
# compute same connections for two bands, fskip=1, and f. avg.
340
343
fmin = (5. , 15. )
341
344
fmax = (15. , 30. )
@@ -354,21 +357,51 @@ def test_spectral_connectivity(method, mode):
354
357
assert (isinstance (freqs3 , list ))
355
358
assert (len (freqs3 ) == len (fmin ))
356
359
for i in range (len (freqs3 )):
357
- assert np .all ((freqs3 [i ] >= fmin [i ]) &
358
- (freqs3 [i ] <= fmax [i ]))
360
+ _fmin = max (fmin [i ], min (cwt_freqs ))
361
+ _fmax = min (fmax [i ], max (cwt_freqs ))
362
+ assert_allclose (freqs3 [i ][0 ], _fmin , atol = 1 )
363
+ assert_allclose (freqs3 [i ][1 ], _fmax , atol = 1 )
359
364
360
365
# average con2 "manually" and we get the same result
366
+ fskip = 1
361
367
if not isinstance (method , list ):
362
368
for i in range (len (freqs3 )):
363
- freq_idx = np .searchsorted (freqs2 , freqs3 [i ])
364
- con2_avg = np .mean (con2 .get_data ()[:, freq_idx ], axis = 1 )
369
+ # now we want to get the frequency indices
370
+ # create a frequency mask for all bands
371
+ n_times = len (con2 .attrs .get ('times_used' ))
372
+
373
+ # compute frequencies to analyze based on number of samples,
374
+ # sampling rate, specified wavelet frequencies and mode
375
+ freqs = _compute_freqs (n_times , sfreq , cwt_freqs , mode )
376
+
377
+ # compute the mask based on specified min/max and decim factor
378
+ freq_mask = _compute_freq_mask (
379
+ freqs , [fmin [i ]], [fmax [i ]], fskip )
380
+ freqs = freqs [freq_mask ]
381
+ freqs_idx = np .searchsorted (freqs2 , freqs )
382
+ con2_avg = np .mean (con2 .get_data ()[:, freqs_idx ], axis = 1 )
365
383
assert_array_almost_equal (con2_avg , con3 .get_data ()[:, i ])
366
384
else :
367
385
for j in range (len (con2 )):
368
386
for i in range (len (freqs3 )):
369
- freq_idx = np .searchsorted (freqs2 , freqs3 [i ])
370
- con2_avg = np .mean (con2 [j ].get_data ()[:, freq_idx ],
371
- axis = 1 )
387
+ # now we want to get the frequency indices
388
+ # create a frequency mask for all bands
389
+ n_times = len (con2 [0 ].attrs .get ('times_used' ))
390
+
391
+ # compute frequencies to analyze based on number of
392
+ # samples, sampling rate, specified wavelet frequencies
393
+ # and mode
394
+ freqs = _compute_freqs (n_times , sfreq , cwt_freqs , mode )
395
+
396
+ # compute the mask based on specified min/max and
397
+ # decim factor
398
+ freq_mask = _compute_freq_mask (
399
+ freqs , [fmin [i ]], [fmax [i ]], fskip )
400
+ freqs = freqs [freq_mask ]
401
+ freqs_idx = np .searchsorted (freqs2 , freqs )
402
+
403
+ con2_avg = np .mean (con2 [j ].get_data ()[
404
+ :, freqs_idx ], axis = 1 )
372
405
assert_array_almost_equal (
373
406
con2_avg , con3 [j ].get_data ()[:, i ])
374
407
@@ -551,3 +584,20 @@ def test_time_resolved_spectral_conn_regression(method, mode):
551
584
conn_data = conn .get_data (output = 'dense' )[
552
585
:, row_triu_inds , col_triu_inds , ...]
553
586
assert_array_almost_equal (conn_data , test_conn )
587
+
588
+
589
+ def test_save (tmp_path ):
590
+ """Test saving results of spectral connectivity."""
591
+ rng = np .random .RandomState (0 )
592
+ n_epochs , n_chs , n_times , sfreq , f = 10 , 2 , 2000 , 1000. , 20.
593
+ data = rng .randn (n_epochs , n_chs , n_times )
594
+ sig = np .sin (2 * np .pi * f * np .arange (1000 ) / sfreq ) * np .hanning (1000 )
595
+ data [:, :, 500 :1500 ] += sig
596
+ info = create_info (n_chs , sfreq , 'eeg' )
597
+ tmin = - 1
598
+ epochs = EpochsArray (data , info , tmin = tmin )
599
+
600
+ conn = spectral_connectivity_epochs (
601
+ epochs , fmin = (4 , 8 , 13 , 30 ), fmax = (8 , 13 , 30 , 45 ),
602
+ faverage = True )
603
+ conn .save (tmp_path / 'foo.nc' )
0 commit comments