Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions elephant/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ def instantaneous_rate(spiketrains, sampling_period, kernel='auto',

Parameters
----------
spiketrains : neo.SpikeTrain, list of neo.SpikeTrain or elephant.trials.Trials # noqa
spiketrains : neo.SpikeTrain, list of neo.SpikeTrain or elephant.trials.Trials
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Other PRs explicitly add links to the documentation of the referred classes using Sphinx's :class: statements, which facilitates navigating through the documentation.

Input spike train(s) for which the instantaneous firing rate is
calculated. If a list of spike trains is supplied, the parameter
pool_spike_trains determines the behavior of the function. If a Trials
Expand Down Expand Up @@ -1031,7 +1031,7 @@ def optimal_kernel(st):
sigma=str(kernel.sigma),
invert=kernel.invert)

if isinstance(spiketrains, neo.core.spiketrainlist.SpikeTrainList) and (
if isinstance(spiketrains, (neo.core.spiketrainlist.SpikeTrainList, list, tuple)) and (
pool_spike_trains):
rate = np.mean(rate, axis=1)

Expand Down
58 changes: 53 additions & 5 deletions elephant/test/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,15 +482,15 @@ def test_cv2_raise_error(self):
self.assertRaises(ValueError, statistics.cv2, np.array([seq, seq]))


class InstantaneousRateTest(unittest.TestCase):
class InstantaneousRateTestCase(unittest.TestCase):

@classmethod
def setUpClass(cls) -> None:
"""
Run once before tests:
"""

block = _create_trials_block(n_trials=36)
block = _create_trials_block(n_trials=36, n_spiketrains=5)
cls.block = block
cls.trial_object = TrialsFromBlock(block,
description='trials are segments')
Expand Down Expand Up @@ -988,8 +988,44 @@ def test_instantaneous_rate_trials_pool_trials(self):
pool_spike_trains=False,
pool_trials=True)
self.assertIsInstance(rate, neo.core.AnalogSignal)
self.assertEqual(rate.shape[1], self.trial_object.n_spiketrains_trial_by_trial[0])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test checks for consistency between the dimensions of the computed rate and the data in the Trials objects. However, an additional test against the integer values would protect against errors in implementing the object attributes. These hard expected values are supposed to be known when designing the test data in line 493.


def test_instantaneous_rate_list_pool_spike_trains(self):
def test_instantaneous_rate_trials_pool_spiketrains(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A suggestion to improve understanding of these unit tests, where the dimensions of the outputs change depending on the pooling, is to include a comment block at the beginning to explicitly state the input dimensions --> expected output dimensions. It would be easier to assess the behavior of the object and what is being tested in each test case.

kernel = kernels.GaussianKernel(sigma=500 * pq.ms)

rate = statistics.instantaneous_rate(self.trial_object,
sampling_period=0.1 * pq.ms,
kernel=kernel,
pool_spike_trains=True,
pool_trials=False)
self.assertIsInstance(rate, list)
self.assertEqual(len(rate), self.trial_object.n_trials)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as above, where the test does not compare directly to the expected values, but retrieves them dynamically from the objects.

self.assertEqual(rate[0].shape[1], 1)

def test_instantaneous_rate_trials_pool_spiketrains_pool_trials(self):
kernel = kernels.GaussianKernel(sigma=500 * pq.ms)

rate = statistics.instantaneous_rate(self.trial_object,
sampling_period=0.1 * pq.ms,
kernel=kernel,
pool_spike_trains=True,
pool_trials=True)
self.assertIsInstance(rate, neo.AnalogSignal)
self.assertEqual(rate.shape[1], 1)

def test_instantaneous_rate_trials_pool_spiketrains_false_pool_trials_false(self):
kernel = kernels.GaussianKernel(sigma=500 * pq.ms)

rate = statistics.instantaneous_rate(self.trial_object,
sampling_period=0.1 * pq.ms,
kernel=kernel,
pool_spike_trains=False,
pool_trials=False)
self.assertIsInstance(rate, list)
self.assertEqual(len(rate), self.trial_object.n_trials)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as above, where the test does not compare directly to the expected values, but retrieves them dynamically from the objects.

self.assertEqual(rate[0].shape[1], self.trial_object.n_spiketrains_trial_by_trial[0])

def test_instantaneous_rate_spiketrainlist_pool_spike_trains(self):
kernel = kernels.GaussianKernel(sigma=500 * pq.ms)

rate = statistics.instantaneous_rate(
Expand All @@ -999,7 +1035,19 @@ def test_instantaneous_rate_list_pool_spike_trains(self):
pool_spike_trains=True,
pool_trials=False)
self.assertIsInstance(rate, neo.core.AnalogSignal)
self.assertEqual(rate.magnitude.shape[1], 1)
self.assertEqual(rate.shape[1], 1)

def test_instantaneous_rate_list_pool_spike_trains(self):
kernel = kernels.GaussianKernel(sigma=500 * pq.ms)

rate = statistics.instantaneous_rate(
list(self.trial_object.get_spiketrains_from_trial_as_list(0)),
sampling_period=0.1 * pq.ms,
kernel=kernel,
pool_spike_trains=True,
pool_trials=False)
self.assertIsInstance(rate, neo.core.AnalogSignal)
self.assertEqual(rate.shape[1], 1)

def test_instantaneous_rate_list_of_spike_trains(self):
kernel = kernels.GaussianKernel(sigma=500 * pq.ms)
Expand All @@ -1010,7 +1058,7 @@ def test_instantaneous_rate_list_of_spike_trains(self):
pool_spike_trains=False,
pool_trials=False)
self.assertIsInstance(rate, neo.core.AnalogSignal)
self.assertEqual(rate.magnitude.shape[1], 2)
self.assertEqual(rate.magnitude.shape[1], self.trial_object.n_spiketrains_trial_by_trial[0])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as above, where the test does not compare directly to the expected values, but retrieves them dynamically from the objects. The integer in the old test could be updated, and the new one just added as an additional check.



class TimeHistogramTestCase(unittest.TestCase):
Expand Down
Loading