Skip to content

Commit b1ce148

Browse files
authored
Merge pull request #50 from salesforce/new_features
New features
2 parents 72989c2 + d3ebd0e commit b1ce148

16 files changed

Lines changed: 534 additions & 22 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ We will continue improving this library to make it more comprehensive in the fut
6363
| Partial dependence plots | Black box | Global | || | | |
6464
| Accumulated local effects | Black box | Global | || | | |
6565
| Sensitivity analysis | Black box | Global | || | | |
66+
| Permutation explanation | Black box | Global | || | | |
6667
| Feature visualization | Torch or TF | Global | | || | |
6768
| Feature maps | Torch or TF | Local | | || | |
6869
| LIME | Black box | Local | |||| |

docs/omnixai.explainers.tabular.agnostic.rst

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,19 @@ omnixai.explainers.tabular.agnostic.L2X.l2x module
6060
:members:
6161
:undoc-members:
6262
:show-inheritance:
63+
64+
omnixai.explainers.tabular.agnostic.permutation module
65+
------------------------------------------------------
66+
67+
.. automodule:: omnixai.explainers.tabular.agnostic.permutation
68+
:members:
69+
:undoc-members:
70+
:show-inheritance:
71+
72+
omnixai.explainers.tabular.agnostic.shap_global module
73+
------------------------------------------------------
74+
75+
.. automodule:: omnixai.explainers.tabular.agnostic.shap_global
76+
:members:
77+
:undoc-members:
78+
:show-inheritance:

