diff --git a/elephant/statistics.py b/elephant/statistics.py index 45d9cd283..cbf2ad371 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -74,7 +74,7 @@ import scipy.signal from numpy import ndarray from scipy.special import erf -from typing import Union +from typing import List, Union import elephant.conversion as conv import elephant.kernels as kernels @@ -270,10 +270,11 @@ def mean_firing_rate(spiketrain, t_start=None, t_stop=None, axis=None): return rates -def fanofactor(spiketrains, warn_tolerance=0.1 * pq.ms): +def fanofactor(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity], List[np.ndarray], elephant.trials.Trials], + warn_tolerance: pq.Quantity = 0.1 * pq.ms) -> Union[float, List[float], List[List[float]]]: r""" Evaluates the empirical Fano factor F of the spike counts of - a list of `neo.SpikeTrain` objects. + a list of `neo.SpikeTrain` objects or `elephant.trials.Trial` object. Given the vector v containing the observed spike counts (one per spike train) in the time window [t0, t1], F is defined as: @@ -288,18 +289,20 @@ def fanofactor(spiketrains, warn_tolerance=0.1 * pq.ms): Parameters ---------- - spiketrains : list + spiketrains : list or elephant.trials.Trial List of `neo.SpikeTrain` or `pq.Quantity` or `np.ndarray` or list of - spike times for which to compute the Fano factor of spike counts. + spike times for which to compute the Fano factor of spike counts, or + an `elephant.trials.Trial` object. If a Trial object is used, spike trains are + pooled across trials before computing the Fano factor. warn_tolerance : pq.Quantity In case of a list of input neo.SpikeTrains, if their durations vary by - more than `warn_tolerence` in their absolute values, throw a warning + more than `warn_tolerance` in their absolute values, throw a warning (see Notes). Default: 0.1 ms Returns ------- - fano : float + fano : float, list of floats The Fano factor of the spike counts of the input spike trains. Returns np.NaN if an empty list is specified, or if all spike trains are empty. @@ -313,7 +316,7 @@ def fanofactor(spiketrains, warn_tolerance=0.1 * pq.ms): Notes ----- The check for the equal duration of the input spike trains is performed - only if the input is of type`neo.SpikeTrain`: if you pass a numpy array, + only if the input is of type`neo.SpikeTrain`: if you pass e.g. a numpy array, please make sure that they all have the same duration manually. Examples @@ -328,29 +331,40 @@ def fanofactor(spiketrains, warn_tolerance=0.1 * pq.ms): 0.07142857142857142 """ - # Build array of spike counts (one per spike train) - spike_counts = np.array([len(st) for st in spiketrains]) - - # Compute FF - if all(count == 0 for count in spike_counts): - # empty list of spiketrains reaches this branch, and NaN is returned - return np.nan - - if all(isinstance(st, neo.SpikeTrain) for st in spiketrains): - if not is_time_quantity(warn_tolerance): - raise TypeError("'warn_tolerance' must be a time quantity.") - durations = [(st.t_stop - st.t_start).simplified.item() - for st in spiketrains] - durations_min = min(durations) - durations_max = max(durations) - if durations_max - durations_min > warn_tolerance.simplified.item(): - warnings.warn("Fano factor calculated for spike trains of " - "different duration (minimum: {_min}s, maximum " - "{_max}s).".format(_min=durations_min, - _max=durations_max)) - - fano = spike_counts.var() / spike_counts.mean() - return fano + # Check if parameters are of the correct type + if not is_time_quantity(warn_tolerance): + raise TypeError(f"'warn_tolerance' must be a time quantity, but got {type(warn_tolerance)}") + + def _check_input_spiketrains_durations(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity], + List[np.ndarray]]) -> None: + if spiketrains and all(isinstance(st, neo.SpikeTrain) for st in spiketrains): + durations = np.array(tuple(st.duration for st in spiketrains)) + if np.max(durations) - np.min(durations) > warn_tolerance: + warnings.warn(f"Fano factor calculated for spike trains of " + f"different duration (minimum: {np.min(durations)}s, maximum " + f"{np.max(durations)}s).") + + def _compute_fano(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity], List[np.ndarray]]) -> float: + # Check spike train durations + _check_input_spiketrains_durations(spiketrains) + # Build array of spike counts (one per spike train) + spike_counts = np.array(tuple(len(st) for st in spiketrains)) + # Compute FF + if np.all(np.array(spike_counts) == 0): + # empty list of spiketrains reaches this branch, and NaN is returned + return np.nan + else: + return spike_counts.var()/spike_counts.mean() + + if isinstance(spiketrains, elephant.trials.Trials): + list_of_lists_of_spiketrains = [ + spiketrains.get_spiketrains_from_trial_as_list(trial_id=trial_no) + for trial_no in range(spiketrains.n_trials)] + return [_compute_fano([list_of_lists_of_spiketrains[trial_no][st_no] + for trial_no in range(len(list_of_lists_of_spiketrains))]) + for st_no in range(len(list_of_lists_of_spiketrains[0]))] + else: # Legacy behavior + return _compute_fano(spiketrains) def __variation_check(v, with_nan): diff --git a/elephant/test/test_statistics.py b/elephant/test/test_statistics.py index 426111810..4dd10604c 100644 --- a/elephant/test/test_statistics.py +++ b/elephant/test/test_statistics.py @@ -21,7 +21,7 @@ from elephant import statistics from elephant.spike_train_generation import StationaryPoissonProcess from elephant.test.test_trials import _create_trials_block -from elephant.trials import TrialsFromBlock +from elephant.trials import TrialsFromBlock, TrialsFromLists class IsiTestCase(unittest.TestCase): @@ -269,32 +269,34 @@ def test_mean_firing_rate_with_plain_array_and_units_start_stop_typeerror( class FanoFactorTestCase(unittest.TestCase): - def setUp(self): + @classmethod + def setUpClass(cls): np.random.seed(100) num_st = 300 - self.test_spiketrains = [] - self.test_array = [] - self.test_quantity = [] - self.test_list = [] - self.sp_counts = np.zeros(num_st) + cls.test_spiketrains = [] + cls.test_array = [] + cls.test_quantity = [] + cls.test_list = [] + cls.sp_counts = np.zeros(num_st) for i in range(num_st): r = np.random.rand(np.random.randint(20) + 1) st = neo.core.SpikeTrain(r * pq.ms, t_start=0.0 * pq.ms, t_stop=20.0 * pq.ms) - self.test_spiketrains.append(st) - self.test_array.append(r) - self.test_quantity.append(r * pq.ms) - self.test_list.append(list(r)) + cls.test_spiketrains.append(st) + cls.test_array.append(r) + cls.test_quantity.append(r * pq.ms) + cls.test_list.append(list(r)) # for cross-validation - self.sp_counts[i] = len(st) + cls.sp_counts[i] = len(st) + + cls.test_trials = TrialsFromLists([cls.test_spiketrains, cls.test_spiketrains]) def test_fanofactor_spiketrains(self): # Test with list of spiketrains self.assertEqual( np.var(self.sp_counts) / np.mean(self.sp_counts), statistics.fanofactor(self.test_spiketrains)) - # One spiketrain in list st = self.test_spiketrains[0] self.assertEqual(statistics.fanofactor([st]), 0.0) @@ -352,6 +354,18 @@ def test_fanofactor_wrong_type(self): self.assertRaises(TypeError, statistics.fanofactor, [st1], warn_tolerance=1e-4) + def test_fanofactor_trials(self): + results = statistics.fanofactor(self.test_trials) + self.assertEqual(len(results), self.test_trials.n_spiketrains_trial_by_trial[0]) + + def test_fanofactor_trials_result(self): + results = statistics.fanofactor(self.test_trials) + for st_idx, result in enumerate(results): + spiketrains = [trial.spiketrains[st_idx] for trial in self.test_trials.get_trials_as_list()] + sp_count = sum([len(spiketrain) for spiketrain in spiketrains]) + self.assertEqual(np.var(sp_count) / np.mean(sp_count), + result) + class LVTestCase(unittest.TestCase): def setUp(self):