Skip to content

Commit 4217950

Browse files
authored
Merge pull request #35 from wwu-mmll/develop
Develop
2 parents 495f2b7 + 7d205e4 commit 4217950

File tree

4 files changed, with 62 additions and 16 deletions.

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
![PHOTON LOGO](http://www.photon-ai.com/static/img/photon/photon-logo-github.png "PHOTON Logo")
1+
[![PHOTON LOGO](https://www.photon-ai.com/static/img/photon/photon-logo-github.png)](https://www.photon-ai.com/)
22

33
[![GitHub Workflow Status](https://img.shields.io/github/workflow/status/wwu-mmll/photonai/PHOTONAI%20test%20and%20test%20deploy)](https://github.com/wwu-mmll/photonai/actions)
44
[![Coverage Status](https://coveralls.io/repos/github/wwu-mmll/photonai/badge.svg?branch=master)](https://coveralls.io/github/wwu-mmll/photonai?branch=master)

examples/advanced/feature_importance.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
my_pipe += PipelineElement('Ridge', alpha=1e-2)
1919
my_pipe.fit(X_train, y_train)
2020

21-
r = my_pipe.get_permutation_feature_importances(X_val, y_val, n_repeats=50, random_state=0)
21+
r = my_pipe.get_permutation_feature_importances(n_repeats=50, random_state=0)
2222

23-
for i in r.importances_mean.argsort()[::-1]:
24-
if r.importances_mean[i] - 2 * r.importances_std[i] > 0:
23+
for i in r["mean"].argsort()[::-1]:
24+
if r["mean"][i] - 2 * r["std"][i] > 0:
2525
print(f"{diabetes.feature_names[i]:<8}"
26-
f"{r.importances_mean[i]:.3f}"
27-
f" +/- {r.importances_std[i]:.3f}")
26+
f"{r['mean'][i]:.3f}"
27+
f" +/- {r['std'][i]:.3f}")

photonai/base/hyperpipe.py

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
PhotonNative
3030
from photonai.base.photon_pipeline import PhotonPipeline
3131
from photonai.base.json_transformer import JsonTransformer
32+
from photonai.helper.helper import PhotonDataHelper
3233
from photonai.optimization import FloatRange
3334
from photonai.photonlogger.logger import logger
3435
from photonai.processing import ResultsHandler
@@ -1166,12 +1167,11 @@ def score(self, data: np.ndarray, y: np.ndarray, **kwargs) -> np.ndarray:
11661167
scorer = Scorer.create(self.optimization.best_config_metric)
11671168
return scorer(y, predictions)
11681169

1169-
def get_permutation_feature_importances(self, X_val: np.ndarray, y_val: np.ndarray, **kwargs):
1170+
def get_permutation_feature_importances(self, **kwargs):
11701171
"""
1171-
Since PHOTONAI is built on top of the scikit-learn interface,
1172-
it is possible to use direct functions from their package.
1173-
Here the example of the [feature importance via permutations](
1174-
https://scikit-learn.org/stable/modules/generated/sklearn.inspection.permutation_importance.html).
1172+
Fits a model for the best config of each outer fold (using the training data of that fold).
1173+
Then calls sklearn.inspection.permutation_importance with the test data and the given kwargs (e.g. n_repeats).
1174+
Returns mean of "importances_mean" and of "importances_std" of all outer folds.
11751175
11761176
Parameters:
11771177
X_val:
@@ -1187,11 +1187,50 @@ def get_permutation_feature_importances(self, X_val: np.ndarray, y_val: np.ndarr
11871187
Keyword arguments, passed to sklearn.permutation_importance.
11881188
11891189
Returns:
1190-
Dictionary-like object, with the following attributes: importances_mean, importances_std, importances.
1190+
Dictionary with average of "mean" and "std" for all outer folds, respectively.
11911191
11921192
"""
11931193

1194-
return permutation_importance(self.optimum_pipe, X_val, y_val, **kwargs)
1194+
importance_list = {'mean': list(), 'std': list()}
1195+
pipe_copy = self.optimum_pipe.copy_me()
1196+
logger.photon_system_log("")
1197+
logger.photon_system_log("Computing permutation importances. This may take a while.")
1198+
logger.stars()
1199+
for outer_fold in self.results.outer_folds:
1200+
1201+
if outer_fold.best_config.best_config_score is None:
1202+
raise ValueError("Cannot compute permutation importances when use_test_set is false")
1203+
1204+
1205+
# prepare data
1206+
train_indices = outer_fold.best_config.best_config_score.training.indices
1207+
test_indices = outer_fold.best_config.best_config_score.validation.indices
1208+
1209+
train_X, train_y, train_kwargs = PhotonDataHelper.split_data(self.data.X,
1210+
self.data.y,
1211+
self.data.kwargs,
1212+
indices=train_indices)
1213+
1214+
test_X, test_y, test_kwargs = PhotonDataHelper.split_data(self.data.X,
1215+
self.data.y,
1216+
self.data.kwargs,
1217+
indices=test_indices)
1218+
# set pipe to config
1219+
pipe_copy.set_params(**outer_fold.best_config.config_dict)
1220+
logger.photon_system_log("Permutation Importances: Fitting model for outer fold " + str(outer_fold.fold_nr))
1221+
pipe_copy.fit(train_X, train_y, **train_kwargs)
1222+
1223+
logger.photon_system_log("Permutation Importances: Calculating performances for outer fold "
1224+
+ str(outer_fold.fold_nr))
1225+
outer_fold_perm_imps = permutation_importance(pipe_copy, test_X, test_y, **kwargs)
1226+
importance_list['mean'].append(outer_fold_perm_imps["importances_mean"])
1227+
importance_list['std'].append(outer_fold_perm_imps["importances_std"])
1228+
1229+
mean_importances = np.mean(np.array(importance_list["mean"]), axis=0)
1230+
std_importances = np.mean(np.array(importance_list["std"]), axis=0)
1231+
logger.stars()
1232+
1233+
return {'mean': mean_importances, 'std': std_importances}
11951234

11961235
def inverse_transform_pipeline(self, hyperparameters: dict,
11971236
data: np.ndarray,

test/base_tests/test_hyperpipe.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,16 @@ def test_permutation_feature_importances(self):
197197
score_element = svc.score(self.__X, self.__y)
198198
self.assertAlmostEqual(score_photon, score_element)
199199

200-
permutation_score = hp.get_permutation_feature_importances(self.__X, self.__y, n_repeats=50, random_state=0)
201-
score_2 = permutation_importance(svc, self.__X, self.__y, n_repeats=50, random_state=0)
202-
np.testing.assert_array_equal(permutation_score["importances"], score_2["importances"])
200+
permutation_score = hp.get_permutation_feature_importances(n_repeats=5, random_state=0)
201+
self.assertTrue("mean" in permutation_score)
202+
self.assertTrue("std" in permutation_score)
203+
self.assertEqual(permutation_score["mean"].shape, (self.__X.shape[1],))
204+
self.assertEqual(permutation_score["std"].shape, (self.__X.shape[1],))
205+
206+
hp.cross_validation.use_test_set = False
207+
hp.fit(self.__X, self.__y)
208+
with self.assertRaises(ValueError):
209+
hp.get_permutation_feature_importances(n_repeats=5)
203210

204211
def test_estimation_type(self):
205212
def callback(X, y=None, **kwargs):

0 commit comments

Comments
 (0)