omnixai/explainers/tabular/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,11 @@
1111
from .agnostic.ale import ALE
1212
from .agnostic.sensitivity import SensitivityAnalysisTabular
1313
from .agnostic.L2X.l2x import L2XTabular
14+
from .agnostic.permutation import PermutationImportance
15+
from .agnostic.shap_global import GlobalShapTabular
1416
from .counterfactual.mace.mace import MACEExplainer
1517
from .counterfactual.ce import CounterfactualExplainer
18+
from .counterfactual.knn import KNNCounterfactualExplainer
1619
from .specific.ig import IntegratedGradientTabular
1720
from .specific.linear import LinearRegression
1821
from .specific.linear import LogisticRegression
@@ -29,8 +32,11 @@
2932
"ALE",
3033
"SensitivityAnalysisTabular",
3134
"L2XTabular",
35+
"PermutationImportance",
36+
"GlobalShapTabular",
3237
"MACEExplainer",
3338
"CounterfactualExplainer",
39+
"KNNCounterfactualExplainer",
3440
"LinearRegression",
3541
"LogisticRegression",
3642
"TreeRegressor",
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
#
2+
# Copyright (c) 2022 salesforce.com, inc.
3+
# All rights reserved.
4+
# SPDX-License-Identifier: BSD-3-Clause
5+
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6+
#
7+
"""
8+
The permutation feature importance explanation for tabular data.
9+
"""
10+
import numpy as np
11+
import pandas as pd
12+
from typing import Callable, Union
13+
from sklearn.metrics import log_loss
14+
from sklearn.inspection import permutation_importance
15+
16+
from ..base import ExplainerBase, TabularExplainerMixin
17+
from ....data.tabular import Tabular
18+
from ....explanations.tabular.feature_importance import GlobalFeatureImportance
19+
20+
21+
class _Estimator:
22+
def fit(self):
23+
pass
24+
25+
26+
class PermutationImportance(ExplainerBase, TabularExplainerMixin):
27+
"""
28+
The permutation feature importance explanations for tabular data. The permutation feature
29+
importance is defined to be the decrease in a model score when a single feature value
30+
is randomly shuffled.
31+
"""
32+
33+
explanation_type = "global"
34+
alias = ["permutation"]
35+
36+
def __init__(self, training_data: Tabular, predict_function, mode="classification", **kwargs):
37+
"""
38+
:param training_data: The training dataset for training the machine learning model.
39+
:param predict_function: The prediction function corresponding to the model to explain.
40+
When the model is for classification, the outputs of the ``predict_function``
41+
are the class probabilities. When the model is for regression, the outputs of
42+
the ``predict_function`` are the estimated values.
43+
:param mode: The task type, e.g., `classification` or `regression`.
44+
"""
45+
super().__init__()
46+
assert isinstance(training_data, Tabular), \
47+
"training_data should be an instance of Tabular."
48+
assert mode in ["classification", "regression"], \
49+
"`mode` can only be `classification` or `regression`."
50+
51+
self.categorical_columns = training_data.categorical_columns
52+
self.predict_function = predict_function
53+
self.mode = mode
54+
55+
def _build_score_function(self, score_func=None):
56+
if score_func is not None:
57+
def _score(estimator, x, y):
58+
z = self.predict_function(
59+
Tabular(x, categorical_columns=self.categorical_columns)
60+
)
61+
return score_func(y, z)
62+
elif self.mode == "classification":
63+
def _score(estimator, x, y):
64+
z = self.predict_function(
65+
Tabular(x, categorical_columns=self.categorical_columns)
66+
)
67+
return -log_loss(y, z)
68+
else:
69+
def _score(estimator, x, y):
70+
z = self.predict_function(
71+
Tabular(x, categorical_columns=self.categorical_columns)
72+
)
73+
return -np.mean((z - y) ** 2)
74+
return _score
75+
76+
def explain(
77+
self,
78+
X: Tabular,
79+
y: Union[np.ndarray, pd.DataFrame],
80+
n_repeats: int = 30,
81+
score_func: Callable = None
82+
) -> GlobalFeatureImportance:
83+
"""
84+
Generate permutation feature importance scores.
85+
86+
:param X: Data on which permutation importance will be computed.
87+
:param y: Targets or labels.
88+
:param n_repeats: The number of times a feature is randomly shuffled.
89+
:param score_func: The score function measuring the difference between
90+
ground-truth targets and predictions, e.g., -sklearn.metrics.log_loss(y_true, y_pred).
91+
:return: The permutation feature importance explanations.
92+
"""
93+
assert X is not None and y is not None, \
94+
"The test data `X` and target `y` cannot be None."
95+
y = y.values if isinstance(y, pd.DataFrame) else np.array(y)
96+
if y.ndim > 1:
97+
y = y.flatten()
98+
assert X.shape[0] == len(y), \
99+
"The numbers of samples in `X` and `y` are different."
100+
X = X.remove_target_column()
101+
102+
results = permutation_importance(
103+
estimator=_Estimator(),
104+
X=X.to_pd(copy=False),
105+
y=y,
106+
scoring=self._build_score_function(score_func)
107+
)
108+
explanations = GlobalFeatureImportance()
109+
explanations.add(
110+
feature_names=list(X.columns),
111+
importance_scores=results["importances_mean"]
112+
)
113+
return explanations

omnixai/explainers/tabular/agnostic/shap.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,15 @@
1313

1414
from ..base import TabularExplainer
1515
from ....data.tabular import Tabular
16-
from ....explanations.tabular.feature_importance import FeatureImportance
16+
from ....explanations.tabular.feature_importance import \
17+
FeatureImportance
1718

1819

1920
class ShapTabular(TabularExplainer):
2021
"""
2122
The SHAP explainer for tabular data.
2223
If using this explainer, please cite the original work: https://github.com/slundberg/shap.
2324
"""
24-
2525
explanation_type = "local"
2626
alias = ["shap"]
2727

