Skip to content

Commit b731fd2

Browse files
committed
Replace _support function
For unknbown reasons, np.sum is slow on a very large boolean array.
1 parent 26ec7ab commit b731fd2

File tree

1 file changed

+11
-30
lines changed

1 file changed

+11
-30
lines changed

mlxtend/frequent_patterns/apriori.py

+11-30
Original file line numberDiff line numberDiff line change
@@ -121,32 +121,6 @@ def apriori(df, min_support=0.5, use_colnames=False, max_len=None, verbose=0,
121121
122122
"""
123123

124-
def _support(_x, _n_rows, _is_sparse):
125-
"""DRY private method to calculate support as the
126-
row-wise sum of values / number of rows
127-
128-
Parameters
129-
-----------
130-
131-
_x : matrix of bools or binary
132-
133-
_n_rows : numeric, number of rows in _x
134-
135-
_is_sparse : bool True if _x is sparse
136-
137-
Returns
138-
-----------
139-
np.array, shape = (n_rows, )
140-
141-
Examples
142-
-----------
143-
For usage examples, please see
144-
http://rasbt.github.io/mlxtend/user_guide/frequent_patterns/apriori/
145-
146-
"""
147-
out = (np.sum(_x, axis=0) / _n_rows)
148-
return np.array(out).reshape(-1)
149-
150124
if min_support <= 0.:
151125
raise ValueError('`min_support` must be a positive '
152126
'number within the interval `(0, 1]`. '
@@ -180,7 +154,17 @@ def _support(_x, _n_rows, _is_sparse):
180154
# dense DataFrame
181155
X = df.values
182156
is_sparse = False
183-
support = _support(X, X.shape[0], is_sparse)
157+
if is_sparse:
158+
# Count nonnull entries via direct access to X indices;
159+
# this requires X to be stored in CSC format, and to call
160+
# X.eliminate_zeros() to remove null entries from X.
161+
support = np.array([X.indptr[idx+1] - X.indptr[idx]
162+
for idx in range(X.shape[1])], dtype=int)
163+
else:
164+
# Faster than np.count_nonzero(X, axis=0) or np.sum(X, axis=0), why?
165+
support = np.array([np.count_nonzero(X[:, idx])
166+
for idx in range(X.shape[1])], dtype=int)
167+
support = support / X.shape[0]
184168
support_dict = {1: support[support >= min_support]}
185169
itemset_dict = {1: [(idx,) for idx in np.where(support >= min_support)[0]]}
186170
max_itemset = 1
@@ -199,9 +183,6 @@ def _support(_x, _n_rows, _is_sparse):
199183
processed += 1
200184
count[:] = 0
201185
for item in itemset:
202-
# Count nonnull entries via direct access to X indices;
203-
# this requires X to be stored in CSC format, and to call
204-
# X.eliminate_zeros() to remove null entries from X.
205186
count[X.indices[X.indptr[item]:X.indptr[item+1]]] += 1
206187
support = np.count_nonzero(count == len(itemset)) / X.shape[0]
207188
if support >= min_support:

0 commit comments

Comments
 (0)