Skip to content

Commit 07e3fc5

Browse files
Clipped small negative values in van_rossum_distance to prevent Nan from np.sqrt; added test case for same
1 parent 5392583 commit 07e3fc5

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

elephant/spike_train_dissimilarity.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
"""
2222

2323
from __future__ import division, print_function, unicode_literals
24-
24+
import warnings
2525
import numpy as np
2626
import quantities as pq
2727
from neo.core import SpikeTrain
@@ -363,6 +363,15 @@ def van_rossum_distance(spiketrains, time_constant=1.0 * pq.s, sort=True):
363363
for i, j in np.ndindex(k_dist.shape):
364364
vr_dist[i, j] = (
365365
k_dist[i, i] + k_dist[j, j] - k_dist[i, j] - k_dist[j, i])
366+
367+
# Clip small negative values
368+
if np.any(vr_dist < 0):
369+
warnings.warn(
370+
"van_rossum_distance: very small negative values encountered "
371+
"(likely due to floating point error); "
372+
"setting them to 0", RuntimeWarning)
373+
vr_dist = np.maximum(vr_dist, 0.0)
374+
366375
return np.sqrt(vr_dist)
367376

368377

elephant/test/test_spike_train_dissimilarity.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import elephant.kernels as kernels
1515
from elephant.spike_train_generation import StationaryPoissonProcess
1616
import elephant.spike_train_dissimilarity as stds
17-
17+
import warnings
1818
from elephant.datasets import download_datasets
1919

2020

@@ -47,6 +47,7 @@ def setUp(self):
4747
self.st23 = StationaryPoissonProcess(rate=30 * Hz, t_start=0 * ms,
4848
t_stop=1000 * ms
4949
).generate_spiketrain()
50+
self.st24 = SpikeTrain([0.1782, 0.2286, 0.2804, 0.4972, 0.5504], units='s',t_stop=4.0)
5051
self.rd_st_list = [self.st21, self.st22, self.st23]
5152
self.st31 = SpikeTrain([12.0], units='ms', t_stop=1000.0)
5253
self.st32 = SpikeTrain([12.0, 12.0], units='ms', t_stop=1000.0)
@@ -73,7 +74,7 @@ def setUp(self):
7374
self.tau7 = 0.01 * s
7475
self.q7 = 1.0 / self.tau7
7576
self.t = np.linspace(0, 200, 20000001) * ms
76-
77+
self.tau8 = 0.1 * s
7778
def test_wrong_input(self):
7879
self.assertRaises(TypeError, stds.victor_purpura_distance,
7980
[self.array1, self.array2], self.q3)
@@ -599,7 +600,15 @@ def test_van_rossum_distance(self):
599600
self.assertEqual(stds.van_rossum_distance(
600601
[self.st21], self.tau3)[0, 0], 0)
601602
self.assertEqual(len(stds.van_rossum_distance([], self.tau3)), 0)
602-
603+
604+
# Check small negative values edge case
605+
with warnings.catch_warnings(record='True') as w:
606+
warnings.simplefilter("always")
607+
result = stds.van_rossum_distance([self.st24, self.st24], self.tau8)
608+
self.assertTrue(any("very small negative values encountered" in str(warn.message)
609+
for warn in w))
610+
self.assertEqual(result[0,1], 0.0)
611+
self.assertFalse(np.any(np.isnan(result)))
603612

604613
if __name__ == '__main__':
605614
unittest.main()

0 commit comments

Comments
 (0)