@@ -293,7 +293,7 @@ def fanofactor(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity], List[
293293 spiketrains : list or elephant.trials.Trial
294294 List of `neo.SpikeTrain` or `pq.Quantity` or `np.ndarray` or list of
295295 spike times for which to compute the Fano factor of spike counts, or
296- an `elephant.trials.Trial` object, here the behavior can be controlled with the
296+ an `elephant.trials.Trial` object, here the behavior can be controlled with the
297297 pool_trials and pool_spike_trains parameters.
298298 warn_tolerance : pq.Quantity
299299 In case of a list of input neo.SpikeTrains, if their durations vary by
@@ -325,7 +325,7 @@ def fanofactor(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity], List[
325325 Notes
326326 -----
327327 The check for the equal duration of the input spike trains is performed
328- only if the input is of type`neo.SpikeTrain`: if you pass a numpy array,
328+ only if the input is of type`neo.SpikeTrain`: if you pass e.g. a numpy array,
329329 please make sure that they all have the same duration manually.
330330
331331 Examples
@@ -346,30 +346,32 @@ def fanofactor(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity], List[
346346 elif not isinstance (pool_spike_trains , bool ):
347347 raise TypeError (f"'pool_spike_trains' must be of type bool, but got { type (pool_spike_trains )} " )
348348 elif not is_time_quantity (warn_tolerance ):
349- raise TypeError ("'warn_tolerance' must be a time quantity." )
349+ raise TypeError (f"'warn_tolerance' must be a time quantity, but got { type (warn_tolerance )} " )
350+
351+ def _check_input_spiketrains_durations (spiketrains : Union [List [neo .SpikeTrain ], List [pq .Quantity ],
352+ List [np .ndarray ]]) -> None :
353+ if spiketrains and all (isinstance (st , neo .SpikeTrain ) for st in spiketrains ):
354+ durations = np .array (tuple (st .duration for st in spiketrains ))
355+ if np .max (durations ) - np .min (durations ) > warn_tolerance :
356+ warnings .warn (f"Fano factor calculated for spike trains of "
357+ f"different duration (minimum: { np .min (durations )} s, maximum "
358+ f"{ np .max (durations )} s)." )
359+ else :
360+ warnings .warn (f"Spiketrains was of type { type (spiketrains )} , which does not support automatic duration"
361+ f"check. The parameter 'warn_tolerance' will have no effect. Please ensure manually that"
362+ f"all spike trains have the same duration." )
350363
351- def _compute_fano (spiketrains : List [neo .SpikeTrain ]) -> float :
364+ def _compute_fano (spiketrains : Union [List [neo .SpikeTrain ], List [pq .Quantity ], List [np .ndarray ]]) -> float :
365+ # Check spike train durations
366+ _check_input_spiketrains_durations (spiketrains )
352367 # Build array of spike counts (one per spike train)
353- spike_counts = np .array ([len (st ) for st in spiketrains ])
354-
368+ spike_counts = np .array (tuple (len (st ) for st in spiketrains ))
355369 # Compute FF
356- if all (count == 0 for count in spike_counts ):
370+ if np . all (np . array ( spike_counts ) == 0 ):
357371 # empty list of spiketrains reaches this branch, and NaN is returned
358372 return np .nan
359-
360- if all (isinstance (st , neo .SpikeTrain ) for st in spiketrains ):
361- durations = [(st .t_stop - st .t_start ).simplified .item ()
362- for st in spiketrains ]
363- durations_min = min (durations )
364- durations_max = max (durations )
365- if durations_max - durations_min > warn_tolerance .simplified .item ():
366- warnings .warn ("Fano factor calculated for spike trains of "
367- "different duration (minimum: {_min}s, maximum "
368- "{_max}s)." .format (_min = durations_min ,
369- _max = durations_max ))
370-
371- fano = spike_counts .var () / spike_counts .mean ()
372- return fano
373+ else :
374+ return spike_counts .var ()/ spike_counts .mean ()
373375
374376 if isinstance (spiketrains , elephant .trials .Trials ):
375377 if not pool_trials and not pool_spike_trains :
0 commit comments