Skip to content

Commit dbfdde1

Browse files
pietroquaglioalperyeg
authored andcommitted
Fixes to spade modules regarding the integration of an updated version of fim (#140)
* Fixes to spade modules which integrate an updated version of `fim` and updated corresponding tests
1 parent 87b9d3b commit dbfdde1

File tree

2 files changed

+36
-30
lines changed

2 files changed

+36
-30
lines changed

elephant/spade.py

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,15 @@
77
SPADE analysis can be applied as demonstrated in this short toy example of 10
88
artificial spike trains of exhibiting fully synchronous events of order 10.
99
10+
This modules relies on the implementation of the fp-growth algorithm contained
11+
in the file fim.so which can be found here (http://www.borgelt.net/pyfim.html)
12+
and should be available in the spade_src folder (elephant/spade_src/).
13+
If the fim.so module is not present in the correct location or cannot be
14+
imported (only available for linux OS) SPADE will make use of a python
15+
implementation of the fast fca algorithm contained in
16+
elephant/spade_src/fast_fca.py, which is about 10 times slower.
17+
18+
1019
import elephant.spade
1120
import elephant.spike_train_generation
1221
import quantities as pq
@@ -56,6 +65,7 @@
5665
import time
5766
import quantities as pq
5867
import warnings
68+
warnings.simplefilter('once', UserWarning)
5969
try:
6070
from mpi4py import MPI # for parallelized routines
6171
HAVE_MPI = True
@@ -65,9 +75,11 @@
6575
try:
6676
from elephant.spade_src import fim
6777
HAVE_FIM = True
68-
# raise ImportError
6978
except ImportError: # pragma: no cover
7079
HAVE_FIM = False
80+
warnings.warn(
81+
'fim.so not found in elephant/spade_src folder,' +
82+
'you are using the python implementation of fast fca')
7183
from elephant.spade_src import fast_fca
7284

7385

@@ -444,7 +456,8 @@ def concepts_mining(data, binsize, winlen, min_spikes=2, min_occ=2,
444456
# By default, set the maximum pattern size to the maximum number of
445457
# spikes in a window
446458
if max_spikes is None:
447-
max_spikes = int(np.max(np.sum(rel_matrix, axis=1)))
459+
max_spikes = np.max((int(np.max(np.sum(rel_matrix, axis=1))),
460+
min_spikes + 1))
448461
# By default, set maximum number of occurrences to number of non-empty
449462
# windows
450463
if max_occ is None:
@@ -537,7 +550,7 @@ def _build_context(binary_matrix, winlen):
537550
rel_matrix[w, :] = times
538551
# appending to the transactions spike idx (fast_fca input) of the
539552
# current window (fpgrowth input)
540-
transactions.append(attributes[times])
553+
transactions.append(list(attributes[times]))
541554
# Return context and rel_matrix
542555
return context, transactions, rel_matrix
543556

@@ -624,26 +637,30 @@ def _fpgrowth(transactions, min_c=2, min_z=2, max_z=None,
624637
'''
625638
# By default, set the maximum pattern size to the number of spiketrains
626639
if max_z is None:
627-
max_z = np.max([len(tr) for tr in transactions]) + 1
640+
max_z = np.max((np.max([len(tr) for tr in transactions]), min_z + 1))
628641
# By default set maximum number of data to number of bins
629642
if max_c is None:
630-
max_c = len(transactions) + 1
631-
if report != '#' or min_neu > 1:
643+
max_c = len(transactions)
644+
if min_neu >= 1:
632645
if min_neu < 1:
633646
raise AttributeError('min_neu must be an integer >=1')
634647
# Inizializing outputs
635648
concepts = []
636649
spec_matrix = np.zeros((max_z, max_c))
637650
spectrum = []
638651
# Mining the data with fpgrowth algorithm
639-
fpgrowth_output = fim.fpgrowth(
640-
tracts=transactions,
641-
target=target,
642-
supp=-min_c,
643-
min=min_z,
644-
max=max_z,
645-
report='a',
646-
algo='s')
652+
if np.unique(transactions, return_counts=True)[1][0] == len(
653+
transactions):
654+
fpgrowth_output = [(tuple(transactions[0]), len(transactions))]
655+
else:
656+
fpgrowth_output = fim.fpgrowth(
657+
tracts=transactions,
658+
target=target,
659+
supp=-min_c,
660+
zmin=min_z,
661+
zmax=max_z,
662+
report='a',
663+
algo='s')
647664
# Applying min/max conditions and computing extent (window positions)
648665
fpgrowth_output = list(filter(
649666
lambda c: _fpgrowth_filter(
@@ -665,16 +682,6 @@ def _fpgrowth(transactions, min_c=2, min_z=2, max_z=None,
665682
return spectrum
666683
else:
667684
return concepts
668-
elif report == '#' and min_neu == 1:
669-
spectrum = fim.fpgrowth(
670-
tracts=transactions,
671-
target=target,
672-
supp=-min_c,
673-
min=min_z,
674-
max=max_z,
675-
report=report,
676-
algo='s')
677-
return spectrum
678685
else:
679686
raise AttributeError('min_neu must be an integer >=1')
680687

@@ -687,7 +694,7 @@ def _fpgrowth_filter(concept, winlen, max_c, min_neu):
687694
keep_concepts = len(
688695
np.unique(
689696
np.array(
690-
concept[0]) // winlen)) >= min_neu and concept[1][0] <= max_c
697+
concept[0]) // winlen)) >= min_neu and concept[1] <= max_c
691698
return keep_concepts
692699

693700

elephant/test/test_spade.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,22 @@
66
"""
77
from __future__ import division
88
import unittest
9-
import os
10-
import warnings
119

1210
import neo
1311
import numpy as np
14-
from numpy.testing.utils import assert_array_almost_equal, assert_array_equal
12+
from numpy.testing.utils import assert_array_equal
1513
import quantities as pq
16-
import elephant.spike_train_generation as stg
1714
import elephant.spade as spade
1815
import elephant.conversion as conv
1916
import elephant.spike_train_generation as stg
17+
2018
try:
21-
import fim
19+
from elephant.spade_src import fim
2220
HAVE_FIM = True
2321
except ImportError:
2422
HAVE_FIM = False
2523

24+
2625
class SpadeTestCase(unittest.TestCase):
2726
def setUp(self):
2827
# Spade parameters

0 commit comments

Comments
 (0)