@@ -47,19 +47,22 @@ def __init__(
4747
Please refer to the doc of `shap.KernelExplainer`.
4848
"""
4949
super().__init__(training_data=training_data, predict_function=predict_function, mode=mode, **kwargs)
50+
self.link = kwargs.get("link", None)
51+
if self.link is None:
52+
self.link = "logit" if self.mode == "classification" else "identity"
53+
5054
self.ignored_features = set(ignored_features) if ignored_features is not None else set()
5155
if self.target_column is not None:
5256
assert self.target_column not in self.ignored_features, \
5357
f"The target column {self.target_column} cannot be in the ignored feature list."
5458
self.valid_indices = [i for i, f in enumerate(self.feature_columns) if f not in self.ignored_features]
5559

56-
if "nsamples" not in kwargs:
57-
kwargs["nsamples"] = 100
58-
self.background_data = shap.sample(self.data, nsamples=kwargs["nsamples"])
60+
self.background_data = shap.sample(self.data, nsamples=kwargs.get("nsamples", 100))
61+
self.explainer = shap.KernelExplainer(self.predict_fn, self.background_data, link=self.link, **kwargs)
5962

6063
def explain(self, X, y=None, **kwargs) -> FeatureImportance:
6164
"""
62-
Generates the feature-importance explanations for the input instances.
65+
Generates the local SHAP explanations for the input instances.
6366
6467
:param X: A batch of input instances. When ``X`` is `pd.DataFrame`
6568
or `np.ndarray`, ``X`` will be converted into `Tabular` automatically.
@@ -68,7 +71,7 @@ def explain(self, X, y=None, **kwargs) -> FeatureImportance:
6871
when ``y = None``.
6972
:param kwargs: Additional parameters for `shap.KernelExplainer.shap_values`,
7073
e.g., ``nsamples`` -- the number of times to re-evaluate the model when explaining each prediction.
71-
:return: The feature-importance explanations for all the input instances.
74+
:return: The feature importance explanations.
7275
"""
7376
X = self._to_tabular(X).remove_target_column()
7477
explanations = FeatureImportance(self.mode)
@@ -90,12 +93,7 @@ def explain(self, X, y=None, **kwargs) -> FeatureImportance:
9093
y = None
9194

9295
if len(self.ignored_features) == 0:
93-
explainer = shap.KernelExplainer(
94-
self.predict_fn, self.background_data,
95-
link="logit" if self.mode == "classification" else "identity", **kwargs
96-
)
97-
shap_values = explainer.shap_values(instances, **kwargs)
98-
96+
shap_values = self.explainer.shap_values(instances, **kwargs)
9997
for i, instance in enumerate(instances):
10098
df = X.iloc(i).to_pd()
10199
feature_values = \
@@ -120,12 +118,12 @@ def _predict(_x):
120118
_y = np.tile(instance, (_x.shape[0], 1))
121119
_y[:, self.valid_indices] = _x
122120
return self.predict_fn(_y)
121+
123122
predict_function = _predict
124123
test_x = instance[self.valid_indices]
125-
126124
explainer = shap.KernelExplainer(
127125
predict_function, self.background_data[:, self.valid_indices],
128-
link="logit" if self.mode == "classification" else "identity", **kwargs
126+
link=self.link, **kwargs
129127
)
130128
shap_values = explainer.shap_values(np.expand_dims(test_x, axis=0), **kwargs)
131129

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
#
2+
# Copyright (c) 2022 salesforce.com, inc.
3+
# All rights reserved.
4+
# SPDX-License-Identifier: BSD-3-Clause
5+
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6+
#
7+
"""
8+
The SHAP explainer for global feature importance.
9+
"""
10+
import shap
11+
import numpy as np
12+
from typing import Callable, List
13+
14+
from ..base import TabularExplainer
15+
from ....data.tabular import Tabular
16+
from ....explanations.tabular.feature_importance import GlobalFeatureImportance
17+
18+
19+
class GlobalShapTabular(TabularExplainer):
20+
"""
21+
The SHAP explainer for global feature importance.
22+
If using this explainer, please cite the original work: https://github.com/slundberg/shap.
23+
"""
24+
25+
explanation_type = "global"
26+
alias = ["shap_global"]
27+
28+
def __init__(
29+
self,
30+
training_data: Tabular,
31+
predict_function: Callable,
32+
mode: str = "classification",
33+
ignored_features: List = None,
34+
**kwargs
35+
):
36+
"""
37+
:param training_data: The data used to initialize a SHAP explainer. ``training_data``
38+
can be the training dataset for training the machine learning model. If the training
39+
dataset is large, please set parameter ``nsamples``, e.g., ``nsamples = 100``.
40+
:param predict_function: The prediction function corresponding to the model to explain.
41+
When the model is for classification, the outputs of the ``predict_function``
42+
are the class probabilities. When the model is for regression, the outputs of
43+
the ``predict_function`` are the estimated values.
44+
:param mode: The task type, e.g., `classification` or `regression`.
45+
:param ignored_features: The features ignored in computing feature importance scores.
46+
:param kwargs: Additional parameters to initialize `shap.KernelExplainer`, e.g., ``nsamples``.
47+
Please refer to the doc of `shap.KernelExplainer`.
48+
"""
49+
super().__init__(training_data=training_data, predict_function=predict_function, mode=mode, **kwargs)
50+
self.ignored_features = set(ignored_features) if ignored_features is not None else set()
51+
if self.target_column is not None:
52+
assert self.target_column not in self.ignored_features, \
53+
f"The target column {self.target_column} cannot be in the ignored feature list."
54+
self.valid_indices = [i for i, f in enumerate(self.feature_columns) if f not in self.ignored_features]
55+
56+
if "nsamples" not in kwargs:
57+
kwargs["nsamples"] = 100
58+
self.background_data = shap.sample(self.data, nsamples=kwargs["nsamples"])
59+
self.sampled_data = shap.sample(self.data, nsamples=kwargs["nsamples"])
60+
61+
def _explain_global(self, X, **kwargs) -> GlobalFeatureImportance:
62+
if "nsamples" not in kwargs:
63+
kwargs["nsamples"] = 100
64+
instances = self.sampled_data if X is None else \
65+
self.transformer.transform(X.remove_target_column())
66+
67+
explanations = GlobalFeatureImportance()
68+
explainer = shap.KernelExplainer(
69+
self.predict_fn, self.background_data,
70+
link="logit" if self.mode == "classification" else "identity", **kwargs
71+
)
72+
shap_values = explainer.shap_values(instances, **kwargs)
73+
74+
if self.mode == "classification":
75+
values = 0
76+
for v in shap_values:
77+
values += np.abs(v)
78+
values /= len(shap_values)
79+
shap_values = values
80+
81+
importance_scores = np.mean(np.abs(shap_values), axis=0)
82+
explanations.add(
83+
feature_names=self.feature_columns,
84+
importance_scores=importance_scores,
85+
sort=True
86+
)
87+
return explanations
88+
89+
def explain(
90+
self,
91+
X: Tabular = None,
92+
**kwargs
93+
):
94+
"""
95+
Generates the global SHAP explanations.
96+
97+
:param X: The data will be used to compute global SHAP values, i.e., the mean of the absolute
98+
SHAP value for each feature. If `X` is None, a set of training samples will be used.
99+
:param kwargs: Additional parameters for `shap.KernelExplainer.shap_values`,
100+
e.g., ``nsamples`` -- the number of times to re-evaluate the model when explaining each prediction.
101+
:return: The global feature importance explanations.
102+
"""
103+
return self._explain_global(X=X, **kwargs)
104+
105+
def save(
106+
self,
107+
directory: str,
108+
filename: str = None,
109+
**kwargs
110+
):
111+
"""
112+
Saves the initialized explainer.
113+
114+
:param directory: The folder for the dumped explainer.
115+
:param filename: The filename (the explainer class name if it is None).
116+
"""
117+
super().save(
118+
directory=directory,
119+
filename=filename,
120+
ignored_attributes=["data"],
121+
**kwargs
122+
)

0 commit comments

Comments
 (0)