diff --git a/elephant/gpfa/gpfa.py b/elephant/gpfa/gpfa.py index 79d490e0d..5f39ae2f6 100644 --- a/elephant/gpfa/gpfa.py +++ b/elephant/gpfa/gpfa.py @@ -76,8 +76,7 @@ import sklearn from elephant.gpfa import gpfa_core, gpfa_util -from elephant.trials import Trials -from elephant.utils import trials_to_list_of_spiketrainlist +from elephant.trials import Trials, trials_to_list_of_spiketrainlist __all__ = ["GPFA"] diff --git a/elephant/test/test_trials.py b/elephant/test/test_trials.py index b472e7a8e..6a0a295f8 100644 --- a/elephant/test/test_trials.py +++ b/elephant/test/test_trials.py @@ -1,35 +1,45 @@ # -*- coding: utf-8 -*- """ -Unit tests for the trials objects. +Unit tests for the objects of the API handling trial data in Elephant. -:copyright: Copyright 2014-2024 by the Elephant team, see AUTHORS.txt. +:copyright: Copyright 2014-2025 by the Elephant team, see AUTHORS.txt. :license: Modified BSD, see LICENSE.txt for details. """ import unittest -import neo.utils +import numpy as np import quantities as pq -from neo.core import AnalogSignal +from neo.core import Block, Segment, AnalogSignal, SpikeTrain +from neo.core.spiketrainlist import SpikeTrainList from elephant.spike_train_generation import StationaryPoissonProcess -from elephant.trials import TrialsFromBlock, TrialsFromLists +from elephant.trials import (TrialsFromBlock, TrialsFromLists, + trials_to_list_of_spiketrainlist) def _create_trials_block(n_trials: int = 0, n_spiketrains: int = 2, - n_analogsignals: int = 2) -> neo.core.Block: - """ Create block with n_trials, n_spiketrains and n_analog_signals """ - block = neo.Block(name='test_block') + n_analogsignals: int = 2) -> Block: + """ + Create Neo `Block` with `n_trials`, `n_spiketrains` and `n_analogsignals`. + """ + block = Block(name='test_block') for trial in range(n_trials): - segment = neo.Segment(name=f'No. {trial}') + segment = Segment(name=f'No. {trial}') spiketrains = StationaryPoissonProcess(rate=50. * pq.Hz, t_start=0 * pq.ms, t_stop=1000 * pq.ms ).generate_n_spiketrains( n_spiketrains=n_spiketrains) + for idx, st in enumerate(spiketrains): + st.name = f"Spiketrain {idx}" + st.description = f"Trial {trial}" + analogsignals = [AnalogSignal(signal=[.01, 3.3, 9.3], units='uV', - sampling_rate=1 * pq.Hz) - for _ in range(n_analogsignals)] + sampling_rate=1 * pq.Hz, + name=f"Signal {idx}", + description=f"Trial {trial}") + for idx in range(n_analogsignals)] for spiketrain in spiketrains: segment.spiketrains.append(spiketrain) for analogsignal in analogsignals: @@ -38,295 +48,723 @@ def _create_trials_block(n_trials: int = 0, return block -######### -# Tests # -######### - +########################## +# Tests - helper classes # +########################## + +class TrialsBaseTestCase(unittest.TestCase): + """ + This is a base `unitest.TestCase` class to act as a helper when + constructing the specific test cases for each implementation of `Trials`. + + This helper class facilitates comparing Neo objects, as custom assertions + are implemented to perform a series of tests to ensure two Neo objects are + indeed equal (e.g., checking metadata or contents of collections such as + spiketrains in a `Segment`). + + As the `Trials` objects are based on references to the input data + structures, checks for `Segment`, `SpikeTrain`, and `AnalogSignal` objects + are enforcing instance equality (i.e., `(a is b) == True`). + """ + + def assertSegmentEqual(self, segment_1, segment_2) -> None: + self.assertIsInstance(segment_1, Segment) + self.assertIsInstance(segment_2, Segment) + self.assertIs(segment_1, segment_2) + self.assertEqual(segment_1.name, segment_2.name) + self.assertEqual(segment_1.description, segment_2.description) + self.assertDictEqual(segment_1.annotations, segment_2.annotations) + self.assertSpikeTrainListEqual(segment_1.spiketrains, + segment_2.spiketrains) + self.assertAnalogSignalListEqual(segment_1.analogsignals, + segment_2.analogsignals) + + def assertSpikeTrainEqual(self, spiketrain_1, spiketrain_2) -> None: + self.assertIsInstance(spiketrain_1, SpikeTrain) + self.assertIsInstance(spiketrain_2, SpikeTrain) + self.assertIs(spiketrain_1, spiketrain_2) + self.assertTrue(np.all(spiketrain_1 == spiketrain_2)) + self.assertEqual(spiketrain_1.name, spiketrain_2.name) + self.assertEqual(spiketrain_1.description, spiketrain_2.description) + self.assertDictEqual(spiketrain_1.annotations, + spiketrain_2.annotations) + + def assertSpikeTrainListEqual(self, spiketrains_1, spiketrains_2) -> None: + self.assertIsInstance(spiketrains_1, SpikeTrainList) + self.assertIsInstance(spiketrains_2, SpikeTrainList) + self.assertEqual(len(spiketrains_1), len(spiketrains_2)) + for st1, st2 in zip(spiketrains_1, spiketrains_2): + self.assertSpikeTrainEqual(st1, st2) + + def assertAnalogSignalEqual(self, signal_1, signal_2) -> None: + self.assertIsInstance(signal_1, AnalogSignal) + self.assertIsInstance(signal_2, AnalogSignal) + self.assertIs(signal_1, signal_2) + self.assertTrue(np.all(signal_1 == signal_2)) + self.assertEqual(signal_1.name, signal_2.name) + self.assertEqual(signal_1.description, signal_2.description) + self.assertDictEqual(signal_1.annotations, signal_2.annotations) + + def assertAnalogSignalListEqual(self, signals_1, signals_2) -> None: + # Not enforcing object type as `Segment.analogsignals` are + # `ObjectList`, and some of the functions return pure Python lists + # containing the `AnalogSignal` objects. Therefore, the type checking + # must be done in each test case accordingly. + self.assertEqual(len(signals_1), len(signals_2)) + for signal_1, signal_2 in zip(signals_1, signals_2): + self.assertAnalogSignalEqual(signal_1, signal_2) + + +class DecoratorTest: + """ + This class is used as a mock for testing the decorator. + """ + @trials_to_list_of_spiketrainlist + def method_to_decorate(self, trials=None, trials_obj=None): + # This is just a mock implementation for testing purposes + if trials_obj: + return trials_obj + return trials + + +###################### +# Tests - test cases # +###################### + +class TestTrialsToListOfSpiketrainlist(TrialsBaseTestCase): -class TrialsFromBlockTestCase(unittest.TestCase): - """Tests for elephant.trials.TrialsFromBlock class""" + @classmethod + def setUpClass(cls): + cls.n_channels = 10 + cls.n_trials = 5 + cls.list_of_list_of_spiketrains = [ + StationaryPoissonProcess(rate=5 * pq.Hz, t_stop=1000.0 * pq.ms + ).generate_n_spiketrains(cls.n_channels) + for _ in range(cls.n_trials)] + cls.trial_object = TrialsFromLists(cls.list_of_list_of_spiketrains) + + def test_decorator_applied(self) -> None: + """ + Test that the decorator is applied correctly. + """ + self.assertTrue(hasattr( + DecoratorTest.method_to_decorate, '__wrapped__' + )) + + def test_decorator_return_with_trials_input_as_arg(self) -> None: + """ + Test if the decorator takes in a `Trials` object and returns a list of + `SpikeTrainList`. + """ + new_class = DecoratorTest() + list_of_spiketrainlists = new_class.method_to_decorate( + self.trial_object) + self.assertEqual(len(list_of_spiketrainlists), self.n_trials) + for spiketrainlist, expected_list in zip( + list_of_spiketrainlists, self.list_of_list_of_spiketrains): + self.assertSpikeTrainListEqual(spiketrainlist, + SpikeTrainList(expected_list)) + + def test_decorator_return_with_list_of_lists_input_as_arg(self) -> None: + """ + Test if the decorator takes in a list of lists of `SpikeTrain` and + does not change the input. + """ + new_class = DecoratorTest() + list_of_list_of_spiketrains = new_class.method_to_decorate( + self.list_of_list_of_spiketrains) + self.assertEqual(len(list_of_list_of_spiketrains), self.n_trials) + for list_of_spiketrains, expected_list in ( + zip(list_of_list_of_spiketrains, + self.list_of_list_of_spiketrains)): + self.assertIsInstance(list_of_spiketrains, list) + for spiketrain, expected_spiketrain in ( + zip(list_of_spiketrains, expected_list)): + self.assertSpikeTrainEqual(spiketrain, expected_spiketrain) + + def test_decorator_return_with_trials_input_as_kwarg(self) -> None: + """ + Test if the decorator takes in a `Trials` object and returns a list of + `SpikeTrainList` when passed as kwarg. + """ + new_class = DecoratorTest() + list_of_spiketrainlists = new_class.method_to_decorate( + trials_obj=self.trial_object) + self.assertEqual(len(list_of_spiketrainlists), self.n_trials) + for spiketrainlist, expected_list in zip( + list_of_spiketrainlists, self.list_of_list_of_spiketrains): + self.assertSpikeTrainListEqual(spiketrainlist, + SpikeTrainList(expected_list)) + + def test_decorator_return_with_list_of_lists_input_as_kwarg(self) -> None: + """ + Test if the decorator takes in a list of lists of `SpikeTrain`and does + not change the input if passed as a kwarg. + """ + new_class = DecoratorTest() + list_of_list_of_spiketrains = new_class.method_to_decorate( + trials_obj=self.list_of_list_of_spiketrains) + self.assertEqual(len(list_of_list_of_spiketrains), self.n_trials) + for list_of_spiketrains, expected_list in ( + zip(list_of_list_of_spiketrains, + self.list_of_list_of_spiketrains)): + self.assertIsInstance(list_of_spiketrains, list) + for spiketrain, expected_spiketrain in ( + zip(list_of_spiketrains, expected_list)): + self.assertSpikeTrainEqual(spiketrain, expected_spiketrain) + + +class TrialsFromBlockTestCase(TrialsBaseTestCase): + """ + Tests for :class:`elephant.trials.TrialsFromBlock`. + """ @classmethod def setUpClass(cls) -> None: - """ - Run once before tests: - """ - block = _create_trials_block(n_trials=36) cls.block = block cls.trial_object = TrialsFromBlock(block, - description='trials are segments') - - def setUp(self) -> None: - """ - Run before every test: - """ + description='trial is Segment') + + def test_deprecations(self) -> None: + """ + Test if all expected deprecation warnings are triggered. + """ + trial_object = self.trial_object + with self.assertWarns(DeprecationWarning): + trial_object.get_trial_as_segment(trial_id=0) + with self.assertWarns(DeprecationWarning): + trial_object.get_trials_as_block(trial_ids=[0, 1]) + with self.assertWarns(DeprecationWarning): + trial_object.get_trials_as_list(trial_ids=[0, 1]) + with self.assertWarns(DeprecationWarning): + trial_object.get_spiketrains_from_trial_as_list(trial_id=0) + with self.assertWarns(DeprecationWarning): + trial_object.get_spiketrains_from_trial_as_segment(trial_id=0) + with self.assertWarns(DeprecationWarning): + trial_object.get_analogsignals_from_trial_as_list(trial_id=0) + with self.assertWarns(DeprecationWarning): + trial_object.get_analogsignals_from_trial_as_segment(trial_id=0) def test_trials_from_block_description(self) -> None: """ - Test description of the trials object. + Test the description of the `Trials` object. """ - self.assertEqual(self.trial_object.description, 'trials are segments') + self.assertEqual(self.trial_object.description, 'trial is Segment') def test_trials_from_block_get_item(self) -> None: """ - Test get a trial from the trials. + Test to get a single trial from the `Trials` object using indexing + with brackets. Return is a `Segment`. """ - self.assertIsInstance(self.trial_object[0], neo.core.Segment) + for trial_index in range(36): + trial_segment = self.trial_object[trial_index] + expected = self.block.segments[trial_index] + self.assertSegmentEqual(trial_segment, expected) def test_trials_from_block_get_trial_as_segment(self) -> None: """ - Test get a trial from the trials. + Test to get a single trial from the `Trials` object as a `Segment`. + """ + for trial_index in range(36): + trial_segment = self.trial_object.get_trial_as_segment(trial_index) + expected = self.block.segments[trial_index] + self.assertSegmentEqual(trial_segment, expected) + + def test_trials_from_block_get_trials_as_block_indexes(self) -> None: + """ + Test to get a set of specific trials grouped as a `Block`. Each trial + is a `Segment` containing all the data in the trial. """ - self.assertIsInstance( - self.trial_object.get_trial_as_segment(0), - neo.core.Segment) - self.assertIsInstance( - self.trial_object.get_trial_as_segment(0).spiketrains[0], - neo.core.SpikeTrain) - self.assertIsInstance( - self.trial_object.get_trial_as_segment(0).analogsignals[0], - neo.core.AnalogSignal) + trial_block = self.trial_object.get_trials_as_block([0, 3, 5]) + self.assertIsInstance(trial_block, Block) + self.assertEqual(len(trial_block.segments), 3) + self.assertSegmentEqual(trial_block.segments[0], + self.block.segments[0]) + self.assertSegmentEqual(trial_block.segments[1], + self.block.segments[3]) + self.assertSegmentEqual(trial_block.segments[2], + self.block.segments[5]) def test_trials_from_block_get_trials_as_block(self) -> None: """ - Test get a block from list of trials. + Test to get all trials grouped as a `Block`, where each trial is a + single `Segment`. """ - block = self.trial_object.get_trials_as_block([0, 3, 5]) - self.assertIsInstance(block, neo.core.Block) - self.assertIsInstance(self.trial_object.get_trials_as_block(), - neo.core.Block) - self.assertEqual(len(block.segments), 3) + trial_block = self.trial_object.get_trials_as_block() + self.assertIsInstance(trial_block, Block) + self.assertEqual(len(trial_block.segments), 36) + for trial, expected_trial in zip(trial_block.segments, + self.block.segments): + self.assertSegmentEqual(trial, expected_trial) - def test_trials_from_block_get_trials_as_list(self) -> None: + def test_trials_from_block_get_trials_as_list_indexes(self) -> None: """ - Test get a list of segments from list of trials. + Test to get a set of specific trials grouped as list of `Segment`. + Each trial is a single `Segment` containing all the data in the trial. """ - list_of_trials = self.trial_object.get_trials_as_list([0, 3, 5]) + list_of_trials = self.trial_object.get_trials_as_list([0, 5, 7]) self.assertIsInstance(list_of_trials, list) - self.assertIsInstance(self.trial_object.get_trials_as_list(), list) - self.assertIsInstance(list_of_trials[0], neo.core.Segment) self.assertEqual(len(list_of_trials), 3) + self.assertSegmentEqual(list_of_trials[0], + self.block.segments[0]) + self.assertSegmentEqual(list_of_trials[1], + self.block.segments[5]) + self.assertSegmentEqual(list_of_trials[2], + self.block.segments[7]) + + def test_trials_from_block_get_trials_as_list(self) -> None: + """ + Test to get all the trials grouped as a list of `Segment`. Each trial + is a single `Segment` containing all the data in the trial. + """ + list_of_trials = self.trial_object.get_trials_as_list() + self.assertIsInstance(list_of_trials, list) + self.assertEqual(len(list_of_trials), 36) + for trial, expected_trial in zip(list_of_trials, + self.block.segments): + self.assertSegmentEqual(trial, expected_trial) def test_trials_from_block_n_trials(self) -> None: """ - Test get number of trials. + Test to get the number of trials. """ - self.assertEqual(self.trial_object.n_trials, len(self.block.segments)) + self.assertEqual(self.trial_object.n_trials, 36) def test_trials_from_block_n_spiketrains_trial_by_trial(self) -> None: """ - Test get number of spiketrains per trial. + Test to get the number of `SpikeTrain` objects per trial. """ self.assertEqual(self.trial_object.n_spiketrains_trial_by_trial, - [len(trial.spiketrains) for trial in - self.block.segments]) + [2] * 36) def test_trials_from_block_n_analogsignals_trial_by_trial(self) -> None: """ - Test get number of analogsignals per trial. + Test to get the number of `AnalogSignal` objects per trial. """ self.assertEqual(self.trial_object.n_analogsignals_trial_by_trial, - [len(trial.analogsignals) for trial in - self.block.segments]) + [2] * 36) def test_trials_from_block_get_spiketrains_from_trial_as_list(self ) -> None: """ - Test get spiketrains from trial as list - """ - self.assertIsInstance( - self.trial_object.get_spiketrains_from_trial_as_list(0), - neo.core.spiketrainlist.SpikeTrainList) - self.assertIsInstance( - self.trial_object.get_spiketrains_from_trial_as_list(0)[0], - neo.core.SpikeTrain) - - def test_trials_from_list_get_spiketrains_from_trial_as_segment(self - ) -> None: - """ - Test get spiketrains from trial as segment - """ - self.assertIsInstance( - self.trial_object.get_spiketrains_from_trial_as_segment(0), - neo.core.Segment) - self.assertIsInstance( - self.trial_object.get_spiketrains_from_trial_as_segment( - 0).spiketrains[0], neo.core.SpikeTrain) + Test to get all spiketrains from a single trial as a `SpikeTrainList`. + """ + for trial_index in range(36): + trial_spiketrains = ( + self.trial_object.get_spiketrains_from_trial_as_list( + trial_index + ) + ) + expected = self.block.segments[trial_index].spiketrains + self.assertSpikeTrainListEqual(trial_spiketrains, expected) + + def test_trials_from_block_get_spiketrains_from_trial_as_segment(self + ) -> None: + """ + Test to get the all spiketrains from a single trial as a `Segment`. + The `Segment.spiketrains` collection contains the spiketrains. + """ + for trial_index in range(36): + trial_spiketrains = ( + self.trial_object.get_spiketrains_from_trial_as_segment( + trial_index + ) + ) + expected = self.block.segments[trial_index].spiketrains + self.assertIsInstance(trial_spiketrains, Segment) + self.assertEqual(len(trial_spiketrains.analogsignals), 0) + self.assertSpikeTrainListEqual(trial_spiketrains.spiketrains, + expected) def test_trials_from_block_get_analogsignals_from_trial_as_list(self ) -> None: """ - Test get analogsignals from trial as list + Test to get all analog signals from a single trial as a list. """ - self.assertIsInstance( - self.trial_object.get_analogsignals_from_trial_as_list(0), list) - self.assertIsInstance( - self.trial_object.get_analogsignals_from_trial_as_list(0)[0], - neo.core.AnalogSignal) + for trial_index in range(36): + trial_signals = ( + self.trial_object.get_analogsignals_from_trial_as_list( + trial_index + ) + ) + expected = self.block.segments[trial_index].analogsignals + self.assertIsInstance(trial_signals, list) + self.assertAnalogSignalListEqual(trial_signals, expected) - def test_trials_from_list_get_analogsignals_from_trial_as_segment(self) \ + def test_trials_from_block_get_analogsignals_from_trial_as_segment(self) \ -> None: """ - Test get spiketrains from trial as segment - """ - self.assertIsInstance( - self.trial_object.get_analogsignals_from_trial_as_segment(0), - neo.core.Segment) - self.assertIsInstance( - self.trial_object.get_analogsignals_from_trial_as_segment( - 0).analogsignals[0], neo.core.AnalogSignal) - - -class TrialsFromListTestCase(unittest.TestCase): - """Tests for elephant.trials.TrialsFromList class""" + Test to get all analog signals from a single trial as a `Segment`. + The `Segment.analogsignals` collection contains the signals. + """ + for trial_index in range(36): + trial_signals = ( + self.trial_object.get_analogsignals_from_trial_as_segment( + trial_index + ) + ) + expected = self.block.segments[trial_index].analogsignals + self.assertIsInstance(trial_signals, Segment) + self.assertEqual(len(trial_signals.spiketrains), 0) + self.assertAnalogSignalListEqual(trial_signals.analogsignals, + expected) + + def test_trials_from_block_get_spiketrains_trial_by_trial(self) -> None: + """ + Test to access all the `SpikeTrain` objects corresponding to the + repetitions of a spiketrain across the trials. + """ + for st_id in (0, 1): + spiketrains = self.trial_object.get_spiketrains_trial_by_trial( + st_id) + expected_spiketrains = [trial.spiketrains[st_id] + for trial in self.block.segments] + self.assertEqual(len(spiketrains), 36) + self.assertSpikeTrainListEqual(spiketrains, + SpikeTrainList(expected_spiketrains)) + + # Each trial-specific `SpikeTrain` object is from the same spiketrain + self.assertTrue(all([st.name == f"Spiketrain {st_id}" + for st in spiketrains] + ) + ) + + # Order of spiketrains is the order of the trials + expected_trials = [f"Trial {i}" for i in range(36)] + self.assertListEqual([st.description for st in spiketrains], + expected_trials) + + def test_trials_from_block_get_analogsignals_trial_by_trial(self) -> None: + """ + Test to access all the `AnalogSignal` objects corresponding to the + repetitions of an analog signal across the trials. + """ + for as_id in (0, 1): + signals = self.trial_object.get_analogsignals_trial_by_trial(as_id) + expected_signals = [trial.analogsignals[as_id] + for trial in self.block.segments] + self.assertEqual(len(signals), 36) + self.assertIsInstance(signals, list) + self.assertAnalogSignalListEqual(signals, expected_signals) + + # Each trial-specific `AnalogSignal` object is from the same signal + self.assertTrue(all([signal.name == f"Signal {as_id}" + for signal in signals] + ) + ) + + # Order in the list is the order of the trials + expected_trials = [f"Trial {i}" for i in range(36)] + self.assertListEqual([signal.description for signal in signals], + expected_trials) + + +class TrialsFromListTestCase(TrialsBaseTestCase): + """ + Tests for :class:`elephant.trials.TrialsFromList`. + """ @classmethod def setUpClass(cls) -> None: - """ - Run once before tests: - Download the dataset from elephant_data - """ - block = _create_trials_block(n_trials=36) - - # Create Trialobject as list of lists - # add spiketrains + cls.n_spiketrains = 2 + cls.n_analogsignals = 3 + block = _create_trials_block(n_trials=36, + n_spiketrains=cls.n_spiketrains, + n_analogsignals=cls.n_analogsignals) + + # Create trial data as list of lists + # 1. Add spiketrains trial_list = [[spiketrain for spiketrain in trial.spiketrains] for trial in block.segments] - # add analogsignals + # 2. Add analog signals for idx, trial in enumerate(block.segments): for analogsignal in trial.analogsignals: trial_list[idx].append(analogsignal) cls.trial_list = trial_list + # Create TrialsFromLists object cls.trial_object = TrialsFromLists(trial_list, description='trial is a list') - def setUp(self) -> None: - """ - Run before every test: - """ + def assertSegmentEqualToList(self, segment, list_data, n_spiketrains, + n_analogsignals): + """ + This function compares trial data in a Segment to trial data in + a Python list. The order of objects is: `SpikeTrain`, `AnalogSignal`. + The number of spiketrains and analog signals must be informed to split + the data in the list. + """ + self.assertIsInstance(segment, Segment) + self.assertIsInstance(list_data, list) + + self.assertEqual(len(list_data), n_spiketrains + n_analogsignals) + self.assertEqual(len(segment.spiketrains), n_spiketrains) + self.assertEqual(len(segment.analogsignals), n_analogsignals) + + spiketrains = list_data[:n_spiketrains] + signals = list_data[n_spiketrains:] + self.assertSpikeTrainListEqual(segment.spiketrains, + SpikeTrainList(spiketrains)) + self.assertAnalogSignalListEqual(segment.analogsignals, signals) + + def test_deprecations(self) -> None: + """ + Test if all expected deprecation warnings are triggered. + """ + trial_object = self.trial_object + with self.assertWarns(DeprecationWarning): + trial_object.get_trial_as_segment(trial_id=0) + with self.assertWarns(DeprecationWarning): + trial_object.get_trials_as_block(trial_ids=[0, 1]) + with self.assertWarns(DeprecationWarning): + trial_object.get_trials_as_list(trial_ids=[0, 1]) + with self.assertWarns(DeprecationWarning): + trial_object.get_spiketrains_from_trial_as_list(trial_id=0) + with self.assertWarns(DeprecationWarning): + trial_object.get_spiketrains_from_trial_as_segment(trial_id=0) + with self.assertWarns(DeprecationWarning): + trial_object.get_analogsignals_from_trial_as_list(trial_id=0) + with self.assertWarns(DeprecationWarning): + trial_object.get_analogsignals_from_trial_as_segment(trial_id=0) def test_trials_from_list_description(self) -> None: """ - Test description of the trials object. + Test the description of the `Trials` object. """ self.assertEqual(self.trial_object.description, 'trial is a list') def test_trials_from_list_get_item(self) -> None: """ - Test get a trial from the trials. + Test to get a single trial from the `Trials` object using indexing + with brackets. Return is a `Segment`. """ - self.assertIsInstance(self.trial_object[0], - neo.core.Segment) - self.assertIsInstance(self.trial_object[0].spiketrains[0], - neo.core.SpikeTrain) - self.assertIsInstance(self.trial_object[0].analogsignals[0], - neo.core.AnalogSignal) + for trial_index in range(36): + trial_segment = self.trial_object[trial_index] + expected = self.trial_list[trial_index] + self.assertSegmentEqualToList(trial_segment, + expected, + n_spiketrains=self.n_spiketrains, + n_analogsignals=self.n_analogsignals) def test_trials_from_list_get_trial_as_segment(self) -> None: """ - Test get a trial from the trials. - """ - self.assertIsInstance( - self.trial_object.get_trial_as_segment(0), - neo.core.Segment) - self.assertIsInstance( - self.trial_object.get_trial_as_segment(0).spiketrains[0], - neo.core.SpikeTrain) - self.assertIsInstance( - self.trial_object.get_trial_as_segment(0).analogsignals[0], - neo.core.AnalogSignal) + Test to get a single trial from the `Trials` object as a `Segment`. + """ + for trial_index in range(36): + trial_segment = self.trial_object[trial_index] + expected = self.trial_list[trial_index] + self.assertSegmentEqualToList(trial_segment, + expected, + n_spiketrains=self.n_spiketrains, + n_analogsignals=self.n_analogsignals) + + def test_trials_from_list_get_trials_as_block_indexes(self) -> None: + """ + Test to get a set of specific trials grouped as a `Block`. Each trial + is a `Segment` containing all the data in the trial. + """ + trial_block = self.trial_object.get_trials_as_block([1, 6, 18]) + self.assertIsInstance(trial_block, Block) + self.assertEqual(len(trial_block.segments), 3) + self.assertSegmentEqualToList(trial_block.segments[0], + self.trial_list[1], + n_spiketrains=self.n_spiketrains, + n_analogsignals=self.n_analogsignals) + self.assertSegmentEqualToList(trial_block.segments[1], + self.trial_list[6], + n_spiketrains=self.n_spiketrains, + n_analogsignals=self.n_analogsignals) + self.assertSegmentEqualToList(trial_block.segments[2], + self.trial_list[18], + n_spiketrains=self.n_spiketrains, + n_analogsignals=self.n_analogsignals) def test_trials_from_list_get_trials_as_block(self) -> None: """ - Test get a block from list of trials. + Test to get all trials grouped as a `Block`, where each trial is a + single `Segment`. """ - block = self.trial_object.get_trials_as_block([0, 3, 5]) - self.assertIsInstance(block, neo.core.Block) - self.assertIsInstance(self.trial_object.get_trials_as_block(), - neo.core.Block) - self.assertEqual(len(block.segments), 3) + trial_block = self.trial_object.get_trials_as_block() + self.assertIsInstance(trial_block, Block) + self.assertEqual(len(trial_block.segments), 36) + for trial, expected_trial in zip(trial_block.segments, + self.trial_list): + self.assertSegmentEqualToList(trial, + expected_trial, + n_spiketrains=self.n_spiketrains, + n_analogsignals=self.n_analogsignals) - def test_trials_from_list_get_trials_as_list(self) -> None: + def test_trials_from_list_get_trials_as_list_indexes(self) -> None: """ - Test get a list of segments from list of trials. + Test to get a set of specific trials grouped as list of `Segment`. + Each trial is a single `Segment` containing all the data in the trial. """ list_of_trials = self.trial_object.get_trials_as_list([0, 3, 5]) self.assertIsInstance(list_of_trials, list) - self.assertIsInstance(self.trial_object.get_trials_as_list(), list) - self.assertIsInstance(list_of_trials[0], neo.core.Segment) self.assertEqual(len(list_of_trials), 3) + self.assertSegmentEqualToList(list_of_trials[0], + self.trial_list[0], + n_spiketrains=self.n_spiketrains, + n_analogsignals=self.n_analogsignals) + self.assertSegmentEqualToList(list_of_trials[1], + self.trial_list[3], + n_spiketrains=self.n_spiketrains, + n_analogsignals=self.n_analogsignals) + self.assertSegmentEqualToList(list_of_trials[2], + self.trial_list[5], + n_spiketrains=self.n_spiketrains, + n_analogsignals=self.n_analogsignals) + + def test_trials_from_list_get_trials_as_list(self) -> None: + """ + Test to get all the trials grouped as a list of `Segment`. Each trial + is a single `Segment` containing all the data in the trial. + """ + list_of_trials = self.trial_object.get_trials_as_list() + self.assertIsInstance(list_of_trials, list) + self.assertEqual(len(list_of_trials), 36) + for trial, expected_trial in zip(list_of_trials, self.trial_list): + self.assertSegmentEqualToList(trial, + expected_trial, + n_spiketrains=self.n_spiketrains, + n_analogsignals=self.n_analogsignals) def test_trials_from_list_n_trials(self) -> None: """ - Test get number of trials. + Test to get the number of trials. """ - self.assertEqual(self.trial_object.n_trials, len(self.trial_list)) + self.assertEqual(self.trial_object.n_trials, 36) def test_trials_from_list_n_spiketrains_trial_by_trial(self) -> None: """ - Test get number of spiketrains per trial. + Test to get the number of `SpikeTrain` objects per trial. """ self.assertEqual(self.trial_object.n_spiketrains_trial_by_trial, - [sum(map(lambda x: isinstance(x, neo.core.SpikeTrain), - trial)) for trial in self.trial_list]) + [self.n_spiketrains] * 36) def test_trials_from_list_n_analogsignals_trial_by_trial(self) -> None: """ - Test get number of analogsignals per trial. + Test to get the number of `AnalogSignal` objects per trial. """ self.assertEqual(self.trial_object.n_analogsignals_trial_by_trial, - [sum(map(lambda x: isinstance(x, - neo.core.AnalogSignal), - trial)) for trial in self.trial_list]) + [self.n_analogsignals] * 36) def test_trials_from_list_get_spiketrains_from_trial_as_list(self) -> None: """ - Test get spiketrains from trial as list + Test to get all spiketrains from a single trial as a `SpikeTrainList`. """ - self.assertIsInstance( - self.trial_object.get_spiketrains_from_trial_as_list(0), - neo.core.spiketrainlist.SpikeTrainList) - self.assertIsInstance( - self.trial_object.get_spiketrains_from_trial_as_list(0)[0], - neo.core.SpikeTrain) + for trial_index in range(36): + trial_spiketrains = ( + self.trial_object.get_spiketrains_from_trial_as_list( + trial_index + ) + ) + expected = self.trial_list[trial_index][:self.n_spiketrains] + self.assertSpikeTrainListEqual(trial_spiketrains, + SpikeTrainList(expected)) def test_trials_from_list_get_spiketrains_from_trial_as_segment(self ) -> None: """ - Test get spiketrains from trial as segment - """ - self.assertIsInstance( - self.trial_object.get_spiketrains_from_trial_as_segment(0), - neo.core.Segment) - self.assertIsInstance( - self.trial_object.get_spiketrains_from_trial_as_segment( - 0).spiketrains[0], neo.core.SpikeTrain) + Test to get the all spiketrains from a single trial as a `Segment`. + The `Segment.spiketrains` collection contains the spiketrains. + """ + for trial_index in range(36): + trial_spiketrains = ( + self.trial_object.get_spiketrains_from_trial_as_segment( + trial_index + ) + ) + expected = self.trial_list[trial_index][:self.n_spiketrains] + self.assertIsInstance(trial_spiketrains, Segment) + self.assertEqual(len(trial_spiketrains.analogsignals), 0) + self.assertSpikeTrainListEqual(trial_spiketrains.spiketrains, + SpikeTrainList(expected)) def test_trials_from_list_get_analogsignals_from_trial_as_list(self ) -> None: """ - Test get analogsignals from trial as list + Test to get all analog signals from a single trial as a list. """ - self.assertIsInstance( - self.trial_object.get_analogsignals_from_trial_as_list(0), list) - self.assertIsInstance( - self.trial_object.get_analogsignals_from_trial_as_list(0)[0], - neo.core.AnalogSignal) + for trial_index in range(36): + trial_signals = ( + self.trial_object.get_analogsignals_from_trial_as_list( + trial_index + ) + ) + expected = self.trial_list[trial_index][self.n_spiketrains:] + self.assertIsInstance(trial_signals, list) + self.assertAnalogSignalListEqual(trial_signals, expected) def test_trials_from_list_get_analogsignals_from_trial_as_segment(self ) \ -> None: """ - Test get spiketrains from trial as segment - """ - self.assertIsInstance( - self.trial_object.get_analogsignals_from_trial_as_segment(0), - neo.core.Segment) - self.assertIsInstance( - self.trial_object.get_analogsignals_from_trial_as_segment( - 0).analogsignals[0], neo.core.AnalogSignal) + Test to get all analog signals from a single trial as a `Segment`. + The `Segment.analogsignals` collection contains the signals. + """ + for trial_index in range(36): + trial_signals = ( + self.trial_object.get_analogsignals_from_trial_as_segment( + trial_index + ) + ) + expected = self.trial_list[trial_index][self.n_spiketrains:] + self.assertIsInstance(trial_signals, Segment) + self.assertEqual(len(trial_signals.spiketrains), 0) + self.assertAnalogSignalListEqual(trial_signals.analogsignals, + expected) + + def test_trials_from_list_get_spiketrains_trial_by_trial(self) -> None: + """ + Test to access all the `SpikeTrain` objects corresponding to the + repetitions of a spiketrain across the trials. + """ + for st_id in (0, 1): + spiketrains = self.trial_object.get_spiketrains_trial_by_trial( + st_id) + expected_spiketrains = [trial[:self.n_spiketrains][st_id] + for trial in self.trial_list] + self.assertEqual(len(spiketrains), 36) + self.assertSpikeTrainListEqual(spiketrains, + SpikeTrainList(expected_spiketrains)) + + # Each trial-specific `SpikeTrain` object is from the same spiketrain + self.assertTrue(all([st.name == f"Spiketrain {st_id}" + for st in spiketrains] + ) + ) + + # Order of spiketrains is the order of the trials + expected_trials = [f"Trial {i}" for i in range(36)] + self.assertListEqual([st.description for st in spiketrains], + expected_trials) + + def test_trials_from_list_get_analogsignals_trial_by_trial(self) -> None: + """ + Test to access all the `AnalogSignal` objects corresponding to the + repetitions of an analog signal across the trials. + """ + for as_id in (0, 1): + signals = self.trial_object.get_analogsignals_trial_by_trial(as_id) + expected_signals = [trial[self.n_spiketrains:][as_id] + for trial in self.trial_list] + self.assertEqual(len(signals), 36) + self.assertIsInstance(signals, list) + self.assertAnalogSignalListEqual(signals, expected_signals) + + # Each trial-specific `AnalogSignal` object is from the same signal + self.assertTrue(all([signal.name == f"Signal {as_id}" + for signal in signals] + ) + ) + + # Order in the list is the order of the trials + expected_trials = [f"Trial {i}" + for i in range(self.trial_object.n_trials)] + self.assertListEqual([signal.description for signal in signals], + expected_trials) if __name__ == '__main__': diff --git a/elephant/test/test_utils.py b/elephant/test/test_utils.py index cc927d53e..ef935f335 100644 --- a/elephant/test/test_utils.py +++ b/elephant/test/test_utils.py @@ -12,11 +12,6 @@ from elephant import utils from numpy.testing import assert_array_equal -from elephant.spike_train_generation import StationaryPoissonProcess -from elephant.trials import TrialsFromLists -from neo.core.spiketrainlist import SpikeTrainList -from neo.core import SpikeTrain - class TestUtils(unittest.TestCase): @@ -61,80 +56,5 @@ def test_round_binning_errors(self): [0, 0, 0, 0]) -class DecoratorTest: - """ - This class is used as a mock for testing the decorator. - """ - @utils.trials_to_list_of_spiketrainlist - def method_to_decorate(self, trials=None, trials_obj=None): - # This is just a mock implementation for testing purposes - if trials_obj: - return trials_obj - else: - return trials - - -class TestTrialsToListOfSpiketrainlist(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.n_channels = 10 - cls.n_trials = 5 - cls.list_of_list_of_spiketrains = [ - StationaryPoissonProcess(rate=5 * pq.Hz, t_stop=1000.0 * pq.ms - ).generate_n_spiketrains(cls.n_channels) - for _ in range(cls.n_trials)] - cls.trial_object = TrialsFromLists(cls.list_of_list_of_spiketrains) - - def test_decorator_applied(self): - # Test that the decorator is applied correctly - self.assertTrue(hasattr( - DecoratorTest.method_to_decorate, '__wrapped__' - )) - - def test_decorator_return_with_trials_input_as_arg(self): - # Test if decorator takes in trial-object and returns - # list of spiketrainlists - new_class = DecoratorTest() - list_of_spiketrainlists = new_class.method_to_decorate( - self.trial_object) - self.assertEqual(len(list_of_spiketrainlists), self.n_trials) - for spiketrainlist in list_of_spiketrainlists: - self.assertIsInstance(spiketrainlist, SpikeTrainList) - - def test_decorator_return_with_list_of_lists_input_as_arg(self): - # Test if decorator takes in list of lists of spiketrains - # and does not change input - new_class = DecoratorTest() - list_of_list_of_spiketrains = new_class.method_to_decorate( - self.list_of_list_of_spiketrains) - self.assertEqual(len(list_of_list_of_spiketrains), self.n_trials) - for list_of_spiketrains in list_of_list_of_spiketrains: - self.assertIsInstance(list_of_spiketrains, list) - for spiketrain in list_of_spiketrains: - self.assertIsInstance(spiketrain, SpikeTrain) - - def test_decorator_return_with_trials_input_as_kwarg(self): - # Test if decorator takes in trial-object and returns - # list of spiketrainlists - new_class = DecoratorTest() - list_of_spiketrainlists = new_class.method_to_decorate( - trials_obj=self.trial_object) - self.assertEqual(len(list_of_spiketrainlists), self.n_trials) - for spiketrainlist in list_of_spiketrainlists: - self.assertIsInstance(spiketrainlist, SpikeTrainList) - - def test_decorator_return_with_list_of_lists_input_as_kwarg(self): - # Test if decorator takes in list of lists of spiketrains - # and does not change input - new_class = DecoratorTest() - list_of_list_of_spiketrains = new_class.method_to_decorate( - trials_obj=self.list_of_list_of_spiketrains) - self.assertEqual(len(list_of_list_of_spiketrains), self.n_trials) - for list_of_spiketrains in list_of_list_of_spiketrains: - self.assertIsInstance(list_of_spiketrains, list) - for spiketrain in list_of_spiketrains: - self.assertIsInstance(spiketrain, SpikeTrain) - - if __name__ == '__main__': unittest.main() diff --git a/elephant/trials.py b/elephant/trials.py index cd006addd..70c94a56b 100644 --- a/elephant/trials.py +++ b/elephant/trials.py @@ -2,29 +2,30 @@ This module defines the basic classes that represent trials in Elephant. Many neuroscience methods rely on the concept of repeated trials to improve the -estimate of quantities measured from the data. In the simplest case, results -from multiple trials are averaged, in other scenarios more intricate steps must -be taken in order to pool information from each repetition of a trial. Typically, -trials are considered as fixed time intervals tied to a specific event in the -experiment, such as the onset of a stimulus. - -Neo does not impose a specific way in which trials are to be represented. A -natural way to represent trials is to have a :class:`neo.Block` containing multiple -:class:`neo.Segment` objects, each representing the data of one trial. Another popular -option is to store trials as lists of lists, where the outer refers to -individual lists, and inner lists contain Neo data objects (:class:`neo.SpikeTrain` -and :class:`neo.AnalogSignal` containing individual data of each trial. - -The classes of this module abstract from these individual data representations -by introducing a set of :class:`Trials` classes with a common API. These classes -are initialized by a supported way of structuring trials, e.g., -:class:`TrialsFromBlock` for the first method described above. Internally, +estimate of quantities measured from the data. Typically, trials are +considered as fixed time intervals tied to a specific event in the experiment, +such as the onset of a stimulus. In the simplest case, results from multiple +trials are averaged. In other scenarios, more intricate steps must be taken +in order to pool information from each repetition of a trial. + +Neo does not impose a specific way to represent trial data. A natural way to +represent trials is to have a :class:`neo.Block` containing multiple +:class:`neo.Segment` objects, each representing the data of one trial. Another +popular option is to store trials as a list of lists, where the outer refers to +the collection of trials, and inner lists contain Neo data objects +(:class:`neo.SpikeTrain` and :class:`neo.AnalogSignal`) containing the +individual data of each trial. + +The classes of this module abstract from these specific data representations +by introducing a set of :class:`Trials` classes with a common API. These +classes are initialized by a supported way of structuring trials, e.g., +:class:`TrialsFromBlock` for the first method described above. Internally, the :class:`Trials` class will not convert this representation, but provide access to data in specific trials (e.g., all spike trains in trial 5) or general -information about the trial structure (e.g., how many trials are there?) via a -fixed API. Trials are consecutively numbered, starting at a trial ID of 0. +information about the trial structure (e.g., how many trials are there?) via a +fixed API. Trials are indexed consecutively starting from 0. -In the release, the classes :class:`TrialsFromBlock` and +In the current implementation, classes :class:`TrialsFromBlock` and :class:`TrialsFromLists` provide this unified way to access trial data. .. autosummary:: @@ -34,394 +35,643 @@ TrialsFromBlock TrialsFromLists -:copyright: Copyright 2014-2024 by the Elephant team, see `doc/authors.rst`. + +:copyright: Copyright 2014-2025 by the Elephant team, see `doc/authors.rst`. :license: Modified BSD, see LICENSE.txt for details. """ -from abc import ABCMeta, abstractmethod +from abc import ABC, abstractmethod from typing import List -import neo.utils -from neo.core import Segment, Block +from functools import wraps +import numpy as np + +import neo +from neo.core import Segment, Block, SpikeTrain, AnalogSignal from neo.core.spiketrainlist import SpikeTrainList +from elephant.utils import deprecated_alias -class Trials: +def trials_to_list_of_spiketrainlist(function): """ - Base class for handling trials. - - This is the base class from which all trial objects inherit. - This class implements support for universally recommended arguments. + Decorator that converts each argument passed as a :class:`Trials` object + into a list of :class:`neo.SpikeTrainList` before calling the wrapped + function. Parameters ---------- - description : string, optional - A textual description of the set of trials. Can be accessed via the - class attribute `description`. - Default: None. + function: callable + The function to be decorated. + Returns + ------- + callable + The decorated function. + + Examples + -------- + The decorator can be used as follows: + + >>> @trials_to_list_of_spiketrainlist + ... def process_data(self, spiketrains): + ... return None """ - __metaclass__ = ABCMeta + @wraps(function) + def wrapper(*args, **kwargs): + new_args = tuple( + [ + arg.get_spiketrains_from_trial_as_list(idx) + for idx in range(arg.n_trials) + ] + if isinstance(arg, Trials) + else arg + for arg in args + ) + new_kwargs = { + key: ( + [ + value.get_spiketrains_from_trial_as_list(idx) + for idx in range(value.n_trials) + ] + if isinstance(value, Trials) + else value + ) + for key, value in kwargs.items() + } + + return function(*new_args, **new_kwargs) + + return wrapper + + +class Trials(ABC): + """ + Abstract base class for handling trial-based data in Elephant. - def __init__(self, description: str = "Trials"): - """Create an instance of the trials class.""" + The `Trials` class defines a standardized interface for accessing and + manipulating trial data. It provides a unified set of attributes and + methods for trial handling, and serves as the base class for all + data-structure-specific implementations. + + Child classes such as :class:`TrialsFromBlock` and :class:`TrialsFromLists` + support specific input data structures. Usage details and examples are + provided in their respective documentation. + + Parameters + ---------- + description : str, optional + Textual description of the set of trials, accessible via the + `description` attribute. + Default: None + + See Also + -------- + :class:`TrialsFromBlock` + :class:`TrialsFromLists` + """ + + def __init__(self, description: str = None): self.description = description @abstractmethod - def __getitem__(self, trial_number: int) -> neo.core.Segment: - """Get a specific trial by number.""" - pass + def __getitem__(self, trial_index: int) -> Segment: + # Get a specific trial by its index as a Segment + raise NotImplementedError + @property @abstractmethod def n_trials(self) -> int: - """Get the number of trials. + """ + Number of trials. Returns ------- - int: Number of trials + int + Total number of trials. """ - pass + raise NotImplementedError @abstractmethod def n_spiketrains_trial_by_trial(self) -> List[int]: - """Get the number of spike trains in each trial as a list. + """ + Number of spike trains per trial. Returns ------- - list of int: For each trial, contains the number of spike trains in the - trial. + List[int] + Number of :class:`neo.SpikeTrain` objects in each trial, ordered by + trial index in ascending order starting from zero. """ - pass + raise NotImplementedError @abstractmethod def n_analogsignals_trial_by_trial(self) -> List[int]: - """Get the number of analogsignal objects in each trial as a list. + """ + Number of analog signals per trial. Returns ------- - list of int: For each trial, contains the number of analogsignal objects - in the trial. + List[int] + Number of :class:`neo.AnalogSignal` objects in each trial, ordered + by trial index in ascending order starting from zero. """ - pass + raise NotImplementedError - @abstractmethod - def get_trial_as_segment(self, trial_id: int) -> neo.core.Segment: - """Get trial as segment. + @deprecated_alias(trial_id="trial_index") + def get_trial_as_segment(self, trial_index: int) -> Segment: + """ + Return a single trial as a :class:`neo.Segment`. Parameters ---------- - trial_id : int - Trial number to get (starting at trial ID 0). + trial_index : int + Zero-based index of the trial to retrieve. Returns ------- - class:`neo.Segment`: a segment containing all spike trains and - analogsignal objects of the trial. + :class:`neo.Segment` + Segment containing all spike trains and analog signals associated + with the specified trial. Spike trains and analog signals are + accessed via the `spiketrains` and `analogsignals` attributes, + respectively. Their order corresponds to their index within these + collections (e.g., `spiketrains[0]` is the first spike train). """ - pass + return self.__getitem__(trial_index) - @abstractmethod - def get_trials_as_block(self, trial_ids: List[int] = None - ) -> neo.core.Block: - """Get trials as block. + @deprecated_alias(trial_ids="trial_indexes") + def get_trials_as_block(self, trial_indexes: List[int] = None) -> Block: + """ + Return multiple trials grouped into a :class:`neo.Block`. Parameters ---------- - trial_ids : list of int - Trial IDs to include in the Block (starting at trial ID 0). - If None is specified, all trials are returned. + trial_indexes : List[int], optional + Zero-based indices of the trials to include in the block. + If None, all trials are returned. Default: None Returns ------- - class:`neo.Block`: a Block containing :class:`neo.Segment` objects for - each of the selected trials, each containing all spike trains and - analogsignal objects of the corresponding trial. + :class:`neo.Block` + Block containing one :class:`neo.Segment` per trial. The trials are + accessed via the `segments` attribute. If all trials are included, + element indices correspond to trial indices. If a subset is + specified, the order matches that of `trial_indexes`. + + See Also + -------- + :method:`get_trial_as_segment()` """ - pass + block = Block() + if not trial_indexes: + trial_indexes = list(range(self.n_trials)) + for trial_index in trial_indexes: + block.segments.append(self.get_trial_as_segment(trial_index)) + return block - @abstractmethod - def get_trials_as_list(self, trial_ids: List[int] = None - ) -> neo.core.spiketrainlist.SpikeTrainList: - """Get trials as list of segments. + @deprecated_alias(trial_ids="trial_indexes") + def get_trials_as_list(self, trial_indexes: List[int] = None + ) -> List[Segment]: + """ + Return multiple trials as a list of :class:`neo.Segment` objects. Parameters ---------- - trial_ids : list of int - Trial IDs to include in the list (starting at trial ID 0). - If None is specified, all trials are returned. + trial_indexes : List[int], optional + Zero-based indices of the trials to include in the list. + If None, all trials are returned. Default: None Returns ------- - list of :class:`neo.Segment`: a list containing :class:`neo.Segment` - objects for each of the selected trials, each containing all spike - trains and analogsignal objects of the corresponding trial. + List[Segment] + List containing one :class:`neo.Segment` per selected trial. If all + trials are returned, list indices correspond to trial indices. + If a subset is specified, the order matches that of + `trial_indexes`. """ - pass + if not trial_indexes: + trial_indexes = list(range(self.n_trials)) + return [self.get_trial_as_segment(trial_index) + for trial_index in trial_indexes] @abstractmethod - def get_spiketrains_from_trial_as_list(self, trial_id: int) -> ( - neo.core.spiketrainlist.SpikeTrainList): + @deprecated_alias(trial_id="trial_index") + def get_spiketrains_from_trial_as_list(self, trial_index: int + ) -> SpikeTrainList: + """ + Return all spike trains from a single trial as a list. + + Parameters + ---------- + trial_index : int + Zero-based index of the trial. + + Returns + ------- + :class:`neo.SpikeTrainList` + List-like container with all :class:`neo.SpikeTrain` objects from + the specified trial. """ - Get all spike trains from a specific trial and return a list. + raise NotImplementedError + + @deprecated_alias(trial_id="trial_index") + def get_spiketrains_from_trial_as_segment(self, trial_index: int + ) -> Segment: + """ + Return all spike trains from a single trial as a :class:`neo.Segment`. Parameters ---------- - trial_id : int - Trial ID to get the spike trains from (starting at trial ID 0). + trial_index : int + Zero-based index of the trial. Returns ------- - list of :class:`neo.SpikeTrain` - List of all spike trains of the trial. + :class:`neo.Segment` + Segment containing all :class:`neo.SpikeTrain` objects from the + specified trial, accessible via the `spiketrains` attribute. """ - pass + segment = Segment() + for spiketrain in self.get_spiketrains_from_trial_as_list(trial_index): + segment.spiketrains.append(spiketrain) + return segment @abstractmethod - def get_spiketrains_from_trial_as_segment(self, trial_id: int - ) -> neo.core.Segment: + @deprecated_alias(trial_id="trial_index") + def get_analogsignals_from_trial_as_list(self, trial_index: int + ) -> List[AnalogSignal]: """ - Get all spike trains from a specific trial and return a Segment. + Return all analog signals from a single trial as a list. Parameters ---------- - trial_id : int - Trial ID to get the spike trains from (starting at trial ID 0). + trial_index : int + Zero-based index of the trial. Returns ------- - :class:`neo.Segment`: Segment containing all spike trains of the trial. + List[AnalogSignal] + List containing all :class:`neo.AnalogSignal` objects from the + specified trial. """ - pass + raise NotImplementedError + + @deprecated_alias(trial_id="trial_index") + def get_analogsignals_from_trial_as_segment(self, trial_index: int + ) -> Segment: + """ + Return all analog signals from a single trial as a :class:`neo.Segment`. + + Parameters + ---------- + trial_index : int + Zero-based index of the trial. + + Returns + ------- + :class:`neo.Segment` + Segment containing all :class:`neo.AnalogSignal` objects from the + specified trial, accessible via the `analogsignals` attribute. + """ + segment = Segment() + for analogsignal in self.get_analogsignals_from_trial_as_list( + trial_index): + segment.analogsignals.append(analogsignal) + return segment @abstractmethod - def get_analogsignals_from_trial_as_list(self, trial_id: int - ) -> List[neo.core.AnalogSignal]: + def get_spiketrains_trial_by_trial(self, spiketrain_index: int) -> ( + SpikeTrainList): """ - Get all analogsignals from a specific trial and return a list. + Return spike train across all its trial repetitions. + + This method returns a list containing :class:`neo.SpikeTrain` objects + corresponding to the same spike train (e.g., from a consistent + recording channel or neuronal source) across multiple trials. Parameters ---------- - trial_id : int - Trial ID to get the analogsignals from (starting at trial ID 0). + spiketrain_index : int + Zero-based index of the spike train to retrieve across trials. Returns ------- - list of :class`neo.AnalogSignal`: list of all analogsignal objects of - the trial. + :class:`neo.SpikeTrainList` + List-like container storing the :class:`neo.SpikeTrain` objects + corresponding to the specified `spiketrain_index`, ordered by trial + index from zero to `n_trials - 1`. """ - pass + raise NotImplementedError @abstractmethod - def get_analogsignals_from_trial_as_segment(self, trial_id: int - ) -> neo.core.Segment: + def get_analogsignals_trial_by_trial(self, signal_index: int + ) -> List[AnalogSignal]: """ - Get all analogsignal objects from a specific trial and return a - :class:`neo.Segment`. + Return an analog signal across all its trial repetitions. + + This method returns a list containing :class:`neo.AnalogSignal` + objects corresponding to a continuous signal recorded from a consistent + recording channel or neuronal source across multiple trials. Parameters ---------- - trial_id : int - Trial ID to get the analogsignals from (starting at trial ID 0). + signal_index : int + Zero-based index of the analog signal to retrieve across trials. Returns ------- - class:`neo.Segment`: segment containing all analogsignal objects of - the trial. + List[AnalogSignal] + List storing :class:`neo.AnalogSignal` objects corresponding to the + specified `signal_index`, ordered by trial index from zero to + `n_trials - 1`. """ + raise NotImplementedError class TrialsFromBlock(Trials): """ - This class implements support for handling trials from neo.Block. + This class handles trial data organized within a :class:`neo.Block` object. + + In this representation, each trial is stored as a separate + :class:`neo.Segment` within the block. All trial segments are accessible + through the `segments` attribute. The data for a specific trial can be + accessed by its index, e.g., `segments[0]` corresponds to the first trial. + + Each :class:`neo.Segment` contains collections for spike trains and analog + signals, accessible via the `spiketrains` and `analogsignals` attributes, + respectively. When accessing data of individual :class:`neo.SpikeTrain` and + :class:`neo.AnalogSignal` objects, the indexes within these collections is + used. For instance, `spiketrains[0]` refers to the first spike train, and + `analogsignals[0]` to the first analog signal. Parameters ---------- - block : neo.Block - An instance of neo.Block containing the trials. - The structure is assumed to follow the neo representation: - A block contains multiple segments which are considered to contain the - single trials. - description : string, optional - A textual description of the set of trials. Can be accessed via the - class attribute `description`. - Default: None. + block : :class:`neo.Block` + An instance of :class:`neo.Block` containing the trial data. The block + contains multiple :class:`neo.Segment` objects, each containing the + data of a single trial. + description : str, optional + Textual description of the set of trials, accessible via the + :attr:`description` attribute. + Default: None + + Attributes + ---------- + description : str + The description of the set of trials. + n_trials : int + The total number of trials. + n_spiketrains_trial_by_trial : List[int] + The number of spike trains in each trial. + n_analogsignals_trial_by_trial : List[int] + The number of analog signals in each trial. + + Examples + -------- + 1. Generate `TrialFromBlock` object to handle data from two trials, each + containing three spike trains and one analog signal. + + >>> import numpy as np + >>> import quantities as pq + >>> import neo + >>> from elephant.spike_train_generation import StationaryPoissonProcess + >>> from elephant.trials import TrialsFromBlock + >>> + >>> st_generator = StationaryPoissonProcess(rate=10*pq.Hz, t_stop=1*pq.s) + >>> trial_block = neo.Block() + >>> for _ in range(2): + ... trial = neo.Segment() + ... trial.spiketrains = st_generator.generate_n_spiketrains(3) + ... signal = np.sin(np.arange(0, 6*np.pi, 2*np.pi/1000)) + ... signal += np.random.normal(size=signal.shape) + ... trial.analogsignals.append( + ... neo.AnalogSignal(signal, units=pq.mV, + ... t_stop=1*pq.s, + ... sampling_rate=(1/3000)*pq.Hz) + ... ) + ... trial_block.segments.append(trial) + >>> + >>> trials = TrialsFromBlock(trial_block) + + 2. Retrieve overall information. + + >>> print(trials.n_trials) + 2 + >>> print(trials.n_spiketrains_trial_by_trial) + [3, 3] + >>> print(trials.n_analogsignals_trial_by_trial) + [1, 1] + + 3. Access data in the first trial. + + >>> first_trial = trials[0] + >>> first_spike_train = first_trial.spiketrains[0] + >>> analog_signal = first_trial.analogsignals[0] """ - def __init__(self, block: neo.core.block, **kwargs): + def __init__(self, block: Block, **kwargs): self.block = block super().__init__(**kwargs) - def __getitem__(self, trial_number: int) -> neo.core.segment: - return self.block.segments[trial_number] - - def get_trial_as_segment(self, trial_id: int) -> neo.core.Segment: - # Get a specific trial by number as a segment - return self.__getitem__(trial_id) - - def get_trials_as_block(self, trial_ids: List[int] = None - ) -> neo.core.Block: - # Get a block of trials by trial numbers - block = Block() - if not trial_ids: - trial_ids = list(range(self.n_trials)) - for trial_number in trial_ids: - block.segments.append(self.get_trial_as_segment(trial_number)) - return block - - def get_trials_as_list(self, trial_ids: List[int] = None - ) -> List[neo.core.Segment]: - if not trial_ids: - trial_ids = list(range(self.n_trials)) - # Get a list of segments by trial numbers - return [self.get_trial_as_segment(trial_number) - for trial_number in trial_ids] + def __getitem__(self, trial_index: int) -> Segment: + # Get a specific trial by its index as a Segment + return self.block.segments[trial_index] @property def n_trials(self) -> int: - # Get the number of trials. + # Get the number of trials return len(self.block.segments) @property def n_spiketrains_trial_by_trial(self) -> List[int]: - # Get the number of SpikeTrain instances in each trial. + # Get the number of SpikeTrain objects in each trial return [len(trial.spiketrains) for trial in self.block.segments] @property def n_analogsignals_trial_by_trial(self) -> List[int]: - # Get the number of AnalogSignals instances in each trial. + # Get the number of AnalogSignal objects in each trial return [len(trial.analogsignals) for trial in self.block.segments] - def get_spiketrains_from_trial_as_list(self, trial_id: int = 0) -> ( - neo.core.spiketrainlist.SpikeTrainList): + @deprecated_alias(trial_id="trial_index") + def get_spiketrains_from_trial_as_list(self, trial_index: int = 0 + ) -> SpikeTrainList: # Return a list of all spike trains from a trial - return SpikeTrainList(items=[spiketrain for spiketrain in - self.block.segments[trial_id].spiketrains]) - - def get_spiketrains_from_trial_as_segment(self, trial_id: int) -> ( - neo.core.Segment): - # Return a segment with all spiketrains from a trial - segment = neo.core.Segment() - for spiketrain in self.get_spiketrains_from_trial_as_list(trial_id - ): - segment.spiketrains.append(spiketrain) - return segment - - def get_analogsignals_from_trial_as_list(self, trial_id: int) -> ( - List[neo.core.AnalogSignal]): - # Return a list of all analogsignals from a trial + return SpikeTrainList( + items=[spiketrain for spiketrain in + self.block.segments[trial_index].spiketrains] + ) + + def get_spiketrains_trial_by_trial(self, spiketrain_index: int + ) -> SpikeTrainList: + # Return a list of all spike train repetitions across trials + return SpikeTrainList( + items=[segment.spiketrains[spiketrain_index] for + segment in self.block.segments] + ) + + @deprecated_alias(trial_id="trial_index") + def get_analogsignals_from_trial_as_list(self, trial_index: int + ) -> List[AnalogSignal]: + # Return a list of all analog signals from a trial return [analogsignal for analogsignal in - self.block.segments[trial_id].analogsignals] + self.block.segments[trial_index].analogsignals] - def get_analogsignals_from_trial_as_segment(self, trial_id: int) -> ( - neo.core.Segment): - # Return a segment with all analogsignals from a trial - segment = neo.core.Segment() - for analogsignal in self.get_analogsignals_from_trial_as_list( - trial_id): - segment.analogsignals.append(analogsignal) - return segment + def get_analogsignals_trial_by_trial(self, signal_index: int + ) -> List[AnalogSignal]: + # Return a list of all analog signal repetitions across trials + return [segment.analogsignals[signal_index] + for segment in self.block.segments] class TrialsFromLists(Trials): """ - This class implements support for handling trials from list of lists. + This class handles trial data structured as a list of lists. + + In this representation, each inner list represents a single trial and + includes one or more data elements, such as spike trains + (:class:`neo.SpikeTrain`) and/or analog signals + (:class:`neo.AnalogSignal`). The order of these elements must remain + consistent across all trial repetitions. + + The index identifying each spike train or analog signal is determined by + its position within the list. For example, if each trial contains two + elements such as a spike train followed by an analog signal, the spike + train at index 0 corresponds to the first element, and the analog signal + at index 0 corresponds to the second element. Parameters ---------- - list_of_trials : list of lists - A list of lists. Each list entry contains a list of neo.SpikeTrains - and/or neo.AnalogSignals. - description : string, optional - A textual description of the set of trials. Can be accessed via the - class attribute `description`. - Default: None. + list_of_trials : list + A list of lists containing trial data. The inner lists must contain + spike train (:class:`neo.SpikeTrain`) or analog signal + (:class:`neo.AnalogSignal`) objects. + description : str, optional + Textual description of the set of trials, accessible via the + :attr:`description` attribute. + Default: None + + Attributes + ---------- + description : str + The description of the set of trials. + n_trials : int + The total number of trials. + n_spiketrains_trial_by_trial : List[int] + The number of spike trains in each trial. + n_analogsignals_trial_by_trial : List[int] + The number of analog signals in each trial. + + Examples + -------- + 1. Generate `TrialFromLists` object to handle data from three trials, each + containing two spike trains and one analog signal. + + >>> import numpy as np + >>> import quantities as pq + >>> import neo + >>> from elephant.spike_train_generation import StationaryPoissonProcess + >>> from elephant.trials import TrialsFromLists + >>> + >>> st_generator = StationaryPoissonProcess(rate=10*pq.Hz, t_stop=1*pq.s) + >>> trial_list = [st_generator.generate_n_spiketrains(2) for _ in range(3)] + >>> for trial in trial_list: + ... signal = np.sin(np.arange(0, 6*np.pi, 2*np.pi/1000)) + ... signal += np.random.normal(size=signal.shape) + ... trial.append( + ... neo.AnalogSignal(signal, units=pq.mV, + ... t_stop=1*pq.s, + ... sampling_rate=(1/3000)*pq.Hz) + ... ) + >>> + >>> trials = TrialsFromLists(trial_list) + + 2. Retrieve overall information. + + >>> print(trials.n_trials) + 3 + >>> print(trials.n_spiketrains_trial_by_trial) + [2, 2, 2] + >>> print(trials.n_analogsignals_trial_by_trial) + [1, 1, 1] + + 3. Access data in the first trial. + + >>> first_trial = trials[0] + >>> first_spike_train = first_trial.spiketrains[0] + >>> analog_signal = first_trial.analogsignals[0] """ def __init__(self, list_of_trials: list, **kwargs): - # Constructor - # (actual documentation is in class documentation, see above!) self.list_of_trials = list_of_trials super().__init__(**kwargs) - def __getitem__(self, trial_number: int) -> neo.core.Segment: - # Get a specific trial by number + # Save indexes for quick search of spike trains or analog signals + # in a trial. The order of elements in the inner list must be + # consistent across all trials (using the first list, corresponding + # to the first trial, to fetch the indexes). + if list_of_trials: + is_spiketrain = np.array([isinstance(data_element, SpikeTrain) + for data_element in list_of_trials[0]]) + self._spiketrain_index = is_spiketrain.nonzero()[0] + self._analogsignal_index = (~is_spiketrain).nonzero()[0] + else: + self._spiketrain_index = [] + self._analogsignal_index = [] + + def __getitem__(self, trial_index: int) -> Segment: + # Get a specific trial by its index as a Segment segment = Segment() - for element in self.list_of_trials[trial_number]: - if isinstance(element, neo.core.SpikeTrain): + for element in self.list_of_trials[trial_index]: + if isinstance(element, SpikeTrain): segment.spiketrains.append(element) - if isinstance(element, neo.core.AnalogSignal): + if isinstance(element, AnalogSignal): segment.analogsignals.append(element) return segment - def get_trial_as_segment(self, trial_id: int) -> neo.core.Segment: - # Get a specific trial by number as a segment - return self.__getitem__(trial_id) - - def get_trials_as_block(self, trial_ids: List[int] = None - ) -> neo.core.Block: - if not trial_ids: - trial_ids = list(range(self.n_trials)) - # Get a block of trials by trial numbers - block = Block() - for trial_number in trial_ids: - block.segments.append(self.get_trial_as_segment(trial_number)) - return block - - def get_trials_as_list(self, trial_ids: List[int] = None - ) -> List[neo.core.Segment]: - if not trial_ids: - trial_ids = list(range(self.n_trials)) - # Get a list of segments by trial numbers - return [self.get_trial_as_segment(trial_number) - for trial_number in trial_ids] - @property def n_trials(self) -> int: - # Get the number of trials. + # Get the number of trials return len(self.list_of_trials) @property def n_spiketrains_trial_by_trial(self) -> List[int]: - # Get the number of spiketrains in each trial. - return [sum(map(lambda x: isinstance(x, neo.core.SpikeTrain), trial)) + # Get the number of SpikeTrain objects in each trial + return [sum(map(lambda x: isinstance(x, SpikeTrain), trial)) for trial in self.list_of_trials] @property def n_analogsignals_trial_by_trial(self) -> List[int]: - # Get the number of analogsignals in each trial. - return [sum(map(lambda x: isinstance(x, neo.core.AnalogSignal), trial)) + # Get the number of AnalogSignal objects in each trial + return [sum(map(lambda x: isinstance(x, AnalogSignal), trial)) for trial in self.list_of_trials] - def get_spiketrains_from_trial_as_list(self, trial_id: int = 0) -> ( - neo.core.spiketrainlist.SpikeTrainList): - # Return a list of all spiketrains from a trial + @deprecated_alias(trial_id="trial_index") + def get_spiketrains_from_trial_as_list(self, trial_index: int = 0 + ) -> SpikeTrainList: + # Return a list of all spike trains from a trial return SpikeTrainList(items=[ - spiketrain for spiketrain in self.list_of_trials[trial_id] - if isinstance(spiketrain, neo.core.SpikeTrain)]) - - def get_spiketrains_from_trial_as_segment(self, trial_id: int) -> ( - neo.core.Segment): - # Return a segment with all spiketrains from a trial - segment = neo.core.Segment() - for spiketrain in self.get_spiketrains_from_trial_as_list(trial_id): - segment.spiketrains.append(spiketrain) - return segment - - def get_analogsignals_from_trial_as_list(self, trial_id: int) -> ( - List[neo.core.AnalogSignal]): - # Return a list of all analogsignals from a trial + spiketrain for spiketrain in self.list_of_trials[trial_index] + if isinstance(spiketrain, SpikeTrain)]) + + def get_spiketrains_trial_by_trial(self, spiketrain_index: int + ) -> SpikeTrainList: + # Return a list of all spike train repetitions across trials + list_idx = self._spiketrain_index[spiketrain_index] + return SpikeTrainList(items=[trial[list_idx] + for trial in self.list_of_trials]) + + @deprecated_alias(trial_id="trial_index") + def get_analogsignals_from_trial_as_list(self, trial_index: int + ) -> List[AnalogSignal]: + # Return a list of all analog signals from a trial return [analogsignal for analogsignal in - self.list_of_trials[trial_id] - if isinstance(analogsignal, neo.core.AnalogSignal)] - - def get_analogsignals_from_trial_as_segment(self, trial_id: int) -> ( - neo.core.Segment): - # Return a segment with all analogsignals from a trial - segment = neo.core.Segment() - for analogsignal in self.get_analogsignals_from_trial_as_list( - trial_id): - segment.analogsignals.append(analogsignal) - return segment + self.list_of_trials[trial_index] + if isinstance(analogsignal, AnalogSignal)] + + def get_analogsignals_trial_by_trial(self, signal_index: int + ) -> AnalogSignal: + # Return a list of all analog signal repetitions across trials + list_idx = self._analogsignal_index[signal_index] + return [trial[list_idx] for trial in self.list_of_trials] diff --git a/elephant/utils.py b/elephant/utils.py index b4ddfee22..e30b43617 100644 --- a/elephant/utils.py +++ b/elephant/utils.py @@ -20,7 +20,6 @@ import numpy as np import quantities as pq -from elephant.trials import Trials __all__ = [ @@ -396,53 +395,3 @@ def get_opencl_capability(): return False -def trials_to_list_of_spiketrainlist(method): - """ - Decorator to convert `Trials` object to a list of `SpikeTrainList` before - calling the wrapped method. - - Parameters - ---------- - method: callable - The method to be decorated. - - Returns - ------- - callable: - The decorated method. - - Examples - -------- - The decorator can be used as follows: - - >>> @trials_to_list_of_spiketrainlist - ... def process_data(self, spiketrains): - ... return None - """ - - @wraps(method) - def wrapper(*args, **kwargs): - new_args = tuple( - [ - arg.get_spiketrains_from_trial_as_list(idx) - for idx in range(arg.n_trials) - ] - if isinstance(arg, Trials) - else arg - for arg in args - ) - new_kwargs = { - key: ( - [ - value.get_spiketrains_from_trial_as_list(idx) - for idx in range(value.n_trials) - ] - if isinstance(value, Trials) - else value - ) - for key, value in kwargs.items() - } - - return method(*new_args, **new_kwargs) - - return wrapper