diff --git a/captum/_utils/exceptions.py b/captum/_utils/exceptions.py
index b952d3740..f548ba207 100644
--- a/captum/_utils/exceptions.py
+++ b/captum/_utils/exceptions.py
@@ -9,3 +9,11 @@ class FeatureAblationFutureError(Exception):
FeatureAblation attribution call"""
pass
+
+
+class ShapleyValueFutureError(Exception):
+ """This custom error is raised when an error
+ occurs within the callback chain of a
+ ShapleyValue attribution call"""
+
+ pass
diff --git a/captum/attr/_core/shapley_value.py b/captum/attr/_core/shapley_value.py
index 83f1811ae..4ee00ce05 100644
--- a/captum/attr/_core/shapley_value.py
+++ b/captum/attr/_core/shapley_value.py
@@ -5,7 +5,7 @@
import itertools
import math
import warnings
-from typing import Callable, cast, Iterable, Optional, Sequence, Tuple, Union
+from typing import Callable, cast, Iterable, List, Optional, Sequence, Tuple, Union
import torch
from captum._utils.common import (
@@ -20,6 +20,7 @@
_is_tuple,
_run_forward,
)
+from captum._utils.exceptions import ShapleyValueFutureError
from captum._utils.progress import progress
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
from captum.attr._utils.attribution import PerturbationAttribution
@@ -29,7 +30,8 @@
_tensorize_baseline,
)
from captum.log import log_usage
-from torch import dtype, Tensor
+from torch import dtype, Size, Tensor
+from torch.futures import collect_all, Future
def _all_perm_generator(num_features: int, num_samples: int) -> Iterable[Sequence[int]]:
@@ -394,7 +396,6 @@ def attribute(
)
if show_progress:
attr_progress.update()
-
if agg_output_mode:
eval_diff = modified_eval - prev_results
prev_results = modified_eval
@@ -438,7 +439,6 @@ def attribute(
# (*output_shape, *input_feature_shape)
total_attrib[j] += cur_attr
-
if show_progress:
attr_progress.close()
@@ -452,15 +452,304 @@ def attribute(
# `Tuple[Tensor, ...]`.
return formatted_attr
- # pyre-fixme[24] Generic type `Callable` expects 2 type parameters.
- def attribute_future(self) -> Callable:
+ def attribute_future(
+ self,
+ inputs: TensorOrTupleOfTensorsGeneric,
+ baselines: BaselineType = None,
+ target: TargetType = None,
+ additional_forward_args: Optional[Tuple[object, ...]] = None,
+ feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
+ n_samples: int = 25,
+ perturbations_per_eval: int = 1,
+ show_progress: bool = False,
+ ) -> Future[TensorOrTupleOfTensorsGeneric]:
r"""
This method is not implemented for ShapleyValueSampling.
"""
- raise NotImplementedError(
- "attribute_future is not implemented for ShapleyValueSampling"
+ is_inputs_tuple = _is_tuple(inputs)
+ inputs_tuple, baselines = _format_input_baseline(inputs, baselines)
+ additional_forward_args = _format_additional_forward_args(
+ additional_forward_args
+ )
+ formatted_feature_mask = _format_feature_mask(feature_mask, inputs_tuple)
+ reshaped_feature_mask = _shape_feature_mask(
+ formatted_feature_mask, inputs_tuple
)
+ assert (
+ isinstance(perturbations_per_eval, int) and perturbations_per_eval >= 1
+        ), "Perturbations per evaluation must be at least 1."
+
+ with torch.no_grad():
+ baselines = _tensorize_baseline(inputs_tuple, baselines)
+ num_examples = inputs_tuple[0].shape[0]
+
+ total_features = _get_max_feature_index(reshaped_feature_mask) + 1
+
+ if show_progress:
+ attr_progress = progress(
+ desc=f"{self.get_name()} attribution",
+ total=self._get_n_evaluations(
+ total_features, n_samples, perturbations_per_eval
+ )
+ + 1, # add 1 for the initial eval
+ )
+ attr_progress.update(0)
+
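+            # Kick off the initial forward pass on the baselines; the result
+            # is a Future, and the rest of the computation chains callbacks
+            # onto it instead of blocking.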
+ initial_eval = self._strict_run_forward_future(
+ self.forward_func, baselines, target, additional_forward_args
+ )
+
+ if show_progress:
+ attr_progress.update()
+
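+            # Bundle (initial_eval, prev_results, output_shape, total_attrib,
+            # agg_output_mode) into one tuple so the whole running state can
+            # travel through the Future callback chain.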
+ prev_result_tuple = initial_eval.then(
+ lambda inp=initial_eval: self._initial_eval_to_prev_results_tuple(
+ inp,
+ num_examples,
+ perturbations_per_eval,
+ reshaped_feature_mask,
+ inputs_tuple,
+ )
+ )
+
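+            # iter_count tracks the number of sampled permutations; the
+            # accumulated attributions are averaged over it at the end.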
+ iter_count = 0
+            # Iterate for the number of samples, generating a permutation of
+            # the features and evaluating the incremental increase for each
+            # feature.
+ for feature_permutation in self.permutation_generator(
+ total_features, n_samples
+ ):
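+                # At the start of each permutation, chain a callback that
+                # resets prev_results back to the initial eval.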
+ prev_result_tuple = prev_result_tuple.then(
+                    lambda inp=prev_result_tuple: self._set_prev_results_to_initial_eval(
+                        inp
+                    )
+ )
+
+ iter_count += 1
+ for (
+ current_inputs,
+ current_add_args,
+ current_target,
+ current_masks,
+ ) in self._perturbation_generator(
+ inputs_tuple,
+ additional_forward_args,
+ target,
+ baselines,
+ reshaped_feature_mask,
+ feature_permutation,
+ perturbations_per_eval,
+ ):
+ if sum(torch.sum(mask).item() for mask in current_masks) == 0:
+ warnings.warn(
+ "Feature mask is missing some integers between 0 and "
+                            "num_features. For optimal performance, make sure each"
+ " consecutive integer corresponds to a feature.",
+ stacklevel=1,
+ )
+ # modified_eval dimensions: 1D tensor with length
+ # equal to #num_examples * #features in batch
+ modified_eval = self._strict_run_forward_future(
+ self.forward_func,
+ current_inputs,
+ current_target,
+ current_add_args,
+ )
+ if show_progress:
+ attr_progress.update()
+
+ assert isinstance(modified_eval, torch.Future), (
+                        "when using the futures method, modified_eval should "
+                        f"be a Future rather than {type(modified_eval)}"
+ )
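+                    # collect_all yields a Future that completes once both the
+                    # running state tuple and the new eval are ready; the
+                    # chained callback then folds this perturbation's eval
+                    # diff into total_attrib.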
+ eval_futs = collect_all([prev_result_tuple, modified_eval])
+ prev_result_tuple = eval_futs.then(
+                        lambda evals=eval_futs, masks=current_masks: self._eval_fut_to_prev_results_tuple(
+ evals, num_examples, inputs_tuple, masks
+ )
+ )
+
+ if show_progress:
+ attr_progress.close()
+
+ # Divide total attributions by number of random permutations and return
+ # formatted attributions.
+ formatted_attr = prev_result_tuple.then(
+ lambda inp=prev_result_tuple: self._prev_result_tuple_to_formatted_attr(
+ inp, iter_count, is_inputs_tuple
+ )
+ )
+ # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got
+ # `Tuple[Tensor, ...]`.
+ return formatted_attr
+
+ def _initial_eval_to_prev_results_tuple(
+ self,
+ initial_eval: Future[Tensor],
+ num_examples: int,
+ perturbations_per_eval: int,
+ reshaped_feature_mask: TensorOrTupleOfTensorsGeneric,
+ inputs_tuple: Tuple[Tensor, ...],
+ ) -> Tuple[Tensor, Tensor, Size, List[Tensor], bool]:
+        """Since the initial eval is a Future, it is easier to bundle
+        prev_result, agg_output_mode, output_shape, and total_attrib together
+        as the Shapley value attributions are being calculated"""
+ try:
+ initial_eval_processed = initial_eval.value()
+ prev_result = initial_eval_processed
+ if not isinstance(initial_eval_processed, Tensor):
+ raise AssertionError(
+                    "_initial_eval_to_prev_results_tuple: "
+ "initial_eval should be a Tensor"
+ )
+ agg_output_mode = _find_output_mode_and_verify(
+ initial_eval_processed,
+ num_examples,
+ perturbations_per_eval,
+ reshaped_feature_mask,
+ allow_multi_outputs=True,
+ )
+ output_shape = initial_eval_processed.shape
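+            # Allocate one accumulator per input tensor, shaped
+            # (*output_shape, *input_feature_shape), matching the synchronous
+            # attribute path.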
+ total_attrib: List[Tensor] = [
+ torch.zeros(
+ tuple(output_shape) + tuple(input.shape[1:]),
+ dtype=torch.float,
+ device=inputs_tuple[0].device,
+ )
+ for input in inputs_tuple
+ ]
+ result = (
+ initial_eval_processed,
+ prev_result,
+ output_shape,
+ total_attrib,
+ agg_output_mode,
+ )
+        except Exception as e:
+ raise ShapleyValueFutureError(
+ "_initial_eval_to_prev_results_tuple func failed"
+ ) from e
+ return result
+
+    def _set_prev_results_to_initial_eval(
+        self,
+        processed_initial_eval: Future[Tuple[Tensor, Tensor, Size, List[Tensor], bool]],
+    ) -> Tuple[Tensor, Tensor, Size, List[Tensor], bool]:
+        """At the beginning of each feature permutation, prev_results is reset
+        to the initial eval; this method performs that reset"""
+ (initial_eval, prev_results, output_shape, total_attrib, agg_output_mode) = (
+ processed_initial_eval.value()
+ )
+ prev_results = initial_eval
+ return (initial_eval, prev_results, output_shape, total_attrib, agg_output_mode)
+
+    def _eval_fut_to_prev_results_tuple(
+ self,
+ eval_futs: Future[
+ List[
+ Union[
+ Future[Tuple[Tensor, Tensor, Size, List[Tensor], bool]],
+ Future[Tensor],
+ ]
+ ]
+ ],
+ num_examples: int,
+ inputs_tuple: Tuple[Tensor, ...],
+ current_masks: Tuple[Tensor, ...],
+ ) -> Tuple[Tensor, Tensor, Size, List[Tensor], bool]:
+        """Helper method that computes the eval differences between the
+        modified eval and the prev_results tensor and accumulates them into
+        total_attrib. Returns the prev_results tuple with updated
+        total_attrib and prev_results"""
+ prev_results_tuple = eval_futs.value()[0].value()
+ modified_eval = eval_futs.value()[1].value()
+ if not isinstance(modified_eval, Tensor) or not isinstance(
+ prev_results_tuple, tuple
+ ):
+ raise ShapleyValueFutureError(
+ "_eval_fut_to_prev_results_tuple func failed due to type mismatch"
+ )
+ (
+ initial_eval,
+ prev_results,
+ output_shape,
+ total_attrib,
+ agg_output_mode,
+ ) = prev_results_tuple
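+        # The logic below mirrors the synchronous attribute loop: compute the
+        # eval diff against prev_results, then accumulate the masked diff into
+        # total_attrib.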
+ if agg_output_mode:
+ eval_diff = modified_eval - prev_results
+ prev_results = modified_eval
+ else:
+            # when perturbations_per_eval > 1, each block of num_examples rows
+            # corresponds to one perturbation. Since the perturbations come
+            # from one consecutive permutation, each perturbation's diff is
+            # its eval minus the eval of the previous perturbation
+
+ all_eval = torch.cat((prev_results, modified_eval), dim=0)
+ eval_diff = all_eval[num_examples:] - all_eval[:-num_examples]
+ prev_results = all_eval[-num_examples:]
+
+ for j in range(len(total_attrib)):
+            # format eval_diff to shape
+            # (n_perturb, *output_shape, 1, ..., 1),
+            # where n_perturb may differ from perturbations_per_eval.
+            # Append dims of size 1 for the input feature dims so the tensor
+            # has the same number of dims as the mask tensor.
+ formatted_eval_diff = eval_diff.reshape(
+ (-1,) + tuple(output_shape) + (len(inputs_tuple[j].shape) - 1) * (1,)
+ )
+
+ # mask in shape (n_perturb, *mask_shape_broadcastable_to_input)
+ # reshape to
+ # (
+ # n_perturb,
+            #     *broadcastable_to_output_shape,
+ # *broadcastable_to_input_feature_shape
+ # )
+ cur_mask = current_masks[j]
+ cur_mask = cur_mask.reshape(
+ tuple(cur_mask.shape[:2])
+ + (len(output_shape) - 1) * (1,)
+ + tuple(cur_mask.shape[2:])
+ )
+
+ # aggregate n_perturb
+ cur_attr = (formatted_eval_diff * cur_mask.float()).sum(dim=0)
+ # (*output_shape, *input_feature_shape)
+ total_attrib[j] += cur_attr
+
+ result = (
+ initial_eval,
+ prev_results,
+ output_shape,
+ total_attrib,
+ agg_output_mode,
+ )
+ return result
+
+ def _prev_result_tuple_to_formatted_attr(
+ self,
+ prev_result_tuple: Future[
+            Tuple[Tensor, Tensor, Size, List[Tensor], bool]
+ ],
+ iter_count: int,
+ is_inputs_tuple: bool,
+ ) -> Union[Tensor, Tuple[Tensor, ...]]:
+        """Helper method that averages total_attrib, a list of tensors, over
+        the number of sampled permutations and formats the result as either
+        a single tensor or a tuple of tensors"""
+
+ (
+ _,
+ _,
+ _,
+ total_attrib,
+ _,
+ ) = prev_result_tuple.value()
+ attrib = tuple(
+ tensor_attrib_total / iter_count for tensor_attrib_total in total_attrib
+ )
+ formatted_attr = _format_output(is_inputs_tuple, attrib)
+ return formatted_attr
+
def _perturbation_generator(
self,
inputs: Tuple[Tensor, ...],
@@ -574,6 +863,38 @@ def _strict_run_forward(self, *args, **kwargs) -> Tensor:
# ref: https://github.com/pytorch/pytorch/pull/21215
return torch.tensor([forward_output], dtype=cast(dtype, output_type))
+ # pyre-fixme[2]: Parameter must be annotated.
+ def _strict_run_forward_future(self, *args, **kwargs) -> Future[Tensor]:
+ """
+        A temp wrapper for the global _run_forward util to force forward
+        output type assertion & conversion, but takes into account the
+        Future tensor type
+ """
+
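+        # This callback is chained onto the forward Future below; it validates
+        # and normalizes the completed output, mirroring the synchronous
+        # _strict_run_forward.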
+ def process_strict_run_forward(fut: Future[Tensor]) -> Tensor:
+ output = fut.value()
+ if isinstance(output, Tensor):
+ # format scalar to shape (1) so we can always assume non-empty output_shape
+ if not output.shape:
+ output = output.reshape(1)
+ return output
+ output_type = type(output)
+ assert output_type is int or output_type is float, (
+ "the return of forward_func must be a Future of tensor, int, or float,"
+ f" received: {output_type}"
+ )
+ output = torch.tensor([output], dtype=cast(dtype, output_type))
+ return output
+
+ forward_output = _run_forward(*args, **kwargs)
+ assert isinstance(forward_output, torch.Future), (
+            "The return type of forward_func must be a Future,"
+ f" received: {type(forward_output)}"
+ )
+
+ return_output = forward_output.then(process_strict_run_forward)
+ return return_output
+
class ShapleyValues(ShapleyValueSampling):
"""
diff --git a/captum/testing/helpers/basic_models.py b/captum/testing/helpers/basic_models.py
index 8c4685f75..8d22adafc 100644
--- a/captum/testing/helpers/basic_models.py
+++ b/captum/testing/helpers/basic_models.py
@@ -8,6 +8,7 @@
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
+from torch.futures import Future
"""
@no_type_check annotation is applied to type-hinted models to avoid errors
@@ -477,6 +478,61 @@ def forward(
return lin2_out
+class BasicModel_MultiLayer_with_Future(nn.Module):
+ # This model is used to test the case where the model returns a future
+ def __init__(self, inplace: bool = False, multi_input_module: bool = False) -> None:
+ super().__init__()
+ # Linear 0 is simply identity transform
+ self.multi_input_module = multi_input_module
+ self.linear0 = nn.Linear(3, 3)
+ self.linear0.weight = nn.Parameter(torch.eye(3))
+ self.linear0.bias = nn.Parameter(torch.zeros(3))
+ self.linear1 = nn.Linear(3, 4)
+ self.linear1.weight = nn.Parameter(torch.ones(4, 3))
+ self.linear1.bias = nn.Parameter(torch.tensor([-10.0, 1.0, 1.0, 1.0]))
+
+ self.linear1_alt = nn.Linear(3, 4)
+ self.linear1_alt.weight = nn.Parameter(torch.ones(4, 3))
+ self.linear1_alt.bias = nn.Parameter(torch.tensor([-10.0, 1.0, 1.0, 1.0]))
+ self.multi_relu = MultiRelu(inplace=inplace)
+ self.relu = nn.ReLU(inplace=inplace)
+
+ self.linear2 = nn.Linear(4, 2)
+ self.linear2.weight = nn.Parameter(torch.ones(2, 4))
+ self.linear2.bias = nn.Parameter(torch.tensor([-1.0, 1.0]))
+
+ @no_type_check
+ # pyre-fixme[3]: Return type must be annotated.
+ def forward(
+ self,
+ x: Tensor,
+ add_input: Optional[Tensor] = None,
+ multidim_output: bool = False,
+ ):
+ input = x if add_input is None else x + add_input
+ lin0_out = self.linear0(input)
+ lin1_out = self.linear1(lin0_out)
+ if self.multi_input_module:
+ relu_out1, relu_out2 = self.multi_relu(lin1_out, self.linear1_alt(input))
+ relu_out = relu_out1 + relu_out2
+            # relu is not used when multi_input_module is set to True,
+            # so this intentionally exercises an unused layer for testing;
+            # its output is not part of the return value
+ self.relu(lin1_out)
+ else:
+ relu_out = self.relu(lin1_out)
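+        # Wrap the output in a torch.futures.Future so this model mimics an
+        # asynchronous forward (e.g. one backed by RPC); callers must
+        # wait()/value() the result.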
+ # pyre-fixme [29]: `typing.Type[Future]` is not a function
+ result = Future()
+ lin2_out = self.linear2(relu_out)
+ if multidim_output:
+ stack_mid = torch.stack((lin2_out, 2 * lin2_out), dim=2)
+ result.set_result(torch.stack((stack_mid, 4 * stack_mid), dim=3))
+ return result
+ else:
+ result.set_result(lin2_out)
+ return result
+
+
class BasicModelBoolInput(nn.Module):
def __init__(self) -> None:
super().__init__()
diff --git a/tests/attr/test_shapley.py b/tests/attr/test_shapley.py
index 976adc55f..b88b478b6 100644
--- a/tests/attr/test_shapley.py
+++ b/tests/attr/test_shapley.py
@@ -14,6 +14,7 @@
from captum.testing.helpers.basic_models import (
BasicModel_MultiLayer,
BasicModel_MultiLayer_MultiInput,
+ BasicModel_MultiLayer_with_Future,
BasicModelBoolInput,
)
@@ -30,6 +31,17 @@ def test_simple_shapley_sampling(self) -> None:
n_samples=250,
)
+ def test_simple_shapley_sampling_future(self) -> None:
+ net = BasicModel_MultiLayer_with_Future()
+ inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True)
+ self._shapley_test_assert_future(
+ net,
+ inp,
+ [[76.66666, 196.66666, 116.66666]],
+ perturbations_per_eval=(1, 2, 3),
+ n_samples=250,
+ )
+
def test_simple_shapley_sampling_with_mask(self) -> None:
net = BasicModel_MultiLayer()
inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True)
@@ -388,15 +400,6 @@ def test_shapley_sampling_with_mask_and_show_progress(self, mock_stderr) -> None
mock_stderr.seek(0)
mock_stderr.truncate(0)
- def test_futures_not_implemented(self) -> None:
- net = BasicModel_MultiLayer()
-
- attributions = None
- shapley_samp = ShapleyValueSampling(net)
- with self.assertRaises(NotImplementedError):
- attributions = shapley_samp.attribute_future()
- self.assertEqual(attributions, None)
-
def _single_input_one_sample_batch_scalar_shapley_assert(
self, func: Callable
) -> None:
@@ -514,6 +517,53 @@ def _shapley_test_assert(
self, attributions, expected_attr, mode="max", delta=0.001
)
+ def _shapley_test_assert_future(
+ self,
+ model: Callable,
+ test_input: TensorOrTupleOfTensorsGeneric,
+ expected_attr,
+ feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
+ additional_input: Any = None,
+ perturbations_per_eval: Tuple[int, ...] = (1,),
+ baselines: BaselineType = None,
+ target: Union[None, int] = 0,
+ n_samples: int = 100,
+ delta: float = 1.0,
+        # leaving this False, as ShapleyValues does not support futures yet
+ test_true_shapley: bool = False,
+ show_progress: bool = False,
+ ) -> None:
+ for batch_size in perturbations_per_eval:
+ shapley_samp = ShapleyValueSampling(model)
+ attributions = shapley_samp.attribute_future(
+ test_input,
+ target=target,
+ feature_mask=feature_mask,
+ additional_forward_args=additional_input,
+ baselines=baselines,
+ perturbations_per_eval=batch_size,
+ n_samples=n_samples,
+ show_progress=show_progress,
+ )
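+            # attribute_future returns immediately; wait() blocks until the
+            # full callback chain has finished running.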
+ attributions.wait()
+ assertTensorTuplesAlmostEqual(
+ self, attributions.value(), expected_attr, delta=delta, mode="max"
+ )
+ if test_true_shapley:
+ shapley_val = ShapleyValues(model)
+ attributions = shapley_val.attribute(
+ test_input,
+ target=target,
+ feature_mask=feature_mask,
+ additional_forward_args=additional_input,
+ baselines=baselines,
+ perturbations_per_eval=batch_size,
+ show_progress=show_progress,
+ )
+ assertTensorTuplesAlmostEqual(
+ self, attributions, expected_attr, mode="max", delta=0.001
+ )
+
if __name__ == "__main__":
unittest.main()