Skip to content

Commit 8bac14c

Browse files
[Fix] spike time tiling coefficient for unsorted spiketrains, added validation test (#564)
* refactor run_P * refactor run_T * add checks for t_start and t_stop in run_t * add input checks and unittests * add regression test for Issue #563 * add validation tests * add check if spike times are sorted, if not sort the spikes
1 parent 971fc6a commit 8bac14c

File tree

3 files changed

+204
-103
lines changed

3 files changed

+204
-103
lines changed

doc/bib/elephant.bib

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,8 @@ @article{Cutts2014_14288
276276
number={43},
277277
pages={14288--14303},
278278
year={2014},
279-
publisher={Soc Neuroscience}
279+
publisher={Soc Neuroscience},
280+
doi={10.1523/JNEUROSCI.2767-14.2014}
280281
}
281282

282283
@article{Holt1996_1806,

elephant/spike_train_correlation.py

Lines changed: 90 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from scipy import integrate
2626

2727
from elephant.conversion import BinnedSpikeTrain
28-
from elephant.utils import deprecated_alias
28+
from elephant.utils import deprecated_alias, check_neo_consistency
2929

3030
__all__ = [
3131
"covariance",
@@ -824,15 +824,17 @@ def cross_correlation_histogram(
824824

825825

826826
@deprecated_alias(spiketrain_1='spiketrain_i', spiketrain_2='spiketrain_j')
827-
def spike_time_tiling_coefficient(spiketrain_i, spiketrain_j, dt=0.005 * pq.s):
827+
def spike_time_tiling_coefficient(spiketrain_i: neo.core.SpikeTrain,
828+
spiketrain_j: neo.core.SpikeTrain,
829+
dt: pq.Quantity = 0.005 * pq.s) -> float:
828830
"""
829831
Calculates the Spike Time Tiling Coefficient (STTC) as described in
830832
:cite:`correlation-Cutts2014_14288` following their implementation in C.
831833
The STTC is a pairwise measure of correlation between spike trains.
832834
It has been proposed as a replacement for the correlation index as it
833835
presents several advantages (e.g. it's not confounded by firing rate,
834836
appropriately distinguishes lack of correlation from anti-correlation,
835-
periods of silence don't add to the correlation and it's sensitive to
837+
periods of silence don't add to the correlation, and it's sensitive to
836838
firing patterns).
837839
838840
The STTC is calculated as follows:
@@ -845,7 +847,7 @@ def spike_time_tiling_coefficient(spiketrain_i, spiketrain_j, dt=0.005 * pq.s):
845847
in train 1, `PB` is the same proportion for the spikes in train 2;
846848
`TA` is the proportion of total recording time within `[-dt, +dt]` of any
847849
spike in train 1, TB is the same proportion for train 2.
848-
For :math:`TA = PB = 1`and for :math:`TB = PA = 1`
850+
For :math:`TA = PB = 1` and for :math:`TB = PA = 1`
849851
the resulting :math:`0/0` is replaced with :math:`1`,
850852
since every spike from the train with :math:`T = 1` is within
851853
`[-dt, +dt]` of a spike of the other train.
@@ -857,7 +859,7 @@ def spike_time_tiling_coefficient(spiketrain_i, spiketrain_j, dt=0.005 * pq.s):
857859
858860
Parameters
859861
----------
860-
spiketrain_i, spiketrain_j : neo.SpikeTrain
862+
spiketrain_i, spiketrain_j : :class:`neo.core.SpikeTrain`
861863
Spike trains to cross-correlate. They must have the same `t_start` and
862864
`t_stop`.
863865
dt : pq.Quantity.
@@ -869,9 +871,9 @@ def spike_time_tiling_coefficient(spiketrain_i, spiketrain_j, dt=0.005 * pq.s):
869871
870872
Returns
871873
-------
872-
index : float or np.nan
873-
The spike time tiling coefficient (STTC). Returns np.nan if any spike
874-
train is empty.
874+
index : :class:`float` or :obj:`numpy.nan`
875+
The spike time tiling coefficient (STTC). Returns :obj:`numpy.nan` if
876+
any spike train is empty.
875877
876878
Notes
877879
-----
@@ -891,109 +893,105 @@ def spike_time_tiling_coefficient(spiketrain_i, spiketrain_j, dt=0.005 * pq.s):
891893
0.4958601655933762
892894
893895
"""
896+
# input checks
897+
if dt <= 0 * pq.s:
898+
raise ValueError(f"dt must be > 0, found: {dt}")
894899

895-
def run_P(spiketrain_i, spiketrain_j):
900+
check_neo_consistency([spiketrain_j, spiketrain_i], neo.core.SpikeTrain)
901+
902+
if dt.units != spiketrain_i.units:
903+
dt = dt.rescale(spiketrain_i.units)
904+
905+
def run_p(spiketrain_j: neo.core.SpikeTrain,
906+
spiketrain_i: neo.core.SpikeTrain,
907+
dt: pq.Quantity = dt) -> float:
896908
"""
897-
Check every spike in train 1 to see if there's a spike in train 2
898-
within dt
909+
Returns number of spikes in spiketrain_j which lie within +- dt of
910+
any spike from spiketrain_i, divided by the total number of spikes in
911+
spiketrain_j
899912
"""
900-
N2 = len(spiketrain_j)
901-
902-
# Search spikes of spiketrain_i in spiketrain_j
903-
# ind will contain index of
904-
ind = np.searchsorted(spiketrain_j.times, spiketrain_i.times)
905-
906-
# To prevent IndexErrors
907-
# If a spike of spiketrain_i is after the last spike of spiketrain_j,
908-
# the index is N2, however spiketrain_j[N2] raises an IndexError.
909-
# By shifting this index, the spike of spiketrain_i will be compared
910-
# to the last 2 spikes of spiketrain_j (negligible overhead).
911-
# Note: Not necessary for index 0 that will be shifted to -1,
912-
# because spiketrain_j[-1] is valid (additional negligible comparison)
913-
ind[ind == N2] = N2 - 1
914-
915-
# Compare to nearest spike in spiketrain_j BEFORE spike in spiketrain_i
916-
close_left = np.abs(
917-
spiketrain_j.times[ind - 1] - spiketrain_i.times) <= dt
918-
# Compare to nearest spike in spiketrain_j AFTER (or simultaneous)
919-
# spike in spiketrain_j
920-
close_right = np.abs(
921-
spiketrain_j.times[ind] - spiketrain_i.times) <= dt
922-
923-
# spiketrain_j spikes that are in [-dt, dt] range of spiketrain_i
924-
# spikes are counted only ONCE (as per original implementation)
925-
close = close_left + close_right
926-
927-
# Count how many spikes in spiketrain_i have a "partner" in
928-
# spiketrain_j
929-
return np.count_nonzero(close)
930-
931-
def run_T(spiketrain):
913+
# Create a boolean array where each element represents whether a spike
914+
# in spiketrain_j lies within +- dt of any spike in spiketrain_i.
915+
tiled_spikes_j = np.isclose(
916+
spiketrain_j.times.magnitude[:, np.newaxis],
917+
spiketrain_i.times.magnitude,
918+
atol=dt.item())
919+
# Determine which spikes in spiketrain_j satisfy the time window
920+
# condition.
921+
tiled_spike_indices = np.any(tiled_spikes_j, axis=1)
922+
# Extract the spike times in spiketrain_j that satisfy the condition.
923+
tiled_spikes_j = spiketrain_j[tiled_spike_indices]
924+
# Calculate the ratio of matching spikes in j to the total spikes in j.
925+
return len(tiled_spikes_j)/len(spiketrain_j)
926+
927+
def run_t(spiketrain: neo.core.SpikeTrain, dt: pq.Quantity = dt) -> float:
932928
"""
933929
Calculate the proportion of the total recording time 'tiled' by spikes.
934930
"""
935-
N = len(spiketrain)
936-
time_A = 2 * N * dt # maximum possible time
937-
938-
if N == 1: # for only a single spike in the train
939-
940-
# Check difference between start of recording and single spike
941-
if spiketrain[0] - spiketrain.t_start < dt:
942-
time_A += - dt + spiketrain[0] - spiketrain.t_start
943-
944-
# Check difference between single spike and end of recording
945-
elif spiketrain[0] + dt > spiketrain.t_stop:
946-
time_A += - dt - spiketrain[0] + spiketrain.t_stop
947-
948-
else: # if more than a single spike in the train
949-
950-
# Calculate difference between consecutive spikes
951-
diff = np.diff(spiketrain)
952-
953-
# Find spikes whose tiles overlap
954-
idx = np.where(diff < 2 * dt)[0]
955-
# Subtract overlapping "2*dt" tiles and add differences instead
956-
time_A += - 2 * dt * len(idx) + diff[idx].sum()
957-
958-
# Check if spikes are within +/-dt of the start and/or end
959-
# if so, subtract overlap of first and/or last spike
960-
if (spiketrain[0] - spiketrain.t_start) < dt:
961-
time_A += spiketrain[0] - dt - spiketrain.t_start
962-
if (spiketrain.t_stop - spiketrain[N - 1]) < dt:
963-
time_A += - spiketrain[-1] - dt + spiketrain.t_stop
964-
965-
# Calculate the proportion of total recorded time to "tiled" time
966-
T = time_A / (spiketrain.t_stop - spiketrain.t_start)
967-
return T.simplified.item() # enforce simplification, strip units
931+
# Get the numerical value of 'dt'.
932+
dt = dt.item()
933+
# Get the start and stop times of the spike train.
934+
t_start = spiketrain.t_start.item()
935+
t_stop = spiketrain.t_stop.item()
936+
# Get the spike times as a NumPy array.
937+
sorted_spikes = spiketrain.times.magnitude
938+
# Check if spikes are sorted and sort them if not.
939+
if (np.diff(sorted_spikes) < 0).any():
940+
sorted_spikes = np.sort(sorted_spikes)
941+
942+
# Calculate the time differences between consecutive spikes.
943+
diff_spikes = np.diff(sorted_spikes)
944+
# Calculate durations of spike overlaps within a time window of 2 * dt.
945+
overlap_durations = diff_spikes[diff_spikes <= 2 * dt]
946+
covered_time_overlap = np.sum(overlap_durations)
947+
948+
# Calculate the durations of non-overlapping spikes.
949+
non_overlap_durations = diff_spikes[diff_spikes > 2 * dt]
950+
covered_time_non_overlap = len(non_overlap_durations) * 2 * dt
951+
952+
# Check if the first and last spikes are within +/-dt of the start
953+
# and end.
954+
# If so, adjust the overlapping and non-overlapping times accordingly.
955+
if sorted_spikes[0] - t_start < dt:
956+
covered_time_overlap += sorted_spikes[0] - t_start
957+
else:
958+
covered_time_non_overlap += dt
959+
if t_stop - sorted_spikes[- 1] < dt:
960+
covered_time_overlap += t_stop - sorted_spikes[-1]
961+
else:
962+
covered_time_non_overlap += dt
968963

969-
N1 = len(spiketrain_i)
970-
N2 = len(spiketrain_j)
964+
# Calculate the total time covered by spikes and the total recording
965+
# time.
966+
total_time_covered = covered_time_overlap + covered_time_non_overlap
967+
total_time = t_stop - t_start
968+
# Calculate and return the proportion of the total recording time
969+
# covered by spikes.
970+
return total_time_covered / total_time
971971

972-
if N1 == 0 or N2 == 0:
972+
if len(spiketrain_i) == 0 or len(spiketrain_j) == 0:
973973
index = np.nan
974974
else:
975-
TA = run_T(spiketrain_i)
976-
TB = run_T(spiketrain_j)
977-
PA = run_P(spiketrain_i, spiketrain_j)
978-
PA = PA / N1
979-
PB = run_P(spiketrain_j, spiketrain_i)
980-
PB = PB / N2
975+
TA = run_t(spiketrain_j, dt)
976+
TB = run_t(spiketrain_i, dt)
977+
PA = run_p(spiketrain_j, spiketrain_i, dt)
978+
PB = run_p(spiketrain_i, spiketrain_j, dt)
979+
981980
# check if the P and T values are 1 to avoid division by zero
982981
# This only happens for TA = PB = 1 and/or TB = PA = 1,
983982
# which leads to 0/0 in the calculation of the index.
984983
# In those cases, every spike in the train with P = 1
985984
# is within dt of a spike in the other train,
986985
# so we set the respective (partial) index to 1.
987-
if PA * TB == 1:
988-
if PB * TA == 1:
989-
index = 1.
990-
else:
991-
index = 0.5 + 0.5 * (PB - TA) / (1 - PB * TA)
986+
if PA * TB == 1 and PB * TA == 1:
987+
index = 1.
988+
elif PA * TB == 1:
989+
index = 0.5 + 0.5 * (PB - TA) / (1 - PB * TA)
992990
elif PB * TA == 1:
993991
index = 0.5 + 0.5 * (PA - TB) / (1 - PA * TB)
994992
else:
995-
index = 0.5 * (PA - TB) / (1 - PA * TB) + 0.5 * (PB - TA) / (
996-
1 - PB * TA)
993+
index = 0.5 * (PA - TB) / (1 - PA * TB) + \
994+
0.5 * (PB - TA) / (1 - PB * TA)
997995
return index
998996

999997

0 commit comments

Comments
 (0)