Skip to content

Commit d3a35a2

Browse files
[Feature] Trial handling in Elephant (#579)
--------- Co-authored-by: Michael Denker <[email protected]>
1 parent 8bac14c commit d3a35a2

File tree

8 files changed

+1046
-111
lines changed

8 files changed

+1046
-111
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
{{ fullname | escape | underline }}
2+
3+
.. currentmodule:: {{ module }}
4+
5+
.. autoclass:: {{ objname }}
6+
:special-members: __contains__,__iter__,__len__,__add__,__sub__,__mul__,__div__,__neg__
7+
:members:
8+
:exclude-members: __getitem__,__init__

doc/modules.rst

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ Spike trains
4646
:maxdepth: 2
4747

4848
reference/spike_train_generation
49-
49+
5050

5151
********************************
5252
LFPs and spike trains (combined)
@@ -77,14 +77,15 @@ Waveforms
7777

7878
reference/waveform_features
7979

80-
********************************
81-
Alternative data representations
82-
********************************
80+
********************
81+
Data Representations
82+
********************
8383

8484
.. toctree::
8585
:maxdepth: 1
8686

8787
reference/conversion
88+
reference/trials
8889

8990
*************
9091
Miscellaneous

doc/reference/trials.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
=====================
2+
Trial representations
3+
=====================
4+
5+
.. automodule:: elephant.trials

elephant/__init__.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,28 @@
66
:license: Modified BSD, see LICENSE.txt for details.
77
"""
88

9-
from . import (statistics,
10-
spike_train_generation,
11-
spike_train_synchrony,
12-
spike_train_correlation,
13-
unitary_event_analysis,
9+
from . import (cell_assembly_detection,
10+
change_point_detection,
11+
conversion,
1412
cubic,
15-
spectral,
13+
current_source_density,
14+
gpfa,
1615
kernels,
16+
neo_tools,
17+
phase_analysis,
18+
signal_processing,
19+
spade,
20+
spectral,
21+
spike_train_correlation,
1722
spike_train_dissimilarity,
23+
spike_train_generation,
1824
spike_train_surrogates,
19-
signal_processing,
20-
current_source_density,
21-
change_point_detection,
22-
phase_analysis,
25+
spike_train_synchrony,
2326
sta,
24-
conversion,
25-
neo_tools,
26-
cell_assembly_detection,
27-
spade,
27+
trials,
28+
unitary_event_analysis,
2829
waveform_features,
29-
gpfa)
30+
statistics)
3031

3132
# not included modules on purpose:
3233
# parallel: avoid warns when elephant is imported

elephant/statistics.py

Lines changed: 152 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,17 @@
6868
import warnings
6969

7070
import neo
71-
from neo.core.spiketrainlist import SpikeTrainList
7271
import numpy as np
7372
import quantities as pq
7473
import scipy.stats
7574
import scipy.signal
75+
from numpy import ndarray
7676
from scipy.special import erf
77+
from typing import Union
7778

7879
import elephant.conversion as conv
7980
import elephant.kernels as kernels
81+
import elephant.trials
8082
from elephant.conversion import BinnedSpikeTrain
8183
from elephant.utils import deprecated_alias, check_neo_consistency, \
8284
is_time_quantity, round_binning_errors
@@ -601,7 +603,8 @@ def lvr(time_intervals, R=5*pq.ms, with_nan=False):
601603
@deprecated_alias(spiketrain='spiketrains')
602604
def instantaneous_rate(spiketrains, sampling_period, kernel='auto',
603605
cutoff=5.0, t_start=None, t_stop=None, trim=False,
604-
center_kernel=True, border_correction=False):
606+
center_kernel=True, border_correction=False,
607+
pool_trials=False, pool_spike_trains=False):
605608
r"""
606609
Estimates instantaneous firing rate by kernel convolution.
607610
@@ -610,9 +613,12 @@ def instantaneous_rate(spiketrains, sampling_period, kernel='auto',
610613
611614
Parameters
612615
----------
613-
spiketrains : neo.SpikeTrain or list of neo.SpikeTrain
614-
Neo object(s) that contains spike times, the unit of the time stamps,
615-
and `t_start` and `t_stop` of the spike train.
616+
spiketrains : neo.SpikeTrain, list of neo.SpikeTrain or elephant.trials.Trials # noqa
617+
Input spike train(s) for which the instantaneous firing rate is
618+
calculated. If a list of spike trains is supplied, the parameter
619+
pool_spike_trains determines the behavior of the function. If a Trials
620+
object is supplied, the behavior is determined by the parameters
621+
pool_spike_trains (within a trial) and pool_trials (across trials).
616622
sampling_period : pq.Quantity
617623
Time stamp resolution of the spike times. The same resolution will
618624
be assumed for the kernel.
@@ -680,6 +686,21 @@ def instantaneous_rate(spiketrains, sampling_period, kernel='auto',
680686
these spike train borders under the assumption that the rate does not
681687
change strongly.
682688
Only possible in the case of a Gaussian kernel.
689+
690+
Default: False
691+
pool_trials: bool, optional
692+
If true, calculate firing rates averaged over trials if spiketrains is
693+
of type elephant.trials.Trials
694+
Has no effect for single spike train or lists of spike trains.
695+
696+
Default: False
697+
pool_spike_trains: bool, optional
698+
If true, calculate firing rates averaged over spike trains. If the
699+
input is a Trials object, spike trains are pooled across spike trains
700+
within each trial, and pool_trials determines whether spike trains are
701+
additionally pooled across trials.
702+
Has no effect for a single spike train.
703+
683704
Default: False
684705
685706
Returns
@@ -788,6 +809,86 @@ def instantaneous_rate(spiketrains, sampling_period, kernel='auto',
788809
[0.05842767]])
789810
790811
"""
812+
if isinstance(spiketrains, elephant.trials.Trials):
813+
kwargs = {
814+
'kernel': kernel,
815+
'cutoff': cutoff,
816+
't_start': t_start,
817+
't_stop': t_stop,
818+
'trim': trim,
819+
'center_kernel': center_kernel,
820+
'border_correction': border_correction,
821+
'pool_trials': False,
822+
'pool_spike_trains': False,
823+
}
824+
825+
if pool_trials:
826+
list_of_lists_of_spiketrains = [
827+
spiketrains.get_spiketrains_from_trial_as_list(
828+
trial_id=trial_no)
829+
for trial_no in range(spiketrains.n_trials)]
830+
831+
spiketrains_cross_trials = (
832+
[list_of_lists_of_spiketrains[trial_no][spiketrain_idx]
833+
for trial_no in range(spiketrains.n_trials)]
834+
for spiketrain_idx, spiketrain in
835+
enumerate(list_of_lists_of_spiketrains[0]))
836+
837+
rates_cross_trials = [instantaneous_rate(spiketrain,
838+
sampling_period,
839+
**kwargs)
840+
for spiketrain in spiketrains_cross_trials]
841+
842+
average_rate_cross_trials = (
843+
np.mean(rates, axis=1) for rates in rates_cross_trials)
844+
if pool_spike_trains:
845+
average_rate = np.mean(list(average_rate_cross_trials), axis=0)
846+
analog_signal = rates_cross_trials[0]
847+
848+
return (neo.AnalogSignal(
849+
signal=average_rate,
850+
sampling_period=analog_signal.sampling_period,
851+
units=analog_signal.units,
852+
t_start=analog_signal.t_start,
853+
t_stop=analog_signal.t_stop,
854+
kernel=analog_signal.annotations)
855+
)
856+
857+
list_of_average_rates_cross_trial = neo.AnalogSignal(
858+
signal=np.array(list(average_rate_cross_trials)).transpose(),
859+
sampling_period=rates_cross_trials[0].sampling_period,
860+
units=rates_cross_trials[0].units,
861+
t_start=rates_cross_trials[0].t_start,
862+
t_stop=rates_cross_trials[0].t_stop,
863+
kernel=rates_cross_trials[0].annotations)
864+
865+
return list_of_average_rates_cross_trial
866+
867+
if not pool_trials and not pool_spike_trains:
868+
return [instantaneous_rate(
869+
spiketrains.get_spiketrains_from_trial_as_list(
870+
trial_id=trial_no), sampling_period, **kwargs)
871+
for trial_no in range(spiketrains.n_trials)]
872+
873+
if not pool_trials and pool_spike_trains:
874+
rates = [instantaneous_rate(
875+
spiketrains.get_spiketrains_from_trial_as_list(
876+
trial_id=trial_no), sampling_period, **kwargs)
877+
for trial_no in range(spiketrains.n_trials)]
878+
879+
average_rates = (np.mean(rate, axis=1) for rate in rates)
880+
881+
list_of_average_rates_over_spiketrains = [
882+
neo.AnalogSignal(signal=average_rate,
883+
sampling_period=analog_signal.sampling_period,
884+
units=analog_signal.units,
885+
t_start=analog_signal.t_start,
886+
t_stop=analog_signal.t_stop,
887+
kernel=analog_signal.annotations)
888+
for average_rate, analog_signal in zip(average_rates, rates)]
889+
890+
return list_of_average_rates_over_spiketrains
891+
791892
def optimal_kernel(st):
792893
width_sigma = None
793894
if len(st) > 0:
@@ -930,6 +1031,10 @@ def optimal_kernel(st):
9301031
sigma=str(kernel.sigma),
9311032
invert=kernel.invert)
9321033

1034+
if isinstance(spiketrains, neo.core.spiketrainlist.SpikeTrainList) and (
1035+
pool_spike_trains):
1036+
rate = np.mean(rate, axis=1)
1037+
9331038
rate = neo.AnalogSignal(signal=rate,
9341039
sampling_period=sampling_period,
9351040
units=pq.Hz, t_start=t_start, t_stop=t_stop,
@@ -1035,7 +1140,8 @@ def time_histogram(spiketrains, bin_size, t_start=None, t_stop=None,
10351140
... neo.SpikeTrain([0.3, 4.5, 6.7, 9.3], t_stop=10, units='s'),
10361141
... neo.SpikeTrain([0.7, 4.3, 8.2], t_stop=10, units='s')
10371142
... ]
1038-
>>> hist = statistics.time_histogram(spiketrains, bin_size=1 * pq.s)
1143+
>>> hist = statistics.time_histogram(spiketrains,
1144+
... bin_size=1 * pq.s)
10391145
>>> hist
10401146
<AnalogSignal(array([[2],
10411147
[0],
@@ -1053,32 +1159,49 @@ def time_histogram(spiketrains, bin_size, t_start=None, t_stop=None,
10531159
10541160
"""
10551161
# Bin the spike trains and sum across columns
1056-
bs = BinnedSpikeTrain(spiketrains, t_start=t_start, t_stop=t_stop,
1057-
bin_size=bin_size)
1058-
10591162
if binary:
1060-
bs = bs.binarize(copy=False)
1061-
bin_hist = bs.get_num_of_spikes(axis=0)
1062-
# Flatten array
1063-
bin_hist = np.ravel(bin_hist)
1064-
# Renormalise the histogram
1065-
if output == 'counts':
1066-
# Raw
1067-
bin_hist = pq.Quantity(bin_hist, units=pq.dimensionless, copy=False)
1068-
elif output == 'mean':
1069-
# Divide by number of input spike trains
1070-
bin_hist = pq.Quantity(bin_hist / len(spiketrains),
1071-
units=pq.dimensionless, copy=False)
1072-
elif output == 'rate':
1073-
# Divide by number of input spike trains and bin width
1074-
bin_hist = bin_hist / (len(spiketrains) * bin_size)
1163+
binned_spiketrain = BinnedSpikeTrain(spiketrains,
1164+
t_start=t_start,
1165+
t_stop=t_stop, bin_size=bin_size
1166+
).binarize(copy=False)
10751167
else:
1168+
binned_spiketrain = BinnedSpikeTrain(spiketrains,
1169+
t_start=t_start,
1170+
t_stop=t_stop, bin_size=bin_size
1171+
)
1172+
1173+
bin_hist: Union[int, ndarray] = binned_spiketrain.get_num_of_spikes(axis=0)
1174+
# Flatten array
1175+
bin_hist.ravel()
1176+
1177+
# Re-normalise the histogram according to desired output
1178+
1179+
def _counts() -> pq.Quantity:
1180+
# 'counts': spike counts at each bin (as integer numbers).
1181+
return pq.Quantity(bin_hist, units=pq.dimensionless, copy=False)
1182+
1183+
def _mean() -> pq.Quantity:
1184+
# 'mean': mean spike counts per spike train.
1185+
return pq.Quantity(bin_hist / len(spiketrains),
1186+
units=pq.dimensionless, copy=False)
1187+
1188+
def _rate() -> pq.Quantity:
1189+
# 'rate': mean spike rate per spike train. Like 'mean', but the
1190+
# counts are additionally normalized by the bin width.
1191+
return bin_hist / (len(spiketrains) * bin_size)
1192+
1193+
output_mapping = {"counts": _counts, "mean": _mean, "rate": _rate}
1194+
try:
1195+
normalise_func = output_mapping.get(output)
1196+
normalised_bin_hist = normalise_func()
1197+
except TypeError:
10761198
raise ValueError(f'Parameter output ({output}) is not valid.')
10771199

1078-
return neo.AnalogSignal(signal=np.expand_dims(bin_hist, axis=1),
1079-
sampling_period=bin_size, units=bin_hist.units,
1080-
t_start=bs.t_start, normalization=output,
1081-
copy=False)
1200+
return neo.AnalogSignal(signal=np.expand_dims(normalised_bin_hist, axis=1),
1201+
sampling_period=bin_size,
1202+
units=normalised_bin_hist.units,
1203+
t_start=binned_spiketrain.t_start,
1204+
normalization=output, copy=False)
10821205

10831206

10841207
@deprecated_alias(binsize='bin_size')
@@ -1343,7 +1466,7 @@ def pdf(self):
13431466
`t_start + j * binsize` and `t_start + (j + 1) * binsize`.
13441467
"""
13451468
norm_hist = self.complexity_histogram / self.complexity_histogram.sum()
1346-
# Convert the Complexity pdf to an neo.AnalogSignal
1469+
# Convert the Complexity pdf to a neo.AnalogSignal
13471470
pdf = neo.AnalogSignal(
13481471
np.expand_dims(norm_hist, axis=1),
13491472
units=pq.dimensionless,

0 commit comments

Comments
 (0)