2525from scipy import integrate
2626
2727from elephant .conversion import BinnedSpikeTrain
28- from elephant .utils import deprecated_alias
28+ from elephant .utils import deprecated_alias , check_neo_consistency
2929
3030__all__ = [
3131 "covariance" ,
@@ -824,15 +824,17 @@ def cross_correlation_histogram(
824824
825825
826826@deprecated_alias (spiketrain_1 = 'spiketrain_i' , spiketrain_2 = 'spiketrain_j' )
827- def spike_time_tiling_coefficient (spiketrain_i , spiketrain_j , dt = 0.005 * pq .s ):
827+ def spike_time_tiling_coefficient (spiketrain_i : neo .core .SpikeTrain ,
828+ spiketrain_j : neo .core .SpikeTrain ,
829+ dt : pq .Quantity = 0.005 * pq .s ) -> float :
828830 """
829831 Calculates the Spike Time Tiling Coefficient (STTC) as described in
830832 :cite:`correlation-Cutts2014_14288` following their implementation in C.
831833 The STTC is a pairwise measure of correlation between spike trains.
832834 It has been proposed as a replacement for the correlation index as it
833835 presents several advantages (e.g. it's not confounded by firing rate,
834836 appropriately distinguishes lack of correlation from anti-correlation,
835- periods of silence don't add to the correlation and it's sensitive to
837+ periods of silence don't add to the correlation, and it's sensitive to
836838 firing patterns).
837839
838840 The STTC is calculated as follows:
@@ -845,7 +847,7 @@ def spike_time_tiling_coefficient(spiketrain_i, spiketrain_j, dt=0.005 * pq.s):
845847 in train 1, `PB` is the same proportion for the spikes in train 2;
846848 `TA` is the proportion of total recording time within `[-dt, +dt]` of any
847849 spike in train 1, TB is the same proportion for train 2.
848- For :math:`TA = PB = 1`and for :math:`TB = PA = 1`
850+ For :math:`TA = PB = 1` and for :math:`TB = PA = 1`
849851 the resulting :math:`0/0` is replaced with :math:`1`,
850852 since every spike from the train with :math:`T = 1` is within
851853 `[-dt, +dt]` of a spike of the other train.
@@ -857,7 +859,7 @@ def spike_time_tiling_coefficient(spiketrain_i, spiketrain_j, dt=0.005 * pq.s):
857859
858860 Parameters
859861 ----------
860- spiketrain_i, spiketrain_j : neo.SpikeTrain
862+ spiketrain_i, spiketrain_j : :class:` neo.core. SpikeTrain`
861863 Spike trains to cross-correlate. They must have the same `t_start` and
862864 `t_stop`.
863865 dt : pq.Quantity.
@@ -869,9 +871,9 @@ def spike_time_tiling_coefficient(spiketrain_i, spiketrain_j, dt=0.005 * pq.s):
869871
870872 Returns
871873 -------
872- index : float or np .nan
873- The spike time tiling coefficient (STTC). Returns np .nan if any spike
874- train is empty.
874+ index : :class:` float` or :obj:`numpy .nan`
875+ The spike time tiling coefficient (STTC). Returns :obj:`numpy .nan` if
876+ any spike train is empty.
875877
876878 Notes
877879 -----
@@ -891,109 +893,105 @@ def spike_time_tiling_coefficient(spiketrain_i, spiketrain_j, dt=0.005 * pq.s):
891893 0.4958601655933762
892894
893895 """
896+ # input checks
897+ if dt <= 0 * pq .s :
898+ raise ValueError (f"dt must be > 0, found: { dt } " )
894899
895- def run_P (spiketrain_i , spiketrain_j ):
900+ check_neo_consistency ([spiketrain_j , spiketrain_i ], neo .core .SpikeTrain )
901+
902+ if dt .units != spiketrain_i .units :
903+ dt = dt .rescale (spiketrain_i .units )
904+
905+ def run_p (spiketrain_j : neo .core .SpikeTrain ,
906+ spiketrain_i : neo .core .SpikeTrain ,
907+ dt : pq .Quantity = dt ) -> float :
896908 """
897- Check every spike in train 1 to see if there's a spike in train 2
898- within dt
909+ Returns number of spikes in spiketrain_j which lie within +- dt of
910+ any spike from spiketrain_i, divided by the total number of spikes in
911+ spiketrain_j
899912 """
900- N2 = len (spiketrain_j )
901-
902- # Search spikes of spiketrain_i in spiketrain_j
903- # ind will contain index of
904- ind = np .searchsorted (spiketrain_j .times , spiketrain_i .times )
905-
906- # To prevent IndexErrors
907- # If a spike of spiketrain_i is after the last spike of spiketrain_j,
908- # the index is N2, however spiketrain_j[N2] raises an IndexError.
909- # By shifting this index, the spike of spiketrain_i will be compared
910- # to the last 2 spikes of spiketrain_j (negligible overhead).
911- # Note: Not necessary for index 0 that will be shifted to -1,
912- # because spiketrain_j[-1] is valid (additional negligible comparison)
913- ind [ind == N2 ] = N2 - 1
914-
915- # Compare to nearest spike in spiketrain_j BEFORE spike in spiketrain_i
916- close_left = np .abs (
917- spiketrain_j .times [ind - 1 ] - spiketrain_i .times ) <= dt
918- # Compare to nearest spike in spiketrain_j AFTER (or simultaneous)
919- # spike in spiketrain_j
920- close_right = np .abs (
921- spiketrain_j .times [ind ] - spiketrain_i .times ) <= dt
922-
923- # spiketrain_j spikes that are in [-dt, dt] range of spiketrain_i
924- # spikes are counted only ONCE (as per original implementation)
925- close = close_left + close_right
926-
927- # Count how many spikes in spiketrain_i have a "partner" in
928- # spiketrain_j
929- return np .count_nonzero (close )
930-
931- def run_T (spiketrain ):
913+ # Create a boolean array where each element represents whether a spike
914+ # in spiketrain_j lies within +- dt of any spike in spiketrain_i.
915+ tiled_spikes_j = np .isclose (
916+ spiketrain_j .times .magnitude [:, np .newaxis ],
917+ spiketrain_i .times .magnitude ,
918+ atol = dt .item ())
919+ # Determine which spikes in spiketrain_j satisfy the time window
920+ # condition.
921+ tiled_spike_indices = np .any (tiled_spikes_j , axis = 1 )
922+ # Extract the spike times in spiketrain_j that satisfy the condition.
923+ tiled_spikes_j = spiketrain_j [tiled_spike_indices ]
924+ # Calculate the ratio of matching spikes in j to the total spikes in j.
925+ return len (tiled_spikes_j )/ len (spiketrain_j )
926+
927+ def run_t (spiketrain : neo .core .SpikeTrain , dt : pq .Quantity = dt ) -> float :
932928 """
933929 Calculate the proportion of the total recording time 'tiled' by spikes.
934930 """
935- N = len (spiketrain )
936- time_A = 2 * N * dt # maximum possible time
937-
938- if N == 1 : # for only a single spike in the train
939-
940- # Check difference between start of recording and single spike
941- if spiketrain [0 ] - spiketrain .t_start < dt :
942- time_A += - dt + spiketrain [0 ] - spiketrain .t_start
943-
944- # Check difference between single spike and end of recording
945- elif spiketrain [0 ] + dt > spiketrain .t_stop :
946- time_A += - dt - spiketrain [0 ] + spiketrain .t_stop
947-
948- else : # if more than a single spike in the train
949-
950- # Calculate difference between consecutive spikes
951- diff = np .diff (spiketrain )
952-
953- # Find spikes whose tiles overlap
954- idx = np .where (diff < 2 * dt )[0 ]
955- # Subtract overlapping "2*dt" tiles and add differences instead
956- time_A += - 2 * dt * len (idx ) + diff [idx ].sum ()
957-
958- # Check if spikes are within +/-dt of the start and/or end
959- # if so, subtract overlap of first and/or last spike
960- if (spiketrain [0 ] - spiketrain .t_start ) < dt :
961- time_A += spiketrain [0 ] - dt - spiketrain .t_start
962- if (spiketrain .t_stop - spiketrain [N - 1 ]) < dt :
963- time_A += - spiketrain [- 1 ] - dt + spiketrain .t_stop
964-
965- # Calculate the proportion of total recorded time to "tiled" time
966- T = time_A / (spiketrain .t_stop - spiketrain .t_start )
967- return T .simplified .item () # enforce simplification, strip units
931+ # Get the numerical value of 'dt'.
932+ dt = dt .item ()
933+ # Get the start and stop times of the spike train.
934+ t_start = spiketrain .t_start .item ()
935+ t_stop = spiketrain .t_stop .item ()
936+ # Get the spike times as a NumPy array.
937+ sorted_spikes = spiketrain .times .magnitude
938+ # Check if spikes are sorted and sort them if not.
939+ if (np .diff (sorted_spikes ) < 0 ).any ():
940+ sorted_spikes = np .sort (sorted_spikes )
941+
942+ # Calculate the time differences between consecutive spikes.
943+ diff_spikes = np .diff (sorted_spikes )
944+ # Calculate durations of spike overlaps within a time window of 2 * dt.
945+ overlap_durations = diff_spikes [diff_spikes <= 2 * dt ]
946+ covered_time_overlap = np .sum (overlap_durations )
947+
948+ # Calculate the durations of non-overlapping spikes.
949+ non_overlap_durations = diff_spikes [diff_spikes > 2 * dt ]
950+ covered_time_non_overlap = len (non_overlap_durations ) * 2 * dt
951+
952+ # Check if the first and last spikes are within +/-dt of the start
953+ # and end.
954+ # If so, adjust the overlapping and non-overlapping times accordingly.
955+ if sorted_spikes [0 ] - t_start < dt :
956+ covered_time_overlap += sorted_spikes [0 ] - t_start
957+ else :
958+ covered_time_non_overlap += dt
959+ if t_stop - sorted_spikes [- 1 ] < dt :
960+ covered_time_overlap += t_stop - sorted_spikes [- 1 ]
961+ else :
962+ covered_time_non_overlap += dt
968963
969- N1 = len (spiketrain_i )
970- N2 = len (spiketrain_j )
964+ # Calculate the total time covered by spikes and the total recording
965+ # time.
966+ total_time_covered = covered_time_overlap + covered_time_non_overlap
967+ total_time = t_stop - t_start
968+ # Calculate and return the proportion of the total recording time
969+ # covered by spikes.
970+ return total_time_covered / total_time
971971
972- if N1 == 0 or N2 == 0 :
972+ if len ( spiketrain_i ) == 0 or len ( spiketrain_j ) == 0 :
973973 index = np .nan
974974 else :
975- TA = run_T (spiketrain_i )
976- TB = run_T (spiketrain_j )
977- PA = run_P (spiketrain_i , spiketrain_j )
978- PA = PA / N1
979- PB = run_P (spiketrain_j , spiketrain_i )
980- PB = PB / N2
975+ TA = run_t (spiketrain_j , dt )
976+ TB = run_t (spiketrain_i , dt )
977+ PA = run_p (spiketrain_j , spiketrain_i , dt )
978+ PB = run_p (spiketrain_i , spiketrain_j , dt )
979+
981980 # check if the P and T values are 1 to avoid division by zero
982981 # This only happens for TA = PB = 1 and/or TB = PA = 1,
983982 # which leads to 0/0 in the calculation of the index.
984983 # In those cases, every spike in the train with P = 1
985984 # is within dt of a spike in the other train,
986985 # so we set the respective (partial) index to 1.
987- if PA * TB == 1 :
988- if PB * TA == 1 :
989- index = 1.
990- else :
991- index = 0.5 + 0.5 * (PB - TA ) / (1 - PB * TA )
986+ if PA * TB == 1 and PB * TA == 1 :
987+ index = 1.
988+ elif PA * TB == 1 :
989+ index = 0.5 + 0.5 * (PB - TA ) / (1 - PB * TA )
992990 elif PB * TA == 1 :
993991 index = 0.5 + 0.5 * (PA - TB ) / (1 - PA * TB )
994992 else :
995- index = 0.5 * (PA - TB ) / (1 - PA * TB ) + 0.5 * ( PB - TA ) / (
996- 1 - PB * TA )
993+ index = 0.5 * (PA - TB ) / (1 - PA * TB ) + \
994+ 0.5 * ( PB - TA ) / ( 1 - PB * TA )
997995 return index
998996
999997
0 commit comments