6868import warnings
6969
7070import neo
71- from neo .core .spiketrainlist import SpikeTrainList
7271import numpy as np
7372import quantities as pq
7473import scipy .stats
7574import scipy .signal
75+ from numpy import ndarray
7676from scipy .special import erf
77+ from typing import Union
7778
7879import elephant .conversion as conv
7980import elephant .kernels as kernels
81+ import elephant .trials
8082from elephant .conversion import BinnedSpikeTrain
8183from elephant .utils import deprecated_alias , check_neo_consistency , \
8284 is_time_quantity , round_binning_errors
@@ -601,7 +603,8 @@ def lvr(time_intervals, R=5*pq.ms, with_nan=False):
601603@deprecated_alias (spiketrain = 'spiketrains' )
602604def instantaneous_rate (spiketrains , sampling_period , kernel = 'auto' ,
603605 cutoff = 5.0 , t_start = None , t_stop = None , trim = False ,
604- center_kernel = True , border_correction = False ):
606+ center_kernel = True , border_correction = False ,
607+ pool_trials = False , pool_spike_trains = False ):
605608 r"""
606609 Estimates instantaneous firing rate by kernel convolution.
607610
@@ -610,9 +613,12 @@ def instantaneous_rate(spiketrains, sampling_period, kernel='auto',
610613
611614 Parameters
612615 ----------
613- spiketrains : neo.SpikeTrain or list of neo.SpikeTrain
614- Neo object(s) that contains spike times, the unit of the time stamps,
615- and `t_start` and `t_stop` of the spike train.
616+ spiketrains : neo.SpikeTrain, list of neo.SpikeTrain or elephant.trials.Trials # noqa
617+ Input spike train(s) for which the instantaneous firing rate is
618+ calculated. If a list of spike trains is supplied, the parameter
619+ pool_spike_trains determines the behavior of the function. If a Trials
620+ object is supplied, the behavior is determined by the parameters
621+ pool_spike_trains (within a trial) and pool_trials (across trials).
616622 sampling_period : pq.Quantity
617623 Time stamp resolution of the spike times. The same resolution will
618624 be assumed for the kernel.
@@ -680,6 +686,21 @@ def instantaneous_rate(spiketrains, sampling_period, kernel='auto',
680686 these spike train borders under the assumption that the rate does not
681687 change strongly.
682688 Only possible in the case of a Gaussian kernel.
689+
690+ Default: False
691+ pool_trials: bool, optional
692+ If true, calculate firing rates averaged over trials if spiketrains is
693+ of type elephant.trials.Trials
694+ Has no effect for single spike train or lists of spike trains.
695+
696+ Default: False
697+ pool_spike_trains: bool, optional
698+ If true, calculate firing rates averaged over spike trains. If the
699+ input is a Trials object, spike trains are pooled across spike trains
700+ within each trial, and pool_trials determines whether spike trains are
701+ additionally pooled across trials.
702+ Has no effect for a single spike train.
703+
683704 Default: False
684705
685706 Returns
@@ -788,6 +809,86 @@ def instantaneous_rate(spiketrains, sampling_period, kernel='auto',
788809 [0.05842767]])
789810
790811 """
812+ if isinstance (spiketrains , elephant .trials .Trials ):
813+ kwargs = {
814+ 'kernel' : kernel ,
815+ 'cutoff' : cutoff ,
816+ 't_start' : t_start ,
817+ 't_stop' : t_stop ,
818+ 'trim' : trim ,
819+ 'center_kernel' : center_kernel ,
820+ 'border_correction' : border_correction ,
821+ 'pool_trials' : False ,
822+ 'pool_spike_trains' : False ,
823+ }
824+
825+ if pool_trials :
826+ list_of_lists_of_spiketrains = [
827+ spiketrains .get_spiketrains_from_trial_as_list (
828+ trial_id = trial_no )
829+ for trial_no in range (spiketrains .n_trials )]
830+
831+ spiketrains_cross_trials = (
832+ [list_of_lists_of_spiketrains [trial_no ][spiketrain_idx ]
833+ for trial_no in range (spiketrains .n_trials )]
834+ for spiketrain_idx , spiketrain in
835+ enumerate (list_of_lists_of_spiketrains [0 ]))
836+
837+ rates_cross_trials = [instantaneous_rate (spiketrain ,
838+ sampling_period ,
839+ ** kwargs )
840+ for spiketrain in spiketrains_cross_trials ]
841+
842+ average_rate_cross_trials = (
843+ np .mean (rates , axis = 1 ) for rates in rates_cross_trials )
844+ if pool_spike_trains :
845+ average_rate = np .mean (list (average_rate_cross_trials ), axis = 0 )
846+ analog_signal = rates_cross_trials [0 ]
847+
848+ return (neo .AnalogSignal (
849+ signal = average_rate ,
850+ sampling_period = analog_signal .sampling_period ,
851+ units = analog_signal .units ,
852+ t_start = analog_signal .t_start ,
853+ t_stop = analog_signal .t_stop ,
854+ kernel = analog_signal .annotations )
855+ )
856+
857+ list_of_average_rates_cross_trial = neo .AnalogSignal (
858+ signal = np .array (list (average_rate_cross_trials )).transpose (),
859+ sampling_period = rates_cross_trials [0 ].sampling_period ,
860+ units = rates_cross_trials [0 ].units ,
861+ t_start = rates_cross_trials [0 ].t_start ,
862+ t_stop = rates_cross_trials [0 ].t_stop ,
863+ kernel = rates_cross_trials [0 ].annotations )
864+
865+ return list_of_average_rates_cross_trial
866+
867+ if not pool_trials and not pool_spike_trains :
868+ return [instantaneous_rate (
869+ spiketrains .get_spiketrains_from_trial_as_list (
870+ trial_id = trial_no ), sampling_period , ** kwargs )
871+ for trial_no in range (spiketrains .n_trials )]
872+
873+ if not pool_trials and pool_spike_trains :
874+ rates = [instantaneous_rate (
875+ spiketrains .get_spiketrains_from_trial_as_list (
876+ trial_id = trial_no ), sampling_period , ** kwargs )
877+ for trial_no in range (spiketrains .n_trials )]
878+
879+ average_rates = (np .mean (rate , axis = 1 ) for rate in rates )
880+
881+ list_of_average_rates_over_spiketrains = [
882+ neo .AnalogSignal (signal = average_rate ,
883+ sampling_period = analog_signal .sampling_period ,
884+ units = analog_signal .units ,
885+ t_start = analog_signal .t_start ,
886+ t_stop = analog_signal .t_stop ,
887+ kernel = analog_signal .annotations )
888+ for average_rate , analog_signal in zip (average_rates , rates )]
889+
890+ return list_of_average_rates_over_spiketrains
891+
791892 def optimal_kernel (st ):
792893 width_sigma = None
793894 if len (st ) > 0 :
@@ -930,6 +1031,10 @@ def optimal_kernel(st):
9301031 sigma = str (kernel .sigma ),
9311032 invert = kernel .invert )
9321033
1034+ if isinstance (spiketrains , neo .core .spiketrainlist .SpikeTrainList ) and (
1035+ pool_spike_trains ):
1036+ rate = np .mean (rate , axis = 1 )
1037+
9331038 rate = neo .AnalogSignal (signal = rate ,
9341039 sampling_period = sampling_period ,
9351040 units = pq .Hz , t_start = t_start , t_stop = t_stop ,
@@ -1035,7 +1140,8 @@ def time_histogram(spiketrains, bin_size, t_start=None, t_stop=None,
10351140 ... neo.SpikeTrain([0.3, 4.5, 6.7, 9.3], t_stop=10, units='s'),
10361141 ... neo.SpikeTrain([0.7, 4.3, 8.2], t_stop=10, units='s')
10371142 ... ]
1038- >>> hist = statistics.time_histogram(spiketrains, bin_size=1 * pq.s)
1143+ >>> hist = statistics.time_histogram(spiketrains,
1144+ ... bin_size=1 * pq.s)
10391145 >>> hist
10401146 <AnalogSignal(array([[2],
10411147 [0],
@@ -1053,32 +1159,49 @@ def time_histogram(spiketrains, bin_size, t_start=None, t_stop=None,
10531159
10541160 """
10551161 # Bin the spike trains and sum across columns
1056- bs = BinnedSpikeTrain (spiketrains , t_start = t_start , t_stop = t_stop ,
1057- bin_size = bin_size )
1058-
10591162 if binary :
1060- bs = bs .binarize (copy = False )
1061- bin_hist = bs .get_num_of_spikes (axis = 0 )
1062- # Flatten array
1063- bin_hist = np .ravel (bin_hist )
1064- # Renormalise the histogram
1065- if output == 'counts' :
1066- # Raw
1067- bin_hist = pq .Quantity (bin_hist , units = pq .dimensionless , copy = False )
1068- elif output == 'mean' :
1069- # Divide by number of input spike trains
1070- bin_hist = pq .Quantity (bin_hist / len (spiketrains ),
1071- units = pq .dimensionless , copy = False )
1072- elif output == 'rate' :
1073- # Divide by number of input spike trains and bin width
1074- bin_hist = bin_hist / (len (spiketrains ) * bin_size )
1163+ binned_spiketrain = BinnedSpikeTrain (spiketrains ,
1164+ t_start = t_start ,
1165+ t_stop = t_stop , bin_size = bin_size
1166+ ).binarize (copy = False )
10751167 else :
1168+ binned_spiketrain = BinnedSpikeTrain (spiketrains ,
1169+ t_start = t_start ,
1170+ t_stop = t_stop , bin_size = bin_size
1171+ )
1172+
1173+ bin_hist : Union [int , ndarray ] = binned_spiketrain .get_num_of_spikes (axis = 0 )
1174+ # Flatten array
1175+ bin_hist .ravel ()
1176+
1177+ # Re-normalise the histogram according to desired output
1178+
1179+ def _counts () -> pq .Quantity :
1180+ # 'counts': spike counts at each bin (as integer numbers).
1181+ return pq .Quantity (bin_hist , units = pq .dimensionless , copy = False )
1182+
1183+ def _mean () -> pq .Quantity :
1184+ # 'mean': mean spike counts per spike train.
1185+ return pq .Quantity (bin_hist / len (spiketrains ),
1186+ units = pq .dimensionless , copy = False )
1187+
1188+ def _rate () -> pq .Quantity :
1189+ # 'rate': mean spike rate per spike train. Like 'mean', but the
1190+ # counts are additionally normalized by the bin width.
1191+ return bin_hist / (len (spiketrains ) * bin_size )
1192+
1193+ output_mapping = {"counts" : _counts , "mean" : _mean , "rate" : _rate }
1194+ try :
1195+ normalise_func = output_mapping .get (output )
1196+ normalised_bin_hist = normalise_func ()
1197+ except TypeError :
10761198 raise ValueError (f'Parameter output ({ output } ) is not valid.' )
10771199
1078- return neo .AnalogSignal (signal = np .expand_dims (bin_hist , axis = 1 ),
1079- sampling_period = bin_size , units = bin_hist .units ,
1080- t_start = bs .t_start , normalization = output ,
1081- copy = False )
1200+ return neo .AnalogSignal (signal = np .expand_dims (normalised_bin_hist , axis = 1 ),
1201+ sampling_period = bin_size ,
1202+ units = normalised_bin_hist .units ,
1203+ t_start = binned_spiketrain .t_start ,
1204+ normalization = output , copy = False )
10821205
10831206
10841207@deprecated_alias (binsize = 'bin_size' )
@@ -1343,7 +1466,7 @@ def pdf(self):
13431466 `t_start + j * binsize` and `t_start + (j + 1) * binsize`.
13441467 """
13451468 norm_hist = self .complexity_histogram / self .complexity_histogram .sum ()
1346- # Convert the Complexity pdf to an neo.AnalogSignal
1469+ # Convert the Complexity pdf to a neo.AnalogSignal
13471470 pdf = neo .AnalogSignal (
13481471 np .expand_dims (norm_hist , axis = 1 ),
13491472 units = pq .dimensionless ,
0 commit comments