Adding LIME and Kernel SHAP #468

Closed · wants to merge 42 commits (diff shown from 37 of the 42 commits)

Commits (42)
b95a23d
Initial lime
vivekmig Aug 19, 2020
a065208
Adding lime changes
vivekmig Aug 26, 2020
e4c9847
Kernel shap
vivekmig Aug 30, 2020
ff4db0c
Shapley update
vivekmig Aug 30, 2020
3f5bbe7
Common updates
vivekmig Aug 30, 2020
868adaa
Merge branch 'master' of https://github.com/pytorch/captum into Lime
vivekmig Aug 30, 2020
b4bf962
Adding tests
vivekmig Sep 2, 2020
2e82a9a
Kernel shap test updates
vivekmig Sep 3, 2020
2fd37c1
Updates to lime
vivekmig Sep 9, 2020
0d32b57
Merging
vivekmig Sep 9, 2020
5b3b93f
Lint fixes
vivekmig Sep 10, 2020
9ef658a
Update docs
vivekmig Sep 14, 2020
3224e2a
Fixing formatting
vivekmig Sep 14, 2020
21778bf
Tests
vivekmig Sep 15, 2020
186fa00
Fixes
vivekmig Sep 15, 2020
a0d9dae
Config updates
vivekmig Sep 15, 2020
fd03aa0
Config formatting
vivekmig Sep 15, 2020
52dd38d
Config fixes
vivekmig Sep 15, 2020
31f9b7a
Fixes
vivekmig Sep 15, 2020
4af72a6
CI fixes
vivekmig Sep 15, 2020
6cd3d97
CI updates
vivekmig Sep 15, 2020
1926eb1
Fix formatting
vivekmig Sep 16, 2020
a6aec45
Fixes
vivekmig Sep 16, 2020
3cf25bc
sklearn fix
vivekmig Sep 16, 2020
a62df2a
Formatting
vivekmig Sep 16, 2020
faa76ee
Fix for torch 1.2
vivekmig Sep 16, 2020
e887280
Starting to address comments
vivekmig Oct 2, 2020
30dff8a
Addressing comments
vivekmig Oct 14, 2020
603f617
Fixes
vivekmig Oct 14, 2020
234b72f
Fixes
vivekmig Oct 14, 2020
d15b801
Fixing default similarity
vivekmig Oct 21, 2020
1eaa3ba
fixes
vivekmig Oct 23, 2020
24db7b0
fixes
vivekmig Oct 23, 2020
c7d6382
Fixing literal
vivekmig Oct 23, 2020
0032fe3
Fixes
vivekmig Oct 23, 2020
202037a
Merge conflicts
vivekmig Oct 23, 2020
8f5e80a
formatting
vivekmig Oct 23, 2020
e8f1b2a
Merge conflicts
vivekmig Nov 6, 2020
98fb170
Fixes
vivekmig Nov 6, 2020
044ba11
Fixes
vivekmig Nov 6, 2020
0c25e66
Fixing typo
vivekmig Nov 7, 2020
f252668
Test skip fix
vivekmig Nov 9, 2020
7 changes: 7 additions & 0 deletions captum/_utils/common.py
@@ -13,6 +13,7 @@
    BaselineType,
    Literal,
    TargetType,
    TensorOrTupleOfTensorsGeneric,
    TupleOrTensorOrBoolGeneric,
)

@@ -520,3 +521,9 @@ def _sort_key_list(
"devices with computed tensors."

return out_list


def _flatten_tensor_or_tuple(inp: TensorOrTupleOfTensorsGeneric) -> Tensor:
    if isinstance(inp, Tensor):
        return inp.flatten()
    return torch.cat([single_inp.flatten() for single_inp in inp])
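
A minimal, self-contained sketch of how this new helper behaves (it mirrors the implementation above rather than importing from the branch): a single tensor is flattened to 1-D, and a tuple of tensors is flattened element-wise and concatenated.

from typing import Tuple, Union

import torch
from torch import Tensor


def _flatten_tensor_or_tuple(inp: Union[Tensor, Tuple[Tensor, ...]]) -> Tensor:
    # Single tensor: flatten to 1-D; tuple: flatten each element, then concatenate.
    if isinstance(inp, Tensor):
        return inp.flatten()
    return torch.cat([single_inp.flatten() for single_inp in inp])


flat = _flatten_tensor_or_tuple((torch.zeros(2, 3), torch.ones(4)))
print(flat.shape)  # torch.Size([10])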
324 changes: 324 additions & 0 deletions captum/attr/_core/kernel_shap.py

Large diffs are not rendered by default.
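
Because the kernel_shap.py diff is not rendered here, the following is only a usage sketch: the constructor and attribute arguments (inputs, target, perturbations_per_eval) are assumed from the perturbation-method pattern exercised in tests/attr/helpers/test_config.py further down, not taken from the file itself.

import torch
import torch.nn as nn

from captum.attr._core.kernel_shap import KernelShap  # added by this PR

model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 2))
inputs = torch.randn(4, 3)

ks = KernelShap(model)
# Attribution for target class 1; several perturbed samples evaluated per forward pass.
attr = ks.attribute(inputs, target=1, perturbations_per_eval=4)
print(attr.shape)  # expected to match the input shape: torch.Size([4, 3])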

1,148 changes: 1,148 additions & 0 deletions captum/attr/_core/lime.py

Large diffs are not rendered by default.
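
Likewise for lime.py, a usage sketch only: it assumes Lime follows the same attribute interface used in the test configs below, and that the default interpretable (surrogate) model relies on scikit-learn, which would explain the new scikit-learn entries in scripts/install_via_conda.sh and setup.py later in this diff.

import torch
import torch.nn as nn

from captum.attr._core.lime import Lime  # added by this PR; default surrogate assumed to need sklearn

model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 2))
inputs = torch.randn(4, 3)

lime = Lime(model)
# No feature_mask is passed, so each input scalar is assumed to be its own
# interpretable feature (see _construct_default_feature_mask below).
attr = lime.attribute(inputs, target=0)
print(attr.shape)  # expected: torch.Size([4, 3])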

22 changes: 2 additions & 20 deletions captum/attr/_core/shapley_value.py
@@ -21,6 +21,7 @@
from ..._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
from .._utils.attribution import PerturbationAttribution
from .._utils.common import (
    _construct_default_feature_mask,
    _find_output_mode_and_verify,
    _format_input_baseline,
    _tensorize_baseline,
@@ -279,7 +280,7 @@ def attribute(
        num_examples = inputs[0].shape[0]

        if feature_mask is None:
            feature_mask, total_features = self.construct_feature_mask(inputs)
            feature_mask, total_features = _construct_default_feature_mask(inputs)
        else:
            total_features = int(
                max(torch.max(single_mask).item() for single_mask in feature_mask)
@@ -446,25 +447,6 @@ def _perturbation_generator(
            combined_masks,
        )

    def construct_feature_mask(
        self, inputs: Tuple[Tensor, ...]
    ) -> Tuple[Tuple[Tensor, ...], int]:
        feature_mask = []
        current_num_features = 0
        for i in range(len(inputs)):
            num_features = torch.numel(inputs[i][0])
            feature_mask.append(
                current_num_features
                + torch.reshape(
                    torch.arange(num_features, device=inputs[i].device),
                    inputs[i][0:1].shape,
                )
            )
            current_num_features += num_features
        total_features = current_num_features
        feature_mask = tuple(feature_mask)
        return feature_mask, total_features


class ShapleyValues(PerturbationAttribution):
"""
25 changes: 25 additions & 0 deletions captum/attr/_utils/batching.py
@@ -9,6 +9,7 @@
from ..._utils.common import (
    _format_additional_forward_args,
    _format_input,
    _format_output,
    _reduce_list,
)
from ..._utils.typing import (
@@ -196,3 +197,27 @@ def _batched_operator(
)
]
return _reduce_list(all_outputs)


def _select_example(curr_arg: Any, index: int, bsz: int) -> Any:
    if curr_arg is None:
        return None
    is_tuple = isinstance(curr_arg, tuple)
    if not is_tuple:
        curr_arg = (curr_arg,)
    selected_arg = []
    for i in range(len(curr_arg)):
        if isinstance(curr_arg[i], (Tensor, list)) and len(curr_arg[i]) == bsz:
            selected_arg.append(curr_arg[i][index : index + 1])
        else:
            selected_arg.append(curr_arg[i])
    return _format_output(is_tuple, tuple(selected_arg))


def _batch_example_iterator(bsz: int, *args) -> Iterator:
    """
    Yields the provided batched arguments one example at a time: each argument
    whose leading dimension equals bsz is sliced to the i-th example, and
    non-batched arguments are passed through unchanged.
    """
    for i in range(bsz):
        curr_args = [_select_example(args[j], i, bsz) for j in range(len(args))]
        yield tuple(curr_args)
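
A small sketch of what this iterator yields; the helper only exists on this branch, and the behavior shown follows directly from the implementation above: batched arguments are sliced one example at a time, anything else is forwarded as-is.

import torch

from captum.attr._utils.batching import _batch_example_iterator  # added by this PR

inputs = torch.arange(6).view(2, 3)   # batch of 2 examples
target = torch.tensor([0, 1])         # batched alongside inputs
extra = 5                             # not batched; forwarded unchanged

for ex_inputs, ex_target, ex_extra in _batch_example_iterator(2, inputs, target, extra):
    print(ex_inputs.shape, ex_target, ex_extra)
# torch.Size([1, 3]) tensor([0]) 5
# torch.Size([1, 3]) tensor([1]) 5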
20 changes: 20 additions & 0 deletions captum/attr/_utils/common.py
@@ -343,3 +343,23 @@ def _find_output_mode_and_verify(
isinstance(initial_eval, torch.Tensor) and initial_eval[0].numel() == 1
), "Target should identify a single element in the model output."
return agg_output_mode


def _construct_default_feature_mask(
    inputs: Tuple[Tensor, ...]
) -> Tuple[Tuple[Tensor, ...], int]:
    feature_mask = []
    current_num_features = 0
    for i in range(len(inputs)):
        num_features = torch.numel(inputs[i][0])
        feature_mask.append(
            current_num_features
            + torch.reshape(
                torch.arange(num_features, device=inputs[i].device),
                inputs[i][0:1].shape,
            )
        )
        current_num_features += num_features
    total_features = current_num_features
    feature_mask = tuple(feature_mask)
    return feature_mask, total_features
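
A brief sketch of the mask this builds, run against this branch where the function now lives in captum/attr/_utils/common.py: every scalar of every input gets a distinct feature index, and the indices keep counting up across the tuple of inputs.

import torch

from captum.attr._utils.common import _construct_default_feature_mask  # moved here by this PR

inputs = (torch.randn(4, 2), torch.randn(4, 3))  # two inputs, batch size 4
mask, total_features = _construct_default_feature_mask(inputs)

print(mask[0])          # tensor([[0, 1]])      -- one row, broadcastable over the batch
print(mask[1])          # tensor([[2, 3, 4]])   -- indices continue from the first input
print(total_features)   # 5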
2 changes: 1 addition & 1 deletion scripts/install_via_conda.sh
@@ -34,7 +34,7 @@ else
fi

# install other deps
conda install -y numpy sphinx pytest flake8 ipywidgets ipython
conda install -y numpy sphinx pytest flake8 ipywidgets ipython scikit-learn
conda install -y -c conda-forge black matplotlib pytest-cov sphinx-autodoc-typehints mypy flask isort

# install node/yarn for insights build
10 changes: 9 additions & 1 deletion setup.py
@@ -66,7 +66,15 @@ def report(*args):
DEV_REQUIRES = (
    INSIGHTS_REQUIRES
    + TEST_REQUIRES
    + ["black", "flake8", "sphinx", "sphinx-autodoc-typehints", "mypy>=0.760", "isort"]
    + [
        "black",
        "flake8",
        "sphinx",
        "sphinx-autodoc-typehints",
        "mypy>=0.760",
        "isort",
        "scikit-learn",
    ]
)

# get version string from module
12 changes: 12 additions & 0 deletions tests/attr/helpers/gen_test_utils.py
@@ -5,6 +5,7 @@

from torch.nn import Module

from captum.attr._core.lime import Lime
from captum.attr._models.base import _get_deep_layer_name
from captum.attr._utils.attribution import Attribution

@@ -39,6 +40,17 @@ def parse_test_config(
    return algorithms, model, args, layer, noise_tunnel, baseline_distr


def should_create_generated_test(algorithm: Type[Attribution]) -> bool:
    if issubclass(algorithm, Lime):
        try:
            import sklearn  # noqa: F401

            return True
        except ImportError:
            return False
    return True


@typing.overload
def get_target_layer(model: Module, layer_name: str) -> Module:
...
37 changes: 35 additions & 2 deletions tests/attr/helpers/test_config.py
@@ -10,6 +10,7 @@
from captum.attr._core.guided_grad_cam import GuidedGradCam
from captum.attr._core.input_x_gradient import InputXGradient
from captum.attr._core.integrated_gradients import IntegratedGradients
from captum.attr._core.kernel_shap import KernelShap
from captum.attr._core.layer.grad_cam import LayerGradCam
from captum.attr._core.layer.internal_influence import InternalInfluence
from captum.attr._core.layer.layer_activation import LayerActivation
@@ -19,6 +20,7 @@
from captum.attr._core.layer.layer_gradient_shap import LayerGradientShap
from captum.attr._core.layer.layer_gradient_x_activation import LayerGradientXActivation
from captum.attr._core.layer.layer_integrated_gradients import LayerIntegratedGradients
from captum.attr._core.lime import Lime
from captum.attr._core.neuron.neuron_conductance import NeuronConductance
from captum.attr._core.neuron.neuron_deep_lift import NeuronDeepLift, NeuronDeepLiftShap
from captum.attr._core.neuron.neuron_feature_ablation import NeuronFeatureAblation
@@ -98,6 +100,8 @@
Deconvolution,
ShapleyValueSampling,
FeaturePermutation,
Lime,
KernelShap,
],
"model": BasicModel_MultiLayer(),
"attribute_args": {"inputs": torch.randn(4, 3), "target": 1},
@@ -114,6 +118,8 @@
Deconvolution,
ShapleyValueSampling,
FeaturePermutation,
Lime,
KernelShap,
],
"model": BasicModel_MultiLayer_MultiInput(),
"attribute_args": {
@@ -135,6 +141,8 @@
Deconvolution,
ShapleyValueSampling,
FeaturePermutation,
Lime,
KernelShap,
],
"model": BasicModel_MultiLayer(),
"attribute_args": {"inputs": torch.randn(4, 3), "target": [0, 1, 1, 0]},
@@ -151,6 +159,8 @@
Deconvolution,
ShapleyValueSampling,
FeaturePermutation,
Lime,
KernelShap,
],
"model": BasicModel_MultiLayer_MultiInput(),
"attribute_args": {
@@ -171,6 +181,8 @@
Deconvolution,
ShapleyValueSampling,
FeaturePermutation,
Lime,
KernelShap,
],
"model": BasicModel_MultiLayer(),
"attribute_args": {
@@ -191,6 +203,8 @@
Deconvolution,
ShapleyValueSampling,
FeaturePermutation,
Lime,
KernelShap,
],
"model": BasicModel_MultiLayer(),
"attribute_args": {"inputs": torch.randn(4, 3), "target": torch.tensor([0])},
@@ -207,6 +221,8 @@
Deconvolution,
ShapleyValueSampling,
FeaturePermutation,
Lime,
KernelShap,
],
"model": BasicModel_MultiLayer(),
"attribute_args": {
@@ -222,6 +238,8 @@
FeatureAblation,
DeepLift,
ShapleyValueSampling,
Lime,
KernelShap,
],
"model": BasicModel_MultiLayer(),
"attribute_args": {
@@ -238,6 +256,8 @@
FeatureAblation,
DeepLift,
ShapleyValueSampling,
Lime,
KernelShap,
],
"model": BasicModel_MultiLayer(),
"attribute_args": {
@@ -508,17 +528,30 @@
# Perturbation-Specific Configs
{
"name": "conv_with_perturbations_per_eval",
"algorithms": [FeatureAblation, ShapleyValueSampling, FeaturePermutation],
"algorithms": [
FeatureAblation,
ShapleyValueSampling,
FeaturePermutation,
Lime,
KernelShap,
],
"model": BasicModel_ConvNet(),
"attribute_args": {
"inputs": torch.arange(400).view(4, 1, 10, 10).float(),
"target": 0,
"perturbations_per_eval": 20,
},
"dp_delta": 0.008,
},
{
"name": "basic_multiple_tuple_target_with_perturbations_per_eval",
"algorithms": [FeatureAblation, ShapleyValueSampling, FeaturePermutation],
"algorithms": [
FeatureAblation,
ShapleyValueSampling,
FeaturePermutation,
Lime,
KernelShap,
],
"model": BasicModel_MultiLayer(),
"attribute_args": {
"inputs": torch.randn(4, 3),
9 changes: 8 additions & 1 deletion tests/attr/test_data_parallel.py
@@ -20,7 +20,12 @@
from captum.attr._utils.attribution import Attribution, InternalAttribution

from ..helpers.basic import BaseTest, assertTensorTuplesAlmostEqual, deep_copy_args
from .helpers.gen_test_utils import gen_test_name, get_target_layer, parse_test_config
from .helpers.gen_test_utils import (
    gen_test_name,
    get_target_layer,
    parse_test_config,
    should_create_generated_test,
)
from .helpers.test_config import config

"""
@@ -66,6 +71,8 @@ def __new__(cls, name: str, bases: Tuple, attrs: Dict):
            dp_delta = test_config["dp_delta"] if "dp_delta" in test_config else 0.0001

            for algorithm in algorithms:
                if not should_create_generated_test(algorithm):
                    continue
                for mode in DataParallelCompareMode:
                    # Creates test case corresponding to each algorithm and
                    # DataParallelCompareMode
9 changes: 8 additions & 1 deletion tests/attr/test_hook_removal.py
@@ -10,7 +10,12 @@
from captum.attr._utils.attribution import Attribution, InternalAttribution

from ..helpers.basic import BaseTest, deep_copy_args
from .helpers.gen_test_utils import gen_test_name, get_target_layer, parse_test_config
from .helpers.gen_test_utils import (
    gen_test_name,
    get_target_layer,
    parse_test_config,
    should_create_generated_test,
)
from .helpers.test_config import config

"""
@@ -63,6 +68,8 @@ def __new__(cls, name: str, bases: Tuple, attrs: Dict):
            ) = parse_test_config(test_config)

            for algorithm in algorithms:
                if not should_create_generated_test(algorithm):
                    continue
                for mode in HookRemovalMode:
                    if mode is HookRemovalMode.invalid_module and layer is None:
                        continue
12 changes: 11 additions & 1 deletion tests/attr/test_jit.py
@@ -13,14 +13,20 @@
from captum.attr._core.gradient_shap import GradientShap
from captum.attr._core.input_x_gradient import InputXGradient
from captum.attr._core.integrated_gradients import IntegratedGradients
from captum.attr._core.kernel_shap import KernelShap
from captum.attr._core.lime import Lime
from captum.attr._core.noise_tunnel import NoiseTunnel
from captum.attr._core.occlusion import Occlusion
from captum.attr._core.saliency import Saliency
from captum.attr._core.shapley_value import ShapleyValueSampling
from captum.attr._utils.attribution import Attribution

from ..helpers.basic import BaseTest, assertTensorTuplesAlmostEqual, deep_copy_args
from .helpers.gen_test_utils import gen_test_name, parse_test_config
from .helpers.gen_test_utils import (
    gen_test_name,
    parse_test_config,
    should_create_generated_test,
)
from .helpers.test_config import config

JIT_SUPPORTED = [
@@ -32,6 +38,8 @@
    Occlusion,
    Saliency,
    ShapleyValueSampling,
    Lime,
    KernelShap,
]

"""
@@ -75,6 +83,8 @@ def __new__(cls, name: str, bases: Tuple, attrs: Dict):
                baseline_distr,
            ) = parse_test_config(test_config)
            for algorithm in algorithms:
                if not should_create_generated_test(algorithm):
                    continue
                if algorithm in JIT_SUPPORTED:
                    for mode in JITCompareMode:
                        # Creates test case corresponding to each algorithm and