Skip to content

Commit d476ba8

Browse files
authored
Merge pull request #270 from MAIF/feature/explanation_metrics
Feature/explanation metrics
2 parents a089646 + 2684285 commit d476ba8

File tree

11 files changed

+2479
-2
lines changed

11 files changed

+2479
-2
lines changed

docs/autodocs/shapash.explainer.rst

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@ The Plot Methods
1414
----------------
1515

1616
.. autoclass:: shapash.explainer.smart_plotter.SmartPlotter
17-
:members: features_importance, contribution_plot, local_plot, compare_plot, top_interactions_plot
17+
:members: features_importance, contribution_plot, local_plot, compare_plot, top_interactions_plot, local_neighbors_plot, stability_plot, compacity_plot
18+
:undoc-members:
19+
:show-inheritance:
20+
21+
.. autoclass:: shapash.explainer.consistency.Consistency
22+
:members: consistency_plot
1823
:undoc-members:
1924
:show-inheritance:

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
requirements = [
1212
'plotly==4.12.0',
13+
'matplotlib>=3.3.0',
1314
'numpy>1.18.0',
1415
'pandas>1.0.2',
1516
'shap>=0.36.0',
@@ -31,7 +32,6 @@
3132
extras['report'] = [
3233
'nbconvert==6.0.7',
3334
'papermill',
34-
'matplotlib',
3535
'seaborn<=0.11.1',
3636
'notebook',
3737
'Jinja2',

shapash/explainer/consistency.py

Lines changed: 465 additions & 0 deletions
Large diffs are not rendered by default.

shapash/explainer/smart_explainer.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import copy
66
import tempfile
77
import shutil
8+
import numpy as np
89
import pandas as pd
910
from shapash.webapp.smart_app import SmartApp
1011
from shapash.utils.io import save_pickle
@@ -26,6 +27,7 @@
2627
from .smart_plotter import SmartPlotter
2728
import shapash.explainer.smart_predictor
2829
from shapash.utils.model import predict_proba, predict
30+
from shapash.utils.explanation_metrics import find_neighbors, shap_neighbors, get_min_nb_features, get_distance
2931

3032
logging.basicConfig(level=logging.INFO)
3133

@@ -91,6 +93,13 @@ class SmartExplainer:
9193
Dictionary that references the numbers of feature values in the x_pred
9294
features_imp: pandas.Series (regression) or list (classification)
9395
Features importance values
96+
local_neighbors: dict
97+
Dictionary of values to be displayed on the local_neighbors plot.
98+
The key is "norm_shap" (normalized contribution values of instance and neighbors)
99+
features_stability: dict
100+
Dictionary of arrays to be displayed on the stability plot.
101+
The keys are "amplitude" (average contributions values for selected instances) and
102+
"variability" (variability of the contribution values across the neighborhood)
94103
preprocessing : category_encoders, ColumnTransformer, list or dict
95104
The processing applied to the original data.
96105
postprocessing : dict
@@ -278,6 +287,9 @@ def compile(self, x, model, explainer=None, contributions=None, y_pred=None,
278287
self.features_groups = features_groups
279288
if features_groups:
280289
self._compile_features_groups(features_groups)
290+
self.local_neighbors = None
291+
self.features_stability = None
292+
self.features_compacity = None
281293

282294
def _compile_features_groups(self, features_groups):
283295
"""
@@ -998,6 +1010,71 @@ def compute_features_import(self, force=False):
9981010
if self.features_imp is None or force:
9991011
self.features_imp = self.state.compute_features_import(self.contributions)
10001012

1013+
def compute_features_stability(self, selection):
1014+
"""
1015+
For a selection of instances, compute features stability metrics used in
1016+
methods `local_neighbors_plot` and `stability_plot`.
1017+
- If selection is a single instance, the method returns the (normalized) contribution values
1018+
of instance and corresponding neighbors.
1019+
- If selection represents multiple instances, the method returns the average (normalized) contribution values
1020+
of instances and neighbors (=amplitude), as well as the variability of those values in the neighborhood (=variability)
1021+
1022+
Parameters
1023+
----------
1024+
selection: list
1025+
Indices of rows to be displayed on the stability plot
1026+
1027+
Returns
1028+
-------
1029+
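None
1030+
Nothing is returned. Values to be displayed on the graph are stored in `self.local_neighbors` (key "norm_shap") for a single instance, or in `self.features_stability` (keys "variability" and "amplitude") for multiple instances.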
1031+
"""
1032+
if (self._case == "classification") and (len(self._classes) > 2):
1033+
raise AssertionError("Multi-class classification is not supported")
1034+
1035+
all_neighbors = find_neighbors(selection, self.x_init, self.model, self._case)
1036+
1037+
# Check if entry is a single instance or not
1038+
if len(selection) == 1:
1039+
# Compute explanations for instance and neighbors
1040+
norm_shap, _, _ = shap_neighbors(all_neighbors[0], self.x_init, self.contributions, self._case)
1041+
self.local_neighbors = {"norm_shap": norm_shap}
1042+
else:
1043+
numb_expl = len(selection)
1044+
amplitude = np.zeros((numb_expl, self.x_pred.shape[1]))
1045+
variability = np.zeros((numb_expl, self.x_pred.shape[1]))
1046+
# For each instance (+ neighbors), compute explanation
1047+
for i in range(numb_expl):
1048+
(_, variability[i, :], amplitude[i, :],) = shap_neighbors(all_neighbors[i], self.x_init, self.contributions, self._case)
1049+
self.features_stability = {"variability": variability, "amplitude": amplitude}
1050+
1051+
def compute_features_compacity(self, selection, distance, nb_features):
1052+
"""
1053+
For a selection of instances, compute features compacity metrics used in method `compacity_plot`.
1054+
1055+
The method returns :
1056+
* the minimum number of features needed for a given approximation level
1057+
* conversely, the approximation reached with a given number of features
1058+
1059+
Parameters
1060+
----------
1061+
selection: list
1062+
Indices of rows to be displayed on the compacity plot
1063+
distance : float
1064+
How close we want to be from model with all features
1065+
nb_features : int
1066+
Number of features used
1067+
"""
1068+
if (self._case == "classification") and (len(self._classes) > 2):
1069+
raise AssertionError("Multi-class classification is not supported")
1070+
1071+
features_needed = get_min_nb_features(selection, self.contributions, self._case, distance)
1072+
distance_reached = get_distance(selection, self.contributions, self._case, nb_features)
1073+
# We clip large approximations to 100%
1074+
distance_reached = np.clip(distance_reached, 0, 1)
1075+
1076+
self.features_compacity = {"features_needed": features_needed, "distance_reached": distance_reached}
1077+
10011078
def init_app(self, settings: dict = None):
10021079
"""
10031080
Simple init of SmartApp in case of host smartapp by another way

0 commit comments

Comments
 (0)