Skip to content

Commit 6bb2fc3

Browse files
add user warning and did refactoring of function
1 parent 3c1eb5d commit 6bb2fc3

File tree

2 files changed

+28
-21
lines changed

2 files changed

+28
-21
lines changed

elephant/statistics.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ def fanofactor(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity], List[
293293
spiketrains : list or elephant.trials.Trial
294294
List of `neo.SpikeTrain` or `pq.Quantity` or `np.ndarray` or list of
295295
spike times for which to compute the Fano factor of spike counts, or
296-
an `elephant.trials.Trial` object, here the behavior can be controlled with the
296+
an `elephant.trials.Trial` object, here the behavior can be controlled with the
297297
pool_trials and pool_spike_trains parameters.
298298
warn_tolerance : pq.Quantity
299299
In case of a list of input neo.SpikeTrains, if their durations vary by
@@ -325,7 +325,7 @@ def fanofactor(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity], List[
325325
Notes
326326
-----
327327
The check for the equal duration of the input spike trains is performed
328-
only if the input is of type`neo.SpikeTrain`: if you pass a numpy array,
328+
only if the input is of type`neo.SpikeTrain`: if you pass e.g. a numpy array,
329329
please make sure that they all have the same duration manually.
330330
331331
Examples
@@ -346,30 +346,32 @@ def fanofactor(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity], List[
346346
elif not isinstance(pool_spike_trains, bool):
347347
raise TypeError(f"'pool_spike_trains' must be of type bool, but got {type(pool_spike_trains)}")
348348
elif not is_time_quantity(warn_tolerance):
349-
raise TypeError("'warn_tolerance' must be a time quantity.")
349+
raise TypeError(f"'warn_tolerance' must be a time quantity, but got {type(warn_tolerance)}")
350+
351+
def _check_input_spiketrains_durations(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity],
352+
List[np.ndarray]]) -> None:
353+
if spiketrains and all(isinstance(st, neo.SpikeTrain) for st in spiketrains):
354+
durations = np.array(tuple(st.duration for st in spiketrains))
355+
if np.max(durations) - np.min(durations) > warn_tolerance:
356+
warnings.warn(f"Fano factor calculated for spike trains of "
357+
f"different duration (minimum: {np.min(durations)}s, maximum "
358+
f"{np.max(durations)}s).")
359+
else:
360+
warnings.warn(f"Spiketrains was of type {type(spiketrains)}, which does not support automatic duration"
361+
f"check. The parameter 'warn_tolerance' will have no effect. Please ensure manually that"
362+
f"all spike trains have the same duration.")
350363

351-
def _compute_fano(spiketrains: List[neo.SpikeTrain]) -> float:
364+
def _compute_fano(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity], List[np.ndarray]]) -> float:
365+
# Check spike train durations
366+
_check_input_spiketrains_durations(spiketrains)
352367
# Build array of spike counts (one per spike train)
353-
spike_counts = np.array([len(st) for st in spiketrains])
354-
368+
spike_counts = np.array(tuple(len(st) for st in spiketrains))
355369
# Compute FF
356-
if all(count == 0 for count in spike_counts):
370+
if np.all(np.array(spike_counts) == 0):
357371
# empty list of spiketrains reaches this branch, and NaN is returned
358372
return np.nan
359-
360-
if all(isinstance(st, neo.SpikeTrain) for st in spiketrains):
361-
durations = [(st.t_stop - st.t_start).simplified.item()
362-
for st in spiketrains]
363-
durations_min = min(durations)
364-
durations_max = max(durations)
365-
if durations_max - durations_min > warn_tolerance.simplified.item():
366-
warnings.warn("Fano factor calculated for spike trains of "
367-
"different duration (minimum: {_min}s, maximum "
368-
"{_max}s).".format(_min=durations_min,
369-
_max=durations_max))
370-
371-
fano = spike_counts.var() / spike_counts.mean()
372-
return fano
373+
else:
374+
return spike_counts.var()/spike_counts.mean()
373375

374376
if isinstance(spiketrains, elephant.trials.Trials):
375377
if not pool_trials and not pool_spike_trains:

elephant/test/test_statistics.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,11 @@ def test_fanofactor_trials_pool_spike_trains_wrong_type(self):
381381
self.assertRaises(TypeError, statistics.fanofactor, self.test_trials, pool_spike_trials="Wrong Type",
382382
pool_spike_trains="Wrong Type")
383383

384+
def test_fanofactor_warn_durations_manual_check(self):
385+
st1 = [1, 2, 3] * pq.s
386+
st2 = [1, 2, 3] * pq.s
387+
self.assertWarns(UserWarning, statistics.fanofactor, (st1, st2))
388+
384389

385390
class LVTestCase(unittest.TestCase):
386391
def setUp(self):

0 commit comments

Comments
 (0)