Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion elephant/spike_train_dissimilarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"""

from __future__ import division, print_function, unicode_literals

import warnings
import numpy as np
import quantities as pq
from neo.core import SpikeTrain
Expand Down Expand Up @@ -363,6 +363,15 @@ def van_rossum_distance(spiketrains, time_constant=1.0 * pq.s, sort=True):
for i, j in np.ndindex(k_dist.shape):
vr_dist[i, j] = (
k_dist[i, i] + k_dist[j, j] - k_dist[i, j] - k_dist[j, i])

# Clip small negative values
if np.any(vr_dist < 0):
warnings.warn(
"van_rossum_distance: very small negative values encountered "
"(likely due to floating point error); "
"setting them to 0", RuntimeWarning)
vr_dist = np.maximum(vr_dist, 0.0)

return np.sqrt(vr_dist)


Expand Down
15 changes: 12 additions & 3 deletions elephant/test/test_spike_train_dissimilarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import elephant.kernels as kernels
from elephant.spike_train_generation import StationaryPoissonProcess
import elephant.spike_train_dissimilarity as stds

import warnings
from elephant.datasets import download_datasets


Expand Down Expand Up @@ -47,6 +47,7 @@ def setUp(self):
self.st23 = StationaryPoissonProcess(rate=30 * Hz, t_start=0 * ms,
t_stop=1000 * ms
).generate_spiketrain()
self.st24 = SpikeTrain([0.1782, 0.2286, 0.2804, 0.4972, 0.5504], units='s',t_stop=4.0)
self.rd_st_list = [self.st21, self.st22, self.st23]
self.st31 = SpikeTrain([12.0], units='ms', t_stop=1000.0)
self.st32 = SpikeTrain([12.0, 12.0], units='ms', t_stop=1000.0)
Expand All @@ -73,7 +74,7 @@ def setUp(self):
self.tau7 = 0.01 * s
self.q7 = 1.0 / self.tau7
self.t = np.linspace(0, 200, 20000001) * ms

self.tau8 = 0.1 * s
def test_wrong_input(self):
self.assertRaises(TypeError, stds.victor_purpura_distance,
[self.array1, self.array2], self.q3)
Expand Down Expand Up @@ -599,7 +600,15 @@ def test_van_rossum_distance(self):
self.assertEqual(stds.van_rossum_distance(
[self.st21], self.tau3)[0, 0], 0)
self.assertEqual(len(stds.van_rossum_distance([], self.tau3)), 0)


# Check small negative values edge case
with warnings.catch_warnings(record='True') as w:
warnings.simplefilter("always")
result = stds.van_rossum_distance([self.st24, self.st24], self.tau8)
self.assertTrue(any("very small negative values encountered" in str(warn.message)
for warn in w))
self.assertEqual(result[0,1], 0.0)
self.assertFalse(np.any(np.isnan(result)))

if __name__ == '__main__':
unittest.main()
Loading