Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions elephant/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
import elephant.trials
from elephant.conversion import BinnedSpikeTrain
from elephant.utils import deprecated_alias, check_neo_consistency, \
is_time_quantity, round_binning_errors
is_time_quantity, round_binning_errors, is_list_neo_spiketrains

# do not import unicode_literals
# (quantities rescale does not work with unicodes)
Expand Down Expand Up @@ -613,7 +613,7 @@ def instantaneous_rate(spiketrains, sampling_period, kernel='auto',

Parameters
----------
spiketrains : neo.SpikeTrain, list of neo.SpikeTrain or elephant.trials.Trials # noqa
spiketrains : neo.SpikeTrain, list of neo.SpikeTrain or elephant.trials.Trials
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Other PRs explicitly add links to the documentation of the referred classes using Sphinx's :class: statements, which facilitates navigating through the documentation.

Input spike train(s) for which the instantaneous firing rate is
calculated. If a list of spike trains is supplied, the parameter
pool_spike_trains determines the behavior of the function. If a Trials
Expand Down Expand Up @@ -1031,8 +1031,7 @@ def optimal_kernel(st):
sigma=str(kernel.sigma),
invert=kernel.invert)

if isinstance(spiketrains, neo.core.spiketrainlist.SpikeTrainList) and (
pool_spike_trains):
if is_list_neo_spiketrains(spiketrains) and (pool_spike_trains):
rate = np.mean(rate, axis=1)

rate = neo.AnalogSignal(signal=rate,
Expand Down
58 changes: 53 additions & 5 deletions elephant/test/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,15 +482,15 @@ def test_cv2_raise_error(self):
self.assertRaises(ValueError, statistics.cv2, np.array([seq, seq]))


class InstantaneousRateTest(unittest.TestCase):
class InstantaneousRateTestCase(unittest.TestCase):

@classmethod
def setUpClass(cls) -> None:
"""
Run once before tests:
"""

block = _create_trials_block(n_trials=36)
block = _create_trials_block(n_trials=36, n_spiketrains=5)
cls.block = block
cls.trial_object = TrialsFromBlock(block,
description='trials are segments')
Expand Down Expand Up @@ -988,8 +988,44 @@ def test_instantaneous_rate_trials_pool_trials(self):
pool_spike_trains=False,
pool_trials=True)
self.assertIsInstance(rate, neo.core.AnalogSignal)
self.assertEqual(rate.shape[1], self.trial_object.n_spiketrains_trial_by_trial[0])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test checks for consistency between the dimensions of the computed rate and the data in the Trials objects. However, an additional test against the integer values would protect against errors in implementing the object attributes. These hard expected values are supposed to be known when designing the test data in line 493.


def test_instantaneous_rate_list_pool_spike_trains(self):
def test_instantaneous_rate_trials_pool_spiketrains(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A suggestion to improve understanding of these unit tests, where the dimensions of the outputs change depending on the pooling, is to include a comment block at the beginning to explicitly state the input dimensions --> expected output dimensions. It would be easier to assess the behavior of the object and what is being tested in each test case.

kernel = kernels.GaussianKernel(sigma=500 * pq.ms)

rate = statistics.instantaneous_rate(self.trial_object,
sampling_period=0.1 * pq.ms,
kernel=kernel,
pool_spike_trains=True,
pool_trials=False)
self.assertIsInstance(rate, list)
self.assertEqual(len(rate), self.trial_object.n_trials)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as above, where the test does not compare directly to the expected values, but retrieves them dynamically from the objects.

self.assertEqual(rate[0].shape[1], 1)

def test_instantaneous_rate_trials_pool_spiketrains_pool_trials(self):
kernel = kernels.GaussianKernel(sigma=500 * pq.ms)

rate = statistics.instantaneous_rate(self.trial_object,
sampling_period=0.1 * pq.ms,
kernel=kernel,
pool_spike_trains=True,
pool_trials=True)
self.assertIsInstance(rate, neo.AnalogSignal)
self.assertEqual(rate.shape[1], 1)

def test_instantaneous_rate_trials_pool_spiketrains_false_pool_trials_false(self):
kernel = kernels.GaussianKernel(sigma=500 * pq.ms)

rate = statistics.instantaneous_rate(self.trial_object,
sampling_period=0.1 * pq.ms,
kernel=kernel,
pool_spike_trains=False,
pool_trials=False)
self.assertIsInstance(rate, list)
self.assertEqual(len(rate), self.trial_object.n_trials)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as above, where the test does not compare directly to the expected values, but retrieves them dynamically from the objects.

self.assertEqual(rate[0].shape[1], self.trial_object.n_spiketrains_trial_by_trial[0])

def test_instantaneous_rate_spiketrainlist_pool_spike_trains(self):
kernel = kernels.GaussianKernel(sigma=500 * pq.ms)

rate = statistics.instantaneous_rate(
Expand All @@ -999,7 +1035,19 @@ def test_instantaneous_rate_list_pool_spike_trains(self):
pool_spike_trains=True,
pool_trials=False)
self.assertIsInstance(rate, neo.core.AnalogSignal)
self.assertEqual(rate.magnitude.shape[1], 1)
self.assertEqual(rate.shape[1], 1)

def test_instantaneous_rate_list_pool_spike_trains(self):
kernel = kernels.GaussianKernel(sigma=500 * pq.ms)

rate = statistics.instantaneous_rate(
list(self.trial_object.get_spiketrains_from_trial_as_list(0)),
sampling_period=0.1 * pq.ms,
kernel=kernel,
pool_spike_trains=True,
pool_trials=False)
self.assertIsInstance(rate, neo.core.AnalogSignal)
self.assertEqual(rate.shape[1], 1)

def test_instantaneous_rate_list_of_spike_trains(self):
kernel = kernels.GaussianKernel(sigma=500 * pq.ms)
Expand All @@ -1010,7 +1058,7 @@ def test_instantaneous_rate_list_of_spike_trains(self):
pool_spike_trains=False,
pool_trials=False)
self.assertIsInstance(rate, neo.core.AnalogSignal)
self.assertEqual(rate.magnitude.shape[1], 2)
self.assertEqual(rate.magnitude.shape[1], self.trial_object.n_spiketrains_trial_by_trial[0])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as above, where the test does not compare directly to the expected values, but retrieves them dynamically from the objects. The integer in the old test could be updated, and the new one just added as an additional check.



