From f489b1bbb78a1382e7c324374e5b97bb83e10ee5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Wed, 18 Jun 2025 17:42:23 +0200 Subject: [PATCH 01/40] Added methods to get all repetitions of a specific analog signal or spiketrain across trials --- elephant/trials.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/elephant/trials.py b/elephant/trials.py index cd006addd..ee7e2b8fc 100644 --- a/elephant/trials.py +++ b/elephant/trials.py @@ -41,6 +41,7 @@ from abc import ABCMeta, abstractmethod from typing import List +import numpy as np import neo.utils from neo.core import Segment, Block from neo.core.spiketrainlist import SpikeTrainList @@ -309,6 +310,12 @@ def get_spiketrains_from_trial_as_segment(self, trial_id: int) -> ( segment.spiketrains.append(spiketrain) return segment + def get_spiketrains_trial_by_trial(self, spiketrain_id: int) -> ( + neo.core.spiketrainlist.SpikeTrainList): + # Return a list of all spike train repetitions across trials + return SpikeTrainList(items=[segment.spiketrains[spiketrain_id] for + segment in self.block.segments]) + def get_analogsignals_from_trial_as_list(self, trial_id: int) -> ( List[neo.core.AnalogSignal]): # Return a list of all analogsignals from a trial @@ -324,6 +331,12 @@ def get_analogsignals_from_trial_as_segment(self, trial_id: int) -> ( segment.analogsignals.append(analogsignal) return segment + def get_analogsignals_trial_by_trial(self, signal_id: int) -> ( + List[neo.core.AnalogSignal]): + # Return a list of all analog signal repetitions across trials + return [segment.analogsignals[signal_id] + for segment in self.block.segments] + class TrialsFromLists(Trials): """ @@ -346,6 +359,15 @@ def __init__(self, list_of_trials: list, **kwargs): self.list_of_trials = list_of_trials super().__init__(**kwargs) + # Save indexes for quick search of spike trains or analog signals + # in a trial. The order of elements in the inner list must be + # consistent across all trials (using the first list, corresponding + # to the first trial, to fetch the indexes). + is_spiketrain = np.array([isinstance(data_element, neo.SpikeTrain) + for data_element in list_of_trials[0]]) + self._spiketrain_index = is_spiketrain.nonzero()[0] + self._analogsignal_index = (~is_spiketrain).nonzero()[0] + def __getitem__(self, trial_number: int) -> neo.core.Segment: # Get a specific trial by number segment = Segment() @@ -410,6 +432,13 @@ def get_spiketrains_from_trial_as_segment(self, trial_id: int) -> ( segment.spiketrains.append(spiketrain) return segment + def get_spiketrains_trial_by_trial(self, spiketrain_id: int) -> ( + neo.core.spiketrainlist.SpikeTrainList): + # Return a list of all spike train repetitions across trials + list_idx = self._spiketrain_index[spiketrain_id] + return SpikeTrainList(items=[trial[list_idx] + for trial in self.list_of_trials]) + def get_analogsignals_from_trial_as_list(self, trial_id: int) -> ( List[neo.core.AnalogSignal]): # Return a list of all analogsignals from a trial @@ -425,3 +454,9 @@ def get_analogsignals_from_trial_as_segment(self, trial_id: int) -> ( trial_id): segment.analogsignals.append(analogsignal) return segment + + def get_analogsignals_trial_by_trial(self, signal_id: int) -> ( + List[neo.core.AnalogSignal]): + # Return a list of all analog signal repetitions across trials + list_idx = self._analogsignal_index[signal_id] + return [trial[list_idx] for trial in self.list_of_trials] From d33d4dc873e663e1aee10306758bd1bdb1180b0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Wed, 18 Jun 2025 18:15:57 +0200 Subject: [PATCH 02/40] Added docstrings for the new methods --- elephant/trials.py | 52 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/elephant/trials.py b/elephant/trials.py index ee7e2b8fc..8c2cb56c8 100644 --- a/elephant/trials.py +++ b/elephant/trials.py @@ -233,6 +233,58 @@ def get_analogsignals_from_trial_as_segment(self, trial_id: int the trial. """ + @abstractmethod + def get_spiketrains_trial_by_trial(self, spiketrain_id: int) -> ( + neo.core.spiketrainlist.SpikeTrainList): + """ + Retrieve a spike train across all its trial repetitions. + + This method returns a list containing :class:`neo.core.SpikeTrain` + objects corresponding to the same spike train (e.g., from a consistent + recording channel or neuronal source) across multiple trials. + + Parameters + ---------- + spiketrain_id : int + Index of the spike train to retrieve across trials. Indexing + starts at 0, so `spiketrain_id == 0` corresponds to the first + spike train in the trial data. + + Returns + ------- + list of :class:`neo.core.SpikeTrain` + A list-like container with the :class:`neo.core.SpikeTrain` + objects for the specified `spiketrain_id`, ordered from the first + trial (ID 0) to the last (ID `n_trials - 1`). + """ + pass + + @abstractmethod + def get_analogsignals_trial_by_trial(self, signal_id: int + ) -> List[neo.core.AnalogSignal]: + """ + Retrieve an analog signal across all its trial repetitions. + + This method returns a list containing :class:`neo.core.AnalogSignal` + objects corresponding to a continuous signal recorded from a consistent + recording channel or neuronal source across multiple trials. + + Parameters + ---------- + signal_id : int + Index of the analog signal to retrieve across trials. Indexing + starts at 0, so `signal_id == 0` corresponds to the first + analog signal in the trial data. + + Returns + ------- + list of :class:`neo.core.AnalogSignal` + A list with the :class:`neo.core.AnalogSignal` objects for the + specified `signal_id`, ordered from the first trial (ID 0) to the + last (ID `n_trials - 1`). + """ + pass + class TrialsFromBlock(Trials): """ From 1c3a1f040d6eeccbcb1cd0b71cb6a9e2791bf8e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Wed, 18 Jun 2025 18:17:14 +0200 Subject: [PATCH 03/40] Added IDs to the spiketrains and analog signals in the generated test data --- elephant/test/test_trials.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/elephant/test/test_trials.py b/elephant/test/test_trials.py index b472e7a8e..9e8724aae 100644 --- a/elephant/test/test_trials.py +++ b/elephant/test/test_trials.py @@ -27,9 +27,15 @@ def _create_trials_block(n_trials: int = 0, t_stop=1000 * pq.ms ).generate_n_spiketrains( n_spiketrains=n_spiketrains) + for idx, st in enumerate(spiketrains): + st.name = f"Spiketrain {idx}" + st.description = f"Trial {trial}" + analogsignals = [AnalogSignal(signal=[.01, 3.3, 9.3], units='uV', - sampling_rate=1 * pq.Hz) - for _ in range(n_analogsignals)] + sampling_rate=1 * pq.Hz, + name=f"Signal {idx}", + description=f"Trial {trial}") + for idx in range(n_analogsignals)] for spiketrain in spiketrains: segment.spiketrains.append(spiketrain) for analogsignal in analogsignals: From a92413b2713659df33000aad966164a17f29b3e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Wed, 18 Jun 2025 18:17:52 +0200 Subject: [PATCH 04/40] Added unit tests for the trial-by-trial methods --- elephant/test/test_trials.py | 129 +++++++++++++++++++++++++++++++++++ 1 file changed, 129 insertions(+) diff --git a/elephant/test/test_trials.py b/elephant/test/test_trials.py index 9e8724aae..5b9a48898 100644 --- a/elephant/test/test_trials.py +++ b/elephant/test/test_trials.py @@ -183,6 +183,70 @@ def test_trials_from_list_get_analogsignals_from_trial_as_segment(self) \ self.trial_object.get_analogsignals_from_trial_as_segment( 0).analogsignals[0], neo.core.AnalogSignal) + def test_trials_from_block_get_spiketrains_trial_by_trial(self) -> None: + """ + Test accessing all the SpikeTrain objects corresponding to the + repetitions of a spiketrain across the trials. + """ + for st_id in (0, 1): + spiketrains = self.trial_object.get_spiketrains_trial_by_trial(st_id) + + # Return is neo.SpikeTrainList + self.assertIsInstance(spiketrains, + neo.core.spiketrainlist.SpikeTrainList) + + # All elements are neo.SpikeTrain + self.assertTrue(all(map(lambda x: isinstance(x, neo.SpikeTrain), + spiketrains) + ) + ) + + # Data for all trials is returned + self.assertEqual(len(spiketrains), self.trial_object.n_trials) + + # Each trial-specific SpikeTrain object is from the same spiketrain + self.assertTrue(all([st.name == f"Spiketrain {st_id}" + for st in spiketrains] + ) + ) + + # Order of spiketrains is the order of the trials + expected_trials = [f"Trial {i}" + for i in range(self.trial_object.n_trials)] + self.assertListEqual([st.description for st in spiketrains], + expected_trials) + + def test_trials_from_block_get_analogsignals_trial_by_trial(self) -> None: + """ + Test accessing all the AnalogSignal objects corresponding to the + repetitions of an analog signal across the trials. + """ + for as_id in (0, 1): + signals = self.trial_object.get_analogsignals_trial_by_trial(as_id) + + # Return is list + self.assertIsInstance(signals, list) + + # All elements are neo.AnalogSignal + self.assertTrue(all(map(lambda x: isinstance(x, neo.AnalogSignal), + signals) + ) + ) + # Data for all trials returned + self.assertEqual(len(signals), self.trial_object.n_trials) + + # Each trial-specific AnalogSignal object is from the same signal + self.assertTrue(all([signal.name == f"Signal {as_id}" + for signal in signals] + ) + ) + + # Order in the list is the order of the trials + expected_trials = [f"Trial {i}" + for i in range(self.trial_object.n_trials)] + self.assertListEqual([signal.description for signal in signals], + expected_trials) + class TrialsFromListTestCase(unittest.TestCase): """Tests for elephant.trials.TrialsFromList class""" @@ -334,6 +398,71 @@ def test_trials_from_list_get_analogsignals_from_trial_as_segment(self self.trial_object.get_analogsignals_from_trial_as_segment( 0).analogsignals[0], neo.core.AnalogSignal) + def test_trials_from_list_get_spiketrains_trial_by_trial(self) -> None: + """ + Test accessing all the SpikeTrain objects corresponding to the + repetitions of a spiketrain across the trials. + """ + for st_id in (0, 1): + spiketrains = self.trial_object.get_spiketrains_trial_by_trial( + st_id) + + # Return is neo.SpikeTrainList + self.assertIsInstance(spiketrains, + neo.core.spiketrainlist.SpikeTrainList) + + # All elements are neo.SpikeTrain + self.assertTrue(all(map(lambda x: isinstance(x, neo.SpikeTrain), + spiketrains) + ) + ) + + # Data for all trials is returned + self.assertEqual(len(spiketrains), self.trial_object.n_trials) + + # Each trial-specific SpikeTrain object is from the same spiketrain + self.assertTrue(all([st.name == f"Spiketrain {st_id}" + for st in spiketrains] + ) + ) + + # Order of spiketrains is the order of the trials + expected_trials = [f"Trial {i}" + for i in range(self.trial_object.n_trials)] + self.assertListEqual([st.description for st in spiketrains], + expected_trials) + + def test_trials_from_list_get_analogsignals_trial_by_trial(self) -> None: + """ + Test accessing all the AnalogSignal objects corresponding to the + repetitions of an analog signal across the trials. + """ + for as_id in (0, 1): + signals = self.trial_object.get_analogsignals_trial_by_trial(as_id) + + # Return is list + self.assertIsInstance(signals, list) + + # All elements are neo.AnalogSignal + self.assertTrue(all(map(lambda x: isinstance(x, neo.AnalogSignal), + signals) + ) + ) + # Data for all trials returned + self.assertEqual(len(signals), self.trial_object.n_trials) + + # Each trial-specific AnalogSignal object is from the same signal + self.assertTrue(all([signal.name == f"Signal {as_id}" + for signal in signals] + ) + ) + + # Order in the list is the order of the trials + expected_trials = [f"Trial {i}" + for i in range(self.trial_object.n_trials)] + self.assertListEqual([signal.description for signal in signals], + expected_trials) + if __name__ == '__main__': unittest.main() From bbbfcc65a29ecc491365a7b1c19c7d210a285cc3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Wed, 18 Jun 2025 18:18:44 +0200 Subject: [PATCH 05/40] Corrected test names for the TrialsFromBlock test case --- elephant/test/test_trials.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/elephant/test/test_trials.py b/elephant/test/test_trials.py index 5b9a48898..807786d77 100644 --- a/elephant/test/test_trials.py +++ b/elephant/test/test_trials.py @@ -148,8 +148,8 @@ def test_trials_from_block_get_spiketrains_from_trial_as_list(self self.trial_object.get_spiketrains_from_trial_as_list(0)[0], neo.core.SpikeTrain) - def test_trials_from_list_get_spiketrains_from_trial_as_segment(self - ) -> None: + def test_trials_from_block_get_spiketrains_from_trial_as_segment(self + ) -> None: """ Test get spiketrains from trial as segment """ From d813a7e998865ea56e39a1a9cebb3601c43108ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Wed, 18 Jun 2025 18:19:14 +0200 Subject: [PATCH 06/40] Corrected test names for the TrialsFromBlock test case --- elephant/test/test_trials.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elephant/test/test_trials.py b/elephant/test/test_trials.py index 807786d77..f2d7661b6 100644 --- a/elephant/test/test_trials.py +++ b/elephant/test/test_trials.py @@ -171,7 +171,7 @@ def test_trials_from_block_get_analogsignals_from_trial_as_list(self self.trial_object.get_analogsignals_from_trial_as_list(0)[0], neo.core.AnalogSignal) - def test_trials_from_list_get_analogsignals_from_trial_as_segment(self) \ + def test_trials_from_block_get_analogsignals_from_trial_as_segment(self) \ -> None: """ Test get spiketrains from trial as segment From 211173ba9f6bbcbccd98fe1cda4a8934f98b0a6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Mon, 22 Sep 2025 10:58:46 +0200 Subject: [PATCH 07/40] Fixed error if empty list is passed to TrialsFromLists --- elephant/trials.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/elephant/trials.py b/elephant/trials.py index 8c2cb56c8..6022baaec 100644 --- a/elephant/trials.py +++ b/elephant/trials.py @@ -415,10 +415,14 @@ def __init__(self, list_of_trials: list, **kwargs): # in a trial. The order of elements in the inner list must be # consistent across all trials (using the first list, corresponding # to the first trial, to fetch the indexes). - is_spiketrain = np.array([isinstance(data_element, neo.SpikeTrain) - for data_element in list_of_trials[0]]) - self._spiketrain_index = is_spiketrain.nonzero()[0] - self._analogsignal_index = (~is_spiketrain).nonzero()[0] + if list_of_trials: + is_spiketrain = np.array([isinstance(data_element, neo.SpikeTrain) + for data_element in list_of_trials[0]]) + self._spiketrain_index = is_spiketrain.nonzero()[0] + self._analogsignal_index = (~is_spiketrain).nonzero()[0] + else: + self._spiketrain_index = [] + self._analogsignal_index = [] def __getitem__(self, trial_number: int) -> neo.core.Segment: # Get a specific trial by number From 6f8be635f0eb0084378e76264009955762f7d400 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Thu, 25 Sep 2025 10:31:16 +0200 Subject: [PATCH 08/40] Moved decorator to extract trials as list of SpikeTrainList from utils.py to trials.py. GPFA methods and unit tests adjusted. --- elephant/gpfa/gpfa.py | 3 +- elephant/test/test_trials.py | 77 +++++++++++++++++++++++++++++++++- elephant/test/test_utils.py | 80 ------------------------------------ elephant/trials.py | 54 ++++++++++++++++++++++++ elephant/utils.py | 51 ----------------------- 5 files changed, 131 insertions(+), 134 deletions(-) diff --git a/elephant/gpfa/gpfa.py b/elephant/gpfa/gpfa.py index 79d490e0d..5f39ae2f6 100644 --- a/elephant/gpfa/gpfa.py +++ b/elephant/gpfa/gpfa.py @@ -76,8 +76,7 @@ import sklearn from elephant.gpfa import gpfa_core, gpfa_util -from elephant.trials import Trials -from elephant.utils import trials_to_list_of_spiketrainlist +from elephant.trials import Trials, trials_to_list_of_spiketrainlist __all__ = ["GPFA"] diff --git a/elephant/test/test_trials.py b/elephant/test/test_trials.py index f2d7661b6..221ba7be0 100644 --- a/elephant/test/test_trials.py +++ b/elephant/test/test_trials.py @@ -12,7 +12,8 @@ from neo.core import AnalogSignal from elephant.spike_train_generation import StationaryPoissonProcess -from elephant.trials import TrialsFromBlock, TrialsFromLists +from elephant.trials import (TrialsFromBlock, TrialsFromLists, + trials_to_list_of_spiketrainlist) def _create_trials_block(n_trials: int = 0, @@ -48,6 +49,80 @@ def _create_trials_block(n_trials: int = 0, # Tests # ######### +class DecoratorTest: + """ + This class is used as a mock for testing the decorator. + """ + @trials_to_list_of_spiketrainlist + def method_to_decorate(self, trials=None, trials_obj=None): + # This is just a mock implementation for testing purposes + if trials_obj: + return trials_obj + else: + return trials + + +class TestTrialsToListOfSpiketrainlist(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.n_channels = 10 + cls.n_trials = 5 + cls.list_of_list_of_spiketrains = [ + StationaryPoissonProcess(rate=5 * pq.Hz, t_stop=1000.0 * pq.ms + ).generate_n_spiketrains(cls.n_channels) + for _ in range(cls.n_trials)] + cls.trial_object = TrialsFromLists(cls.list_of_list_of_spiketrains) + + def test_decorator_applied(self): + # Test that the decorator is applied correctly + self.assertTrue(hasattr( + DecoratorTest.method_to_decorate, '__wrapped__' + )) + + def test_decorator_return_with_trials_input_as_arg(self): + # Test if decorator takes in trial-object and returns + # list of spiketrainlists + new_class = DecoratorTest() + list_of_spiketrainlists = new_class.method_to_decorate( + self.trial_object) + self.assertEqual(len(list_of_spiketrainlists), self.n_trials) + for spiketrainlist in list_of_spiketrainlists: + self.assertIsInstance(spiketrainlist, SpikeTrainList) + + def test_decorator_return_with_list_of_lists_input_as_arg(self): + # Test if decorator takes in list of lists of spiketrains + # and does not change input + new_class = DecoratorTest() + list_of_list_of_spiketrains = new_class.method_to_decorate( + self.list_of_list_of_spiketrains) + self.assertEqual(len(list_of_list_of_spiketrains), self.n_trials) + for list_of_spiketrains in list_of_list_of_spiketrains: + self.assertIsInstance(list_of_spiketrains, list) + for spiketrain in list_of_spiketrains: + self.assertIsInstance(spiketrain, SpikeTrain) + + def test_decorator_return_with_trials_input_as_kwarg(self): + # Test if decorator takes in trial-object and returns + # list of spiketrainlists + new_class = DecoratorTest() + list_of_spiketrainlists = new_class.method_to_decorate( + trials_obj=self.trial_object) + self.assertEqual(len(list_of_spiketrainlists), self.n_trials) + for spiketrainlist in list_of_spiketrainlists: + self.assertIsInstance(spiketrainlist, SpikeTrainList) + + def test_decorator_return_with_list_of_lists_input_as_kwarg(self): + # Test if decorator takes in list of lists of spiketrains + # and does not change input + new_class = DecoratorTest() + list_of_list_of_spiketrains = new_class.method_to_decorate( + trials_obj=self.list_of_list_of_spiketrains) + self.assertEqual(len(list_of_list_of_spiketrains), self.n_trials) + for list_of_spiketrains in list_of_list_of_spiketrains: + self.assertIsInstance(list_of_spiketrains, list) + for spiketrain in list_of_spiketrains: + self.assertIsInstance(spiketrain, SpikeTrain) + class TrialsFromBlockTestCase(unittest.TestCase): """Tests for elephant.trials.TrialsFromBlock class""" diff --git a/elephant/test/test_utils.py b/elephant/test/test_utils.py index cc927d53e..ef935f335 100644 --- a/elephant/test/test_utils.py +++ b/elephant/test/test_utils.py @@ -12,11 +12,6 @@ from elephant import utils from numpy.testing import assert_array_equal -from elephant.spike_train_generation import StationaryPoissonProcess -from elephant.trials import TrialsFromLists -from neo.core.spiketrainlist import SpikeTrainList -from neo.core import SpikeTrain - class TestUtils(unittest.TestCase): @@ -61,80 +56,5 @@ def test_round_binning_errors(self): [0, 0, 0, 0]) -class DecoratorTest: - """ - This class is used as a mock for testing the decorator. - """ - @utils.trials_to_list_of_spiketrainlist - def method_to_decorate(self, trials=None, trials_obj=None): - # This is just a mock implementation for testing purposes - if trials_obj: - return trials_obj - else: - return trials - - -class TestTrialsToListOfSpiketrainlist(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.n_channels = 10 - cls.n_trials = 5 - cls.list_of_list_of_spiketrains = [ - StationaryPoissonProcess(rate=5 * pq.Hz, t_stop=1000.0 * pq.ms - ).generate_n_spiketrains(cls.n_channels) - for _ in range(cls.n_trials)] - cls.trial_object = TrialsFromLists(cls.list_of_list_of_spiketrains) - - def test_decorator_applied(self): - # Test that the decorator is applied correctly - self.assertTrue(hasattr( - DecoratorTest.method_to_decorate, '__wrapped__' - )) - - def test_decorator_return_with_trials_input_as_arg(self): - # Test if decorator takes in trial-object and returns - # list of spiketrainlists - new_class = DecoratorTest() - list_of_spiketrainlists = new_class.method_to_decorate( - self.trial_object) - self.assertEqual(len(list_of_spiketrainlists), self.n_trials) - for spiketrainlist in list_of_spiketrainlists: - self.assertIsInstance(spiketrainlist, SpikeTrainList) - - def test_decorator_return_with_list_of_lists_input_as_arg(self): - # Test if decorator takes in list of lists of spiketrains - # and does not change input - new_class = DecoratorTest() - list_of_list_of_spiketrains = new_class.method_to_decorate( - self.list_of_list_of_spiketrains) - self.assertEqual(len(list_of_list_of_spiketrains), self.n_trials) - for list_of_spiketrains in list_of_list_of_spiketrains: - self.assertIsInstance(list_of_spiketrains, list) - for spiketrain in list_of_spiketrains: - self.assertIsInstance(spiketrain, SpikeTrain) - - def test_decorator_return_with_trials_input_as_kwarg(self): - # Test if decorator takes in trial-object and returns - # list of spiketrainlists - new_class = DecoratorTest() - list_of_spiketrainlists = new_class.method_to_decorate( - trials_obj=self.trial_object) - self.assertEqual(len(list_of_spiketrainlists), self.n_trials) - for spiketrainlist in list_of_spiketrainlists: - self.assertIsInstance(spiketrainlist, SpikeTrainList) - - def test_decorator_return_with_list_of_lists_input_as_kwarg(self): - # Test if decorator takes in list of lists of spiketrains - # and does not change input - new_class = DecoratorTest() - list_of_list_of_spiketrains = new_class.method_to_decorate( - trials_obj=self.list_of_list_of_spiketrains) - self.assertEqual(len(list_of_list_of_spiketrains), self.n_trials) - for list_of_spiketrains in list_of_list_of_spiketrains: - self.assertIsInstance(list_of_spiketrains, list) - for spiketrain in list_of_spiketrains: - self.assertIsInstance(spiketrain, SpikeTrain) - - if __name__ == '__main__': unittest.main() diff --git a/elephant/trials.py b/elephant/trials.py index 6022baaec..a0512ff0f 100644 --- a/elephant/trials.py +++ b/elephant/trials.py @@ -41,10 +41,64 @@ from abc import ABCMeta, abstractmethod from typing import List +from functools import wraps import numpy as np import neo.utils from neo.core import Segment, Block from neo.core.spiketrainlist import SpikeTrainList +from elephant.utils import deprecated_alias + + +def trials_to_list_of_spiketrainlist(method): + """ + Decorator to convert `Trials` object to a list of `SpikeTrainList` before + calling the wrapped method. + + Parameters + ---------- + method: callable + The method to be decorated. + + Returns + ------- + callable + The decorated method. + + Examples + -------- + The decorator can be used as follows: + + >>> @trials_to_list_of_spiketrainlist + ... def process_data(self, spiketrains): + ... return None + """ + + @wraps(method) + def wrapper(*args, **kwargs): + new_args = tuple( + [ + arg.get_spiketrains_from_trial_as_list(idx) + for idx in range(arg.n_trials) + ] + if isinstance(arg, Trials) + else arg + for arg in args + ) + new_kwargs = { + key: ( + [ + value.get_spiketrains_from_trial_as_list(idx) + for idx in range(value.n_trials) + ] + if isinstance(value, Trials) + else value + ) + for key, value in kwargs.items() + } + + return method(*new_args, **new_kwargs) + + return wrapper class Trials: diff --git a/elephant/utils.py b/elephant/utils.py index b4ddfee22..e30b43617 100644 --- a/elephant/utils.py +++ b/elephant/utils.py @@ -20,7 +20,6 @@ import numpy as np import quantities as pq -from elephant.trials import Trials __all__ = [ @@ -396,53 +395,3 @@ def get_opencl_capability(): return False -def trials_to_list_of_spiketrainlist(method): - """ - Decorator to convert `Trials` object to a list of `SpikeTrainList` before - calling the wrapped method. - - Parameters - ---------- - method: callable - The method to be decorated. - - Returns - ------- - callable: - The decorated method. - - Examples - -------- - The decorator can be used as follows: - - >>> @trials_to_list_of_spiketrainlist - ... def process_data(self, spiketrains): - ... return None - """ - - @wraps(method) - def wrapper(*args, **kwargs): - new_args = tuple( - [ - arg.get_spiketrains_from_trial_as_list(idx) - for idx in range(arg.n_trials) - ] - if isinstance(arg, Trials) - else arg - for arg in args - ) - new_kwargs = { - key: ( - [ - value.get_spiketrains_from_trial_as_list(idx) - for idx in range(value.n_trials) - ] - if isinstance(value, Trials) - else value - ) - for key, value in kwargs.items() - } - - return method(*new_args, **new_kwargs) - - return wrapper From 040e72ea6650eab488a933bcd8da7be05b32ec96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Thu, 25 Sep 2025 14:50:28 +0200 Subject: [PATCH 09/40] Changed argument names to make explicit that trials and other elements are accessed by indexes (instead of ambiguous terms ID or number). Added deprecations with warning decorator and corresponding unit tests. --- elephant/test/test_trials.py | 35 +++++++ elephant/trials.py | 193 ++++++++++++++++++++--------------- 2 files changed, 143 insertions(+), 85 deletions(-) diff --git a/elephant/test/test_trials.py b/elephant/test/test_trials.py index 221ba7be0..9fde472b6 100644 --- a/elephant/test/test_trials.py +++ b/elephant/test/test_trials.py @@ -143,6 +143,24 @@ def setUp(self) -> None: Run before every test: """ + def test_deprecations(self): + trial_object = self.trial_object + with self.assertWarns(DeprecationWarning): + trial_object.get_trial_as_segment(trial_id=0) + with self.assertWarns(DeprecationWarning): + trial_object.get_trials_as_block(trial_ids=[0, 1]) + with self.assertWarns(DeprecationWarning): + trial_object.get_trials_as_list(trial_ids=[0, 1]) + with self.assertWarns(DeprecationWarning): + trial_object.get_spiketrains_from_trial_as_list(trial_id=0) + with self.assertWarns(DeprecationWarning): + trial_object.get_spiketrains_from_trial_as_segment(trial_id=0) + with self.assertWarns(DeprecationWarning): + trial_object.get_analogsignals_from_trial_as_list(trial_id=0) + with self.assertWarns(DeprecationWarning): + trial_object.get_analogsignals_from_trial_as_segment(trial_id=0) + + def test_trials_from_block_description(self) -> None: """ Test description of the trials object. @@ -352,6 +370,23 @@ def setUp(self) -> None: Run before every test: """ + def test_deprecations(self): + trial_object = self.trial_object + with self.assertWarns(DeprecationWarning): + trial_object.get_trial_as_segment(trial_id=0) + with self.assertWarns(DeprecationWarning): + trial_object.get_trials_as_block(trial_ids=[0, 1]) + with self.assertWarns(DeprecationWarning): + trial_object.get_trials_as_list(trial_ids=[0, 1]) + with self.assertWarns(DeprecationWarning): + trial_object.get_spiketrains_from_trial_as_list(trial_id=0) + with self.assertWarns(DeprecationWarning): + trial_object.get_spiketrains_from_trial_as_segment(trial_id=0) + with self.assertWarns(DeprecationWarning): + trial_object.get_analogsignals_from_trial_as_list(trial_id=0) + with self.assertWarns(DeprecationWarning): + trial_object.get_analogsignals_from_trial_as_segment(trial_id=0) + def test_trials_from_list_description(self) -> None: """ Test description of the trials object. diff --git a/elephant/trials.py b/elephant/trials.py index a0512ff0f..da08d070a 100644 --- a/elephant/trials.py +++ b/elephant/trials.py @@ -124,8 +124,7 @@ def __init__(self, description: str = "Trials"): self.description = description @abstractmethod - def __getitem__(self, trial_number: int) -> neo.core.Segment: - """Get a specific trial by number.""" + def __getitem__(self, trial_index: int) -> neo.core.Segment: pass @abstractmethod @@ -163,11 +162,13 @@ def n_analogsignals_trial_by_trial(self) -> List[int]: @abstractmethod def get_trial_as_segment(self, trial_id: int) -> neo.core.Segment: """Get trial as segment. + @deprecated_alias(trial_id="trial_index") + def get_trial_as_segment(self, trial_index: int) -> neo.core.Segment: Parameters ---------- - trial_id : int - Trial number to get (starting at trial ID 0). + trial_index : int + Index of the trial to retrieve (zero-based). Returns ------- @@ -177,14 +178,15 @@ def get_trial_as_segment(self, trial_id: int) -> neo.core.Segment: pass @abstractmethod - def get_trials_as_block(self, trial_ids: List[int] = None + @deprecated_alias(trial_ids="trial_indexes") + def get_trials_as_block(self, trial_indexes: List[int] = None ) -> neo.core.Block: """Get trials as block. Parameters ---------- - trial_ids : list of int - Trial IDs to include in the Block (starting at trial ID 0). + trial_indexes : list of int + Indexes of the trials to include in the Block (zero-based). If None is specified, all trials are returned. Default: None @@ -197,14 +199,15 @@ def get_trials_as_block(self, trial_ids: List[int] = None pass @abstractmethod - def get_trials_as_list(self, trial_ids: List[int] = None + @deprecated_alias(trial_ids="trial_indexes") + def get_trials_as_list(self, trial_indexes: List[int] = None ) -> neo.core.spiketrainlist.SpikeTrainList: """Get trials as list of segments. Parameters ---------- - trial_ids : list of int - Trial IDs to include in the list (starting at trial ID 0). + trial_indexes : list of int + Indexes of the trials to include in the list (zero-based). If None is specified, all trials are returned. Default: None @@ -217,15 +220,16 @@ def get_trials_as_list(self, trial_ids: List[int] = None pass @abstractmethod - def get_spiketrains_from_trial_as_list(self, trial_id: int) -> ( + @deprecated_alias(trial_id="trial_index") + def get_spiketrains_from_trial_as_list(self, trial_index: int) -> ( neo.core.spiketrainlist.SpikeTrainList): """ Get all spike trains from a specific trial and return a list. Parameters ---------- - trial_id : int - Trial ID to get the spike trains from (starting at trial ID 0). + trial_index : int + Index of the trial to get the spike trains from (zero-based). Returns ------- @@ -235,15 +239,16 @@ def get_spiketrains_from_trial_as_list(self, trial_id: int) -> ( pass @abstractmethod - def get_spiketrains_from_trial_as_segment(self, trial_id: int + @deprecated_alias(trial_id="trial_index") + def get_spiketrains_from_trial_as_segment(self, trial_index: int ) -> neo.core.Segment: """ Get all spike trains from a specific trial and return a Segment. Parameters ---------- - trial_id : int - Trial ID to get the spike trains from (starting at trial ID 0). + trial_index : int + Index of the trial to get the spike trains from (zero-based). Returns ------- @@ -252,15 +257,16 @@ def get_spiketrains_from_trial_as_segment(self, trial_id: int pass @abstractmethod - def get_analogsignals_from_trial_as_list(self, trial_id: int + @deprecated_alias(trial_id="trial_index") + def get_analogsignals_from_trial_as_list(self, trial_index: int ) -> List[neo.core.AnalogSignal]: """ Get all analogsignals from a specific trial and return a list. Parameters ---------- - trial_id : int - Trial ID to get the analogsignals from (starting at trial ID 0). + trial_index : int + Index of the trial to get the analog signals from (zero-based). Returns ------- @@ -270,7 +276,8 @@ def get_analogsignals_from_trial_as_list(self, trial_id: int pass @abstractmethod - def get_analogsignals_from_trial_as_segment(self, trial_id: int + @deprecated_alias(trial_id="trial_index") + def get_analogsignals_from_trial_as_segment(self, trial_index: int ) -> neo.core.Segment: """ Get all analogsignal objects from a specific trial and return a @@ -278,8 +285,8 @@ def get_analogsignals_from_trial_as_segment(self, trial_id: int Parameters ---------- - trial_id : int - Trial ID to get the analogsignals from (starting at trial ID 0). + trial_index : int + Index of the trial to get the analog signals from (zero-based). Returns ------- @@ -288,7 +295,7 @@ def get_analogsignals_from_trial_as_segment(self, trial_id: int """ @abstractmethod - def get_spiketrains_trial_by_trial(self, spiketrain_id: int) -> ( + def get_spiketrains_trial_by_trial(self, spiketrain_index: int) -> ( neo.core.spiketrainlist.SpikeTrainList): """ Retrieve a spike train across all its trial repetitions. @@ -299,9 +306,9 @@ def get_spiketrains_trial_by_trial(self, spiketrain_id: int) -> ( Parameters ---------- - spiketrain_id : int + spiketrain_index : int Index of the spike train to retrieve across trials. Indexing - starts at 0, so `spiketrain_id == 0` corresponds to the first + starts at 0, so `spiketrain_index == 0` corresponds to the first spike train in the trial data. Returns @@ -314,7 +321,7 @@ def get_spiketrains_trial_by_trial(self, spiketrain_id: int) -> ( pass @abstractmethod - def get_analogsignals_trial_by_trial(self, signal_id: int + def get_analogsignals_trial_by_trial(self, signal_index: int ) -> List[neo.core.AnalogSignal]: """ Retrieve an analog signal across all its trial repetitions. @@ -325,9 +332,9 @@ def get_analogsignals_trial_by_trial(self, signal_id: int Parameters ---------- - signal_id : int + signal_index : int Index of the analog signal to retrieve across trials. Indexing - starts at 0, so `signal_id == 0` corresponds to the first + starts at 0, so `signal_index == 0` corresponds to the first analog signal in the trial data. Returns @@ -363,28 +370,33 @@ def __init__(self, block: neo.core.block, **kwargs): def __getitem__(self, trial_number: int) -> neo.core.segment: return self.block.segments[trial_number] + def __getitem__(self, trial_index: int) -> neo.core.segment: + return self.block.segments[trial_index] - def get_trial_as_segment(self, trial_id: int) -> neo.core.Segment: - # Get a specific trial by number as a segment - return self.__getitem__(trial_id) + @deprecated_alias(trial_id="trial_index") + def get_trial_as_segment(self, trial_index: int) -> neo.core.Segment: + # Get a specific trial by its index as a Segment + return self.__getitem__(trial_index) - def get_trials_as_block(self, trial_ids: List[int] = None + @deprecated_alias(trial_ids="trial_indexes") + def get_trials_as_block(self, trial_indexes: List[int] = None ) -> neo.core.Block: - # Get a block of trials by trial numbers + # Get a set of trials by their indexes as a Block block = Block() - if not trial_ids: - trial_ids = list(range(self.n_trials)) - for trial_number in trial_ids: - block.segments.append(self.get_trial_as_segment(trial_number)) + if not trial_indexes: + trial_indexes = list(range(self.n_trials)) + for trial_index in trial_indexes: + block.segments.append(self.get_trial_as_segment(trial_index)) return block - def get_trials_as_list(self, trial_ids: List[int] = None + @deprecated_alias(trial_ids="trial_indexes") + def get_trials_as_list(self, trial_indexes: List[int] = None ) -> List[neo.core.Segment]: - if not trial_ids: - trial_ids = list(range(self.n_trials)) - # Get a list of segments by trial numbers - return [self.get_trial_as_segment(trial_number) - for trial_number in trial_ids] + if not trial_indexes: + trial_indexes = list(range(self.n_trials)) + # Get a set of trials by their indexes as a list of Segment + return [self.get_trial_as_segment(trial_index) + for trial_index in trial_indexes] @property def n_trials(self) -> int: @@ -401,46 +413,50 @@ def n_analogsignals_trial_by_trial(self) -> List[int]: # Get the number of AnalogSignals instances in each trial. return [len(trial.analogsignals) for trial in self.block.segments] - def get_spiketrains_from_trial_as_list(self, trial_id: int = 0) -> ( + @deprecated_alias(trial_id="trial_index") + def get_spiketrains_from_trial_as_list(self, trial_index: int = 0) -> ( neo.core.spiketrainlist.SpikeTrainList): # Return a list of all spike trains from a trial return SpikeTrainList(items=[spiketrain for spiketrain in - self.block.segments[trial_id].spiketrains]) + self.block.segments[trial_index].spiketrains]) - def get_spiketrains_from_trial_as_segment(self, trial_id: int) -> ( + @deprecated_alias(trial_id="trial_index") + def get_spiketrains_from_trial_as_segment(self, trial_index: int) -> ( neo.core.Segment): # Return a segment with all spiketrains from a trial segment = neo.core.Segment() - for spiketrain in self.get_spiketrains_from_trial_as_list(trial_id + for spiketrain in self.get_spiketrains_from_trial_as_list(trial_index ): segment.spiketrains.append(spiketrain) return segment - def get_spiketrains_trial_by_trial(self, spiketrain_id: int) -> ( + def get_spiketrains_trial_by_trial(self, spiketrain_index: int) -> ( neo.core.spiketrainlist.SpikeTrainList): # Return a list of all spike train repetitions across trials - return SpikeTrainList(items=[segment.spiketrains[spiketrain_id] for + return SpikeTrainList(items=[segment.spiketrains[spiketrain_index] for segment in self.block.segments]) - def get_analogsignals_from_trial_as_list(self, trial_id: int) -> ( + @deprecated_alias(trial_id="trial_index") + def get_analogsignals_from_trial_as_list(self, trial_index: int) -> ( List[neo.core.AnalogSignal]): # Return a list of all analogsignals from a trial return [analogsignal for analogsignal in - self.block.segments[trial_id].analogsignals] + self.block.segments[trial_index].analogsignals] - def get_analogsignals_from_trial_as_segment(self, trial_id: int) -> ( + @deprecated_alias(trial_id="trial_index") + def get_analogsignals_from_trial_as_segment(self, trial_index: int) -> ( neo.core.Segment): # Return a segment with all analogsignals from a trial segment = neo.core.Segment() for analogsignal in self.get_analogsignals_from_trial_as_list( - trial_id): + trial_index): segment.analogsignals.append(analogsignal) return segment - def get_analogsignals_trial_by_trial(self, signal_id: int) -> ( + def get_analogsignals_trial_by_trial(self, signal_index: int) -> ( List[neo.core.AnalogSignal]): # Return a list of all analog signal repetitions across trials - return [segment.analogsignals[signal_id] + return [segment.analogsignals[signal_index] for segment in self.block.segments] @@ -478,37 +494,40 @@ def __init__(self, list_of_trials: list, **kwargs): self._spiketrain_index = [] self._analogsignal_index = [] - def __getitem__(self, trial_number: int) -> neo.core.Segment: - # Get a specific trial by number + def __getitem__(self, trial_index: int) -> neo.core.Segment: + # Get a specific trial by its index as a Segment segment = Segment() - for element in self.list_of_trials[trial_number]: + for element in self.list_of_trials[trial_index]: if isinstance(element, neo.core.SpikeTrain): segment.spiketrains.append(element) if isinstance(element, neo.core.AnalogSignal): segment.analogsignals.append(element) return segment - def get_trial_as_segment(self, trial_id: int) -> neo.core.Segment: - # Get a specific trial by number as a segment - return self.__getitem__(trial_id) + @deprecated_alias(trial_id="trial_index") + def get_trial_as_segment(self, trial_index: int) -> neo.core.Segment: + # Get a specific trial by its index as a Segment + return self.__getitem__(trial_index) - def get_trials_as_block(self, trial_ids: List[int] = None + @deprecated_alias(trial_ids="trial_indexes") + def get_trials_as_block(self, trial_indexes: List[int] = None ) -> neo.core.Block: - if not trial_ids: - trial_ids = list(range(self.n_trials)) - # Get a block of trials by trial numbers + if not trial_indexes: + trial_indexes = list(range(self.n_trials)) + # Get a block of trials by trial indexes block = Block() - for trial_number in trial_ids: - block.segments.append(self.get_trial_as_segment(trial_number)) + for trial_index in trial_indexes: + block.segments.append(self.get_trial_as_segment(trial_index)) return block - def get_trials_as_list(self, trial_ids: List[int] = None + @deprecated_alias(trial_ids="trial_indexes") + def get_trials_as_list(self, trial_indexes: List[int] = None ) -> List[neo.core.Segment]: - if not trial_ids: - trial_ids = list(range(self.n_trials)) - # Get a list of segments by trial numbers - return [self.get_trial_as_segment(trial_number) - for trial_number in trial_ids] + if not trial_indexes: + trial_indexes = list(range(self.n_trials)) + # Get a list of segments by trial indexes + return [self.get_trial_as_segment(trial_index) + for trial_index in trial_indexes] @property def n_trials(self) -> int: @@ -527,46 +546,50 @@ def n_analogsignals_trial_by_trial(self) -> List[int]: return [sum(map(lambda x: isinstance(x, neo.core.AnalogSignal), trial)) for trial in self.list_of_trials] - def get_spiketrains_from_trial_as_list(self, trial_id: int = 0) -> ( + @deprecated_alias(trial_id="trial_index") + def get_spiketrains_from_trial_as_list(self, trial_index: int = 0) -> ( neo.core.spiketrainlist.SpikeTrainList): # Return a list of all spiketrains from a trial return SpikeTrainList(items=[ - spiketrain for spiketrain in self.list_of_trials[trial_id] + spiketrain for spiketrain in self.list_of_trials[trial_index] if isinstance(spiketrain, neo.core.SpikeTrain)]) - def get_spiketrains_from_trial_as_segment(self, trial_id: int) -> ( + @deprecated_alias(trial_id="trial_index") + def get_spiketrains_from_trial_as_segment(self, trial_index: int) -> ( neo.core.Segment): # Return a segment with all spiketrains from a trial segment = neo.core.Segment() - for spiketrain in self.get_spiketrains_from_trial_as_list(trial_id): + for spiketrain in self.get_spiketrains_from_trial_as_list(trial_index): segment.spiketrains.append(spiketrain) return segment - def get_spiketrains_trial_by_trial(self, spiketrain_id: int) -> ( + def get_spiketrains_trial_by_trial(self, spiketrain_index: int) -> ( neo.core.spiketrainlist.SpikeTrainList): # Return a list of all spike train repetitions across trials - list_idx = self._spiketrain_index[spiketrain_id] + list_idx = self._spiketrain_index[spiketrain_index] return SpikeTrainList(items=[trial[list_idx] for trial in self.list_of_trials]) - def get_analogsignals_from_trial_as_list(self, trial_id: int) -> ( + @deprecated_alias(trial_id="trial_index") + def get_analogsignals_from_trial_as_list(self, trial_index: int) -> ( List[neo.core.AnalogSignal]): # Return a list of all analogsignals from a trial return [analogsignal for analogsignal in - self.list_of_trials[trial_id] + self.list_of_trials[trial_index] if isinstance(analogsignal, neo.core.AnalogSignal)] - def get_analogsignals_from_trial_as_segment(self, trial_id: int) -> ( + @deprecated_alias(trial_id="trial_index") + def get_analogsignals_from_trial_as_segment(self, trial_index: int) -> ( neo.core.Segment): # Return a segment with all analogsignals from a trial segment = neo.core.Segment() for analogsignal in self.get_analogsignals_from_trial_as_list( - trial_id): + trial_index): segment.analogsignals.append(analogsignal) return segment - def get_analogsignals_trial_by_trial(self, signal_id: int) -> ( + def get_analogsignals_trial_by_trial(self, signal_index: int) -> ( List[neo.core.AnalogSignal]): # Return a list of all analog signal repetitions across trials - list_idx = self._analogsignal_index[signal_id] + list_idx = self._analogsignal_index[signal_index] return [trial[list_idx] for trial in self.list_of_trials] From 7e856b138bb93c55fe087520033c230af925f8f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Thu, 25 Sep 2025 16:47:13 +0200 Subject: [PATCH 10/40] Comments with short function descriptions rewritten for consistency and clarity --- elephant/trials.py | 37 +++++++++++++++++-------------------- 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/elephant/trials.py b/elephant/trials.py index da08d070a..4cc40ff95 100644 --- a/elephant/trials.py +++ b/elephant/trials.py @@ -368,9 +368,8 @@ def __init__(self, block: neo.core.block, **kwargs): self.block = block super().__init__(**kwargs) - def __getitem__(self, trial_number: int) -> neo.core.segment: - return self.block.segments[trial_number] def __getitem__(self, trial_index: int) -> neo.core.segment: + # Get a specific trial by its index as a Segment return self.block.segments[trial_index] @deprecated_alias(trial_id="trial_index") @@ -392,25 +391,25 @@ def get_trials_as_block(self, trial_indexes: List[int] = None @deprecated_alias(trial_ids="trial_indexes") def get_trials_as_list(self, trial_indexes: List[int] = None ) -> List[neo.core.Segment]: + # Get a set of trials by their indexes as a list of Segment if not trial_indexes: trial_indexes = list(range(self.n_trials)) - # Get a set of trials by their indexes as a list of Segment return [self.get_trial_as_segment(trial_index) for trial_index in trial_indexes] @property def n_trials(self) -> int: - # Get the number of trials. + # Get the number of trials return len(self.block.segments) @property def n_spiketrains_trial_by_trial(self) -> List[int]: - # Get the number of SpikeTrain instances in each trial. + # Get the number of SpikeTrain objects in each trial return [len(trial.spiketrains) for trial in self.block.segments] @property def n_analogsignals_trial_by_trial(self) -> List[int]: - # Get the number of AnalogSignals instances in each trial. + # Get the number of AnalogSignal objects in each trial return [len(trial.analogsignals) for trial in self.block.segments] @deprecated_alias(trial_id="trial_index") @@ -423,7 +422,7 @@ def get_spiketrains_from_trial_as_list(self, trial_index: int = 0) -> ( @deprecated_alias(trial_id="trial_index") def get_spiketrains_from_trial_as_segment(self, trial_index: int) -> ( neo.core.Segment): - # Return a segment with all spiketrains from a trial + # Return a Segment with all spike trains from a trial segment = neo.core.Segment() for spiketrain in self.get_spiketrains_from_trial_as_list(trial_index ): @@ -439,14 +438,14 @@ def get_spiketrains_trial_by_trial(self, spiketrain_index: int) -> ( @deprecated_alias(trial_id="trial_index") def get_analogsignals_from_trial_as_list(self, trial_index: int) -> ( List[neo.core.AnalogSignal]): - # Return a list of all analogsignals from a trial + # Return a list of all analog signals from a trial return [analogsignal for analogsignal in self.block.segments[trial_index].analogsignals] @deprecated_alias(trial_id="trial_index") def get_analogsignals_from_trial_as_segment(self, trial_index: int) -> ( neo.core.Segment): - # Return a segment with all analogsignals from a trial + # Return a Segment with all analog signals from a trial segment = neo.core.Segment() for analogsignal in self.get_analogsignals_from_trial_as_list( trial_index): @@ -476,8 +475,6 @@ class attribute `description`. """ def __init__(self, list_of_trials: list, **kwargs): - # Constructor - # (actual documentation is in class documentation, see above!) self.list_of_trials = list_of_trials super().__init__(**kwargs) @@ -512,9 +509,9 @@ def get_trial_as_segment(self, trial_index: int) -> neo.core.Segment: @deprecated_alias(trial_ids="trial_indexes") def get_trials_as_block(self, trial_indexes: List[int] = None ) -> neo.core.Block: + # Get a set of trials by their indexes as a Block if not trial_indexes: trial_indexes = list(range(self.n_trials)) - # Get a block of trials by trial indexes block = Block() for trial_index in trial_indexes: block.segments.append(self.get_trial_as_segment(trial_index)) @@ -523,33 +520,33 @@ def get_trials_as_block(self, trial_indexes: List[int] = None @deprecated_alias(trial_ids="trial_indexes") def get_trials_as_list(self, trial_indexes: List[int] = None ) -> List[neo.core.Segment]: + # Get a set of trials by their indexes as a list of Segment if not trial_indexes: trial_indexes = list(range(self.n_trials)) - # Get a list of segments by trial indexes return [self.get_trial_as_segment(trial_index) for trial_index in trial_indexes] @property def n_trials(self) -> int: - # Get the number of trials. + # Get the number of trials return len(self.list_of_trials) @property def n_spiketrains_trial_by_trial(self) -> List[int]: - # Get the number of spiketrains in each trial. + # Get the number of SpikeTrain objects in each trial return [sum(map(lambda x: isinstance(x, neo.core.SpikeTrain), trial)) for trial in self.list_of_trials] @property def n_analogsignals_trial_by_trial(self) -> List[int]: - # Get the number of analogsignals in each trial. + # Get the number of AnalogSignal objects in each trial return [sum(map(lambda x: isinstance(x, neo.core.AnalogSignal), trial)) for trial in self.list_of_trials] @deprecated_alias(trial_id="trial_index") def get_spiketrains_from_trial_as_list(self, trial_index: int = 0) -> ( neo.core.spiketrainlist.SpikeTrainList): - # Return a list of all spiketrains from a trial + # Return a list of all spike trains from a trial return SpikeTrainList(items=[ spiketrain for spiketrain in self.list_of_trials[trial_index] if isinstance(spiketrain, neo.core.SpikeTrain)]) @@ -557,7 +554,7 @@ def get_spiketrains_from_trial_as_list(self, trial_index: int = 0) -> ( @deprecated_alias(trial_id="trial_index") def get_spiketrains_from_trial_as_segment(self, trial_index: int) -> ( neo.core.Segment): - # Return a segment with all spiketrains from a trial + # Return a Segment with all spike trains from a trial segment = neo.core.Segment() for spiketrain in self.get_spiketrains_from_trial_as_list(trial_index): segment.spiketrains.append(spiketrain) @@ -573,7 +570,7 @@ def get_spiketrains_trial_by_trial(self, spiketrain_index: int) -> ( @deprecated_alias(trial_id="trial_index") def get_analogsignals_from_trial_as_list(self, trial_index: int) -> ( List[neo.core.AnalogSignal]): - # Return a list of all analogsignals from a trial + # Return a list of all analog signals from a trial return [analogsignal for analogsignal in self.list_of_trials[trial_index] if isinstance(analogsignal, neo.core.AnalogSignal)] @@ -581,7 +578,7 @@ def get_analogsignals_from_trial_as_list(self, trial_index: int) -> ( @deprecated_alias(trial_id="trial_index") def get_analogsignals_from_trial_as_segment(self, trial_index: int) -> ( neo.core.Segment): - # Return a segment with all analogsignals from a trial + # Return a Segment with all analog signals from a trial segment = neo.core.Segment() for analogsignal in self.get_analogsignals_from_trial_as_list( trial_index): From a79ebb6f76b738742a399bc5e52ed5a0ad191850 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Thu, 25 Sep 2025 16:49:36 +0200 Subject: [PATCH 11/40] Module docstring rewritten for clarity. --- elephant/trials.py | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/elephant/trials.py b/elephant/trials.py index 4cc40ff95..d5327baac 100644 --- a/elephant/trials.py +++ b/elephant/trials.py @@ -2,29 +2,30 @@ This module defines the basic classes that represent trials in Elephant. Many neuroscience methods rely on the concept of repeated trials to improve the -estimate of quantities measured from the data. In the simplest case, results -from multiple trials are averaged, in other scenarios more intricate steps must -be taken in order to pool information from each repetition of a trial. Typically, -trials are considered as fixed time intervals tied to a specific event in the -experiment, such as the onset of a stimulus. - -Neo does not impose a specific way in which trials are to be represented. A -natural way to represent trials is to have a :class:`neo.Block` containing multiple -:class:`neo.Segment` objects, each representing the data of one trial. Another popular -option is to store trials as lists of lists, where the outer refers to -individual lists, and inner lists contain Neo data objects (:class:`neo.SpikeTrain` -and :class:`neo.AnalogSignal` containing individual data of each trial. +estimate of quantities measured from the data. Typically, trials are +considered as fixed time intervals tied to a specific event in the experiment, +such as the onset of a stimulus. In the simplest case, results from multiple +trials are averaged. In other scenarios, more intricate steps must be taken +in order to pool information from each repetition of a trial. + +Neo does not impose a specific way to represent trial data. A natural way to +represent trials is to have a :class:`neo.Block` containing multiple +:class:`neo.Segment` objects, each representing the data of one trial. Another +popular option is to store trials as lists of lists, where the outer refers to +individual lists, and inner lists contain Neo data objects +(:class:`neo.SpikeTrain` and :class:`neo.AnalogSignal`) containing individual +data of each trial. The classes of this module abstract from these individual data representations -by introducing a set of :class:`Trials` classes with a common API. These classes -are initialized by a supported way of structuring trials, e.g., -:class:`TrialsFromBlock` for the first method described above. Internally, +by introducing a set of :class:`Trials` classes with a common API. These +classes are initialized by a supported way of structuring trials, e.g., +:class:`TrialsFromBlock` for the first method described above. Internally, the :class:`Trials` class will not convert this representation, but provide access to data in specific trials (e.g., all spike trains in trial 5) or general -information about the trial structure (e.g., how many trials are there?) via a -fixed API. Trials are consecutively numbered, starting at a trial ID of 0. +information about the trial structure (e.g., how many trials are there?) via a +fixed API. Trials are indexed consecutively starting from 0. -In the release, the classes :class:`TrialsFromBlock` and +In the current implementation, classes :class:`TrialsFromBlock` and :class:`TrialsFromLists` provide this unified way to access trial data. .. autosummary:: From ab8c9c779a37afbe3126932deb50f6ae02942ef2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Thu, 25 Sep 2025 16:50:15 +0200 Subject: [PATCH 12/40] Added section on the accompanying tutorial. --- elephant/trials.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/elephant/trials.py b/elephant/trials.py index d5327baac..4923949dc 100644 --- a/elephant/trials.py +++ b/elephant/trials.py @@ -36,6 +36,17 @@ TrialsFromLists :copyright: Copyright 2014-2024 by the Elephant team, see `doc/authors.rst`. +Tutorial +-------- +For a detailed example on the classes usage and trial handling for analyses +using Elephant, check the :doc:`tutorial <../tutorials/trials>`. + +Run tutorial interactively: + +.. image:: https://mybinder.org/badge.svg + :target: https://mybinder.org/v2/gh/NeuralEnsemble/elephant/master + ?filepath=doc/tutorials/trials.ipynb + :license: Modified BSD, see LICENSE.txt for details. """ From c506414905cafb98497543a6dbe01a8a2b03359a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Thu, 25 Sep 2025 18:09:45 +0200 Subject: [PATCH 13/40] Updated copyright --- elephant/trials.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elephant/trials.py b/elephant/trials.py index 4923949dc..a2142da1e 100644 --- a/elephant/trials.py +++ b/elephant/trials.py @@ -35,7 +35,6 @@ TrialsFromBlock TrialsFromLists -:copyright: Copyright 2014-2024 by the Elephant team, see `doc/authors.rst`. Tutorial -------- For a detailed example on the classes usage and trial handling for analyses @@ -47,6 +46,7 @@ :target: https://mybinder.org/v2/gh/NeuralEnsemble/elephant/master ?filepath=doc/tutorials/trials.ipynb +:copyright: Copyright 2014-2025 by the Elephant team, see `doc/authors.rst`. :license: Modified BSD, see LICENSE.txt for details. """ From 4b421f7542ed3d7f8f8ba6f9c0342607b7ea9967 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Wed, 1 Oct 2025 12:02:47 +0200 Subject: [PATCH 14/40] Organized imports to avoid repeated namespace accesses --- elephant/test/test_trials.py | 94 +++++++++++++++++------------------- 1 file changed, 43 insertions(+), 51 deletions(-) diff --git a/elephant/test/test_trials.py b/elephant/test/test_trials.py index 9fde472b6..acf195a92 100644 --- a/elephant/test/test_trials.py +++ b/elephant/test/test_trials.py @@ -7,9 +7,9 @@ """ import unittest -import neo.utils import quantities as pq -from neo.core import AnalogSignal +from neo.core import Block, Segment, AnalogSignal, SpikeTrain +from neo.core.spiketrainlist import SpikeTrainList from elephant.spike_train_generation import StationaryPoissonProcess from elephant.trials import (TrialsFromBlock, TrialsFromLists, @@ -18,11 +18,11 @@ def _create_trials_block(n_trials: int = 0, n_spiketrains: int = 2, - n_analogsignals: int = 2) -> neo.core.Block: + n_analogsignals: int = 2) -> Block: """ Create block with n_trials, n_spiketrains and n_analog_signals """ - block = neo.Block(name='test_block') + block = Block(name='test_block') for trial in range(n_trials): - segment = neo.Segment(name=f'No. {trial}') + segment = Segment(name=f'No. {trial}') spiketrains = StationaryPoissonProcess(rate=50. * pq.Hz, t_start=0 * pq.ms, t_stop=1000 * pq.ms @@ -171,7 +171,7 @@ def test_trials_from_block_get_item(self) -> None: """ Test get a trial from the trials. """ - self.assertIsInstance(self.trial_object[0], neo.core.Segment) + self.assertIsInstance(self.trial_object[0], Segment) def test_trials_from_block_get_trial_as_segment(self) -> None: """ @@ -179,22 +179,21 @@ def test_trials_from_block_get_trial_as_segment(self) -> None: """ self.assertIsInstance( self.trial_object.get_trial_as_segment(0), - neo.core.Segment) + Segment) self.assertIsInstance( self.trial_object.get_trial_as_segment(0).spiketrains[0], - neo.core.SpikeTrain) + SpikeTrain) self.assertIsInstance( self.trial_object.get_trial_as_segment(0).analogsignals[0], - neo.core.AnalogSignal) + AnalogSignal) def test_trials_from_block_get_trials_as_block(self) -> None: """ Test get a block from list of trials. """ block = self.trial_object.get_trials_as_block([0, 3, 5]) - self.assertIsInstance(block, neo.core.Block) - self.assertIsInstance(self.trial_object.get_trials_as_block(), - neo.core.Block) + self.assertIsInstance(block, Block) + self.assertIsInstance(self.trial_object.get_trials_as_block(), Block) self.assertEqual(len(block.segments), 3) def test_trials_from_block_get_trials_as_list(self) -> None: @@ -204,7 +203,7 @@ def test_trials_from_block_get_trials_as_list(self) -> None: list_of_trials = self.trial_object.get_trials_as_list([0, 3, 5]) self.assertIsInstance(list_of_trials, list) self.assertIsInstance(self.trial_object.get_trials_as_list(), list) - self.assertIsInstance(list_of_trials[0], neo.core.Segment) + self.assertIsInstance(list_of_trials[0], Segment) self.assertEqual(len(list_of_trials), 3) def test_trials_from_block_n_trials(self) -> None: @@ -236,10 +235,10 @@ def test_trials_from_block_get_spiketrains_from_trial_as_list(self """ self.assertIsInstance( self.trial_object.get_spiketrains_from_trial_as_list(0), - neo.core.spiketrainlist.SpikeTrainList) + SpikeTrainList) self.assertIsInstance( self.trial_object.get_spiketrains_from_trial_as_list(0)[0], - neo.core.SpikeTrain) + SpikeTrain) def test_trials_from_block_get_spiketrains_from_trial_as_segment(self ) -> None: @@ -248,10 +247,10 @@ def test_trials_from_block_get_spiketrains_from_trial_as_segment(self """ self.assertIsInstance( self.trial_object.get_spiketrains_from_trial_as_segment(0), - neo.core.Segment) + Segment) self.assertIsInstance( self.trial_object.get_spiketrains_from_trial_as_segment( - 0).spiketrains[0], neo.core.SpikeTrain) + 0).spiketrains[0], SpikeTrain) def test_trials_from_block_get_analogsignals_from_trial_as_list(self ) -> None: @@ -262,7 +261,7 @@ def test_trials_from_block_get_analogsignals_from_trial_as_list(self self.trial_object.get_analogsignals_from_trial_as_list(0), list) self.assertIsInstance( self.trial_object.get_analogsignals_from_trial_as_list(0)[0], - neo.core.AnalogSignal) + AnalogSignal) def test_trials_from_block_get_analogsignals_from_trial_as_segment(self) \ -> None: @@ -271,10 +270,10 @@ def test_trials_from_block_get_analogsignals_from_trial_as_segment(self) \ """ self.assertIsInstance( self.trial_object.get_analogsignals_from_trial_as_segment(0), - neo.core.Segment) + Segment) self.assertIsInstance( self.trial_object.get_analogsignals_from_trial_as_segment( - 0).analogsignals[0], neo.core.AnalogSignal) + 0).analogsignals[0], AnalogSignal) def test_trials_from_block_get_spiketrains_trial_by_trial(self) -> None: """ @@ -285,11 +284,10 @@ def test_trials_from_block_get_spiketrains_trial_by_trial(self) -> None: spiketrains = self.trial_object.get_spiketrains_trial_by_trial(st_id) # Return is neo.SpikeTrainList - self.assertIsInstance(spiketrains, - neo.core.spiketrainlist.SpikeTrainList) + self.assertIsInstance(spiketrains, SpikeTrainList) # All elements are neo.SpikeTrain - self.assertTrue(all(map(lambda x: isinstance(x, neo.SpikeTrain), + self.assertTrue(all(map(lambda x: isinstance(x, SpikeTrain), spiketrains) ) ) @@ -321,7 +319,7 @@ def test_trials_from_block_get_analogsignals_trial_by_trial(self) -> None: self.assertIsInstance(signals, list) # All elements are neo.AnalogSignal - self.assertTrue(all(map(lambda x: isinstance(x, neo.AnalogSignal), + self.assertTrue(all(map(lambda x: isinstance(x, AnalogSignal), signals) ) ) @@ -397,35 +395,31 @@ def test_trials_from_list_get_item(self) -> None: """ Test get a trial from the trials. """ - self.assertIsInstance(self.trial_object[0], - neo.core.Segment) - self.assertIsInstance(self.trial_object[0].spiketrains[0], - neo.core.SpikeTrain) + self.assertIsInstance(self.trial_object[0], Segment) + self.assertIsInstance(self.trial_object[0].spiketrains[0], SpikeTrain) self.assertIsInstance(self.trial_object[0].analogsignals[0], - neo.core.AnalogSignal) + AnalogSignal) def test_trials_from_list_get_trial_as_segment(self) -> None: """ Test get a trial from the trials. """ self.assertIsInstance( - self.trial_object.get_trial_as_segment(0), - neo.core.Segment) + self.trial_object.get_trial_as_segment(0), Segment) self.assertIsInstance( self.trial_object.get_trial_as_segment(0).spiketrains[0], - neo.core.SpikeTrain) + SpikeTrain) self.assertIsInstance( self.trial_object.get_trial_as_segment(0).analogsignals[0], - neo.core.AnalogSignal) + AnalogSignal) def test_trials_from_list_get_trials_as_block(self) -> None: """ Test get a block from list of trials. """ block = self.trial_object.get_trials_as_block([0, 3, 5]) - self.assertIsInstance(block, neo.core.Block) - self.assertIsInstance(self.trial_object.get_trials_as_block(), - neo.core.Block) + self.assertIsInstance(block, Block) + self.assertIsInstance(self.trial_object.get_trials_as_block(), Block) self.assertEqual(len(block.segments), 3) def test_trials_from_list_get_trials_as_list(self) -> None: @@ -435,7 +429,7 @@ def test_trials_from_list_get_trials_as_list(self) -> None: list_of_trials = self.trial_object.get_trials_as_list([0, 3, 5]) self.assertIsInstance(list_of_trials, list) self.assertIsInstance(self.trial_object.get_trials_as_list(), list) - self.assertIsInstance(list_of_trials[0], neo.core.Segment) + self.assertIsInstance(list_of_trials[0], Segment) self.assertEqual(len(list_of_trials), 3) def test_trials_from_list_n_trials(self) -> None: @@ -449,7 +443,7 @@ def test_trials_from_list_n_spiketrains_trial_by_trial(self) -> None: Test get number of spiketrains per trial. """ self.assertEqual(self.trial_object.n_spiketrains_trial_by_trial, - [sum(map(lambda x: isinstance(x, neo.core.SpikeTrain), + [sum(map(lambda x: isinstance(x, SpikeTrain), trial)) for trial in self.trial_list]) def test_trials_from_list_n_analogsignals_trial_by_trial(self) -> None: @@ -457,8 +451,7 @@ def test_trials_from_list_n_analogsignals_trial_by_trial(self) -> None: Test get number of analogsignals per trial. """ self.assertEqual(self.trial_object.n_analogsignals_trial_by_trial, - [sum(map(lambda x: isinstance(x, - neo.core.AnalogSignal), + [sum(map(lambda x: isinstance(x, AnalogSignal), trial)) for trial in self.trial_list]) def test_trials_from_list_get_spiketrains_from_trial_as_list(self) -> None: @@ -467,10 +460,10 @@ def test_trials_from_list_get_spiketrains_from_trial_as_list(self) -> None: """ self.assertIsInstance( self.trial_object.get_spiketrains_from_trial_as_list(0), - neo.core.spiketrainlist.SpikeTrainList) + SpikeTrainList) self.assertIsInstance( self.trial_object.get_spiketrains_from_trial_as_list(0)[0], - neo.core.SpikeTrain) + SpikeTrain) def test_trials_from_list_get_spiketrains_from_trial_as_segment(self ) -> None: @@ -479,10 +472,10 @@ def test_trials_from_list_get_spiketrains_from_trial_as_segment(self """ self.assertIsInstance( self.trial_object.get_spiketrains_from_trial_as_segment(0), - neo.core.Segment) + Segment) self.assertIsInstance( self.trial_object.get_spiketrains_from_trial_as_segment( - 0).spiketrains[0], neo.core.SpikeTrain) + 0).spiketrains[0], SpikeTrain) def test_trials_from_list_get_analogsignals_from_trial_as_list(self ) -> None: @@ -493,7 +486,7 @@ def test_trials_from_list_get_analogsignals_from_trial_as_list(self self.trial_object.get_analogsignals_from_trial_as_list(0), list) self.assertIsInstance( self.trial_object.get_analogsignals_from_trial_as_list(0)[0], - neo.core.AnalogSignal) + AnalogSignal) def test_trials_from_list_get_analogsignals_from_trial_as_segment(self ) \ @@ -503,10 +496,10 @@ def test_trials_from_list_get_analogsignals_from_trial_as_segment(self """ self.assertIsInstance( self.trial_object.get_analogsignals_from_trial_as_segment(0), - neo.core.Segment) + Segment) self.assertIsInstance( self.trial_object.get_analogsignals_from_trial_as_segment( - 0).analogsignals[0], neo.core.AnalogSignal) + 0).analogsignals[0], AnalogSignal) def test_trials_from_list_get_spiketrains_trial_by_trial(self) -> None: """ @@ -518,11 +511,10 @@ def test_trials_from_list_get_spiketrains_trial_by_trial(self) -> None: st_id) # Return is neo.SpikeTrainList - self.assertIsInstance(spiketrains, - neo.core.spiketrainlist.SpikeTrainList) + self.assertIsInstance(spiketrains, SpikeTrainList) # All elements are neo.SpikeTrain - self.assertTrue(all(map(lambda x: isinstance(x, neo.SpikeTrain), + self.assertTrue(all(map(lambda x: isinstance(x, SpikeTrain), spiketrains) ) ) @@ -554,7 +546,7 @@ def test_trials_from_list_get_analogsignals_trial_by_trial(self) -> None: self.assertIsInstance(signals, list) # All elements are neo.AnalogSignal - self.assertTrue(all(map(lambda x: isinstance(x, neo.AnalogSignal), + self.assertTrue(all(map(lambda x: isinstance(x, AnalogSignal), signals) ) ) From d89980020514cfcd2ce214151ff58a924930651a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Thu, 2 Oct 2025 23:38:04 +0200 Subject: [PATCH 15/40] Updated module docstring --- elephant/test/test_trials.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/elephant/test/test_trials.py b/elephant/test/test_trials.py index acf195a92..47a8946e2 100644 --- a/elephant/test/test_trials.py +++ b/elephant/test/test_trials.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- """ -Unit tests for the trials objects. +nit tests for the objects of the API handling trial data in Elephant. -:copyright: Copyright 2014-2024 by the Elephant team, see AUTHORS.txt. +:copyright: Copyright 2014-2025 by the Elephant team, see AUTHORS.txt. :license: Modified BSD, see LICENSE.txt for details. """ From 34e1d8e08be3a0818f57d59a4ca4277c08dcbd7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Thu, 2 Oct 2025 23:38:38 +0200 Subject: [PATCH 16/40] Simplified method logic --- elephant/test/test_trials.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/elephant/test/test_trials.py b/elephant/test/test_trials.py index 47a8946e2..27f40dd7a 100644 --- a/elephant/test/test_trials.py +++ b/elephant/test/test_trials.py @@ -58,8 +58,7 @@ def method_to_decorate(self, trials=None, trials_obj=None): # This is just a mock implementation for testing purposes if trials_obj: return trials_obj - else: - return trials + return trials class TestTrialsToListOfSpiketrainlist(unittest.TestCase): From d9fade24bad493e55d2b6c44a3b2b8d1de3c6430 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Thu, 2 Oct 2025 23:43:30 +0200 Subject: [PATCH 17/40] Added return type hints to all tests --- elephant/test/test_trials.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/elephant/test/test_trials.py b/elephant/test/test_trials.py index 27f40dd7a..de6ad4fa8 100644 --- a/elephant/test/test_trials.py +++ b/elephant/test/test_trials.py @@ -74,6 +74,7 @@ def setUpClass(cls): def test_decorator_applied(self): # Test that the decorator is applied correctly + def test_decorator_applied(self) -> None: self.assertTrue(hasattr( DecoratorTest.method_to_decorate, '__wrapped__' )) @@ -81,6 +82,7 @@ def test_decorator_applied(self): def test_decorator_return_with_trials_input_as_arg(self): # Test if decorator takes in trial-object and returns # list of spiketrainlists + def test_decorator_return_with_trials_input_as_arg(self) -> None: new_class = DecoratorTest() list_of_spiketrainlists = new_class.method_to_decorate( self.trial_object) @@ -91,6 +93,7 @@ def test_decorator_return_with_trials_input_as_arg(self): def test_decorator_return_with_list_of_lists_input_as_arg(self): # Test if decorator takes in list of lists of spiketrains # and does not change input + def test_decorator_return_with_list_of_lists_input_as_arg(self) -> None: new_class = DecoratorTest() list_of_list_of_spiketrains = new_class.method_to_decorate( self.list_of_list_of_spiketrains) @@ -103,6 +106,7 @@ def test_decorator_return_with_list_of_lists_input_as_arg(self): def test_decorator_return_with_trials_input_as_kwarg(self): # Test if decorator takes in trial-object and returns # list of spiketrainlists + def test_decorator_return_with_trials_input_as_kwarg(self) -> None: new_class = DecoratorTest() list_of_spiketrainlists = new_class.method_to_decorate( trials_obj=self.trial_object) @@ -113,6 +117,7 @@ def test_decorator_return_with_trials_input_as_kwarg(self): def test_decorator_return_with_list_of_lists_input_as_kwarg(self): # Test if decorator takes in list of lists of spiketrains # and does not change input + def test_decorator_return_with_list_of_lists_input_as_kwarg(self) -> None: new_class = DecoratorTest() list_of_list_of_spiketrains = new_class.method_to_decorate( trials_obj=self.list_of_list_of_spiketrains) @@ -137,9 +142,7 @@ def setUpClass(cls) -> None: cls.trial_object = TrialsFromBlock(block, description='trials are segments') - def setUp(self) -> None: - """ - Run before every test: + def test_deprecations(self) -> None: """ def test_deprecations(self): @@ -363,6 +366,7 @@ def setUpClass(cls) -> None: description='trial is a list') def setUp(self) -> None: + def test_deprecations(self) -> None: """ Run before every test: """ From 443f485e20a3218ca78cf6561c9f1411579ff1d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Thu, 2 Oct 2025 23:47:01 +0200 Subject: [PATCH 18/40] Refactored TestCase class setUp methods for clarity --- elephant/test/test_trials.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/elephant/test/test_trials.py b/elephant/test/test_trials.py index de6ad4fa8..9f0f3058d 100644 --- a/elephant/test/test_trials.py +++ b/elephant/test/test_trials.py @@ -133,10 +133,6 @@ class TrialsFromBlockTestCase(unittest.TestCase): @classmethod def setUpClass(cls) -> None: - """ - Run once before tests: - """ - block = _create_trials_block(n_trials=36) cls.block = block cls.trial_object = TrialsFromBlock(block, @@ -346,22 +342,23 @@ class TrialsFromListTestCase(unittest.TestCase): @classmethod def setUpClass(cls) -> None: - """ - Run once before tests: - Download the dataset from elephant_data - """ - block = _create_trials_block(n_trials=36) - - # Create Trialobject as list of lists - # add spiketrains + cls.n_spiketrains = 2 + cls.n_analogsignals = 3 + block = _create_trials_block(n_trials=36, + n_spiketrains=cls.n_spiketrains, + n_analogsignals=cls.n_analogsignals) + + # Create trial data as list of lists + # 1. Add spiketrains trial_list = [[spiketrain for spiketrain in trial.spiketrains] for trial in block.segments] - # add analogsignals + # 2. Add analog signals for idx, trial in enumerate(block.segments): for analogsignal in trial.analogsignals: trial_list[idx].append(analogsignal) cls.trial_list = trial_list + # Create TrialsFromLists object cls.trial_object = TrialsFromLists(trial_list, description='trial is a list') From 25f2ddde1650fb3c3d139346e2dcd166cdb704c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Thu, 2 Oct 2025 23:48:23 +0200 Subject: [PATCH 19/40] Improved documentation of the decorator test case --- elephant/test/test_trials.py | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/elephant/test/test_trials.py b/elephant/test/test_trials.py index 9f0f3058d..b79e3261c 100644 --- a/elephant/test/test_trials.py +++ b/elephant/test/test_trials.py @@ -72,17 +72,19 @@ def setUpClass(cls): for _ in range(cls.n_trials)] cls.trial_object = TrialsFromLists(cls.list_of_list_of_spiketrains) - def test_decorator_applied(self): - # Test that the decorator is applied correctly def test_decorator_applied(self) -> None: + """ + Test that the decorator is applied correctly. + """ self.assertTrue(hasattr( DecoratorTest.method_to_decorate, '__wrapped__' )) - def test_decorator_return_with_trials_input_as_arg(self): - # Test if decorator takes in trial-object and returns - # list of spiketrainlists def test_decorator_return_with_trials_input_as_arg(self) -> None: + """ + Test if the decorator takes in a `Trials` object and returns a list of + `SpikeTrainList`. + """ new_class = DecoratorTest() list_of_spiketrainlists = new_class.method_to_decorate( self.trial_object) @@ -90,10 +92,11 @@ def test_decorator_return_with_trials_input_as_arg(self) -> None: for spiketrainlist in list_of_spiketrainlists: self.assertIsInstance(spiketrainlist, SpikeTrainList) - def test_decorator_return_with_list_of_lists_input_as_arg(self): - # Test if decorator takes in list of lists of spiketrains - # and does not change input def test_decorator_return_with_list_of_lists_input_as_arg(self) -> None: + """ + Test if the decorator takes in a list of lists of `SpikeTrain` and + does not change the input. + """ new_class = DecoratorTest() list_of_list_of_spiketrains = new_class.method_to_decorate( self.list_of_list_of_spiketrains) @@ -103,10 +106,11 @@ def test_decorator_return_with_list_of_lists_input_as_arg(self) -> None: for spiketrain in list_of_spiketrains: self.assertIsInstance(spiketrain, SpikeTrain) - def test_decorator_return_with_trials_input_as_kwarg(self): - # Test if decorator takes in trial-object and returns - # list of spiketrainlists def test_decorator_return_with_trials_input_as_kwarg(self) -> None: + """ + Test if the decorator takes in a `Trials` object and returns a list of + `SpikeTrainList` when passed as kwarg. + """ new_class = DecoratorTest() list_of_spiketrainlists = new_class.method_to_decorate( trials_obj=self.trial_object) @@ -114,10 +118,11 @@ def test_decorator_return_with_trials_input_as_kwarg(self) -> None: for spiketrainlist in list_of_spiketrainlists: self.assertIsInstance(spiketrainlist, SpikeTrainList) - def test_decorator_return_with_list_of_lists_input_as_kwarg(self): - # Test if decorator takes in list of lists of spiketrains - # and does not change input def test_decorator_return_with_list_of_lists_input_as_kwarg(self) -> None: + """ + Test if the decorator takes in a list of lists of `SpikeTrain`and does + not change the input if passed as a kwarg. + """ new_class = DecoratorTest() list_of_list_of_spiketrains = new_class.method_to_decorate( trials_obj=self.list_of_list_of_spiketrains) From 60337b94de09fcfc8820b798f6382e488e6230f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Thu, 2 Oct 2025 23:48:55 +0200 Subject: [PATCH 20/40] Improved documentation of the data generation function --- elephant/test/test_trials.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/elephant/test/test_trials.py b/elephant/test/test_trials.py index b79e3261c..0e269a23c 100644 --- a/elephant/test/test_trials.py +++ b/elephant/test/test_trials.py @@ -19,7 +19,9 @@ def _create_trials_block(n_trials: int = 0, n_spiketrains: int = 2, n_analogsignals: int = 2) -> Block: - """ Create block with n_trials, n_spiketrains and n_analog_signals """ + """ + Create Neo `Block` with `n_trials`, `n_spiketrains` and `n_analogsignals`. + """ block = Block(name='test_block') for trial in range(n_trials): segment = Segment(name=f'No. {trial}') From 0cfad26135b98eddaa7fdbd788d20111892bf92b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Thu, 2 Oct 2025 23:50:54 +0200 Subject: [PATCH 21/40] Implemented base TestCase class with custom assertions for checking Neo objects --- elephant/test/test_trials.py | 73 ++++++++++++++++++++++++++++++++++-- 1 file changed, 70 insertions(+), 3 deletions(-) diff --git a/elephant/test/test_trials.py b/elephant/test/test_trials.py index 0e269a23c..46387a9b0 100644 --- a/elephant/test/test_trials.py +++ b/elephant/test/test_trials.py @@ -47,9 +47,72 @@ def _create_trials_block(n_trials: int = 0, return block -######### -# Tests # -######### +########################## +# Tests - helper classes # +########################## + +class TrialsBaseTestCase(unittest.TestCase): + """ + This is a base `unitest.TestCase` class to act as a helper when + constructing the specific test cases for each implementation of `Trials`. + + This helper class facilitates comparing Neo objects, as custom assertions + are implemented to perform a series of tests to ensure two Neo objects are + indeed equal (e.g., checking metadata or contents of collections such as + spiketrains in a `Segment`). + + As the `Trials` objects are based on references to the input data + structures, checks for `Segment`, `SpikeTrain`, and `AnalogSignal` objects + are enforcing instance equality (i.e., `(a is b) == True`). + """ + + def assertSegmentEqual(self, segment_1, segment_2) -> None: + self.assertIsInstance(segment_1, Segment) + self.assertIsInstance(segment_2, Segment) + self.assertIs(segment_1, segment_2) + self.assertEqual(segment_1.name, segment_2.name) + self.assertEqual(segment_2.description, segment_2.description) + self.assertDictEqual(segment_1.annotations, segment_2.annotations) + self.assertSpikeTrainListEqual(segment_1.spiketrains, + segment_2.spiketrains) + self.assertAnalogSignalListEqual(segment_1.analogsignals, + segment_2.analogsignals) + + def assertSpikeTrainEqual(self, spiketrain_1, spiketrain_2) -> None: + self.assertIsInstance(spiketrain_1, SpikeTrain) + self.assertIsInstance(spiketrain_2, SpikeTrain) + self.assertIs(spiketrain_1, spiketrain_2) + self.assertTrue(np.all(spiketrain_1 == spiketrain_2)) + self.assertEqual(spiketrain_1.name, spiketrain_2.name) + self.assertEqual(spiketrain_1.description, spiketrain_2.description) + self.assertDictEqual(spiketrain_1.annotations, + spiketrain_2.annotations) + + def assertSpikeTrainListEqual(self, spiketrains_1, spiketrains_2) -> None: + self.assertIsInstance(spiketrains_1, SpikeTrainList) + self.assertIsInstance(spiketrains_2, SpikeTrainList) + self.assertEqual(len(spiketrains_1), len(spiketrains_2)) + for st1, st2 in zip(spiketrains_1, spiketrains_2): + self.assertSpikeTrainEqual(st1, st2) + + def assertAnalogSignalEqual(self, signal_1, signal_2) -> None: + self.assertIsInstance(signal_1, AnalogSignal) + self.assertIsInstance(signal_2, AnalogSignal) + self.assertIs(signal_1, signal_2) + self.assertTrue(np.all(signal_1 == signal_2)) + self.assertEqual(signal_1.name, signal_2.name) + self.assertEqual(signal_1.description, signal_2.description) + self.assertDictEqual(signal_1.annotations, signal_2.annotations) + + def assertAnalogSignalListEqual(self, signals_1, signals_2) -> None: + # Not enforcing object type as `Segment.analogsignals` are + # `ObjectList`, and some of the functions return pure Python lists + # containing the `AnalogSignal` objects. Therefore, the type checking + # must be done in each test case accordingly. + self.assertEqual(len(signals_1), len(signals_2)) + for signal_1, signal_2 in zip(signals_1, signals_2): + self.assertAnalogSignalEqual(signal_1, signal_2) + class DecoratorTest: """ @@ -64,6 +127,10 @@ def method_to_decorate(self, trials=None, trials_obj=None): class TestTrialsToListOfSpiketrainlist(unittest.TestCase): +###################### +# Tests - test cases # +###################### + @classmethod def setUpClass(cls): cls.n_channels = 10 From cac295aac3251b1d0108d8cb7d9e41d7db0913e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Thu, 2 Oct 2025 23:51:41 +0200 Subject: [PATCH 22/40] Implemented base TestCase class with custom assertions for checking Neo objects --- elephant/test/test_trials.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elephant/test/test_trials.py b/elephant/test/test_trials.py index 46387a9b0..48901f7ec 100644 --- a/elephant/test/test_trials.py +++ b/elephant/test/test_trials.py @@ -7,6 +7,7 @@ """ import unittest +import numpy as np import quantities as pq from neo.core import Block, Segment, AnalogSignal, SpikeTrain from neo.core.spiketrainlist import SpikeTrainList @@ -126,7 +127,6 @@ def method_to_decorate(self, trials=None, trials_obj=None): return trials -class TestTrialsToListOfSpiketrainlist(unittest.TestCase): ###################### # Tests - test cases # ###################### From 785f50d36e2768aae38480f8756b160efa2e3bf9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Thu, 2 Oct 2025 23:54:02 +0200 Subject: [PATCH 23/40] Updated test cases to use base class --- elephant/test/test_trials.py | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/elephant/test/test_trials.py b/elephant/test/test_trials.py index 48901f7ec..4619bbb61 100644 --- a/elephant/test/test_trials.py +++ b/elephant/test/test_trials.py @@ -131,6 +131,8 @@ def method_to_decorate(self, trials=None, trials_obj=None): # Tests - test cases # ###################### +class TestTrialsToListOfSpiketrainlist(TrialsBaseTestCase): + @classmethod def setUpClass(cls): cls.n_channels = 10 @@ -202,8 +204,8 @@ def test_decorator_return_with_list_of_lists_input_as_kwarg(self) -> None: self.assertIsInstance(spiketrain, SpikeTrain) -class TrialsFromBlockTestCase(unittest.TestCase): """Tests for elephant.trials.TrialsFromBlock class""" +class TrialsFromBlockTestCase(TrialsBaseTestCase): @classmethod def setUpClass(cls) -> None: @@ -411,8 +413,8 @@ def test_trials_from_block_get_analogsignals_trial_by_trial(self) -> None: expected_trials) -class TrialsFromListTestCase(unittest.TestCase): """Tests for elephant.trials.TrialsFromList class""" +class TrialsFromListTestCase(TrialsBaseTestCase): @classmethod def setUpClass(cls) -> None: @@ -436,7 +438,27 @@ def setUpClass(cls) -> None: cls.trial_object = TrialsFromLists(trial_list, description='trial is a list') - def setUp(self) -> None: + def assertSegmentEqualToList(self, segment, list_data, n_spiketrains, + n_analogsignals): + """ + This function compares trial data in a Segment to trial data in + a Python list. The order of objects is: `SpikeTrain`, `AnalogSignal`. + The number of spiketrains and analog signals must be informed to split + the data in the list. + """ + self.assertIsInstance(segment, Segment) + self.assertIsInstance(list_data, list) + + self.assertEqual(len(list_data), n_spiketrains + n_analogsignals) + self.assertEqual(len(segment.spiketrains), n_spiketrains) + self.assertEqual(len(segment.analogsignals), n_analogsignals) + + spiketrains = list_data[:n_spiketrains] + signals = list_data[n_spiketrains:] + self.assertSpikeTrainListEqual(segment.spiketrains, + SpikeTrainList(spiketrains)) + self.assertAnalogSignalListEqual(segment.analogsignals, signals) + def test_deprecations(self) -> None: """ Run before every test: From 7ecfec282ac343e4325c5942dc89999e9dd3b482 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Thu, 2 Oct 2025 23:55:09 +0200 Subject: [PATCH 24/40] Refactored decorator test case to use custom assertions --- elephant/test/test_trials.py | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/elephant/test/test_trials.py b/elephant/test/test_trials.py index 4619bbb61..fa96305a2 100644 --- a/elephant/test/test_trials.py +++ b/elephant/test/test_trials.py @@ -160,8 +160,10 @@ def test_decorator_return_with_trials_input_as_arg(self) -> None: list_of_spiketrainlists = new_class.method_to_decorate( self.trial_object) self.assertEqual(len(list_of_spiketrainlists), self.n_trials) - for spiketrainlist in list_of_spiketrainlists: - self.assertIsInstance(spiketrainlist, SpikeTrainList) + for spiketrainlist, expected_list in zip( + list_of_spiketrainlists, self.list_of_list_of_spiketrains): + self.assertSpikeTrainListEqual(spiketrainlist, + SpikeTrainList(expected_list)) def test_decorator_return_with_list_of_lists_input_as_arg(self) -> None: """ @@ -172,10 +174,13 @@ def test_decorator_return_with_list_of_lists_input_as_arg(self) -> None: list_of_list_of_spiketrains = new_class.method_to_decorate( self.list_of_list_of_spiketrains) self.assertEqual(len(list_of_list_of_spiketrains), self.n_trials) - for list_of_spiketrains in list_of_list_of_spiketrains: + for list_of_spiketrains, expected_list in ( + zip(list_of_list_of_spiketrains, + self.list_of_list_of_spiketrains)): self.assertIsInstance(list_of_spiketrains, list) - for spiketrain in list_of_spiketrains: - self.assertIsInstance(spiketrain, SpikeTrain) + for spiketrain, expected_spiketrain in ( + zip(list_of_spiketrains, expected_list)): + self.assertSpikeTrainEqual(spiketrain, expected_spiketrain) def test_decorator_return_with_trials_input_as_kwarg(self) -> None: """ @@ -186,8 +191,10 @@ def test_decorator_return_with_trials_input_as_kwarg(self) -> None: list_of_spiketrainlists = new_class.method_to_decorate( trials_obj=self.trial_object) self.assertEqual(len(list_of_spiketrainlists), self.n_trials) - for spiketrainlist in list_of_spiketrainlists: - self.assertIsInstance(spiketrainlist, SpikeTrainList) + for spiketrainlist, expected_list in zip( + list_of_spiketrainlists, self.list_of_list_of_spiketrains): + self.assertSpikeTrainListEqual(spiketrainlist, + SpikeTrainList(expected_list)) def test_decorator_return_with_list_of_lists_input_as_kwarg(self) -> None: """ @@ -198,10 +205,13 @@ def test_decorator_return_with_list_of_lists_input_as_kwarg(self) -> None: list_of_list_of_spiketrains = new_class.method_to_decorate( trials_obj=self.list_of_list_of_spiketrains) self.assertEqual(len(list_of_list_of_spiketrains), self.n_trials) - for list_of_spiketrains in list_of_list_of_spiketrains: + for list_of_spiketrains, expected_list in ( + zip(list_of_list_of_spiketrains, + self.list_of_list_of_spiketrains)): self.assertIsInstance(list_of_spiketrains, list) - for spiketrain in list_of_spiketrains: - self.assertIsInstance(spiketrain, SpikeTrain) + for spiketrain, expected_spiketrain in ( + zip(list_of_spiketrains, expected_list)): + self.assertSpikeTrainEqual(spiketrain, expected_spiketrain) """Tests for elephant.trials.TrialsFromBlock class""" @@ -413,8 +423,10 @@ def test_trials_from_block_get_analogsignals_trial_by_trial(self) -> None: expected_trials) - """Tests for elephant.trials.TrialsFromList class""" class TrialsFromListTestCase(TrialsBaseTestCase): + """ + Tests for :class:`elephant.trials.TrialsFromList`. + """ @classmethod def setUpClass(cls) -> None: From a729821996196de80aa51d159a092e08f4cffaf7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Thu, 2 Oct 2025 23:55:57 +0200 Subject: [PATCH 25/40] Improved test case class documentation --- elephant/test/test_trials.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/elephant/test/test_trials.py b/elephant/test/test_trials.py index fa96305a2..798a38b2e 100644 --- a/elephant/test/test_trials.py +++ b/elephant/test/test_trials.py @@ -214,8 +214,10 @@ def test_decorator_return_with_list_of_lists_input_as_kwarg(self) -> None: self.assertSpikeTrainEqual(spiketrain, expected_spiketrain) - """Tests for elephant.trials.TrialsFromBlock class""" class TrialsFromBlockTestCase(TrialsBaseTestCase): + """ + Tests for :class:`elephant.trials.TrialsFromBlock`. + """ @classmethod def setUpClass(cls) -> None: From 32d89697d70a3adf39e5fee730aab37ada6cee26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Thu, 2 Oct 2025 23:59:10 +0200 Subject: [PATCH 26/40] Improved documentation of each test --- elephant/test/test_trials.py | 75 ++++++++++++++++++++---------------- 1 file changed, 41 insertions(+), 34 deletions(-) diff --git a/elephant/test/test_trials.py b/elephant/test/test_trials.py index 798a38b2e..90ea28de3 100644 --- a/elephant/test/test_trials.py +++ b/elephant/test/test_trials.py @@ -228,8 +228,8 @@ def setUpClass(cls) -> None: def test_deprecations(self) -> None: """ - - def test_deprecations(self): + Test if all expected deprecation warnings are triggered. + """ trial_object = self.trial_object with self.assertWarns(DeprecationWarning): trial_object.get_trial_as_segment(trial_id=0) @@ -246,22 +246,22 @@ def test_deprecations(self): with self.assertWarns(DeprecationWarning): trial_object.get_analogsignals_from_trial_as_segment(trial_id=0) - def test_trials_from_block_description(self) -> None: """ - Test description of the trials object. + Test the description of the `Trials` object. """ self.assertEqual(self.trial_object.description, 'trials are segments') def test_trials_from_block_get_item(self) -> None: """ - Test get a trial from the trials. + Test to get a single trial from the `Trials` object using indexing + with brackets. Return is a `Segment`. """ self.assertIsInstance(self.trial_object[0], Segment) def test_trials_from_block_get_trial_as_segment(self) -> None: """ - Test get a trial from the trials. + Test to get a single trial from the `Trials` object as a `Segment`. """ self.assertIsInstance( self.trial_object.get_trial_as_segment(0), @@ -275,7 +275,8 @@ def test_trials_from_block_get_trial_as_segment(self) -> None: def test_trials_from_block_get_trials_as_block(self) -> None: """ - Test get a block from list of trials. + Test to get a set of specific trials grouped as a `Block`. Each trial + is a `Segment` containing all the data in the trial. """ block = self.trial_object.get_trials_as_block([0, 3, 5]) self.assertIsInstance(block, Block) @@ -284,7 +285,8 @@ def test_trials_from_block_get_trials_as_block(self) -> None: def test_trials_from_block_get_trials_as_list(self) -> None: """ - Test get a list of segments from list of trials. + Test to get a set of specific trials grouped as list of `Segment`. + Each trial is a single `Segment` containing all the data in the trial. """ list_of_trials = self.trial_object.get_trials_as_list([0, 3, 5]) self.assertIsInstance(list_of_trials, list) @@ -294,13 +296,13 @@ def test_trials_from_block_get_trials_as_list(self) -> None: def test_trials_from_block_n_trials(self) -> None: """ - Test get number of trials. + Test to get the number of trials. """ self.assertEqual(self.trial_object.n_trials, len(self.block.segments)) def test_trials_from_block_n_spiketrains_trial_by_trial(self) -> None: """ - Test get number of spiketrains per trial. + Test to get the number of `SpikeTrain` objects per trial. """ self.assertEqual(self.trial_object.n_spiketrains_trial_by_trial, [len(trial.spiketrains) for trial in @@ -308,7 +310,7 @@ def test_trials_from_block_n_spiketrains_trial_by_trial(self) -> None: def test_trials_from_block_n_analogsignals_trial_by_trial(self) -> None: """ - Test get number of analogsignals per trial. + Test to get the number of `AnalogSignal` objects per trial. """ self.assertEqual(self.trial_object.n_analogsignals_trial_by_trial, [len(trial.analogsignals) for trial in @@ -317,7 +319,7 @@ def test_trials_from_block_n_analogsignals_trial_by_trial(self) -> None: def test_trials_from_block_get_spiketrains_from_trial_as_list(self ) -> None: """ - Test get spiketrains from trial as list + Test to get all spiketrains from a single trial as a `SpikeTrainList`. """ self.assertIsInstance( self.trial_object.get_spiketrains_from_trial_as_list(0), @@ -329,7 +331,8 @@ def test_trials_from_block_get_spiketrains_from_trial_as_list(self def test_trials_from_block_get_spiketrains_from_trial_as_segment(self ) -> None: """ - Test get spiketrains from trial as segment + Test to get the all spiketrains from a single trial as a `Segment`. + The `Segment.spiketrains` collection contains the spiketrains. """ self.assertIsInstance( self.trial_object.get_spiketrains_from_trial_as_segment(0), @@ -341,7 +344,7 @@ def test_trials_from_block_get_spiketrains_from_trial_as_segment(self def test_trials_from_block_get_analogsignals_from_trial_as_list(self ) -> None: """ - Test get analogsignals from trial as list + Test to get all analog signals from a single trial as a list. """ self.assertIsInstance( self.trial_object.get_analogsignals_from_trial_as_list(0), list) @@ -352,7 +355,8 @@ def test_trials_from_block_get_analogsignals_from_trial_as_list(self def test_trials_from_block_get_analogsignals_from_trial_as_segment(self) \ -> None: """ - Test get spiketrains from trial as segment + Test to get all analog signals from a single trial as a `Segment`. + The `Segment.analogsignals` collection contains the signals. """ self.assertIsInstance( self.trial_object.get_analogsignals_from_trial_as_segment(0), @@ -363,7 +367,7 @@ def test_trials_from_block_get_analogsignals_from_trial_as_segment(self) \ def test_trials_from_block_get_spiketrains_trial_by_trial(self) -> None: """ - Test accessing all the SpikeTrain objects corresponding to the + Test to access all the `SpikeTrain` objects corresponding to the repetitions of a spiketrain across the trials. """ for st_id in (0, 1): @@ -395,7 +399,7 @@ def test_trials_from_block_get_spiketrains_trial_by_trial(self) -> None: def test_trials_from_block_get_analogsignals_trial_by_trial(self) -> None: """ - Test accessing all the AnalogSignal objects corresponding to the + Test to access all the `AnalogSignal` objects corresponding to the repetitions of an analog signal across the trials. """ for as_id in (0, 1): @@ -475,10 +479,8 @@ def assertSegmentEqualToList(self, segment, list_data, n_spiketrains, def test_deprecations(self) -> None: """ - Run before every test: + Test if all expected deprecation warnings are triggered. """ - - def test_deprecations(self): trial_object = self.trial_object with self.assertWarns(DeprecationWarning): trial_object.get_trial_as_segment(trial_id=0) @@ -497,13 +499,14 @@ def test_deprecations(self): def test_trials_from_list_description(self) -> None: """ - Test description of the trials object. + Test the description of the `Trials` object. """ self.assertEqual(self.trial_object.description, 'trial is a list') def test_trials_from_list_get_item(self) -> None: """ - Test get a trial from the trials. + Test to get a single trial from the `Trials` object using indexing + with brackets. Return is a `Segment`. """ self.assertIsInstance(self.trial_object[0], Segment) self.assertIsInstance(self.trial_object[0].spiketrains[0], SpikeTrain) @@ -512,7 +515,7 @@ def test_trials_from_list_get_item(self) -> None: def test_trials_from_list_get_trial_as_segment(self) -> None: """ - Test get a trial from the trials. + Test to get a single trial from the `Trials` object as a `Segment`. """ self.assertIsInstance( self.trial_object.get_trial_as_segment(0), Segment) @@ -525,7 +528,8 @@ def test_trials_from_list_get_trial_as_segment(self) -> None: def test_trials_from_list_get_trials_as_block(self) -> None: """ - Test get a block from list of trials. + Test to get a set of specific trials grouped as a `Block`. Each trial + is a `Segment` containing all the data in the trial. """ block = self.trial_object.get_trials_as_block([0, 3, 5]) self.assertIsInstance(block, Block) @@ -534,7 +538,8 @@ def test_trials_from_list_get_trials_as_block(self) -> None: def test_trials_from_list_get_trials_as_list(self) -> None: """ - Test get a list of segments from list of trials. + Test to get a set of specific trials grouped as list of `Segment`. + Each trial is a single `Segment` containing all the data in the trial. """ list_of_trials = self.trial_object.get_trials_as_list([0, 3, 5]) self.assertIsInstance(list_of_trials, list) @@ -544,13 +549,13 @@ def test_trials_from_list_get_trials_as_list(self) -> None: def test_trials_from_list_n_trials(self) -> None: """ - Test get number of trials. + Test to get the number of trials. """ self.assertEqual(self.trial_object.n_trials, len(self.trial_list)) def test_trials_from_list_n_spiketrains_trial_by_trial(self) -> None: """ - Test get number of spiketrains per trial. + Test to get the number of `SpikeTrain` objects per trial. """ self.assertEqual(self.trial_object.n_spiketrains_trial_by_trial, [sum(map(lambda x: isinstance(x, SpikeTrain), @@ -558,7 +563,7 @@ def test_trials_from_list_n_spiketrains_trial_by_trial(self) -> None: def test_trials_from_list_n_analogsignals_trial_by_trial(self) -> None: """ - Test get number of analogsignals per trial. + Test to get the number of `AnalogSignal` objects per trial. """ self.assertEqual(self.trial_object.n_analogsignals_trial_by_trial, [sum(map(lambda x: isinstance(x, AnalogSignal), @@ -566,7 +571,7 @@ def test_trials_from_list_n_analogsignals_trial_by_trial(self) -> None: def test_trials_from_list_get_spiketrains_from_trial_as_list(self) -> None: """ - Test get spiketrains from trial as list + Test to get all spiketrains from a single trial as a `SpikeTrainList`. """ self.assertIsInstance( self.trial_object.get_spiketrains_from_trial_as_list(0), @@ -578,7 +583,8 @@ def test_trials_from_list_get_spiketrains_from_trial_as_list(self) -> None: def test_trials_from_list_get_spiketrains_from_trial_as_segment(self ) -> None: """ - Test get spiketrains from trial as segment + Test to get the all spiketrains from a single trial as a `Segment`. + The `Segment.spiketrains` collection contains the spiketrains. """ self.assertIsInstance( self.trial_object.get_spiketrains_from_trial_as_segment(0), @@ -590,7 +596,7 @@ def test_trials_from_list_get_spiketrains_from_trial_as_segment(self def test_trials_from_list_get_analogsignals_from_trial_as_list(self ) -> None: """ - Test get analogsignals from trial as list + Test to get all analog signals from a single trial as a list. """ self.assertIsInstance( self.trial_object.get_analogsignals_from_trial_as_list(0), list) @@ -602,7 +608,8 @@ def test_trials_from_list_get_analogsignals_from_trial_as_segment(self ) \ -> None: """ - Test get spiketrains from trial as segment + Test to get all analog signals from a single trial as a `Segment`. + The `Segment.analogsignals` collection contains the signals. """ self.assertIsInstance( self.trial_object.get_analogsignals_from_trial_as_segment(0), @@ -613,7 +620,7 @@ def test_trials_from_list_get_analogsignals_from_trial_as_segment(self def test_trials_from_list_get_spiketrains_trial_by_trial(self) -> None: """ - Test accessing all the SpikeTrain objects corresponding to the + Test to access all the `SpikeTrain` objects corresponding to the repetitions of a spiketrain across the trials. """ for st_id in (0, 1): @@ -646,7 +653,7 @@ def test_trials_from_list_get_spiketrains_trial_by_trial(self) -> None: def test_trials_from_list_get_analogsignals_trial_by_trial(self) -> None: """ - Test accessing all the AnalogSignal objects corresponding to the + Test to access all the `AnalogSignal` objects corresponding to the repetitions of an analog signal across the trials. """ for as_id in (0, 1): From ee1dbe200188c83b108a736dcae44b804b322c1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Fri, 3 Oct 2025 00:00:23 +0200 Subject: [PATCH 27/40] Improved unit test for TrialsFromBlock description --- elephant/test/test_trials.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/elephant/test/test_trials.py b/elephant/test/test_trials.py index 90ea28de3..bdce29d54 100644 --- a/elephant/test/test_trials.py +++ b/elephant/test/test_trials.py @@ -224,7 +224,7 @@ def setUpClass(cls) -> None: block = _create_trials_block(n_trials=36) cls.block = block cls.trial_object = TrialsFromBlock(block, - description='trials are segments') + description='trial is Segment') def test_deprecations(self) -> None: """ @@ -250,7 +250,7 @@ def test_trials_from_block_description(self) -> None: """ Test the description of the `Trials` object. """ - self.assertEqual(self.trial_object.description, 'trials are segments') + self.assertEqual(self.trial_object.description, 'trial is Segment') def test_trials_from_block_get_item(self) -> None: """ From cb8bd6d785c71f6d97bce23cda372497f16073af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Fri, 3 Oct 2025 00:05:18 +0200 Subject: [PATCH 28/40] Refactor and extended unit tests to use custom assertions to perform specific value checks --- elephant/test/test_trials.py | 307 +++++++++++++++++++---------------- 1 file changed, 171 insertions(+), 136 deletions(-) diff --git a/elephant/test/test_trials.py b/elephant/test/test_trials.py index bdce29d54..70361d6bc 100644 --- a/elephant/test/test_trials.py +++ b/elephant/test/test_trials.py @@ -257,42 +257,49 @@ def test_trials_from_block_get_item(self) -> None: Test to get a single trial from the `Trials` object using indexing with brackets. Return is a `Segment`. """ - self.assertIsInstance(self.trial_object[0], Segment) + for trial_index in range(36): + trial_segment = self.trial_object[trial_index] + expected = self.block.segments[trial_index] + self.assertSegmentEqual(trial_segment, expected) def test_trials_from_block_get_trial_as_segment(self) -> None: """ Test to get a single trial from the `Trials` object as a `Segment`. """ - self.assertIsInstance( - self.trial_object.get_trial_as_segment(0), - Segment) - self.assertIsInstance( - self.trial_object.get_trial_as_segment(0).spiketrains[0], - SpikeTrain) - self.assertIsInstance( - self.trial_object.get_trial_as_segment(0).analogsignals[0], - AnalogSignal) + for trial_index in range(36): + trial_segment = self.trial_object.get_trial_as_segment(trial_index) + expected = self.block.segments[trial_index] + self.assertSegmentEqual(trial_segment, expected) def test_trials_from_block_get_trials_as_block(self) -> None: """ Test to get a set of specific trials grouped as a `Block`. Each trial is a `Segment` containing all the data in the trial. """ - block = self.trial_object.get_trials_as_block([0, 3, 5]) - self.assertIsInstance(block, Block) - self.assertIsInstance(self.trial_object.get_trials_as_block(), Block) - self.assertEqual(len(block.segments), 3) + trial_block = self.trial_object.get_trials_as_block([0, 3, 5]) + self.assertIsInstance(trial_block, Block) + self.assertEqual(len(trial_block.segments), 3) + self.assertSegmentEqual(trial_block.segments[0], + self.block.segments[0]) + self.assertSegmentEqual(trial_block.segments[1], + self.block.segments[3]) + self.assertSegmentEqual(trial_block.segments[2], + self.block.segments[5]) def test_trials_from_block_get_trials_as_list(self) -> None: """ Test to get a set of specific trials grouped as list of `Segment`. Each trial is a single `Segment` containing all the data in the trial. """ - list_of_trials = self.trial_object.get_trials_as_list([0, 3, 5]) + list_of_trials = self.trial_object.get_trials_as_list([0, 5, 7]) self.assertIsInstance(list_of_trials, list) - self.assertIsInstance(self.trial_object.get_trials_as_list(), list) - self.assertIsInstance(list_of_trials[0], Segment) self.assertEqual(len(list_of_trials), 3) + self.assertSegmentEqual(list_of_trials[0], + self.block.segments[0]) + self.assertSegmentEqual(list_of_trials[1], + self.block.segments[5]) + self.assertSegmentEqual(list_of_trials[2], + self.block.segments[7]) def test_trials_from_block_n_trials(self) -> None: """ @@ -321,12 +328,14 @@ def test_trials_from_block_get_spiketrains_from_trial_as_list(self """ Test to get all spiketrains from a single trial as a `SpikeTrainList`. """ - self.assertIsInstance( - self.trial_object.get_spiketrains_from_trial_as_list(0), - SpikeTrainList) - self.assertIsInstance( - self.trial_object.get_spiketrains_from_trial_as_list(0)[0], - SpikeTrain) + for trial_index in range(36): + trial_spiketrains = ( + self.trial_object.get_spiketrains_from_trial_as_list( + trial_index + ) + ) + expected = self.block.segments[trial_index].spiketrains + self.assertSpikeTrainListEqual(trial_spiketrains, expected) def test_trials_from_block_get_spiketrains_from_trial_as_segment(self ) -> None: @@ -334,23 +343,32 @@ def test_trials_from_block_get_spiketrains_from_trial_as_segment(self Test to get the all spiketrains from a single trial as a `Segment`. The `Segment.spiketrains` collection contains the spiketrains. """ - self.assertIsInstance( - self.trial_object.get_spiketrains_from_trial_as_segment(0), - Segment) - self.assertIsInstance( - self.trial_object.get_spiketrains_from_trial_as_segment( - 0).spiketrains[0], SpikeTrain) + for trial_index in range(36): + trial_spiketrains = ( + self.trial_object.get_spiketrains_from_trial_as_segment( + trial_index + ) + ) + expected = self.block.segments[trial_index].spiketrains + self.assertIsInstance(trial_spiketrains, Segment) + self.assertEqual(len(trial_spiketrains.analogsignals), 0) + self.assertSpikeTrainListEqual(trial_spiketrains.spiketrains, + expected) def test_trials_from_block_get_analogsignals_from_trial_as_list(self ) -> None: """ Test to get all analog signals from a single trial as a list. """ - self.assertIsInstance( - self.trial_object.get_analogsignals_from_trial_as_list(0), list) - self.assertIsInstance( - self.trial_object.get_analogsignals_from_trial_as_list(0)[0], - AnalogSignal) + for trial_index in range(36): + trial_signals = ( + self.trial_object.get_analogsignals_from_trial_as_list( + trial_index + ) + ) + expected = self.block.segments[trial_index].analogsignals + self.assertIsInstance(trial_signals, list) + self.assertAnalogSignalListEqual(trial_signals, expected) def test_trials_from_block_get_analogsignals_from_trial_as_segment(self) \ -> None: @@ -358,12 +376,17 @@ def test_trials_from_block_get_analogsignals_from_trial_as_segment(self) \ Test to get all analog signals from a single trial as a `Segment`. The `Segment.analogsignals` collection contains the signals. """ - self.assertIsInstance( - self.trial_object.get_analogsignals_from_trial_as_segment(0), - Segment) - self.assertIsInstance( - self.trial_object.get_analogsignals_from_trial_as_segment( - 0).analogsignals[0], AnalogSignal) + for trial_index in range(36): + trial_signals = ( + self.trial_object.get_analogsignals_from_trial_as_segment( + trial_index + ) + ) + expected = self.block.segments[trial_index].analogsignals + self.assertIsInstance(trial_signals, Segment) + self.assertEqual(len(trial_signals.spiketrains), 0) + self.assertAnalogSignalListEqual(trial_signals.analogsignals, + expected) def test_trials_from_block_get_spiketrains_trial_by_trial(self) -> None: """ @@ -371,29 +394,22 @@ def test_trials_from_block_get_spiketrains_trial_by_trial(self) -> None: repetitions of a spiketrain across the trials. """ for st_id in (0, 1): - spiketrains = self.trial_object.get_spiketrains_trial_by_trial(st_id) - - # Return is neo.SpikeTrainList - self.assertIsInstance(spiketrains, SpikeTrainList) - - # All elements are neo.SpikeTrain - self.assertTrue(all(map(lambda x: isinstance(x, SpikeTrain), - spiketrains) - ) - ) - - # Data for all trials is returned - self.assertEqual(len(spiketrains), self.trial_object.n_trials) + spiketrains = self.trial_object.get_spiketrains_trial_by_trial( + st_id) + expected_spiketrains = [trial.spiketrains[st_id] + for trial in self.block.segments] + self.assertEqual(len(spiketrains), 36) + self.assertSpikeTrainListEqual(spiketrains, + SpikeTrainList(expected_spiketrains)) - # Each trial-specific SpikeTrain object is from the same spiketrain + # Each trial-specific `SpikeTrain` object is from the same spiketrain self.assertTrue(all([st.name == f"Spiketrain {st_id}" for st in spiketrains] ) ) # Order of spiketrains is the order of the trials - expected_trials = [f"Trial {i}" - for i in range(self.trial_object.n_trials)] + expected_trials = [f"Trial {i}" for i in range(36)] self.assertListEqual([st.description for st in spiketrains], expected_trials) @@ -404,27 +420,20 @@ def test_trials_from_block_get_analogsignals_trial_by_trial(self) -> None: """ for as_id in (0, 1): signals = self.trial_object.get_analogsignals_trial_by_trial(as_id) - - # Return is list + expected_signals = [trial.analogsignals[as_id] + for trial in self.block.segments] + self.assertEqual(len(signals), 36) self.assertIsInstance(signals, list) + self.assertAnalogSignalListEqual(signals, expected_signals) - # All elements are neo.AnalogSignal - self.assertTrue(all(map(lambda x: isinstance(x, AnalogSignal), - signals) - ) - ) - # Data for all trials returned - self.assertEqual(len(signals), self.trial_object.n_trials) - - # Each trial-specific AnalogSignal object is from the same signal + # Each trial-specific `AnalogSignal` object is from the same signal self.assertTrue(all([signal.name == f"Signal {as_id}" for signal in signals] ) ) # Order in the list is the order of the trials - expected_trials = [f"Trial {i}" - for i in range(self.trial_object.n_trials)] + expected_trials = [f"Trial {i}" for i in range(36)] self.assertListEqual([signal.description for signal in signals], expected_trials) @@ -508,33 +517,46 @@ def test_trials_from_list_get_item(self) -> None: Test to get a single trial from the `Trials` object using indexing with brackets. Return is a `Segment`. """ - self.assertIsInstance(self.trial_object[0], Segment) - self.assertIsInstance(self.trial_object[0].spiketrains[0], SpikeTrain) - self.assertIsInstance(self.trial_object[0].analogsignals[0], - AnalogSignal) + for trial_index in range(36): + trial_segment = self.trial_object[trial_index] + expected = self.trial_list[trial_index] + self.assertSegmentEqualToList(trial_segment, + expected, + n_spiketrains=self.n_spiketrains, + n_analogsignals=self.n_analogsignals) def test_trials_from_list_get_trial_as_segment(self) -> None: """ Test to get a single trial from the `Trials` object as a `Segment`. """ - self.assertIsInstance( - self.trial_object.get_trial_as_segment(0), Segment) - self.assertIsInstance( - self.trial_object.get_trial_as_segment(0).spiketrains[0], - SpikeTrain) - self.assertIsInstance( - self.trial_object.get_trial_as_segment(0).analogsignals[0], - AnalogSignal) + for trial_index in range(36): + trial_segment = self.trial_object[trial_index] + expected = self.trial_list[trial_index] + self.assertSegmentEqualToList(trial_segment, + expected, + n_spiketrains=self.n_spiketrains, + n_analogsignals=self.n_analogsignals) def test_trials_from_list_get_trials_as_block(self) -> None: """ Test to get a set of specific trials grouped as a `Block`. Each trial is a `Segment` containing all the data in the trial. """ - block = self.trial_object.get_trials_as_block([0, 3, 5]) - self.assertIsInstance(block, Block) - self.assertIsInstance(self.trial_object.get_trials_as_block(), Block) - self.assertEqual(len(block.segments), 3) + trial_block = self.trial_object.get_trials_as_block([1, 6, 18]) + self.assertIsInstance(trial_block, Block) + self.assertEqual(len(trial_block.segments), 3) + self.assertSegmentEqualToList(trial_block.segments[0], + self.trial_list[1], + n_spiketrains=self.n_spiketrains, + n_analogsignals=self.n_analogsignals) + self.assertSegmentEqualToList(trial_block.segments[1], + self.trial_list[6], + n_spiketrains=self.n_spiketrains, + n_analogsignals=self.n_analogsignals) + self.assertSegmentEqualToList(trial_block.segments[2], + self.trial_list[18], + n_spiketrains=self.n_spiketrains, + n_analogsignals=self.n_analogsignals) def test_trials_from_list_get_trials_as_list(self) -> None: """ @@ -543,9 +565,19 @@ def test_trials_from_list_get_trials_as_list(self) -> None: """ list_of_trials = self.trial_object.get_trials_as_list([0, 3, 5]) self.assertIsInstance(list_of_trials, list) - self.assertIsInstance(self.trial_object.get_trials_as_list(), list) - self.assertIsInstance(list_of_trials[0], Segment) self.assertEqual(len(list_of_trials), 3) + self.assertSegmentEqualToList(list_of_trials[0], + self.trial_list[0], + n_spiketrains=self.n_spiketrains, + n_analogsignals=self.n_analogsignals) + self.assertSegmentEqualToList(list_of_trials[1], + self.trial_list[3], + n_spiketrains=self.n_spiketrains, + n_analogsignals=self.n_analogsignals) + self.assertSegmentEqualToList(list_of_trials[2], + self.trial_list[5], + n_spiketrains=self.n_spiketrains, + n_analogsignals=self.n_analogsignals) def test_trials_from_list_n_trials(self) -> None: """ @@ -573,12 +605,15 @@ def test_trials_from_list_get_spiketrains_from_trial_as_list(self) -> None: """ Test to get all spiketrains from a single trial as a `SpikeTrainList`. """ - self.assertIsInstance( - self.trial_object.get_spiketrains_from_trial_as_list(0), - SpikeTrainList) - self.assertIsInstance( - self.trial_object.get_spiketrains_from_trial_as_list(0)[0], - SpikeTrain) + for trial_index in range(36): + trial_spiketrains = ( + self.trial_object.get_spiketrains_from_trial_as_list( + trial_index + ) + ) + expected = self.trial_list[trial_index][:self.n_spiketrains] + self.assertSpikeTrainListEqual(trial_spiketrains, + SpikeTrainList(expected)) def test_trials_from_list_get_spiketrains_from_trial_as_segment(self ) -> None: @@ -586,23 +621,32 @@ def test_trials_from_list_get_spiketrains_from_trial_as_segment(self Test to get the all spiketrains from a single trial as a `Segment`. The `Segment.spiketrains` collection contains the spiketrains. """ - self.assertIsInstance( - self.trial_object.get_spiketrains_from_trial_as_segment(0), - Segment) - self.assertIsInstance( - self.trial_object.get_spiketrains_from_trial_as_segment( - 0).spiketrains[0], SpikeTrain) + for trial_index in range(36): + trial_spiketrains = ( + self.trial_object.get_spiketrains_from_trial_as_segment( + trial_index + ) + ) + expected = self.trial_list[trial_index][:self.n_spiketrains] + self.assertIsInstance(trial_spiketrains, Segment) + self.assertEqual(len(trial_spiketrains.analogsignals), 0) + self.assertSpikeTrainListEqual(trial_spiketrains.spiketrains, + SpikeTrainList(expected)) def test_trials_from_list_get_analogsignals_from_trial_as_list(self ) -> None: """ Test to get all analog signals from a single trial as a list. """ - self.assertIsInstance( - self.trial_object.get_analogsignals_from_trial_as_list(0), list) - self.assertIsInstance( - self.trial_object.get_analogsignals_from_trial_as_list(0)[0], - AnalogSignal) + for trial_index in range(36): + trial_signals = ( + self.trial_object.get_analogsignals_from_trial_as_list( + trial_index + ) + ) + expected = self.trial_list[trial_index][self.n_spiketrains:] + self.assertIsInstance(trial_signals, list) + self.assertAnalogSignalListEqual(trial_signals, expected) def test_trials_from_list_get_analogsignals_from_trial_as_segment(self ) \ @@ -611,12 +655,17 @@ def test_trials_from_list_get_analogsignals_from_trial_as_segment(self Test to get all analog signals from a single trial as a `Segment`. The `Segment.analogsignals` collection contains the signals. """ - self.assertIsInstance( - self.trial_object.get_analogsignals_from_trial_as_segment(0), - Segment) - self.assertIsInstance( - self.trial_object.get_analogsignals_from_trial_as_segment( - 0).analogsignals[0], AnalogSignal) + for trial_index in range(36): + trial_signals = ( + self.trial_object.get_analogsignals_from_trial_as_segment( + trial_index + ) + ) + expected = self.trial_list[trial_index][self.n_spiketrains:] + self.assertIsInstance(trial_signals, Segment) + self.assertEqual(len(trial_signals.spiketrains), 0) + self.assertAnalogSignalListEqual(trial_signals.analogsignals, + expected) def test_trials_from_list_get_spiketrains_trial_by_trial(self) -> None: """ @@ -626,28 +675,20 @@ def test_trials_from_list_get_spiketrains_trial_by_trial(self) -> None: for st_id in (0, 1): spiketrains = self.trial_object.get_spiketrains_trial_by_trial( st_id) + expected_spiketrains = [trial[:self.n_spiketrains][st_id] + for trial in self.trial_list] + self.assertEqual(len(spiketrains), 36) + self.assertSpikeTrainListEqual(spiketrains, + SpikeTrainList(expected_spiketrains)) - # Return is neo.SpikeTrainList - self.assertIsInstance(spiketrains, SpikeTrainList) - - # All elements are neo.SpikeTrain - self.assertTrue(all(map(lambda x: isinstance(x, SpikeTrain), - spiketrains) - ) - ) - - # Data for all trials is returned - self.assertEqual(len(spiketrains), self.trial_object.n_trials) - - # Each trial-specific SpikeTrain object is from the same spiketrain + # Each trial-specific `SpikeTrain` object is from the same spiketrain self.assertTrue(all([st.name == f"Spiketrain {st_id}" for st in spiketrains] ) ) # Order of spiketrains is the order of the trials - expected_trials = [f"Trial {i}" - for i in range(self.trial_object.n_trials)] + expected_trials = [f"Trial {i}" for i in range(36)] self.assertListEqual([st.description for st in spiketrains], expected_trials) @@ -658,19 +699,13 @@ def test_trials_from_list_get_analogsignals_trial_by_trial(self) -> None: """ for as_id in (0, 1): signals = self.trial_object.get_analogsignals_trial_by_trial(as_id) - - # Return is list + expected_signals = [trial[self.n_spiketrains:][as_id] + for trial in self.trial_list] + self.assertEqual(len(signals), 36) self.assertIsInstance(signals, list) + self.assertAnalogSignalListEqual(signals, expected_signals) - # All elements are neo.AnalogSignal - self.assertTrue(all(map(lambda x: isinstance(x, AnalogSignal), - signals) - ) - ) - # Data for all trials returned - self.assertEqual(len(signals), self.trial_object.n_trials) - - # Each trial-specific AnalogSignal object is from the same signal + # Each trial-specific `AnalogSignal` object is from the same signal self.assertTrue(all([signal.name == f"Signal {as_id}" for signal in signals] ) From faf6f233cd72869bcee61b69a8c8458723165596 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Fri, 3 Oct 2025 00:06:18 +0200 Subject: [PATCH 29/40] Split different unit tests --- elephant/test/test_trials.py | 61 +++++++++++++++++++++++++++++++++--- 1 file changed, 57 insertions(+), 4 deletions(-) diff --git a/elephant/test/test_trials.py b/elephant/test/test_trials.py index 70361d6bc..121f88be8 100644 --- a/elephant/test/test_trials.py +++ b/elephant/test/test_trials.py @@ -271,7 +271,7 @@ def test_trials_from_block_get_trial_as_segment(self) -> None: expected = self.block.segments[trial_index] self.assertSegmentEqual(trial_segment, expected) - def test_trials_from_block_get_trials_as_block(self) -> None: + def test_trials_from_block_get_trials_as_block_indexes(self) -> None: """ Test to get a set of specific trials grouped as a `Block`. Each trial is a `Segment` containing all the data in the trial. @@ -286,7 +286,19 @@ def test_trials_from_block_get_trials_as_block(self) -> None: self.assertSegmentEqual(trial_block.segments[2], self.block.segments[5]) - def test_trials_from_block_get_trials_as_list(self) -> None: + def test_trials_from_block_get_trials_as_block(self) -> None: + """ + Test to get all trials grouped as a `Block`, where each trial is a + single `Segment`. + """ + trial_block = self.trial_object.get_trials_as_block() + self.assertIsInstance(trial_block, Block) + self.assertEqual(len(trial_block.segments), 36) + for trial, expected_trial in zip(trial_block.segments, + self.block.segments): + self.assertSegmentEqual(trial, expected_trial) + + def test_trials_from_block_get_trials_as_list_indexes(self) -> None: """ Test to get a set of specific trials grouped as list of `Segment`. Each trial is a single `Segment` containing all the data in the trial. @@ -301,6 +313,18 @@ def test_trials_from_block_get_trials_as_list(self) -> None: self.assertSegmentEqual(list_of_trials[2], self.block.segments[7]) + def test_trials_from_block_get_trials_as_list(self) -> None: + """ + Test to get all the trials grouped as a list of `Segment`. Each trial + is a single `Segment` containing all the data in the trial. + """ + list_of_trials = self.trial_object.get_trials_as_list() + self.assertIsInstance(list_of_trials, list) + self.assertEqual(len(list_of_trials), 36) + for trial, expected_trial in zip(list_of_trials, + self.block.segments): + self.assertSegmentEqual(trial, expected_trial) + def test_trials_from_block_n_trials(self) -> None: """ Test to get the number of trials. @@ -537,7 +561,7 @@ def test_trials_from_list_get_trial_as_segment(self) -> None: n_spiketrains=self.n_spiketrains, n_analogsignals=self.n_analogsignals) - def test_trials_from_list_get_trials_as_block(self) -> None: + def test_trials_from_list_get_trials_as_block_indexes(self) -> None: """ Test to get a set of specific trials grouped as a `Block`. Each trial is a `Segment` containing all the data in the trial. @@ -558,7 +582,22 @@ def test_trials_from_list_get_trials_as_block(self) -> None: n_spiketrains=self.n_spiketrains, n_analogsignals=self.n_analogsignals) - def test_trials_from_list_get_trials_as_list(self) -> None: + def test_trials_from_list_get_trials_as_block(self) -> None: + """ + Test to get all trials grouped as a `Block`, where each trial is a + single `Segment`. + """ + trial_block = self.trial_object.get_trials_as_block() + self.assertIsInstance(trial_block, Block) + self.assertEqual(len(trial_block.segments), 36) + for trial, expected_trial in zip(trial_block.segments, + self.trial_list): + self.assertSegmentEqualToList(trial, + expected_trial, + n_spiketrains=self.n_spiketrains, + n_analogsignals=self.n_analogsignals) + + def test_trials_from_list_get_trials_as_list_indexes(self) -> None: """ Test to get a set of specific trials grouped as list of `Segment`. Each trial is a single `Segment` containing all the data in the trial. @@ -579,6 +618,20 @@ def test_trials_from_list_get_trials_as_list(self) -> None: n_spiketrains=self.n_spiketrains, n_analogsignals=self.n_analogsignals) + def test_trials_from_list_get_trials_as_list(self) -> None: + """ + Test to get all the trials grouped as a list of `Segment`. Each trial + is a single `Segment` containing all the data in the trial. + """ + list_of_trials = self.trial_object.get_trials_as_list() + self.assertIsInstance(list_of_trials, list) + self.assertEqual(len(list_of_trials), 36) + for trial, expected_trial in zip(list_of_trials, self.trial_list): + self.assertSegmentEqualToList(trial, + expected_trial, + n_spiketrains=self.n_spiketrains, + n_analogsignals=self.n_analogsignals) + def test_trials_from_list_n_trials(self) -> None: """ Test to get the number of trials. From 5a04ecfd8aa137e6eea1901fca3024f286e382de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Fri, 3 Oct 2025 00:08:23 +0200 Subject: [PATCH 30/40] Corrected typo --- elephant/test/test_trials.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elephant/test/test_trials.py b/elephant/test/test_trials.py index 121f88be8..726d3ee1f 100644 --- a/elephant/test/test_trials.py +++ b/elephant/test/test_trials.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- """ -nit tests for the objects of the API handling trial data in Elephant. +Unit tests for the objects of the API handling trial data in Elephant. :copyright: Copyright 2014-2025 by the Elephant team, see AUTHORS.txt. :license: Modified BSD, see LICENSE.txt for details. From 5e901ccb26e9a23789e904c8a3717c2f1f2b3d7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Fri, 3 Oct 2025 00:14:20 +0200 Subject: [PATCH 31/40] Changed to use ABC (Python 3 style) --- elephant/trials.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/elephant/trials.py b/elephant/trials.py index a2142da1e..ffe81b19c 100644 --- a/elephant/trials.py +++ b/elephant/trials.py @@ -50,7 +50,7 @@ :license: Modified BSD, see LICENSE.txt for details. """ -from abc import ABCMeta, abstractmethod +from abc import ABC, abstractmethod from typing import List from functools import wraps @@ -113,7 +113,7 @@ def wrapper(*args, **kwargs): return wrapper -class Trials: +class Trials(ABC): """ Base class for handling trials. @@ -129,7 +129,6 @@ class attribute `description`. """ - __metaclass__ = ABCMeta def __init__(self, description: str = "Trials"): """Create an instance of the trials class.""" From 98ba9f68138050c983249a7661654e14649ce541 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Fri, 3 Oct 2025 00:14:58 +0200 Subject: [PATCH 32/40] Organized imports --- elephant/trials.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/elephant/trials.py b/elephant/trials.py index ffe81b19c..c4793010f 100644 --- a/elephant/trials.py +++ b/elephant/trials.py @@ -55,7 +55,8 @@ from functools import wraps import numpy as np -import neo.utils + +import neo from neo.core import Segment, Block from neo.core.spiketrainlist import SpikeTrainList from elephant.utils import deprecated_alias From 72eb5e2bfd68443b5d588048a9a3a5715c5e92ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Fri, 3 Oct 2025 00:16:58 +0200 Subject: [PATCH 33/40] Added NotImplementedError exceptions to abstract methods --- elephant/trials.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/elephant/trials.py b/elephant/trials.py index c4793010f..9d5abd68d 100644 --- a/elephant/trials.py +++ b/elephant/trials.py @@ -137,7 +137,8 @@ def __init__(self, description: str = "Trials"): @abstractmethod def __getitem__(self, trial_index: int) -> neo.core.Segment: - pass + # Get a specific trial by its index as a Segment + raise NotImplementedError @abstractmethod def n_trials(self) -> int: @@ -147,7 +148,7 @@ def n_trials(self) -> int: ------- int: Number of trials """ - pass + raise NotImplementedError @abstractmethod def n_spiketrains_trial_by_trial(self) -> List[int]: @@ -158,7 +159,7 @@ def n_spiketrains_trial_by_trial(self) -> List[int]: list of int: For each trial, contains the number of spike trains in the trial. """ - pass + raise NotImplementedError @abstractmethod def n_analogsignals_trial_by_trial(self) -> List[int]: @@ -169,7 +170,7 @@ def n_analogsignals_trial_by_trial(self) -> List[int]: list of int: For each trial, contains the number of analogsignal objects in the trial. """ - pass + raise NotImplementedError @abstractmethod def get_trial_as_segment(self, trial_id: int) -> neo.core.Segment: @@ -248,7 +249,7 @@ def get_spiketrains_from_trial_as_list(self, trial_index: int) -> ( list of :class:`neo.SpikeTrain` List of all spike trains of the trial. """ - pass + raise NotImplementedError @abstractmethod @deprecated_alias(trial_id="trial_index") @@ -285,7 +286,7 @@ def get_analogsignals_from_trial_as_list(self, trial_index: int list of :class`neo.AnalogSignal`: list of all analogsignal objects of the trial. """ - pass + raise NotImplementedError @abstractmethod @deprecated_alias(trial_id="trial_index") @@ -330,7 +331,7 @@ def get_spiketrains_trial_by_trial(self, spiketrain_index: int) -> ( objects for the specified `spiketrain_id`, ordered from the first trial (ID 0) to the last (ID `n_trials - 1`). """ - pass + raise NotImplementedError @abstractmethod def get_analogsignals_trial_by_trial(self, signal_index: int @@ -356,7 +357,7 @@ def get_analogsignals_trial_by_trial(self, signal_index: int specified `signal_id`, ordered from the first trial (ID 0) to the last (ID `n_trials - 1`). """ - pass + raise NotImplementedError class TrialsFromBlock(Trials): From 7d3f9f49a9afd89e78fec1c6184125be0de534d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Fri, 3 Oct 2025 00:19:50 +0200 Subject: [PATCH 34/40] Moved all methods with redundant code to base class --- elephant/trials.py | 121 ++++++++------------------------------------- 1 file changed, 21 insertions(+), 100 deletions(-) diff --git a/elephant/trials.py b/elephant/trials.py index 9d5abd68d..8838e9df9 100644 --- a/elephant/trials.py +++ b/elephant/trials.py @@ -140,6 +140,7 @@ def __getitem__(self, trial_index: int) -> neo.core.Segment: # Get a specific trial by its index as a Segment raise NotImplementedError + @property @abstractmethod def n_trials(self) -> int: """Get the number of trials. @@ -172,9 +173,6 @@ def n_analogsignals_trial_by_trial(self) -> List[int]: """ raise NotImplementedError - @abstractmethod - def get_trial_as_segment(self, trial_id: int) -> neo.core.Segment: - """Get trial as segment. @deprecated_alias(trial_id="trial_index") def get_trial_as_segment(self, trial_index: int) -> neo.core.Segment: @@ -188,9 +186,8 @@ def get_trial_as_segment(self, trial_index: int) -> neo.core.Segment: class:`neo.Segment`: a segment containing all spike trains and analogsignal objects of the trial. """ - pass + return self.__getitem__(trial_index) - @abstractmethod @deprecated_alias(trial_ids="trial_indexes") def get_trials_as_block(self, trial_indexes: List[int] = None ) -> neo.core.Block: @@ -209,9 +206,13 @@ def get_trials_as_block(self, trial_indexes: List[int] = None each of the selected trials, each containing all spike trains and analogsignal objects of the corresponding trial. """ - pass + block = Block() + if not trial_indexes: + trial_indexes = list(range(self.n_trials)) + for trial_index in trial_indexes: + block.segments.append(self.get_trial_as_segment(trial_index)) + return block - @abstractmethod @deprecated_alias(trial_ids="trial_indexes") def get_trials_as_list(self, trial_indexes: List[int] = None ) -> neo.core.spiketrainlist.SpikeTrainList: @@ -230,7 +231,10 @@ def get_trials_as_list(self, trial_indexes: List[int] = None objects for each of the selected trials, each containing all spike trains and analogsignal objects of the corresponding trial. """ - pass + if not trial_indexes: + trial_indexes = list(range(self.n_trials)) + return [self.get_trial_as_segment(trial_index) + for trial_index in trial_indexes] @abstractmethod @deprecated_alias(trial_id="trial_index") @@ -251,7 +255,6 @@ def get_spiketrains_from_trial_as_list(self, trial_index: int) -> ( """ raise NotImplementedError - @abstractmethod @deprecated_alias(trial_id="trial_index") def get_spiketrains_from_trial_as_segment(self, trial_index: int ) -> neo.core.Segment: @@ -267,7 +270,10 @@ def get_spiketrains_from_trial_as_segment(self, trial_index: int ------- :class:`neo.Segment`: Segment containing all spike trains of the trial. """ - pass + segment = neo.core.Segment() + for spiketrain in self.get_spiketrains_from_trial_as_list(trial_index): + segment.spiketrains.append(spiketrain) + return segment @abstractmethod @deprecated_alias(trial_id="trial_index") @@ -288,7 +294,6 @@ def get_analogsignals_from_trial_as_list(self, trial_index: int """ raise NotImplementedError - @abstractmethod @deprecated_alias(trial_id="trial_index") def get_analogsignals_from_trial_as_segment(self, trial_index: int ) -> neo.core.Segment: @@ -306,6 +311,11 @@ def get_analogsignals_from_trial_as_segment(self, trial_index: int class:`neo.Segment`: segment containing all analogsignal objects of the trial. """ + segment = neo.core.Segment() + for analogsignal in self.get_analogsignals_from_trial_as_list( + trial_index): + segment.analogsignals.append(analogsignal) + return segment @abstractmethod def get_spiketrains_trial_by_trial(self, spiketrain_index: int) -> ( @@ -385,31 +395,6 @@ def __getitem__(self, trial_index: int) -> neo.core.segment: # Get a specific trial by its index as a Segment return self.block.segments[trial_index] - @deprecated_alias(trial_id="trial_index") - def get_trial_as_segment(self, trial_index: int) -> neo.core.Segment: - # Get a specific trial by its index as a Segment - return self.__getitem__(trial_index) - - @deprecated_alias(trial_ids="trial_indexes") - def get_trials_as_block(self, trial_indexes: List[int] = None - ) -> neo.core.Block: - # Get a set of trials by their indexes as a Block - block = Block() - if not trial_indexes: - trial_indexes = list(range(self.n_trials)) - for trial_index in trial_indexes: - block.segments.append(self.get_trial_as_segment(trial_index)) - return block - - @deprecated_alias(trial_ids="trial_indexes") - def get_trials_as_list(self, trial_indexes: List[int] = None - ) -> List[neo.core.Segment]: - # Get a set of trials by their indexes as a list of Segment - if not trial_indexes: - trial_indexes = list(range(self.n_trials)) - return [self.get_trial_as_segment(trial_index) - for trial_index in trial_indexes] - @property def n_trials(self) -> int: # Get the number of trials @@ -432,16 +417,6 @@ def get_spiketrains_from_trial_as_list(self, trial_index: int = 0) -> ( return SpikeTrainList(items=[spiketrain for spiketrain in self.block.segments[trial_index].spiketrains]) - @deprecated_alias(trial_id="trial_index") - def get_spiketrains_from_trial_as_segment(self, trial_index: int) -> ( - neo.core.Segment): - # Return a Segment with all spike trains from a trial - segment = neo.core.Segment() - for spiketrain in self.get_spiketrains_from_trial_as_list(trial_index - ): - segment.spiketrains.append(spiketrain) - return segment - def get_spiketrains_trial_by_trial(self, spiketrain_index: int) -> ( neo.core.spiketrainlist.SpikeTrainList): # Return a list of all spike train repetitions across trials @@ -455,16 +430,6 @@ def get_analogsignals_from_trial_as_list(self, trial_index: int) -> ( return [analogsignal for analogsignal in self.block.segments[trial_index].analogsignals] - @deprecated_alias(trial_id="trial_index") - def get_analogsignals_from_trial_as_segment(self, trial_index: int) -> ( - neo.core.Segment): - # Return a Segment with all analog signals from a trial - segment = neo.core.Segment() - for analogsignal in self.get_analogsignals_from_trial_as_list( - trial_index): - segment.analogsignals.append(analogsignal) - return segment - def get_analogsignals_trial_by_trial(self, signal_index: int) -> ( List[neo.core.AnalogSignal]): # Return a list of all analog signal repetitions across trials @@ -514,31 +479,6 @@ def __getitem__(self, trial_index: int) -> neo.core.Segment: segment.analogsignals.append(element) return segment - @deprecated_alias(trial_id="trial_index") - def get_trial_as_segment(self, trial_index: int) -> neo.core.Segment: - # Get a specific trial by its index as a Segment - return self.__getitem__(trial_index) - - @deprecated_alias(trial_ids="trial_indexes") - def get_trials_as_block(self, trial_indexes: List[int] = None - ) -> neo.core.Block: - # Get a set of trials by their indexes as a Block - if not trial_indexes: - trial_indexes = list(range(self.n_trials)) - block = Block() - for trial_index in trial_indexes: - block.segments.append(self.get_trial_as_segment(trial_index)) - return block - - @deprecated_alias(trial_ids="trial_indexes") - def get_trials_as_list(self, trial_indexes: List[int] = None - ) -> List[neo.core.Segment]: - # Get a set of trials by their indexes as a list of Segment - if not trial_indexes: - trial_indexes = list(range(self.n_trials)) - return [self.get_trial_as_segment(trial_index) - for trial_index in trial_indexes] - @property def n_trials(self) -> int: # Get the number of trials @@ -564,15 +504,6 @@ def get_spiketrains_from_trial_as_list(self, trial_index: int = 0) -> ( spiketrain for spiketrain in self.list_of_trials[trial_index] if isinstance(spiketrain, neo.core.SpikeTrain)]) - @deprecated_alias(trial_id="trial_index") - def get_spiketrains_from_trial_as_segment(self, trial_index: int) -> ( - neo.core.Segment): - # Return a Segment with all spike trains from a trial - segment = neo.core.Segment() - for spiketrain in self.get_spiketrains_from_trial_as_list(trial_index): - segment.spiketrains.append(spiketrain) - return segment - def get_spiketrains_trial_by_trial(self, spiketrain_index: int) -> ( neo.core.spiketrainlist.SpikeTrainList): # Return a list of all spike train repetitions across trials @@ -588,16 +519,6 @@ def get_analogsignals_from_trial_as_list(self, trial_index: int) -> ( self.list_of_trials[trial_index] if isinstance(analogsignal, neo.core.AnalogSignal)] - @deprecated_alias(trial_id="trial_index") - def get_analogsignals_from_trial_as_segment(self, trial_index: int) -> ( - neo.core.Segment): - # Return a Segment with all analog signals from a trial - segment = neo.core.Segment() - for analogsignal in self.get_analogsignals_from_trial_as_list( - trial_index): - segment.analogsignals.append(analogsignal) - return segment - def get_analogsignals_trial_by_trial(self, signal_index: int) -> ( List[neo.core.AnalogSignal]): # Return a list of all analog signal repetitions across trials From f8b3fcac2d552988054726ee07b5dbc0c9be849f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Fri, 3 Oct 2025 00:21:14 +0200 Subject: [PATCH 35/40] Improved documentation of the base Trials class. Default for description is None. --- elephant/trials.py | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/elephant/trials.py b/elephant/trials.py index 8838e9df9..a3bbe53e8 100644 --- a/elephant/trials.py +++ b/elephant/trials.py @@ -116,23 +116,34 @@ def wrapper(*args, **kwargs): class Trials(ABC): """ - Base class for handling trials. + Abstract base class for handling trial-based data in Elephant. - This is the base class from which all trial objects inherit. - This class implements support for universally recommended arguments. + The `Trials` class defines a standardized interface for accessing and + manipulating trial data. It provides universally recommended methods and + attributes for trial handling, and serves as the base class for all + data-structure-specific implementations. + + Child classes such as :class:`TrialsFromBlock` and :class:`TrialsFromLists` + support specific input data structures. Usage details and examples are + provided in their respective documentation. Parameters ---------- - description : string, optional - A textual description of the set of trials. Can be accessed via the - class attribute `description`. - Default: None. + description : str, optional + A textual description of the set of trials. Accessible via the + `description` attribute. + Default: None + See Also + -------- + :class:`TrialsFromBlock` + :class:`TrialsFromLists` """ - - def __init__(self, description: str = "Trials"): - """Create an instance of the trials class.""" + def __init__(self, description: str = None): + """ + Create an instance of the `Trials` class. + """ self.description = description @abstractmethod From 8302c6841c0e83c0886ef1b7a634221f366c79e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Fri, 3 Oct 2025 00:30:52 +0200 Subject: [PATCH 36/40] Replaced dynamic expected values by explicit values --- elephant/test/test_trials.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/elephant/test/test_trials.py b/elephant/test/test_trials.py index 726d3ee1f..31b0bc7e7 100644 --- a/elephant/test/test_trials.py +++ b/elephant/test/test_trials.py @@ -329,23 +329,21 @@ def test_trials_from_block_n_trials(self) -> None: """ Test to get the number of trials. """ - self.assertEqual(self.trial_object.n_trials, len(self.block.segments)) + self.assertEqual(self.trial_object.n_trials, 36) def test_trials_from_block_n_spiketrains_trial_by_trial(self) -> None: """ Test to get the number of `SpikeTrain` objects per trial. """ self.assertEqual(self.trial_object.n_spiketrains_trial_by_trial, - [len(trial.spiketrains) for trial in - self.block.segments]) + [2] * 36) def test_trials_from_block_n_analogsignals_trial_by_trial(self) -> None: """ Test to get the number of `AnalogSignal` objects per trial. """ self.assertEqual(self.trial_object.n_analogsignals_trial_by_trial, - [len(trial.analogsignals) for trial in - self.block.segments]) + [2] * 36) def test_trials_from_block_get_spiketrains_from_trial_as_list(self ) -> None: @@ -636,23 +634,21 @@ def test_trials_from_list_n_trials(self) -> None: """ Test to get the number of trials. """ - self.assertEqual(self.trial_object.n_trials, len(self.trial_list)) + self.assertEqual(self.trial_object.n_trials, 36) def test_trials_from_list_n_spiketrains_trial_by_trial(self) -> None: """ Test to get the number of `SpikeTrain` objects per trial. """ self.assertEqual(self.trial_object.n_spiketrains_trial_by_trial, - [sum(map(lambda x: isinstance(x, SpikeTrain), - trial)) for trial in self.trial_list]) + [self.n_spiketrains] * 36) def test_trials_from_list_n_analogsignals_trial_by_trial(self) -> None: """ Test to get the number of `AnalogSignal` objects per trial. """ self.assertEqual(self.trial_object.n_analogsignals_trial_by_trial, - [sum(map(lambda x: isinstance(x, AnalogSignal), - trial)) for trial in self.trial_list]) + [self.n_analogsignals] * 36) def test_trials_from_list_get_spiketrains_from_trial_as_list(self) -> None: """ From 38ca75c238f17a97b8edfdbde1982c008069927b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Thu, 9 Oct 2025 00:00:42 +0200 Subject: [PATCH 37/40] Revised docstrings for content and clarity. Added examples. --- elephant/trials.py | 434 ++++++++++++++++++++++++++++++--------------- 1 file changed, 292 insertions(+), 142 deletions(-) diff --git a/elephant/trials.py b/elephant/trials.py index a3bbe53e8..75e617927 100644 --- a/elephant/trials.py +++ b/elephant/trials.py @@ -11,12 +11,12 @@ Neo does not impose a specific way to represent trial data. A natural way to represent trials is to have a :class:`neo.Block` containing multiple :class:`neo.Segment` objects, each representing the data of one trial. Another -popular option is to store trials as lists of lists, where the outer refers to -individual lists, and inner lists contain Neo data objects -(:class:`neo.SpikeTrain` and :class:`neo.AnalogSignal`) containing individual -data of each trial. +popular option is to store trials as a list of lists, where the outer refers to +the collection of trials, and inner lists contain Neo data objects +(:class:`neo.SpikeTrain` and :class:`neo.AnalogSignal`) containing the +individual data of each trial. -The classes of this module abstract from these individual data representations +The classes of this module abstract from these specific data representations by introducing a set of :class:`Trials` classes with a common API. These classes are initialized by a supported way of structuring trials, e.g., :class:`TrialsFromBlock` for the first method described above. Internally, the @@ -57,25 +57,26 @@ import numpy as np import neo -from neo.core import Segment, Block +from neo.core import Segment, Block, SpikeTrain, AnalogSignal from neo.core.spiketrainlist import SpikeTrainList from elephant.utils import deprecated_alias -def trials_to_list_of_spiketrainlist(method): +def trials_to_list_of_spiketrainlist(function): """ - Decorator to convert `Trials` object to a list of `SpikeTrainList` before - calling the wrapped method. + Decorator that converts each argument passed as a :class:`Trials` object + into a list of :class:`neo.SpikeTrainList` before calling the wrapped + function. Parameters ---------- - method: callable - The method to be decorated. + function: callable + The function to be decorated. Returns ------- callable - The decorated method. + The decorated function. Examples -------- @@ -86,7 +87,7 @@ def trials_to_list_of_spiketrainlist(method): ... return None """ - @wraps(method) + @wraps(function) def wrapper(*args, **kwargs): new_args = tuple( [ @@ -109,7 +110,7 @@ def wrapper(*args, **kwargs): for key, value in kwargs.items() } - return method(*new_args, **new_kwargs) + return function(*new_args, **new_kwargs) return wrapper @@ -119,8 +120,8 @@ class Trials(ABC): Abstract base class for handling trial-based data in Elephant. The `Trials` class defines a standardized interface for accessing and - manipulating trial data. It provides universally recommended methods and - attributes for trial handling, and serves as the base class for all + manipulating trial data. It provides a unified set of attributes and + methods for trial handling, and serves as the base class for all data-structure-specific implementations. Child classes such as :class:`TrialsFromBlock` and :class:`TrialsFromLists` @@ -130,7 +131,7 @@ class Trials(ABC): Parameters ---------- description : str, optional - A textual description of the set of trials. Accessible via the + Textual description of the set of trials, accessible via the `description` attribute. Default: None @@ -141,81 +142,96 @@ class Trials(ABC): """ def __init__(self, description: str = None): - """ - Create an instance of the `Trials` class. - """ self.description = description @abstractmethod - def __getitem__(self, trial_index: int) -> neo.core.Segment: + def __getitem__(self, trial_index: int) -> Segment: # Get a specific trial by its index as a Segment raise NotImplementedError @property @abstractmethod def n_trials(self) -> int: - """Get the number of trials. + """ + Number of trials. Returns ------- - int: Number of trials + int + Total number of trials. """ raise NotImplementedError @abstractmethod def n_spiketrains_trial_by_trial(self) -> List[int]: - """Get the number of spike trains in each trial as a list. + """ + Number of spike trains per trial. Returns ------- - list of int: For each trial, contains the number of spike trains in the - trial. + List[int] + Number of :class:`neo.SpikeTrain` objects in each trial, ordered by + trial index in ascending order starting from zero. """ raise NotImplementedError @abstractmethod def n_analogsignals_trial_by_trial(self) -> List[int]: - """Get the number of analogsignal objects in each trial as a list. + """ + Number of analog signals per trial. Returns ------- - list of int: For each trial, contains the number of analogsignal objects - in the trial. + List[int] + Number of :class:`neo.AnalogSignal` objects in each trial, ordered + by trial index in ascending order starting from zero. """ raise NotImplementedError @deprecated_alias(trial_id="trial_index") - def get_trial_as_segment(self, trial_index: int) -> neo.core.Segment: + def get_trial_as_segment(self, trial_index: int) -> Segment: + """ + Return a single trial as a :class:`neo.Segment`. Parameters ---------- trial_index : int - Index of the trial to retrieve (zero-based). + Zero-based index of the trial to retrieve. Returns ------- - class:`neo.Segment`: a segment containing all spike trains and - analogsignal objects of the trial. + :class:`neo.Segment` + Segment containing all spike trains and analog signals associated + with the specified trial. Spike trains and analog signals are + accessed via the `spiketrains` and `analogsignals` attributes, + respectively. Their order corresponds to their index within these + collections (e.g., `spiketrains[0]` is the first spike train). """ return self.__getitem__(trial_index) @deprecated_alias(trial_ids="trial_indexes") - def get_trials_as_block(self, trial_indexes: List[int] = None - ) -> neo.core.Block: - """Get trials as block. + def get_trials_as_block(self, trial_indexes: List[int] = None) -> Block: + """ + Return multiple trials grouped into a :class:`neo.Block`. Parameters ---------- - trial_indexes : list of int - Indexes of the trials to include in the Block (zero-based). - If None is specified, all trials are returned. + trial_indexes : List[int], optional + Zero-based indices of the trials to include in the block. + If None, all trials are returned. Default: None Returns ------- - class:`neo.Block`: a Block containing :class:`neo.Segment` objects for - each of the selected trials, each containing all spike trains and - analogsignal objects of the corresponding trial. + :class:`neo.Block` + Block containing one :class:`neo.Segment` per trial. The trials are + accessed via the `segments` attribute. If all trials are included, + element indices correspond to trial indices. If a subset is + specified, the order matches that of `trial_indexes`. + + See Also + -------- + :method:`get_trial_as_segment()` """ block = Block() if not trial_indexes: @@ -226,21 +242,24 @@ def get_trials_as_block(self, trial_indexes: List[int] = None @deprecated_alias(trial_ids="trial_indexes") def get_trials_as_list(self, trial_indexes: List[int] = None - ) -> neo.core.spiketrainlist.SpikeTrainList: - """Get trials as list of segments. + ) -> List[Segment]: + """ + Return multiple trials as a list of :class:`neo.Segment` objects. Parameters ---------- - trial_indexes : list of int - Indexes of the trials to include in the list (zero-based). - If None is specified, all trials are returned. + trial_indexes : List[int], optional + Zero-based indices of the trials to include in the list. + If None, all trials are returned. Default: None Returns ------- - list of :class:`neo.Segment`: a list containing :class:`neo.Segment` - objects for each of the selected trials, each containing all spike - trains and analogsignal objects of the corresponding trial. + List[Segment] + List containing one :class:`neo.Segment` per selected trial. If all + trials are returned, list indices correspond to trial indices. + If a subset is specified, the order matches that of + `trial_indexes`. """ if not trial_indexes: trial_indexes = list(range(self.n_trials)) @@ -249,39 +268,42 @@ def get_trials_as_list(self, trial_indexes: List[int] = None @abstractmethod @deprecated_alias(trial_id="trial_index") - def get_spiketrains_from_trial_as_list(self, trial_index: int) -> ( - neo.core.spiketrainlist.SpikeTrainList): + def get_spiketrains_from_trial_as_list(self, trial_index: int + ) -> SpikeTrainList: """ - Get all spike trains from a specific trial and return a list. + Return all spike trains from a single trial as a list. Parameters ---------- trial_index : int - Index of the trial to get the spike trains from (zero-based). + Zero-based index of the trial. Returns ------- - list of :class:`neo.SpikeTrain` - List of all spike trains of the trial. + :class:`neo.SpikeTrainList` + List-like container with all :class:`neo.SpikeTrain` objects from + the specified trial. """ raise NotImplementedError @deprecated_alias(trial_id="trial_index") def get_spiketrains_from_trial_as_segment(self, trial_index: int - ) -> neo.core.Segment: + ) -> Segment: """ - Get all spike trains from a specific trial and return a Segment. + Return all spike trains from a single trial as a :class:`neo.Segment`. Parameters ---------- trial_index : int - Index of the trial to get the spike trains from (zero-based). + Zero-based index of the trial. Returns ------- - :class:`neo.Segment`: Segment containing all spike trains of the trial. + :class:`neo.Segment` + Segment containing all :class:`neo.SpikeTrain` objects from the + specified trial, accessible via the `spiketrains` attribute. """ - segment = neo.core.Segment() + segment = Segment() for spiketrain in self.get_spiketrains_from_trial_as_list(trial_index): segment.spiketrains.append(spiketrain) return segment @@ -289,40 +311,41 @@ def get_spiketrains_from_trial_as_segment(self, trial_index: int @abstractmethod @deprecated_alias(trial_id="trial_index") def get_analogsignals_from_trial_as_list(self, trial_index: int - ) -> List[neo.core.AnalogSignal]: + ) -> List[AnalogSignal]: """ - Get all analogsignals from a specific trial and return a list. + Return all analog signals from a single trial as a list. Parameters ---------- trial_index : int - Index of the trial to get the analog signals from (zero-based). + Zero-based index of the trial. Returns ------- - list of :class`neo.AnalogSignal`: list of all analogsignal objects of - the trial. + List[AnalogSignal] + List containing all :class:`neo.AnalogSignal` objects from the + specified trial. """ raise NotImplementedError @deprecated_alias(trial_id="trial_index") def get_analogsignals_from_trial_as_segment(self, trial_index: int - ) -> neo.core.Segment: + ) -> Segment: """ - Get all analogsignal objects from a specific trial and return a - :class:`neo.Segment`. + Return all analog signals from a single trial as a :class:`neo.Segment`. Parameters ---------- trial_index : int - Index of the trial to get the analog signals from (zero-based). + Zero-based index of the trial. Returns ------- - class:`neo.Segment`: segment containing all analogsignal objects of - the trial. + :class:`neo.Segment` + Segment containing all :class:`neo.AnalogSignal` objects from the + specified trial, accessible via the `analogsignals` attribute. """ - segment = neo.core.Segment() + segment = Segment() for analogsignal in self.get_analogsignals_from_trial_as_list( trial_index): segment.analogsignals.append(analogsignal) @@ -330,79 +353,139 @@ def get_analogsignals_from_trial_as_segment(self, trial_index: int @abstractmethod def get_spiketrains_trial_by_trial(self, spiketrain_index: int) -> ( - neo.core.spiketrainlist.SpikeTrainList): + SpikeTrainList): """ - Retrieve a spike train across all its trial repetitions. + Return spike train across all its trial repetitions. - This method returns a list containing :class:`neo.core.SpikeTrain` - objects corresponding to the same spike train (e.g., from a consistent + This method returns a list containing :class:`neo.SpikeTrain` objects + corresponding to the same spike train (e.g., from a consistent recording channel or neuronal source) across multiple trials. Parameters ---------- spiketrain_index : int - Index of the spike train to retrieve across trials. Indexing - starts at 0, so `spiketrain_index == 0` corresponds to the first - spike train in the trial data. + Zero-based index of the spike train to retrieve across trials. Returns ------- - list of :class:`neo.core.SpikeTrain` - A list-like container with the :class:`neo.core.SpikeTrain` - objects for the specified `spiketrain_id`, ordered from the first - trial (ID 0) to the last (ID `n_trials - 1`). + :class:`neo.SpikeTrainList` + List-like container storing the :class:`neo.SpikeTrain` objects + corresponding to the specified `spiketrain_index`, ordered by trial + index from zero to `n_trials - 1`. """ raise NotImplementedError @abstractmethod def get_analogsignals_trial_by_trial(self, signal_index: int - ) -> List[neo.core.AnalogSignal]: + ) -> List[AnalogSignal]: """ - Retrieve an analog signal across all its trial repetitions. + Return an analog signal across all its trial repetitions. - This method returns a list containing :class:`neo.core.AnalogSignal` + This method returns a list containing :class:`neo.AnalogSignal` objects corresponding to a continuous signal recorded from a consistent recording channel or neuronal source across multiple trials. Parameters ---------- signal_index : int - Index of the analog signal to retrieve across trials. Indexing - starts at 0, so `signal_index == 0` corresponds to the first - analog signal in the trial data. + Zero-based index of the analog signal to retrieve across trials. Returns ------- - list of :class:`neo.core.AnalogSignal` - A list with the :class:`neo.core.AnalogSignal` objects for the - specified `signal_id`, ordered from the first trial (ID 0) to the - last (ID `n_trials - 1`). + List[AnalogSignal] + List storing :class:`neo.AnalogSignal` objects corresponding to the + specified `signal_index`, ordered by trial index from zero to + `n_trials - 1`. """ raise NotImplementedError class TrialsFromBlock(Trials): """ - This class implements support for handling trials from neo.Block. + This class handles trial data organized within a :class:`neo.Block` object. + + In this representation, each trial is stored as a separate + :class:`neo.Segment` within the block. All trial segments are accessible + through the `segments` attribute. The data for a specific trial can be + accessed by its index, e.g., `segments[0]` corresponds to the first trial. + + Each :class:`neo.Segment` contains collections for spike trains and analog + signals, accessible via the `spiketrains` and `analogsignals` attributes, + respectively. When accessing data of individual :class:`neo.SpikeTrain` and + :class:`neo.AnalogSignal` objects, the indexes within these collections is + used. For instance, `spiketrains[0]` refers to the first spike train, and + `analogsignals[0]` to the first analog signal. Parameters ---------- - block : neo.Block - An instance of neo.Block containing the trials. - The structure is assumed to follow the neo representation: - A block contains multiple segments which are considered to contain the - single trials. - description : string, optional - A textual description of the set of trials. Can be accessed via the - class attribute `description`. - Default: None. + block : :class:`neo.Block` + An instance of :class:`neo.Block` containing the trial data. The block + contains multiple :class:`neo.Segment` objects, each containing the + data of a single trial. + description : str, optional + Textual description of the set of trials, accessible via the + :attr:`description` attribute. + Default: None + + Attributes + ---------- + description : str + The description of the set of trials. + n_trials : int + The total number of trials. + n_spiketrains_trial_by_trial : List[int] + The number of spike trains in each trial. + n_analogsignals_trial_by_trial : List[int] + The number of analog signals in each trial. + + Examples + -------- + 1. Generate `TrialFromBlock` object to handle data from two trials, each + containing three spike trains and one analog signal. + + >>> import numpy as np + >>> import quantities as pq + >>> import neo + >>> from elephant.spike_train_generation import StationaryPoissonProcess + >>> from elephant.trials import TrialsFromBlock + >>> + >>> st_generator = StationaryPoissonProcess(rate=10*pq.Hz, t_stop=1*pq.s) + >>> trial_block = neo.Block() + >>> for _ in range(2): + >>> trial = neo.Segment() + >>> trial.spiketrains = st_generator.generate_n_spiketrains(3) + >>> signal = np.sin(np.arange(0, 6*np.pi, 2*np.pi/1000)) + >>> signal += np.random.normal(size=signal.shape) + >>> trial.analogsignals.append( + ... neo.AnalogSignal(signal, units=pq.mV, + ... t_stop=1*pq.s, + ... sampling_rate=(1/3000)*pq.Hz) + ... ) + >>> trial_block.segments.append(trial) + >>> + >>> trials = TrialsFromBlock(trial_block) + + 2. Retrieve overall information. + + >>> print(trials.n_trials) + 2 + >>> print(trials.n_spiketrains_trial_by_trial) + [3, 3] + >>> print(trials.n_analogsignals_trial_by_trial) + [1, 1] + + 3. Access data in the first trial. + + >>> first_trial = trials[0] + >>> first_spike_train = first_trial.spiketrains[0] + >>> analog_signal = first_trial.analogsignals[0] """ - def __init__(self, block: neo.core.block, **kwargs): + def __init__(self, block: Block, **kwargs): self.block = block super().__init__(**kwargs) - def __getitem__(self, trial_index: int) -> neo.core.segment: + def __getitem__(self, trial_index: int) -> Segment: # Get a specific trial by its index as a Segment return self.block.segments[trial_index] @@ -422,27 +505,31 @@ def n_analogsignals_trial_by_trial(self) -> List[int]: return [len(trial.analogsignals) for trial in self.block.segments] @deprecated_alias(trial_id="trial_index") - def get_spiketrains_from_trial_as_list(self, trial_index: int = 0) -> ( - neo.core.spiketrainlist.SpikeTrainList): + def get_spiketrains_from_trial_as_list(self, trial_index: int = 0 + ) -> SpikeTrainList: # Return a list of all spike trains from a trial - return SpikeTrainList(items=[spiketrain for spiketrain in - self.block.segments[trial_index].spiketrains]) + return SpikeTrainList( + items=[spiketrain for spiketrain in + self.block.segments[trial_index].spiketrains] + ) - def get_spiketrains_trial_by_trial(self, spiketrain_index: int) -> ( - neo.core.spiketrainlist.SpikeTrainList): + def get_spiketrains_trial_by_trial(self, spiketrain_index: int + ) -> SpikeTrainList: # Return a list of all spike train repetitions across trials - return SpikeTrainList(items=[segment.spiketrains[spiketrain_index] for - segment in self.block.segments]) + return SpikeTrainList( + items=[segment.spiketrains[spiketrain_index] for + segment in self.block.segments] + ) @deprecated_alias(trial_id="trial_index") - def get_analogsignals_from_trial_as_list(self, trial_index: int) -> ( - List[neo.core.AnalogSignal]): + def get_analogsignals_from_trial_as_list(self, trial_index: int + ) -> List[AnalogSignal]: # Return a list of all analog signals from a trial return [analogsignal for analogsignal in self.block.segments[trial_index].analogsignals] - def get_analogsignals_trial_by_trial(self, signal_index: int) -> ( - List[neo.core.AnalogSignal]): + def get_analogsignals_trial_by_trial(self, signal_index: int + ) -> List[AnalogSignal]: # Return a list of all analog signal repetitions across trials return [segment.analogsignals[signal_index] for segment in self.block.segments] @@ -450,17 +537,80 @@ def get_analogsignals_trial_by_trial(self, signal_index: int) -> ( class TrialsFromLists(Trials): """ - This class implements support for handling trials from list of lists. + This class handles trial data structured as a list of lists. + + In this representation, each inner list represents a single trial and + includes one or more data elements, such as spike trains + (:class:`neo.SpikeTrain`) and/or analog signals + (:class:`neo.AnalogSignal`). The order of these elements must remain + consistent across all trial repetitions. + + The index identifying each spike train or analog signal is determined by + its position within the list. For example, if each trial contains two + elements such as a spike train followed by an analog signal, the spike + train at index 0 corresponds to the first element, and the analog signal + at index 0 corresponds to the second element. Parameters ---------- - list_of_trials : list of lists - A list of lists. Each list entry contains a list of neo.SpikeTrains - and/or neo.AnalogSignals. - description : string, optional - A textual description of the set of trials. Can be accessed via the - class attribute `description`. - Default: None. + list_of_trials : list + A list of lists containing trial data. The inner lists must contain + spike train (:class:`neo.SpikeTrain`) or analog signal + (:class:`neo.AnalogSignal`) objects. + description : str, optional + Textual description of the set of trials, accessible via the + :attr:`description` attribute. + Default: None + + Attributes + ---------- + description : str + The description of the set of trials. + n_trials : int + The total number of trials. + n_spiketrains_trial_by_trial : List[int] + The number of spike trains in each trial. + n_analogsignals_trial_by_trial : List[int] + The number of analog signals in each trial. + + Examples + -------- + 1. Generate `TrialFromLists` object to handle data from three trials, each + containing two spike trains and one analog signal. + + >>> import numpy as np + >>> import quantities as pq + >>> import neo + >>> from elephant.spike_train_generation import StationaryPoissonProcess + >>> from elephant.trials import TrialsFromLists + >>> + >>> st_generator = StationaryPoissonProcess(rate=10*pq.Hz, t_stop=1*pq.s) + >>> trial_list = [st_generator.generate_n_spiketrains(2) for _ in range(3)] + >>> for trial in trial_list: + >>> signal = np.sin(np.arange(0, 6*np.pi, 2*np.pi/1000)) + >>> signal += np.random.normal(size=signal.shape) + >>> trial.append( + ... neo.AnalogSignal(signal, units=pq.mV, + ... t_stop=1*pq.s, + ... sampling_rate=(1/3000)*pq.Hz) + ... ) + >>> + >>> trials = TrialsFromLists(trial_list) + + 2. Retrieve overall information. + + >>> print(trials.n_trials) + 3 + >>> print(trials.n_spiketrains_trial_by_trial) + [2, 2, 2] + >>> print(trials.n_analogsignals_trial_by_trial) + [1, 1, 1] + + 3. Access data in the first trial. + + >>> first_trial = trials[0] + >>> first_spike_train = first_trial.spiketrains[0] + >>> analog_signal = first_trial.analogsignals[0] """ def __init__(self, list_of_trials: list, **kwargs): @@ -472,7 +622,7 @@ def __init__(self, list_of_trials: list, **kwargs): # consistent across all trials (using the first list, corresponding # to the first trial, to fetch the indexes). if list_of_trials: - is_spiketrain = np.array([isinstance(data_element, neo.SpikeTrain) + is_spiketrain = np.array([isinstance(data_element, SpikeTrain) for data_element in list_of_trials[0]]) self._spiketrain_index = is_spiketrain.nonzero()[0] self._analogsignal_index = (~is_spiketrain).nonzero()[0] @@ -480,13 +630,13 @@ def __init__(self, list_of_trials: list, **kwargs): self._spiketrain_index = [] self._analogsignal_index = [] - def __getitem__(self, trial_index: int) -> neo.core.Segment: + def __getitem__(self, trial_index: int) -> Segment: # Get a specific trial by its index as a Segment segment = Segment() for element in self.list_of_trials[trial_index]: - if isinstance(element, neo.core.SpikeTrain): + if isinstance(element, SpikeTrain): segment.spiketrains.append(element) - if isinstance(element, neo.core.AnalogSignal): + if isinstance(element, AnalogSignal): segment.analogsignals.append(element) return segment @@ -498,40 +648,40 @@ def n_trials(self) -> int: @property def n_spiketrains_trial_by_trial(self) -> List[int]: # Get the number of SpikeTrain objects in each trial - return [sum(map(lambda x: isinstance(x, neo.core.SpikeTrain), trial)) + return [sum(map(lambda x: isinstance(x, SpikeTrain), trial)) for trial in self.list_of_trials] @property def n_analogsignals_trial_by_trial(self) -> List[int]: # Get the number of AnalogSignal objects in each trial - return [sum(map(lambda x: isinstance(x, neo.core.AnalogSignal), trial)) + return [sum(map(lambda x: isinstance(x, AnalogSignal), trial)) for trial in self.list_of_trials] @deprecated_alias(trial_id="trial_index") - def get_spiketrains_from_trial_as_list(self, trial_index: int = 0) -> ( - neo.core.spiketrainlist.SpikeTrainList): + def get_spiketrains_from_trial_as_list(self, trial_index: int = 0 + ) -> SpikeTrainList: # Return a list of all spike trains from a trial return SpikeTrainList(items=[ spiketrain for spiketrain in self.list_of_trials[trial_index] - if isinstance(spiketrain, neo.core.SpikeTrain)]) + if isinstance(spiketrain, SpikeTrain)]) - def get_spiketrains_trial_by_trial(self, spiketrain_index: int) -> ( - neo.core.spiketrainlist.SpikeTrainList): + def get_spiketrains_trial_by_trial(self, spiketrain_index: int + ) -> SpikeTrainList: # Return a list of all spike train repetitions across trials list_idx = self._spiketrain_index[spiketrain_index] return SpikeTrainList(items=[trial[list_idx] for trial in self.list_of_trials]) @deprecated_alias(trial_id="trial_index") - def get_analogsignals_from_trial_as_list(self, trial_index: int) -> ( - List[neo.core.AnalogSignal]): + def get_analogsignals_from_trial_as_list(self, trial_index: int + ) -> List[AnalogSignal]: # Return a list of all analog signals from a trial return [analogsignal for analogsignal in self.list_of_trials[trial_index] - if isinstance(analogsignal, neo.core.AnalogSignal)] + if isinstance(analogsignal, AnalogSignal)] - def get_analogsignals_trial_by_trial(self, signal_index: int) -> ( - List[neo.core.AnalogSignal]): + def get_analogsignals_trial_by_trial(self, signal_index: int + ) -> AnalogSignal: # Return a list of all analog signal repetitions across trials list_idx = self._analogsignal_index[signal_index] return [trial[list_idx] for trial in self.list_of_trials] From cba383d40ce93b47a15126353de44302461694b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Thu, 6 Nov 2025 10:10:10 +0100 Subject: [PATCH 38/40] Fixed IndentationError in doctests --- elephant/trials.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/elephant/trials.py b/elephant/trials.py index 75e617927..e2bc7bdce 100644 --- a/elephant/trials.py +++ b/elephant/trials.py @@ -452,16 +452,16 @@ class TrialsFromBlock(Trials): >>> st_generator = StationaryPoissonProcess(rate=10*pq.Hz, t_stop=1*pq.s) >>> trial_block = neo.Block() >>> for _ in range(2): - >>> trial = neo.Segment() - >>> trial.spiketrains = st_generator.generate_n_spiketrains(3) - >>> signal = np.sin(np.arange(0, 6*np.pi, 2*np.pi/1000)) - >>> signal += np.random.normal(size=signal.shape) - >>> trial.analogsignals.append( + ... trial = neo.Segment() + ... trial.spiketrains = st_generator.generate_n_spiketrains(3) + ... signal = np.sin(np.arange(0, 6*np.pi, 2*np.pi/1000)) + ... signal += np.random.normal(size=signal.shape) + ... trial.analogsignals.append( ... neo.AnalogSignal(signal, units=pq.mV, ... t_stop=1*pq.s, ... sampling_rate=(1/3000)*pq.Hz) ... ) - >>> trial_block.segments.append(trial) + ... trial_block.segments.append(trial) >>> >>> trials = TrialsFromBlock(trial_block) @@ -587,9 +587,9 @@ class TrialsFromLists(Trials): >>> st_generator = StationaryPoissonProcess(rate=10*pq.Hz, t_stop=1*pq.s) >>> trial_list = [st_generator.generate_n_spiketrains(2) for _ in range(3)] >>> for trial in trial_list: - >>> signal = np.sin(np.arange(0, 6*np.pi, 2*np.pi/1000)) - >>> signal += np.random.normal(size=signal.shape) - >>> trial.append( + ... signal = np.sin(np.arange(0, 6*np.pi, 2*np.pi/1000)) + ... signal += np.random.normal(size=signal.shape) + ... trial.append( ... neo.AnalogSignal(signal, units=pq.mV, ... t_stop=1*pq.s, ... sampling_rate=(1/3000)*pq.Hz) From 585162e535d2940a971ae6515cd190fa4fe93519 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Tue, 2 Dec 2025 11:27:41 +0100 Subject: [PATCH 39/40] Moved tutorial link in docstring to the branch on the trial tutorial --- elephant/trials.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/elephant/trials.py b/elephant/trials.py index e2bc7bdce..70c94a56b 100644 --- a/elephant/trials.py +++ b/elephant/trials.py @@ -35,16 +35,6 @@ TrialsFromBlock TrialsFromLists -Tutorial --------- -For a detailed example on the classes usage and trial handling for analyses -using Elephant, check the :doc:`tutorial <../tutorials/trials>`. - -Run tutorial interactively: - -.. image:: https://mybinder.org/badge.svg - :target: https://mybinder.org/v2/gh/NeuralEnsemble/elephant/master - ?filepath=doc/tutorials/trials.ipynb :copyright: Copyright 2014-2025 by the Elephant team, see `doc/authors.rst`. :license: Modified BSD, see LICENSE.txt for details. From 77cfdb852ce164eb9d2d86e44f0f9476062bf378 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristiano=20K=C3=B6hler?= Date: Tue, 2 Dec 2025 11:29:42 +0100 Subject: [PATCH 40/40] Fixed error in assertion that checked the same segment --- elephant/test/test_trials.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elephant/test/test_trials.py b/elephant/test/test_trials.py index 31b0bc7e7..6a0a295f8 100644 --- a/elephant/test/test_trials.py +++ b/elephant/test/test_trials.py @@ -72,7 +72,7 @@ def assertSegmentEqual(self, segment_1, segment_2) -> None: self.assertIsInstance(segment_2, Segment) self.assertIs(segment_1, segment_2) self.assertEqual(segment_1.name, segment_2.name) - self.assertEqual(segment_2.description, segment_2.description) + self.assertEqual(segment_1.description, segment_2.description) self.assertDictEqual(segment_1.annotations, segment_2.annotations) self.assertSpikeTrainListEqual(segment_1.spiketrains, segment_2.spiketrains)