1414import elephant .kernels as kernels
1515from elephant .spike_train_generation import StationaryPoissonProcess
1616import elephant .spike_train_dissimilarity as stds
17-
17+ import warnings
1818from elephant .datasets import download_datasets
1919
2020
@@ -47,6 +47,7 @@ def setUp(self):
4747 self .st23 = StationaryPoissonProcess (rate = 30 * Hz , t_start = 0 * ms ,
4848 t_stop = 1000 * ms
4949 ).generate_spiketrain ()
50+ self .st24 = SpikeTrain ([0.1782 , 0.2286 , 0.2804 , 0.4972 , 0.5504 ], units = 's' ,t_stop = 4.0 )
5051 self .rd_st_list = [self .st21 , self .st22 , self .st23 ]
5152 self .st31 = SpikeTrain ([12.0 ], units = 'ms' , t_stop = 1000.0 )
5253 self .st32 = SpikeTrain ([12.0 , 12.0 ], units = 'ms' , t_stop = 1000.0 )
@@ -73,7 +74,7 @@ def setUp(self):
7374 self .tau7 = 0.01 * s
7475 self .q7 = 1.0 / self .tau7
7576 self .t = np .linspace (0 , 200 , 20000001 ) * ms
76-
77+ self . tau8 = 0.1 * s
7778 def test_wrong_input (self ):
7879 self .assertRaises (TypeError , stds .victor_purpura_distance ,
7980 [self .array1 , self .array2 ], self .q3 )
@@ -599,7 +600,15 @@ def test_van_rossum_distance(self):
599600 self .assertEqual (stds .van_rossum_distance (
600601 [self .st21 ], self .tau3 )[0 , 0 ], 0 )
601602 self .assertEqual (len (stds .van_rossum_distance ([], self .tau3 )), 0 )
602-
603+
604+ # Check small negative values edge case
605+ with warnings .catch_warnings (record = 'True' ) as w :
606+ warnings .simplefilter ("always" )
607+ result = stds .van_rossum_distance ([self .st24 , self .st24 ], self .tau8 )
608+ self .assertTrue (any ("very small negative values encountered" in str (warn .message )
609+ for warn in w ))
610+ self .assertEqual (result [0 ,1 ], 0.0 )
611+ self .assertFalse (np .any (np .isnan (result )))
603612
604613if __name__ == '__main__' :
605614 unittest .main ()
0 commit comments