diff --git a/elephant/statistics.py b/elephant/statistics.py index 45d9cd283..6a8fdcb10 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -74,7 +74,7 @@ import scipy.signal from numpy import ndarray from scipy.special import erf -from typing import Union +from typing import List, Optional, Union import elephant.conversion as conv import elephant.kernels as kernels @@ -1062,46 +1062,51 @@ def optimal_kernel(st): @deprecated_alias(binsize='bin_size') -def time_histogram(spiketrains, bin_size, t_start=None, t_stop=None, - output='counts', binary=False): +def time_histogram(spiketrains: Union[List[neo.SpikeTrain], neo.SpikeTrain], + bin_size: pq.Quantity, + t_start: Optional[pq.Quantity] = None, + t_stop: Optional[pq.Quantity] = None, + output: str = 'counts', + binary: bool = False) -> neo.AnalogSignal: """ - Time Histogram of a list of `neo.SpikeTrain` objects. + Time Histogram of a list of :class:`neo.core.SpikeTrain` objects. Visualization of this function is covered in Viziphant: :func:`viziphant.statistics.plot_time_histogram`. Parameters ---------- - spiketrains : list of neo.SpikeTrain - `neo.SpikeTrain`s with a common time axis (same `t_start` and `t_stop`) + spiketrains : list of :class:`neo.core.SpikeTrain` or :class:`neo.core.SpikeTrain` + `neo.SpikeTrain` objects with a common time axis (same `t_start` and `t_stop`) bin_size : pq.Quantity Width of the histogram's time bins. t_start : pq.Quantity, optional Start time of the histogram. Only events in `spiketrains` falling between `t_start` and `t_stop` (both included) are considered in the histogram. - If None, the maximum `t_start` of all `neo.SpikeTrain`s is used as + If None, the maximum `t_start` of all :class:`neo.core.SpikeTrain`s is used as `t_start`. Default: None t_stop : pq.Quantity, optional Stop time of the histogram. Only events in `spiketrains` falling between `t_start` and `t_stop` (both included) are considered in the histogram. - If None, the minimum `t_stop` of all `neo.SpikeTrain`s is used as + If None, the minimum `t_stop` of all :class:`neo.core.SpikeTrain` s is used as `t_stop`. Default: None output : {'counts', 'mean', 'rate'}, optional Normalization of the histogram. Can be one of: - * 'counts': spike counts at each bin (as integer numbers). - * 'mean': mean spike counts per spike train. - * 'rate': mean spike rate per spike train. Like 'mean', but the - counts are additionally normalized by the bin width. + + - 'counts': spike counts at each bin (as integer numbers). + - 'mean': mean spike counts per spike train. + - 'rate': mean spike rate per spike train. Like 'mean', but the counts are additionally normalized + by the bin width. Default: 'counts' binary : bool, optional - If True, indicates whether all `neo.SpikeTrain` objects should first + If True, indicates whether all :class:`neo.core.SpikeTrain` objects should first be binned to a binary representation (using the - `conversion.BinnedSpikeTrain` class) and the calculation of the + [:class:`elephant.conversion.BinnedSpikeTrain` class] and the calculation of the histogram is based on this representation. Note that the output is not binary, but a histogram of the converted, binary representation. @@ -1110,8 +1115,8 @@ def time_histogram(spiketrains, bin_size, t_start=None, t_stop=None, Returns ------- neo.AnalogSignal - A `neo.AnalogSignal` object containing the histogram values. - `neo.AnalogSignal[j]` is the histogram computed between + A :class:`neo.core.SpikeTrain` object containing the histogram values. + :class:`neo.core.SpikeTrain `[j]` is the histogram computed between `t_start + j * bin_size` and `t_start + (j + 1) * bin_size`. Raises @@ -1129,7 +1134,7 @@ def time_histogram(spiketrains, bin_size, t_start=None, t_stop=None, See also -------- - elephant.conversion.BinnedSpikeTrain + :func:`elephant.conversion.BinnedSpikeTrain` Examples -------- @@ -1178,17 +1183,17 @@ def time_histogram(spiketrains, bin_size, t_start=None, t_stop=None, def _counts() -> pq.Quantity: # 'counts': spike counts at each bin (as integer numbers). - return pq.Quantity(bin_hist, units=pq.dimensionless, copy=False) + return pq.Quantity(bin_hist, units=pq.dimensionless) def _mean() -> pq.Quantity: # 'mean': mean spike counts per spike train. - return pq.Quantity(bin_hist / len(spiketrains), - units=pq.dimensionless, copy=False) + return pq.Quantity(bin_hist / binned_spiketrain.shape[0], + units=pq.dimensionless) def _rate() -> pq.Quantity: # 'rate': mean spike rate per spike train. Like 'mean', but the # counts are additionally normalized by the bin width. - return bin_hist / (len(spiketrains) * bin_size) + return bin_hist / (binned_spiketrain.shape[0] * bin_size) output_mapping = {"counts": _counts, "mean": _mean, "rate": _rate} try: diff --git a/elephant/test/test_statistics.py b/elephant/test/test_statistics.py index 426111810..1e6a8fec3 100644 --- a/elephant/test/test_statistics.py +++ b/elephant/test/test_statistics.py @@ -1089,6 +1089,20 @@ def test_annotations(self): self.assertIn('normalization', histogram.annotations) self.assertEqual(histogram.annotations['normalization'], output) + def test_time_histogram_regression_648_single_spiketrain(self): + # Create a single spike train + spiketrain = neo.SpikeTrain([0.1, 0.5, 1.0, 1.5, 2.0] * pq.s, t_stop=3.0 * pq.s) + + # Run time_histogram with spiketrain directly and observe the incorrect result + histogram_direct = statistics.time_histogram(spiketrain, output='rate', bin_size=0.5 * pq.s) + + # Wrap spiketrain in a list and run time_histogram + histogram_wrapped = statistics.time_histogram([spiketrain], output='rate', bin_size=0.5 * pq.s) + # Check if passing a single spiketrain directly vs in a list gives same result + np.testing.assert_array_equal(histogram_direct.magnitude, histogram_wrapped.magnitude) + # Check if the spike rate calculation is correct for a single spike train + np.testing.assert_array_equal(histogram_direct.magnitude.flatten(), [2., 2., 2., 2., 2., 0.]*pq.Hz) + class ComplexityTestCase(unittest.TestCase): def test_complexity_pdf_deprecated(self):