Skip to content

Commit 7b0ccf9

Browse files
add type check for pool parameters
1 parent 1711f43 commit 7b0ccf9

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

elephant/statistics.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -364,23 +364,31 @@ def _compute_fano(spiketrains: neo.SpikeTrain) -> float:
364364
return fano
365365

366366
if isinstance(spiketrains, elephant.trials.Trials):
367+
# Check if parameters are of the correct type
368+
if not isinstance(pool_trials, bool):
369+
raise TypeError(f"'pool_trials' must be of type bool, but got {type(pool_trials)}")
370+
elif not isinstance(pool_spike_trains, bool):
371+
raise TypeError(f"'pool_spike_trains' must be of type bool, but got {type(pool_spike_trains)}")
367372
if not pool_trials and not pool_spike_trains:
368373
return [[_compute_fano([spiketrain]) for spiketrain in spiketrains.get_spiketrains_from_trial_as_list(idx)]
369374
for idx in range(spiketrains.n_trials)]
370-
if not pool_trials and pool_spike_trains:
375+
elif not pool_trials and pool_spike_trains:
371376
return [_compute_fano(spiketrains.get_spiketrains_from_trial_as_list(idx))
372377
for idx in range(spiketrains.n_trials)]
373-
if pool_trials and not pool_spike_trains:
378+
elif pool_trials and not pool_spike_trains:
374379
list_of_lists_of_spiketrains = [
375380
spiketrains.get_spiketrains_from_trial_as_list(trial_id=trial_no)
376381
for trial_no in range(spiketrains.n_trials)]
377382
return [_compute_fano([list_of_lists_of_spiketrains[trial_no][st_no]
378383
for trial_no in range(len(list_of_lists_of_spiketrains))])
379384
for st_no in range(len(list_of_lists_of_spiketrains[0]))]
380-
if pool_trials and pool_spike_trains:
385+
elif pool_trials and pool_spike_trains:
381386
return [_compute_fano(
382387
[spiketrain for trial_no in range(spiketrains.n_trials)
383388
for spiketrain in spiketrains.get_spiketrains_from_trial_as_list(trial_id=trial_no)])]
389+
else:
390+
raise TypeError(f"pool_spiketrains and pool_trials must be of type: bool, but are "
391+
f"{type(pool_spike_trains)} and {type(pool_trials)}")
384392
else: # Legacy behavior
385393
return _compute_fano(spiketrains)
386394

elephant/test/test_statistics.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,12 @@ def test_fanofactor_trials_pool_trials_false_pool_spiketrains_false(self):
375375
for result in results:
376376
self.assertEqual(len(result), self.test_trials.n_spiketrains_trial_by_trial[0])
377377

378+
def test_fanofactor_trials_pool_spike_trains_wrong_type(self):
379+
self.assertRaises(TypeError, statistics.fanofactor, self.test_trials, pool_spike_trains="Wrong Type")
380+
self.assertRaises(TypeError, statistics.fanofactor, self.test_trials, pool_spike_trials="Wrong Type")
381+
self.assertRaises(TypeError, statistics.fanofactor, self.test_trials, pool_spike_trials="Wrong Type",
382+
pool_spike_trains="Wrong Type")
383+
378384

379385
class LVTestCase(unittest.TestCase):
380386
def setUp(self):

0 commit comments

Comments
 (0)