class TimeHistogramTestCase(unittest.TestCase):
Expand Down
37 changes: 37 additions & 0 deletions elephant/test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,5 +136,42 @@ def test_decorator_return_with_list_of_lists_input_as_kwarg(self):
self.assertIsInstance(spiketrain, SpikeTrain)


class TestIsListNeoSpiketrains(unittest.TestCase):
def setUp(self):
# Set up common test spiketrains.
self.spiketrain1 = neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=4 * pq.s)
self.spiketrain2 = neo.SpikeTrain([2, 3, 4] * pq.s, t_stop=5 * pq.s)

def test_valid_list_input(self):
valid_list = [self.spiketrain1, self.spiketrain2]
self.assertTrue(utils.is_list_neo_spiketrains(valid_list))

def test_valid_tuple_input(self):
valid_tuple = (self.spiketrain1, self.spiketrain2)
self.assertTrue(utils.is_list_neo_spiketrains(valid_tuple))

def test_valid_spiketrainlist_input(self):
valid_spiketrainlist = neo.core.spiketrainlist.SpikeTrainList(items=(self.spiketrain1, self.spiketrain2))
self.assertTrue(utils.is_list_neo_spiketrains(valid_spiketrainlist))

def test_non_iterable_input(self):
with self.assertRaises(TypeError):
utils.is_list_neo_spiketrains(42)

def test_non_spiketrain_objects(self):
invalid_list = [self.spiketrain1, "not a spiketrain"]
with self.assertRaises(TypeError):
utils.is_list_neo_spiketrains(invalid_list)

def test_mixed_types_input(self):
invalid_mixed = [self.spiketrain1, 42, self.spiketrain2]
with self.assertRaises(TypeError):
utils.is_list_neo_spiketrains(invalid_mixed)

def test_none_input(self):
with self.assertRaises(TypeError):
utils.is_list_neo_spiketrains(None)


if __name__ == '__main__':
unittest.main()
34 changes: 33 additions & 1 deletion elephant/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
check_neo_consistency
check_same_units
round_binning_errors
is_list_neo_spiketrains
"""

from __future__ import division, print_function, unicode_literals
Expand All @@ -21,7 +22,8 @@
import quantities as pq

from elephant.trials import Trials

import collections.abc
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from collections.abc import Iterable. Improved readability and reduced number of elements in the namespace.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It also avoids always loading additional objects with attribute statements each time is_list_spiketrains is executed.

import neo

__all__ = [
"deprecated_alias",
Expand All @@ -31,6 +33,7 @@
"check_neo_consistency",
"check_same_units",
"round_binning_errors",
"is_list_neo_spiketrains",
]


Expand Down Expand Up @@ -446,3 +449,32 @@ def wrapper(*args, **kwargs):
return method(*new_args, **new_kwargs)

return wrapper


def is_list_neo_spiketrains(obj: object) -> bool:
"""
Check if input is an iterable containing only neo.SpikeTrain objects.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Corrections:

  • ... if the input ...
  • highlight neo.SpikeTrain using backticks or the class reference


Parameters
----------
obj : object
The object to check. Can be a neo.spiketrainlist, list, tuple or any other iterable.

Returns
-------
bool
True if obj is an iterable containing only neo.SpikeTrain objects.
Raises
------
TypeError
If obj is not an iterable, or if any element is not a neo.SpikeTrain.

"""

if not isinstance(obj, collections.abc.Iterable):
raise TypeError("Input must be an iterable (list, tuple, etc.)")

if not all(isinstance(st, neo.SpikeTrain) for st in obj):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be replaced by a direct return statement:

return isinstance(obj, Iterable) and all(isinstance(st, neo.SpikeTrain) for st in obj)

AND statements are evaluated lazyly from left to right. Whenever an non-Iterable is passed, the expression will be False and evaluation stops.

raise TypeError("All elements must be neo.SpikeTrain objects")

return True
Loading