Skip to content

Commit cf23b33

Browse files
committed
Improved feature selection
1 parent 470f06a commit cf23b33

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

nimfa/models/nmf.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ def score_features(self, idx=None):
315315
investigate features that have strong component-specific membership values
316316
to the latent components.
317317
318-
Return the list of feature scores. Feature scores are real-valued from interval [0,1].
318+
Return array with feature scores. Feature scores are real-valued from interval [0,1].
319319
Higher value indicates greater feature specificity.
320320
321321
:param idx: Used in the multiple NMF model. In standard NMF model or nonsmooth NMF model
@@ -345,19 +345,20 @@ def select_features(self, idx=None):
345345
the corresponding row of the basis matrix (W)) is larger
346346
than the median of all contributions (i.e. of all elements of basis matrix (W)).
347347
348-
Return list of retained features' indices.
348+
Return a boolean array indicating whether features were selected.
349349
350-
:param idx: Used in the multiple NMF model. In factorizations following
351-
standard NMF model or nonsmooth NMF model ``idx`` is always None.
350+
:param idx: Used in the multiple NMF model. In standard NMF model or nonsmooth NMF
351+
model ``idx`` is always None.
352352
:type idx: None or `str` with values 'coef' or 'coef1' (`int` value of 0 or 1, respectively)
353353
"""
354354
scores = self.score_features(idx=idx)
355-
u = np.median(scores)
356-
s = np.median(abs(scores - u))
357-
res = [i for i in range(len(scores)) if scores[i] > u + 3. * s]
355+
th = np.median(scores) + 3 * np.median(abs(scores - np.median(scores)))
356+
sel = scores > th
358357
W = self.basis()
359358
m = np.median(W.toarray() if sp.isspmatrix(W) else W.tolist())
360-
return [i for i in res if np.max(W[i, :].toarray() if sp.isspmatrix(W) else W[i,:]) > m]
359+
sel = np.array([sel[i] and np.max(W[i, :].toarray() if sp.isspmatrix(W) else W[i,:]) > m
360+
for i in range(W.shape[0])])
361+
return sel
361362

362363
def purity(self, membership=None, idx=None):
363364
"""

0 commit comments

Comments
 (0)