
Commit 279d0be

vivekmig authored and facebook-github-bot committed
Adding LIME and Kernel SHAP (pytorch#468)
Summary: Adding LIME and Kernel SHAP to Captum. This includes all documentation, type hints, and data parallel / JIT tests.

Pull Request resolved: pytorch#468
Reviewed By: NarineK
Differential Revision: D23733322
Pulled By: vivekmig
fbshipit-source-id: 1ecc21306493ce4bd84ce175d4e08c21aaa49083
1 parent eadab0e · commit 279d0be

16 files changed: +2494 −31 lines

captum/_utils/common.py

Lines changed: 7 additions & 0 deletions
@@ -13,6 +13,7 @@
     BaselineType,
     Literal,
     TargetType,
+    TensorOrTupleOfTensorsGeneric,
     TupleOrTensorOrBoolGeneric,
 )

@@ -547,3 +548,9 @@ def _sort_key_list(
         "devices with computed tensors."
     )
     return out_list
+
+
+def _flatten_tensor_or_tuple(inp: TensorOrTupleOfTensorsGeneric) -> Tensor:
+    if isinstance(inp, Tensor):
+        return inp.flatten()
+    return torch.cat([single_inp.flatten() for single_inp in inp])
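
For orientation, a minimal sketch of what the new helper returns; the toy tensors are hypothetical:

import torch
from captum._utils.common import _flatten_tensor_or_tuple  # added above

flat = _flatten_tensor_or_tuple((torch.ones(2, 3), torch.zeros(4)))
assert flat.shape == (10,)  # 6 + 4 elements, flattened and concatenated in input order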

captum/attr/_core/kernel_shap.py

Lines changed: 324 additions & 0 deletions
Large diffs are not rendered by default.
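
The module itself is not rendered above. As a rough sketch of the expected call pattern — the toy model, target, and n_samples values are illustrative assumptions, with the attribute signature assumed to mirror Captum's other perturbation-based methods:

import torch
import torch.nn as nn
from captum.attr._core.kernel_shap import KernelShap  # import path as used in the tests below

net = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 2))  # toy model (assumption)
ks = KernelShap(net)
# Kernel SHAP approximates Shapley values by fitting a weighted linear
# model over randomly sampled feature coalitions.
attr = ks.attribute(torch.randn(4, 3), target=1, n_samples=200)
print(attr.shape)  # torch.Size([4, 3]): one score per input feature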

captum/attr/_core/lime.py

Lines changed: 1148 additions & 0 deletions
Large diffs are not rendered by default.
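
Also not rendered. A minimal sketch under the same assumptions; the dependency changes and the sklearn version assert below suggest the default interpretable model is a scikit-learn Lasso fit with sample_weight:

import torch
import torch.nn as nn
from captum.attr._core.lime import Lime  # import path as used in the tests below

net = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 2))  # toy model (assumption)
lime = Lime(net)  # default surrogate assumed: sklearn Lasso, hence the new dependency
# LIME fits the linear surrogate on n_samples perturbed inputs; the
# surrogate's coefficients are returned as per-feature attributions.
attr = lime.attribute(torch.randn(4, 3), target=1, n_samples=200)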

captum/attr/_core/shapley_value.py

Lines changed: 2 additions & 20 deletions
@@ -19,6 +19,7 @@
 from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
 from captum.attr._utils.attribution import PerturbationAttribution
 from captum.attr._utils.common import (
+    _construct_default_feature_mask,
     _find_output_mode_and_verify,
     _format_input_baseline,
     _tensorize_baseline,
@@ -278,7 +279,7 @@ def attribute(
         num_examples = inputs[0].shape[0]

         if feature_mask is None:
-            feature_mask, total_features = self.construct_feature_mask(inputs)
+            feature_mask, total_features = _construct_default_feature_mask(inputs)
         else:
             total_features = int(
                 max(torch.max(single_mask).item() for single_mask in feature_mask)
@@ -445,25 +446,6 @@ def _perturbation_generator(
             combined_masks,
         )

-    def construct_feature_mask(
-        self, inputs: Tuple[Tensor, ...]
-    ) -> Tuple[Tuple[Tensor, ...], int]:
-        feature_mask = []
-        current_num_features = 0
-        for i in range(len(inputs)):
-            num_features = torch.numel(inputs[i][0])
-            feature_mask.append(
-                current_num_features
-                + torch.reshape(
-                    torch.arange(num_features, device=inputs[i].device),
-                    inputs[i][0:1].shape,
-                )
-            )
-            current_num_features += num_features
-        total_features = current_num_features
-        feature_mask = tuple(feature_mask)
-        return feature_mask, total_features
-

 class ShapleyValues(PerturbationAttribution):
     """

captum/attr/_utils/batching.py

Lines changed: 25 additions & 0 deletions
@@ -9,6 +9,7 @@
 from captum._utils.common import (
     _format_additional_forward_args,
     _format_input,
+    _format_output,
     _reduce_list,
 )
 from captum._utils.typing import (
@@ -196,3 +197,27 @@ def _batched_operator(
         )
     ]
     return _reduce_list(all_outputs)
+
+
+def _select_example(curr_arg: Any, index: int, bsz: int) -> Any:
+    if curr_arg is None:
+        return None
+    is_tuple = isinstance(curr_arg, tuple)
+    if not is_tuple:
+        curr_arg = (curr_arg,)
+    selected_arg = []
+    for i in range(len(curr_arg)):
+        if isinstance(curr_arg[i], (Tensor, list)) and len(curr_arg[i]) == bsz:
+            selected_arg.append(curr_arg[i][index : index + 1])
+        else:
+            selected_arg.append(curr_arg[i])
+    return _format_output(is_tuple, tuple(selected_arg))
+
+
+def _batch_example_iterator(bsz: int, *args) -> Iterator:
+    """
+    Yields each example in the batch one at a time from the provided arguments.
+    """
+    for i in range(bsz):
+        curr_args = [_select_example(args[j], i, bsz) for j in range(len(args))]
+        yield tuple(curr_args)
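
A sketch of how the new iterator behaves, with hypothetical example data; arguments whose length matches bsz are sliced down to the i-th example, everything else passes through unchanged:

import torch
from captum.attr._utils.batching import _batch_example_iterator

inputs = torch.randn(4, 3)
targets = [0, 1, 1, 0]
for single_inp, single_target, extra in _batch_example_iterator(4, inputs, targets, 7):
    # First iteration: torch.Size([1, 3]) [0] 7 -- the scalar 7 is passed through
    print(single_inp.shape, single_target, extra)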

captum/attr/_utils/common.py

Lines changed: 20 additions & 0 deletions
@@ -347,6 +347,26 @@ def _find_output_mode_and_verify(
     return agg_output_mode


+def _construct_default_feature_mask(
+    inputs: Tuple[Tensor, ...]
+) -> Tuple[Tuple[Tensor, ...], int]:
+    feature_mask = []
+    current_num_features = 0
+    for i in range(len(inputs)):
+        num_features = torch.numel(inputs[i][0])
+        feature_mask.append(
+            current_num_features
+            + torch.reshape(
+                torch.arange(num_features, device=inputs[i].device),
+                inputs[i][0:1].shape,
+            )
+        )
+        current_num_features += num_features
+    total_features = current_num_features
+    feature_mask = tuple(feature_mask)
+    return feature_mask, total_features
+
+
 def neuron_index_deprecation_decorator(func):
     r"""
     Decorator to deprecate neuron_index parameter for Neuron Attribution methods.
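
This is the helper relocated from shapley_value.py (see above) so that Lime and KernelShap can share it: every scalar position across all inputs gets a distinct feature id. A sketch on hypothetical inputs:

import torch
from captum.attr._utils.common import _construct_default_feature_mask

mask, total = _construct_default_feature_mask((torch.randn(4, 3), torch.randn(4, 2)))
print(mask[0])  # tensor([[0, 1, 2]]) -- ids for the first input's features
print(mask[1])  # tensor([[3, 4]])   -- ids continue across inputs
print(total)    # 5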

scripts/install_via_conda.sh

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ else
 fi

 # install other deps
-conda install -y numpy sphinx pytest flake8 ipywidgets ipython
+conda install -y numpy sphinx pytest flake8 ipywidgets ipython scikit-learn
 conda install -y -c conda-forge black matplotlib pytest-cov sphinx-autodoc-typehints mypy flask isort

 # install node/yarn for insights build

setup.py

Lines changed: 9 additions & 1 deletion
@@ -66,7 +66,15 @@ def report(*args):
 DEV_REQUIRES = (
     INSIGHTS_REQUIRES
     + TEST_REQUIRES
-    + ["black", "flake8", "sphinx", "sphinx-autodoc-typehints", "mypy>=0.760", "isort"]
+    + [
+        "black",
+        "flake8",
+        "sphinx",
+        "sphinx-autodoc-typehints",
+        "mypy>=0.760",
+        "isort",
+        "scikit-learn",
+    ]
 )

 # get version string from module

tests/attr/helpers/gen_test_utils.py

Lines changed: 16 additions & 0 deletions
@@ -5,6 +5,7 @@

 from torch.nn import Module

+from captum.attr._core.lime import Lime
 from captum.attr._models.base import _get_deep_layer_name
 from captum.attr._utils.attribution import Attribution

@@ -39,6 +40,21 @@ def parse_test_config(
     return algorithms, model, args, layer, noise_tunnel, baseline_distr


+def should_create_generated_test(algorithm: Type[Attribution]) -> bool:
+    if issubclass(algorithm, Lime):
+        try:
+            import sklearn  # noqa: F401
+
+            assert sklearn.__version__ >= "0.23.0", (
+                "Must have sklearn version 0.23.0 or higher to use "
+                "sample_weight in Lasso regression."
+            )
+            return True
+        except (ImportError, AssertionError):
+            return False
+    return True
+
+
 @typing.overload
 def get_target_layer(model: Module, layer_name: str) -> Module:
     ...

tests/attr/helpers/test_config.py

Lines changed: 35 additions & 2 deletions
@@ -10,6 +10,7 @@
 from captum.attr._core.guided_grad_cam import GuidedGradCam
 from captum.attr._core.input_x_gradient import InputXGradient
 from captum.attr._core.integrated_gradients import IntegratedGradients
+from captum.attr._core.kernel_shap import KernelShap
 from captum.attr._core.layer.grad_cam import LayerGradCam
 from captum.attr._core.layer.internal_influence import InternalInfluence
 from captum.attr._core.layer.layer_activation import LayerActivation
@@ -19,6 +20,7 @@
 from captum.attr._core.layer.layer_gradient_shap import LayerGradientShap
 from captum.attr._core.layer.layer_gradient_x_activation import LayerGradientXActivation
 from captum.attr._core.layer.layer_integrated_gradients import LayerIntegratedGradients
+from captum.attr._core.lime import Lime
 from captum.attr._core.neuron.neuron_conductance import NeuronConductance
 from captum.attr._core.neuron.neuron_deep_lift import NeuronDeepLift, NeuronDeepLiftShap
 from captum.attr._core.neuron.neuron_feature_ablation import NeuronFeatureAblation
@@ -98,6 +100,8 @@
             Deconvolution,
             ShapleyValueSampling,
             FeaturePermutation,
+            Lime,
+            KernelShap,
         ],
         "model": BasicModel_MultiLayer(),
         "attribute_args": {"inputs": torch.randn(4, 3), "target": 1},
@@ -114,6 +118,8 @@
             Deconvolution,
             ShapleyValueSampling,
             FeaturePermutation,
+            Lime,
+            KernelShap,
         ],
         "model": BasicModel_MultiLayer_MultiInput(),
         "attribute_args": {
@@ -135,6 +141,8 @@
             Deconvolution,
             ShapleyValueSampling,
             FeaturePermutation,
+            Lime,
+            KernelShap,
         ],
         "model": BasicModel_MultiLayer(),
         "attribute_args": {"inputs": torch.randn(4, 3), "target": [0, 1, 1, 0]},
@@ -151,6 +159,8 @@
             Deconvolution,
             ShapleyValueSampling,
             FeaturePermutation,
+            Lime,
+            KernelShap,
         ],
         "model": BasicModel_MultiLayer_MultiInput(),
         "attribute_args": {
@@ -171,6 +181,8 @@
             Deconvolution,
             ShapleyValueSampling,
             FeaturePermutation,
+            Lime,
+            KernelShap,
         ],
         "model": BasicModel_MultiLayer(),
         "attribute_args": {
@@ -191,6 +203,8 @@
             Deconvolution,
             ShapleyValueSampling,
             FeaturePermutation,
+            Lime,
+            KernelShap,
         ],
         "model": BasicModel_MultiLayer(),
         "attribute_args": {"inputs": torch.randn(4, 3), "target": torch.tensor([0])},
@@ -207,6 +221,8 @@
             Deconvolution,
             ShapleyValueSampling,
             FeaturePermutation,
+            Lime,
+            KernelShap,
         ],
         "model": BasicModel_MultiLayer(),
         "attribute_args": {
@@ -222,6 +238,8 @@
             FeatureAblation,
             DeepLift,
             ShapleyValueSampling,
+            Lime,
+            KernelShap,
         ],
         "model": BasicModel_MultiLayer(),
         "attribute_args": {
@@ -238,6 +256,8 @@
             FeatureAblation,
             DeepLift,
             ShapleyValueSampling,
+            Lime,
+            KernelShap,
         ],
         "model": BasicModel_MultiLayer(),
         "attribute_args": {
@@ -508,17 +528,30 @@
     # Perturbation-Specific Configs
     {
         "name": "conv_with_perturbations_per_eval",
-        "algorithms": [FeatureAblation, ShapleyValueSampling, FeaturePermutation],
+        "algorithms": [
+            FeatureAblation,
+            ShapleyValueSampling,
+            FeaturePermutation,
+            Lime,
+            KernelShap,
+        ],
         "model": BasicModel_ConvNet(),
         "attribute_args": {
             "inputs": torch.arange(400).view(4, 1, 10, 10).float(),
             "target": 0,
             "perturbations_per_eval": 20,
         },
+        "dp_delta": 0.008,
     },
     {
         "name": "basic_multiple_tuple_target_with_perturbations_per_eval",
-        "algorithms": [FeatureAblation, ShapleyValueSampling, FeaturePermutation],
+        "algorithms": [
+            FeatureAblation,
+            ShapleyValueSampling,
+            FeaturePermutation,
+            Lime,
+            KernelShap,
+        ],
         "model": BasicModel_MultiLayer(),
         "attribute_args": {
             "inputs": torch.randn(4, 3),

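Pieced together from the hunks above, one complete entry of the generated-test config now reads as follows (a reconstruction for readability, not a new entry; names are those imported at the top of test_config.py, and dp_delta is presumably the looser DataParallel comparison tolerance consumed by test_data_parallel.py below):

{
    "name": "conv_with_perturbations_per_eval",
    "algorithms": [
        FeatureAblation,
        ShapleyValueSampling,
        FeaturePermutation,
        Lime,
        KernelShap,
    ],
    "model": BasicModel_ConvNet(),
    "attribute_args": {
        "inputs": torch.arange(400).view(4, 1, 10, 10).float(),
        "target": 0,
        "perturbations_per_eval": 20,
    },
    "dp_delta": 0.008,
}
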
tests/attr/test_data_parallel.py

Lines changed: 8 additions & 1 deletion
@@ -20,7 +20,12 @@
 from captum.attr._utils.attribution import Attribution, InternalAttribution

 from ..helpers.basic import BaseTest, assertTensorTuplesAlmostEqual, deep_copy_args
-from .helpers.gen_test_utils import gen_test_name, get_target_layer, parse_test_config
+from .helpers.gen_test_utils import (
+    gen_test_name,
+    get_target_layer,
+    parse_test_config,
+    should_create_generated_test,
+)
 from .helpers.test_config import config

 """
@@ -66,6 +71,8 @@ def __new__(cls, name: str, bases: Tuple, attrs: Dict):
         dp_delta = test_config["dp_delta"] if "dp_delta" in test_config else 0.0001

         for algorithm in algorithms:
+            if not should_create_generated_test(algorithm):
+                continue
             for mode in DataParallelCompareMode:
                 # Creates test case corresponding to each algorithm and
                 # DataParallelCompareMode

tests/attr/test_hook_removal.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,12 @@
1010
from captum.attr._utils.attribution import Attribution, InternalAttribution
1111

1212
from ..helpers.basic import BaseTest, deep_copy_args
13-
from .helpers.gen_test_utils import gen_test_name, get_target_layer, parse_test_config
13+
from .helpers.gen_test_utils import (
14+
gen_test_name,
15+
get_target_layer,
16+
parse_test_config,
17+
should_create_generated_test,
18+
)
1419
from .helpers.test_config import config
1520

1621
"""
@@ -63,6 +68,8 @@ def __new__(cls, name: str, bases: Tuple, attrs: Dict):
6368
) = parse_test_config(test_config)
6469

6570
for algorithm in algorithms:
71+
if not should_create_generated_test(algorithm):
72+
continue
6673
for mode in HookRemovalMode:
6774
if mode is HookRemovalMode.invalid_module and layer is None:
6875
continue

tests/attr/test_jit.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,20 @@
1313
from captum.attr._core.gradient_shap import GradientShap
1414
from captum.attr._core.input_x_gradient import InputXGradient
1515
from captum.attr._core.integrated_gradients import IntegratedGradients
16+
from captum.attr._core.kernel_shap import KernelShap
17+
from captum.attr._core.lime import Lime
1618
from captum.attr._core.noise_tunnel import NoiseTunnel
1719
from captum.attr._core.occlusion import Occlusion
1820
from captum.attr._core.saliency import Saliency
1921
from captum.attr._core.shapley_value import ShapleyValueSampling
2022
from captum.attr._utils.attribution import Attribution
2123

2224
from ..helpers.basic import BaseTest, assertTensorTuplesAlmostEqual, deep_copy_args
23-
from .helpers.gen_test_utils import gen_test_name, parse_test_config
25+
from .helpers.gen_test_utils import (
26+
gen_test_name,
27+
parse_test_config,
28+
should_create_generated_test,
29+
)
2430
from .helpers.test_config import config
2531

2632
JIT_SUPPORTED = [
@@ -32,6 +38,8 @@
3238
Occlusion,
3339
Saliency,
3440
ShapleyValueSampling,
41+
Lime,
42+
KernelShap,
3543
]
3644

3745
"""
@@ -75,6 +83,8 @@ def __new__(cls, name: str, bases: Tuple, attrs: Dict):
7583
baseline_distr,
7684
) = parse_test_config(test_config)
7785
for algorithm in algorithms:
86+
if not should_create_generated_test(algorithm):
87+
continue
7888
if algorithm in JIT_SUPPORTED:
7989
for mode in JITCompareMode:
8090
# Creates test case corresponding to each algorithm and
