Adding LIME and Kernel SHAP (pytorch#468)
Summary:
Adding LIME and Kernel SHAP to Captum. This includes all documentation, type hints, and data parallel / JIT tests.

Pull Request resolved: pytorch#468

Reviewed By: NarineK

Differential Revision: D23733322

Pulled By: vivekmig

fbshipit-source-id: 1ecc21306493ce4bd84ce175d4e08c21aaa49083
vivekmig authored and facebook-github-bot committed Nov 9, 2020
1 parent eadab0e commit 279d0be
Showing 16 changed files with 2,494 additions and 31 deletions.
7 changes: 7 additions & 0 deletions captum/_utils/common.py
@@ -13,6 +13,7 @@
BaselineType,
Literal,
TargetType,
TensorOrTupleOfTensorsGeneric,
TupleOrTensorOrBoolGeneric,
)

@@ -547,3 +548,9 @@ def _sort_key_list(
"devices with computed tensors."

return out_list


def _flatten_tensor_or_tuple(inp: TensorOrTupleOfTensorsGeneric) -> Tensor:
if isinstance(inp, Tensor):
return inp.flatten()
return torch.cat([single_inp.flatten() for single_inp in inp])
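For reference, a minimal sketch of the new helper's behavior; the input shapes are assumptions chosen for illustration, not taken from the diff:

import torch
from captum._utils.common import _flatten_tensor_or_tuple

# concatenates the flattened tensors of a tuple into a single 1-D tensor
flat = _flatten_tensor_or_tuple((torch.ones(2, 3), torch.zeros(4)))
assert flat.shape == (10,)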
324 changes: 324 additions & 0 deletions captum/attr/_core/kernel_shap.py

Large diffs are not rendered by default.

1,148 changes: 1,148 additions & 0 deletions captum/attr/_core/lime.py

Large diffs are not rendered by default.
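Since the two new modules above are not rendered, here is a hedged usage sketch of the classes they add; the model, input shapes, and argument values are assumptions for illustration, not taken from the diff:

import torch
from captum.attr._core.kernel_shap import KernelShap
from captum.attr._core.lime import Lime

# assumed toy forward function; any model mapping inputs to outputs works
model = torch.nn.Sequential(torch.nn.Linear(3, 3), torch.nn.ReLU(), torch.nn.Linear(3, 2))
inputs = torch.randn(4, 3)

lime = Lime(model)
lime_attr = lime.attribute(inputs, target=1, n_samples=200)

kernel_shap = KernelShap(model)
ks_attr = kernel_shap.attribute(inputs, target=1, n_samples=200)

As the test configs below suggest, both classes follow the perturbation-based attribute signature used elsewhere in Captum, including the perturbations_per_eval argument exercised in the generated tests.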

22 changes: 2 additions & 20 deletions captum/attr/_core/shapley_value.py
@@ -19,6 +19,7 @@
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
from captum.attr._utils.attribution import PerturbationAttribution
from captum.attr._utils.common import (
_construct_default_feature_mask,
_find_output_mode_and_verify,
_format_input_baseline,
_tensorize_baseline,
@@ -278,7 +279,7 @@ def attribute(
num_examples = inputs[0].shape[0]

if feature_mask is None:
feature_mask, total_features = self.construct_feature_mask(inputs)
feature_mask, total_features = _construct_default_feature_mask(inputs)
else:
total_features = int(
max(torch.max(single_mask).item() for single_mask in feature_mask)
@@ -445,25 +446,6 @@ def _perturbation_generator(
combined_masks,
)

def construct_feature_mask(
self, inputs: Tuple[Tensor, ...]
) -> Tuple[Tuple[Tensor, ...], int]:
feature_mask = []
current_num_features = 0
for i in range(len(inputs)):
num_features = torch.numel(inputs[i][0])
feature_mask.append(
current_num_features
+ torch.reshape(
torch.arange(num_features, device=inputs[i].device),
inputs[i][0:1].shape,
)
)
current_num_features += num_features
total_features = current_num_features
feature_mask = tuple(feature_mask)
return feature_mask, total_features


class ShapleyValues(PerturbationAttribution):
"""
25 changes: 25 additions & 0 deletions captum/attr/_utils/batching.py
@@ -9,6 +9,7 @@
from captum._utils.common import (
_format_additional_forward_args,
_format_input,
_format_output,
_reduce_list,
)
from captum._utils.typing import (
@@ -196,3 +197,27 @@ def _batched_operator(
)
]
return _reduce_list(all_outputs)


def _select_example(curr_arg: Any, index: int, bsz: int) -> Any:
if curr_arg is None:
return None
is_tuple = isinstance(curr_arg, tuple)
if not is_tuple:
curr_arg = (curr_arg,)
selected_arg = []
for i in range(len(curr_arg)):
if isinstance(curr_arg[i], (Tensor, list)) and len(curr_arg[i]) == bsz:
selected_arg.append(curr_arg[i][index : index + 1])
else:
selected_arg.append(curr_arg[i])
return _format_output(is_tuple, tuple(selected_arg))


def _batch_example_iterator(bsz: int, *args) -> Iterator:
"""
Iterates over the batch dimension of the provided arguments, yielding one
example (a slice of size 1 along the batch dimension) from each argument
at a time.
"""
for i in range(bsz):
curr_args = [_select_example(args[j], i, bsz) for j in range(len(args))]
yield tuple(curr_args)
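A brief sketch of how the new iterator behaves; tensor shapes are assumed for illustration:

import torch
from captum.attr._utils.batching import _batch_example_iterator

inputs = torch.randn(4, 3)
extra_arg = torch.randn(4, 2)
# yields one example per iteration: a slice of size 1 along the batch dimension
for single_inp, single_extra in _batch_example_iterator(4, inputs, extra_arg):
    assert single_inp.shape == (1, 3) and single_extra.shape == (1, 2)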
20 changes: 20 additions & 0 deletions captum/attr/_utils/common.py
@@ -347,6 +347,26 @@ def _find_output_mode_and_verify(
return agg_output_mode


def _construct_default_feature_mask(
inputs: Tuple[Tensor, ...]
) -> Tuple[Tuple[Tensor, ...], int]:
feature_mask = []
current_num_features = 0
for i in range(len(inputs)):
num_features = torch.numel(inputs[i][0])
feature_mask.append(
current_num_features
+ torch.reshape(
torch.arange(num_features, device=inputs[i].device),
inputs[i][0:1].shape,
)
)
current_num_features += num_features
total_features = current_num_features
feature_mask = tuple(feature_mask)
return feature_mask, total_features


def neuron_index_deprecation_decorator(func):
r"""
Decorator to deprecate neuron_index parameter for Neuron Attribution methods.
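A worked illustration of the default mask produced by _construct_default_feature_mask above; the input shapes are assumed for the example:

import torch
from captum.attr._utils.common import _construct_default_feature_mask

inp1, inp2 = torch.randn(4, 3), torch.randn(4, 2, 2)
mask, total_features = _construct_default_feature_mask((inp1, inp2))
# each scalar feature gets a consecutive index, shared across the batch dimension:
# mask[0] == tensor([[0, 1, 2]]), mask[1] == tensor([[[3, 4], [5, 6]]]), total_features == 7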
2 changes: 1 addition & 1 deletion scripts/install_via_conda.sh
@@ -34,7 +34,7 @@ else
fi

# install other deps
conda install -y numpy sphinx pytest flake8 ipywidgets ipython
conda install -y numpy sphinx pytest flake8 ipywidgets ipython scikit-learn
conda install -y -c conda-forge black matplotlib pytest-cov sphinx-autodoc-typehints mypy flask isort

# install node/yarn for insights build
10 changes: 9 additions & 1 deletion setup.py
@@ -66,7 +66,15 @@ def report(*args):
DEV_REQUIRES = (
INSIGHTS_REQUIRES
+ TEST_REQUIRES
+ ["black", "flake8", "sphinx", "sphinx-autodoc-typehints", "mypy>=0.760", "isort"]
+ [
"black",
"flake8",
"sphinx",
"sphinx-autodoc-typehints",
"mypy>=0.760",
"isort",
"scikit-learn",
]
)

# get version string from module
16 changes: 16 additions & 0 deletions tests/attr/helpers/gen_test_utils.py
@@ -5,6 +5,7 @@

from torch.nn import Module

from captum.attr._core.lime import Lime
from captum.attr._models.base import _get_deep_layer_name
from captum.attr._utils.attribution import Attribution

@@ -39,6 +40,21 @@ def parse_test_config(
return algorithms, model, args, layer, noise_tunnel, baseline_distr


def should_create_generated_test(algorithm: Type[Attribution]) -> bool:
if issubclass(algorithm, Lime):
try:
import sklearn # noqa: F401

assert sklearn.__version__ >= "0.23.0", (
"Must have sklearn version 0.23.0 or higher to use "
"sample_weight in Lasso regression."
)
return True
except (ImportError, AssertionError):
return False
return True


@typing.overload
def get_target_layer(model: Module, layer_name: str) -> Module:
...
37 changes: 35 additions & 2 deletions tests/attr/helpers/test_config.py
@@ -10,6 +10,7 @@
from captum.attr._core.guided_grad_cam import GuidedGradCam
from captum.attr._core.input_x_gradient import InputXGradient
from captum.attr._core.integrated_gradients import IntegratedGradients
from captum.attr._core.kernel_shap import KernelShap
from captum.attr._core.layer.grad_cam import LayerGradCam
from captum.attr._core.layer.internal_influence import InternalInfluence
from captum.attr._core.layer.layer_activation import LayerActivation
@@ -19,6 +20,7 @@
from captum.attr._core.layer.layer_gradient_shap import LayerGradientShap
from captum.attr._core.layer.layer_gradient_x_activation import LayerGradientXActivation
from captum.attr._core.layer.layer_integrated_gradients import LayerIntegratedGradients
from captum.attr._core.lime import Lime
from captum.attr._core.neuron.neuron_conductance import NeuronConductance
from captum.attr._core.neuron.neuron_deep_lift import NeuronDeepLift, NeuronDeepLiftShap
from captum.attr._core.neuron.neuron_feature_ablation import NeuronFeatureAblation
@@ -98,6 +100,8 @@
Deconvolution,
ShapleyValueSampling,
FeaturePermutation,
Lime,
KernelShap,
],
"model": BasicModel_MultiLayer(),
"attribute_args": {"inputs": torch.randn(4, 3), "target": 1},
@@ -114,6 +118,8 @@
Deconvolution,
ShapleyValueSampling,
FeaturePermutation,
Lime,
KernelShap,
],
"model": BasicModel_MultiLayer_MultiInput(),
"attribute_args": {
@@ -135,6 +141,8 @@
Deconvolution,
ShapleyValueSampling,
FeaturePermutation,
Lime,
KernelShap,
],
"model": BasicModel_MultiLayer(),
"attribute_args": {"inputs": torch.randn(4, 3), "target": [0, 1, 1, 0]},
@@ -151,6 +159,8 @@
Deconvolution,
ShapleyValueSampling,
FeaturePermutation,
Lime,
KernelShap,
],
"model": BasicModel_MultiLayer_MultiInput(),
"attribute_args": {
@@ -171,6 +181,8 @@
Deconvolution,
ShapleyValueSampling,
FeaturePermutation,
Lime,
KernelShap,
],
"model": BasicModel_MultiLayer(),
"attribute_args": {
@@ -191,6 +203,8 @@
Deconvolution,
ShapleyValueSampling,
FeaturePermutation,
Lime,
KernelShap,
],
"model": BasicModel_MultiLayer(),
"attribute_args": {"inputs": torch.randn(4, 3), "target": torch.tensor([0])},
@@ -207,6 +221,8 @@
Deconvolution,
ShapleyValueSampling,
FeaturePermutation,
Lime,
KernelShap,
],
"model": BasicModel_MultiLayer(),
"attribute_args": {
@@ -222,6 +238,8 @@
FeatureAblation,
DeepLift,
ShapleyValueSampling,
Lime,
KernelShap,
],
"model": BasicModel_MultiLayer(),
"attribute_args": {
@@ -238,6 +256,8 @@
FeatureAblation,
DeepLift,
ShapleyValueSampling,
Lime,
KernelShap,
],
"model": BasicModel_MultiLayer(),
"attribute_args": {
@@ -508,17 +528,30 @@
# Perturbation-Specific Configs
{
"name": "conv_with_perturbations_per_eval",
"algorithms": [FeatureAblation, ShapleyValueSampling, FeaturePermutation],
"algorithms": [
FeatureAblation,
ShapleyValueSampling,
FeaturePermutation,
Lime,
KernelShap,
],
"model": BasicModel_ConvNet(),
"attribute_args": {
"inputs": torch.arange(400).view(4, 1, 10, 10).float(),
"target": 0,
"perturbations_per_eval": 20,
},
"dp_delta": 0.008,
},
{
"name": "basic_multiple_tuple_target_with_perturbations_per_eval",
"algorithms": [FeatureAblation, ShapleyValueSampling, FeaturePermutation],
"algorithms": [
FeatureAblation,
ShapleyValueSampling,
FeaturePermutation,
Lime,
KernelShap,
],
"model": BasicModel_MultiLayer(),
"attribute_args": {
"inputs": torch.randn(4, 3),
9 changes: 8 additions & 1 deletion tests/attr/test_data_parallel.py
@@ -20,7 +20,12 @@
from captum.attr._utils.attribution import Attribution, InternalAttribution

from ..helpers.basic import BaseTest, assertTensorTuplesAlmostEqual, deep_copy_args
from .helpers.gen_test_utils import gen_test_name, get_target_layer, parse_test_config
from .helpers.gen_test_utils import (
gen_test_name,
get_target_layer,
parse_test_config,
should_create_generated_test,
)
from .helpers.test_config import config

"""
@@ -66,6 +71,8 @@ def __new__(cls, name: str, bases: Tuple, attrs: Dict):
dp_delta = test_config["dp_delta"] if "dp_delta" in test_config else 0.0001

for algorithm in algorithms:
if not should_create_generated_test(algorithm):
continue
for mode in DataParallelCompareMode:
# Creates test case corresponding to each algorithm and
# DataParallelCompareMode
9 changes: 8 additions & 1 deletion tests/attr/test_hook_removal.py
@@ -10,7 +10,12 @@
from captum.attr._utils.attribution import Attribution, InternalAttribution

from ..helpers.basic import BaseTest, deep_copy_args
from .helpers.gen_test_utils import gen_test_name, get_target_layer, parse_test_config
from .helpers.gen_test_utils import (
gen_test_name,
get_target_layer,
parse_test_config,
should_create_generated_test,
)
from .helpers.test_config import config

"""
@@ -63,6 +68,8 @@ def __new__(cls, name: str, bases: Tuple, attrs: Dict):
) = parse_test_config(test_config)

for algorithm in algorithms:
if not should_create_generated_test(algorithm):
continue
for mode in HookRemovalMode:
if mode is HookRemovalMode.invalid_module and layer is None:
continue
12 changes: 11 additions & 1 deletion tests/attr/test_jit.py
@@ -13,14 +13,20 @@
from captum.attr._core.gradient_shap import GradientShap
from captum.attr._core.input_x_gradient import InputXGradient
from captum.attr._core.integrated_gradients import IntegratedGradients
from captum.attr._core.kernel_shap import KernelShap
from captum.attr._core.lime import Lime
from captum.attr._core.noise_tunnel import NoiseTunnel
from captum.attr._core.occlusion import Occlusion
from captum.attr._core.saliency import Saliency
from captum.attr._core.shapley_value import ShapleyValueSampling
from captum.attr._utils.attribution import Attribution

from ..helpers.basic import BaseTest, assertTensorTuplesAlmostEqual, deep_copy_args
from .helpers.gen_test_utils import gen_test_name, parse_test_config
from .helpers.gen_test_utils import (
gen_test_name,
parse_test_config,
should_create_generated_test,
)
from .helpers.test_config import config

JIT_SUPPORTED = [
@@ -32,6 +38,8 @@
Occlusion,
Saliency,
ShapleyValueSampling,
Lime,
KernelShap,
]

"""
@@ -75,6 +83,8 @@ def __new__(cls, name: str, bases: Tuple, attrs: Dict):
baseline_distr,
) = parse_test_config(test_config)
for algorithm in algorithms:
if not should_create_generated_test(algorithm):
continue
if algorithm in JIT_SUPPORTED:
for mode in JITCompareMode:
# Creates test case corresponding to each algorithm and