|
5 | 5 | import copy |
6 | 6 | import tempfile |
7 | 7 | import shutil |
| 8 | +import numpy as np |
8 | 9 | import pandas as pd |
9 | 10 | from shapash.webapp.smart_app import SmartApp |
10 | 11 | from shapash.utils.io import save_pickle |
|
26 | 27 | from .smart_plotter import SmartPlotter |
27 | 28 | import shapash.explainer.smart_predictor |
28 | 29 | from shapash.utils.model import predict_proba, predict |
| 30 | +from shapash.utils.explanation_metrics import find_neighbors, shap_neighbors, get_min_nb_features, get_distance |
29 | 31 |
|
30 | 32 | logging.basicConfig(level=logging.INFO) |
31 | 33 |
|
@@ -91,6 +93,13 @@ class SmartExplainer: |
91 | 93 | Dictionary that references the numbers of feature values in the x_pred |
92 | 94 | features_imp: pandas.Series (regression) or list (classification) |
93 | 95 | Features importance values |
| 96 | + local_neighbors: dict |
| 97 | + Dictionary of values to be displayed on the local_neighbors plot. |
| 98 | + The key is "norm_shap" (normalized contribution values of the instance and its neighbors) |
| 99 | + features_stability: dict |
| 100 | + Dictionary of arrays to be displayed on the stability plot. |
| 101 | + The keys are "amplitude" (average contributions values for selected instances) and |
| 102 | + "variability" (variability of those contributions across the neighborhood) |
94 | 103 | preprocessing : category_encoders, ColumnTransformer, list or dict |
95 | 104 | The processing apply to the original data. |
96 | 105 | postprocessing : dict |
@@ -278,6 +287,9 @@ def compile(self, x, model, explainer=None, contributions=None, y_pred=None, |
278 | 287 | self.features_groups = features_groups |
279 | 288 | if features_groups: |
280 | 289 | self._compile_features_groups(features_groups) |
| 290 | + self.local_neighbors = None |
| 291 | + self.features_stability = None |
| 292 | + self.features_compacity = None |
281 | 293 |
|
282 | 294 | def _compile_features_groups(self, features_groups): |
283 | 295 | """ |
@@ -998,6 +1010,71 @@ def compute_features_import(self, force=False): |
998 | 1010 | if self.features_imp is None or force: |
999 | 1011 | self.features_imp = self.state.compute_features_import(self.contributions) |
1000 | 1012 |
|
def compute_features_stability(self, selection):
    """
    Compute the stability metrics consumed by `local_neighbors_plot` and
    `local_stability_plot`, and store them on the explainer.

    Nothing is returned; the outcome depends on the size of `selection`:
    - one instance: `self.local_neighbors` is set to a dictionary with the key
      "norm_shap" (first value returned by `shap_neighbors` for that instance).
    - any other size: `self.features_stability` is set to a dictionary with the keys
      "variability" and "amplitude", each an array holding one row per selected
      instance and one column per column of `self.x_pred`.

    Parameters
    ----------
    selection: list
        Indices of rows to be displayed on the stability plot

    Raises
    ------
    AssertionError
        If the explainer handles a classification with more than two classes.
    """
    is_multiclass = self._case == "classification" and len(self._classes) > 2
    if is_multiclass:
        raise AssertionError("Multi-class classification is not supported")

    # One neighborhood per selected instance, in the same order as `selection`
    # (presumably; the ordering is defined by `find_neighbors` - TODO confirm).
    neighborhoods = find_neighbors(selection, self.x_init, self.model, self._case)

    n_selected = len(selection)
    if n_selected == 1:
        # Single instance: keep only the normalized contributions.
        normalized, _, _ = shap_neighbors(neighborhoods[0], self.x_init, self.contributions, self._case)
        self.local_neighbors = {"norm_shap": normalized}
        return

    # Several instances: fill one row per instance in both result arrays.
    result_shape = (n_selected, self.x_pred.shape[1])
    amplitude = np.zeros(result_shape)
    variability = np.zeros(result_shape)
    for row in range(n_selected):
        _, row_variability, row_amplitude = shap_neighbors(
            neighborhoods[row], self.x_init, self.contributions, self._case
        )
        variability[row, :] = row_variability
        amplitude[row, :] = row_amplitude

    self.features_stability = {"variability": variability, "amplitude": amplitude}
| 1050 | + |
def compute_features_compacity(self, selection, distance, nb_features):
    """
    Compute the compacity metrics consumed by `compacity_plot` and store them
    in `self.features_compacity`. Nothing is returned.

    The stored dictionary has two keys:
    * "features_needed": result of `get_min_nb_features`, i.e. the minimum number
      of features needed for the approximation level `distance`
    * "distance_reached": result of `get_distance` for `nb_features` features,
      clipped to the interval [0, 1]

    Parameters
    ----------
    selection: list
        Indices of the rows the metrics are computed for
    distance : float
        How close we want to be from model with all features
    nb_features : int
        Number of features used

    Raises
    ------
    AssertionError
        If the explainer handles a classification with more than two classes.
    """
    is_multiclass = self._case == "classification" and len(self._classes) > 2
    if is_multiclass:
        raise AssertionError("Multi-class classification is not supported")

    min_features = get_min_nb_features(selection, self.contributions, self._case, distance)
    raw_approximation = get_distance(selection, self.contributions, self._case, nb_features)
    # Approximations above 100% (or below 0%) are brought back into [0, 1].
    approximation = np.clip(raw_approximation, 0, 1)

    self.features_compacity = {
        "features_needed": min_features,
        "distance_reached": approximation,
    }
| 1077 | + |
1001 | 1078 | def init_app(self, settings: dict = None): |
1002 | 1079 | """ |
1003 | 1080 | Simple init of SmartApp in case of host smartapp by another way |
|
0 commit comments