Skip to content

Commit cc84b52

Browse files
Vahidrostamimdenker
authored andcommitted
Update test_unitary_event_analysis.py with validation test (#152)
* Update test_unitary_event_analysis.py test added to reproduce the results of Riehle et al. 1997 science using unitary event method (implemented in Elephant) and check the consistency with the original publication (see Rostami et al. 1997 [Re] science for a detailed explanation) * Update test_unitary_event_analysis.py used mock library to exclude importing matplotlib * Update test_unitary_event_analysis.py Added neo loading routine to the test file. * Update test_unitary_event_analysis.py changed indentation and `numpy` to `np` * Update test_unitary_event_analysis.py removed mock * Update test_unitary_event_analysis.py Removed argument `encoding` * Update test_unitary_event_analysis.py added `try except` block around `np.load` to support loading for python 2 and 3 * Update test_unitary_event_analysis.py except changed to `UnicodeError` * Update test_unitary_event_analysis.py added decorator to skip certain configuration in the unittest
1 parent f48fc20 commit cc84b52

File tree

1 file changed

+153
-0
lines changed

1 file changed

+153
-0
lines changed

elephant/test/test_unitary_event_analysis.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,16 @@
1111
import types
1212
import elephant.unitary_event_analysis as ue
1313
import neo
14+
import sys
15+
import os
16+
17+
from distutils.version import StrictVersion
18+
19+
20+
def _check_for_incompatibilty():
21+
smaller_version = StrictVersion(np.__version__) < '1.10.0'
22+
return sys.version_info >= (3, 0) and smaller_version
23+
1424

1525
class UETestCase(unittest.TestCase):
1626

@@ -338,6 +348,149 @@ def test_jointJ_window_analysis(self):
338348
UE_dic['indices']['trial26'],expected_indecis_tril26))
339349
self.assertTrue(np.allclose(
340350
UE_dic['indices']['trial4'],expected_indecis_tril4))
351+
352+
@staticmethod
353+
def load_gdf2Neo(fname, trigger, t_pre, t_post):
354+
"""
355+
load and convert the gdf file to Neo format by
356+
cutting and aligning around a given trigger
357+
# codes for trigger events (extracted from a
358+
# documentation of an old file after
359+
# contacting Dr. Alexa Rihle)
360+
# 700 : ST (correct) 701, 702, 703, 704*
361+
# 500 : ST (error =5) 501, 502, 503, 504*
362+
# 1000: ST (if no selec) 1001,1002,1003,1004*
363+
# 11 : PS 111, 112, 113, 114
364+
# 12 : RS 121, 122, 123, 124
365+
# 13 : RT 131, 132, 133, 134
366+
# 14 : MT 141, 142, 143, 144
367+
# 15 : ES 151, 152, 153, 154
368+
# 16 : ES 161, 162, 163, 164
369+
# 17 : ES 171, 172, 173, 174
370+
# 19 : RW 191, 192, 193, 194
371+
# 20 : ET 201, 202, 203, 204
372+
"""
373+
data = np.loadtxt(fname)
374+
375+
if trigger == 'PS_4':
376+
trigger_code = 114
377+
if trigger == 'RS_4':
378+
trigger_code = 124
379+
if trigger == 'RS':
380+
trigger_code = 12
381+
if trigger == 'ES':
382+
trigger_code = 15
383+
# specify units
384+
units_id = np.unique(data[:, 0][data[:, 0] < 7])
385+
# indecies of the trigger
386+
sel_tr_idx = np.where(data[:, 0] == trigger_code)[0]
387+
# cutting the data by aligning on the trigger
388+
data_tr = []
389+
for id_tmp in units_id:
390+
data_sel_units = []
391+
for i_cnt, i in enumerate(sel_tr_idx):
392+
start_tmp = data[i][1] - t_pre.magnitude
393+
stop_tmp = data[i][1] + t_post.magnitude
394+
sel_data_tmp = np.array(
395+
data[np.where((data[:, 1] <= stop_tmp) &
396+
(data[:, 1] >= start_tmp))])
397+
sp_units_tmp = sel_data_tmp[:, 1][
398+
np.where(sel_data_tmp[:, 0] == id_tmp)[0]]
399+
if len(sp_units_tmp) > 0:
400+
aligned_time = sp_units_tmp - start_tmp
401+
data_sel_units.append(neo.SpikeTrain(
402+
aligned_time * pq.ms, t_start=0 * pq.ms,
403+
t_stop=t_pre + t_post))
404+
else:
405+
data_sel_units.append(neo.SpikeTrain(
406+
[] * pq.ms, t_start=0 * pq.ms,
407+
t_stop=t_pre + t_post))
408+
data_tr.append(data_sel_units)
409+
data_tr.reverse()
410+
spiketrain = np.vstack([i for i in data_tr]).T
411+
return spiketrain
412+
413+
# test if the result of newly implemented Unitary Events in
414+
# Elephant is consistent with the result of
415+
# Riehle et al 1997 Science
416+
# (see Rostami et al (2016) [Re] Science, 3(1):1-17)
417+
@unittest.skipIf(_check_for_incompatibilty(),
418+
'Incompatible package versions')
419+
def test_Riehle_et_al_97_UE(self):
420+
from neo.rawio.tests.tools import (download_test_file,
421+
create_local_temp_dir,
422+
make_all_directories)
423+
from neo.test.iotest.tools import (cleanup_test_file)
424+
url = [
425+
"https://raw.githubusercontent.com/ReScience-Archives/" +
426+
"Rostami-Ito-Denker-Gruen-2017/master/data",
427+
"https://raw.githubusercontent.com/ReScience-Archives/" +
428+
"Rostami-Ito-Denker-Gruen-2017/master/data"]
429+
shortname = "unitary_event_analysis_test_data"
430+
local_test_dir = create_local_temp_dir(
431+
shortname, os.environ.get("ELEPHANT_TEST_FILE_DIR"))
432+
files_to_download = ["extracted_data.npy", "winny131_23.gdf"]
433+
make_all_directories(files_to_download,
434+
local_test_dir)
435+
for f_cnt, f in enumerate(files_to_download):
436+
download_test_file(f, local_test_dir, url[f_cnt])
437+
438+
# load spike data of figure 2 of Riehle et al 1997
439+
sys.path.append(local_test_dir)
440+
file_name = '/winny131_23.gdf'
441+
trigger = 'RS_4'
442+
t_pre = 1799 * pq.ms
443+
t_post = 300 * pq.ms
444+
spiketrain = self.load_gdf2Neo(local_test_dir + file_name,
445+
trigger, t_pre, t_post)
446+
447+
# calculating UE ...
448+
winsize = 100 * pq.ms
449+
binsize = 5 * pq.ms
450+
winstep = 5 * pq.ms
451+
pattern_hash = [3]
452+
method = 'analytic_TrialAverage'
453+
t_start = spiketrain[0][0].t_start
454+
t_stop = spiketrain[0][0].t_stop
455+
t_winpos = ue._winpos(t_start, t_stop, winsize, winstep)
456+
significance_level = 0.05
457+
458+
UE = ue.jointJ_window_analysis(
459+
spiketrain, binsize, winsize, winstep,
460+
pattern_hash, method=method)
461+
# load extracted data from figure 2 of Riehle et al 1997
462+
try:
463+
extracted_data = np.load(
464+
local_test_dir + '/extracted_data.npy').item()
465+
except UnicodeError:
466+
extracted_data = np.load(
467+
local_test_dir + '/extracted_data.npy', encoding='latin1').item()
468+
Js_sig = ue.jointJ(significance_level)
469+
sig_idx_win = np.where(UE['Js'] >= Js_sig)[0]
470+
diff_UE_rep = []
471+
y_cnt = 0
472+
for tr in range(len(spiketrain)):
473+
x_idx = np.sort(
474+
np.unique(UE['indices']['trial' + str(tr)],
475+
return_index=True)[1])
476+
x = UE['indices']['trial' + str(tr)][x_idx]
477+
if len(x) > 0:
478+
# choose only the significant coincidences
479+
xx = []
480+
for j in sig_idx_win:
481+
xx = np.append(xx, x[np.where(
482+
(x * binsize >= t_winpos[j]) &
483+
(x * binsize < t_winpos[j] + winsize))])
484+
x_tmp = np.unique(xx) * binsize.magnitude
485+
if len(x_tmp) > 0:
486+
ue_trial = np.sort(extracted_data['ue'][y_cnt])
487+
diff_UE_rep = np.append(
488+
diff_UE_rep, x_tmp - ue_trial)
489+
y_cnt += +1
490+
np.testing.assert_array_less(np.abs(diff_UE_rep), 0.3)
491+
cleanup_test_file('dir', local_test_dir)
492+
493+
341494
def suite():
342495
suite = unittest.makeSuite(UETestCase, 'test')
343496
return suite

0 commit comments

Comments
 (0)