Skip to content

Commit 045f6d4

Browse files
committed
FIX NEOFIT
1 parent ccb0b73 commit 045f6d4

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

treeple/stats/neofit.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -280,11 +280,7 @@ def test(self, feature_importance):
280280
null_stat = np.array(
281281
Parallel(n_jobs=self.n_jobs)(
282282
delayed(self.perm_stat)(ranks)
283-
for _ in tqdm(
284-
range(self.n_permutations),
285-
desc="Calculating null distribution",
286-
disable=not self.verbose,
287-
)
283+
for _ in range(self.n_permutations)
288284
)
289285
)
290286

@@ -334,22 +330,19 @@ def feat_imp_test(self, X, y):
334330
Corrected p-values for each feature.
335331
"""
336332
# Training on original data
337-
print(f"Training forest with {self.n_estimators} trees on original data...")
338333
results = Parallel(n_jobs=self.n_jobs)(
339-
delayed(self.train)(ii, X, y) for ii in tqdm(range(self.n_estimators))
334+
delayed(self.train)(ii, X, y) for ii in range(self.n_estimators)
340335
)
341336
feat_imp_all, _ = zip(*results)
342337

343338
# Training on shuffled data
344-
print(f"Training forest with {self.n_estimators} trees on shuffled data...")
345339
y_shuffled = shuffle(y, random_state=0)
346340
results = Parallel(n_jobs=self.n_jobs)(
347-
delayed(self.train)(ii, X, y_shuffled) for ii in tqdm(range(self.n_estimators))
341+
delayed(self.train)(ii, X, y_shuffled) for ii in range(self.n_estimators)
348342
)
349343
feat_imp_all_rand, _ = zip(*results)
350344

351345
# Computing p-values
352-
print(f"Computing p-values with {self.n_permutations} permutations...")
353346
p_corrected = self.get_p(np.array(feat_imp_all), np.array(feat_imp_all_rand))
354347

355348
return p_corrected
@@ -380,6 +373,15 @@ def get_significant_features(self, X, y):
380373
print(
381374
f"Found {np.sum(significant_features)} significant features out of {len(significant_features)}"
382375
)
383-
X_important = X[:, significant_features]
376+
# print the top 10 features: name, index
377+
if X.shape[1] > 1:
378+
print(f"Significant features: {np.where(significant_features)[0][:10]}")
379+
else:
380+
print(f"Significant feature: {np.where(significant_features)[0][:10]}")
381+
382+
if np.sum(significant_features) > 0:
383+
X_important = X[:, significant_features]
384+
else:
385+
X_important = X
384386

385387
return p_values, significant_features, X_important

0 commit comments

Comments
 (0)