From 07e3fc513e0df4c846ee2c81e5fd12f312bb7c9d Mon Sep 17 00:00:00 2001 From: Harris Antony Jos Date: Wed, 26 Nov 2025 17:03:10 +0100 Subject: [PATCH] Clipped small negative values in van_rossum_distance to prevent Nan from np.sqrt; added test case for same --- elephant/spike_train_dissimilarity.py | 11 ++++++++++- elephant/test/test_spike_train_dissimilarity.py | 15 ++++++++++++--- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/elephant/spike_train_dissimilarity.py b/elephant/spike_train_dissimilarity.py index 3234f8916..2a541186e 100644 --- a/elephant/spike_train_dissimilarity.py +++ b/elephant/spike_train_dissimilarity.py @@ -21,7 +21,7 @@ """ from __future__ import division, print_function, unicode_literals - +import warnings import numpy as np import quantities as pq from neo.core import SpikeTrain @@ -363,6 +363,15 @@ def van_rossum_distance(spiketrains, time_constant=1.0 * pq.s, sort=True): for i, j in np.ndindex(k_dist.shape): vr_dist[i, j] = ( k_dist[i, i] + k_dist[j, j] - k_dist[i, j] - k_dist[j, i]) + + # Clip small negative values + if np.any(vr_dist < 0): + warnings.warn( + "van_rossum_distance: very small negative values encountered " + "(likely due to floating point error); " + "setting them to 0", RuntimeWarning) + vr_dist = np.maximum(vr_dist, 0.0) + return np.sqrt(vr_dist) diff --git a/elephant/test/test_spike_train_dissimilarity.py b/elephant/test/test_spike_train_dissimilarity.py index 4619d4bba..5e9b803c3 100644 --- a/elephant/test/test_spike_train_dissimilarity.py +++ b/elephant/test/test_spike_train_dissimilarity.py @@ -14,7 +14,7 @@ import elephant.kernels as kernels from elephant.spike_train_generation import StationaryPoissonProcess import elephant.spike_train_dissimilarity as stds - +import warnings from elephant.datasets import download_datasets @@ -47,6 +47,7 @@ def setUp(self): self.st23 = StationaryPoissonProcess(rate=30 * Hz, t_start=0 * ms, t_stop=1000 * ms ).generate_spiketrain() + self.st24 = SpikeTrain([0.1782, 0.2286, 0.2804, 0.4972, 0.5504], units='s',t_stop=4.0) self.rd_st_list = [self.st21, self.st22, self.st23] self.st31 = SpikeTrain([12.0], units='ms', t_stop=1000.0) self.st32 = SpikeTrain([12.0, 12.0], units='ms', t_stop=1000.0) @@ -73,7 +74,7 @@ def setUp(self): self.tau7 = 0.01 * s self.q7 = 1.0 / self.tau7 self.t = np.linspace(0, 200, 20000001) * ms - + self.tau8 = 0.1 * s def test_wrong_input(self): self.assertRaises(TypeError, stds.victor_purpura_distance, [self.array1, self.array2], self.q3) @@ -599,7 +600,15 @@ def test_van_rossum_distance(self): self.assertEqual(stds.van_rossum_distance( [self.st21], self.tau3)[0, 0], 0) self.assertEqual(len(stds.van_rossum_distance([], self.tau3)), 0) - + + # Check small negative values edge case + with warnings.catch_warnings(record='True') as w: + warnings.simplefilter("always") + result = stds.van_rossum_distance([self.st24, self.st24], self.tau8) + self.assertTrue(any("very small negative values encountered" in str(warn.message) + for warn in w)) + self.assertEqual(result[0,1], 0.0) + self.assertFalse(np.any(np.isnan(result))) if __name__ == '__main__': unittest.main()