diff --git a/captum/_utils/exceptions.py b/captum/_utils/exceptions.py index b952d3740..f548ba207 100644 --- a/captum/_utils/exceptions.py +++ b/captum/_utils/exceptions.py @@ -9,3 +9,11 @@ class FeatureAblationFutureError(Exception): FeatureAblation attribution call""" pass + + +class ShapleyValueFutureError(Exception): + """This custom error is raised when an error + occurs within the callback chain of a + ShapleyValue attribution call""" + + pass diff --git a/captum/attr/_core/shapley_value.py b/captum/attr/_core/shapley_value.py index 83f1811ae..ca7f6f7e9 100644 --- a/captum/attr/_core/shapley_value.py +++ b/captum/attr/_core/shapley_value.py @@ -5,7 +5,7 @@ import itertools import math import warnings -from typing import Callable, cast, Iterable, Optional, Sequence, Tuple, Union +from typing import Callable, cast, Iterable, List, Optional, Sequence, Tuple, Union import torch from captum._utils.common import ( @@ -20,6 +20,7 @@ _is_tuple, _run_forward, ) +from captum._utils.exceptions import ShapleyValueFutureError from captum._utils.progress import progress from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric from captum.attr._utils.attribution import PerturbationAttribution @@ -29,7 +30,8 @@ _tensorize_baseline, ) from captum.log import log_usage -from torch import dtype, Tensor +from torch import dtype, Size, Tensor +from torch.futures import collect_all, Future def _all_perm_generator(num_features: int, num_samples: int) -> Iterable[Sequence[int]]: @@ -394,7 +396,6 @@ def attribute( ) if show_progress: attr_progress.update() - if agg_output_mode: eval_diff = modified_eval - prev_results prev_results = modified_eval @@ -438,7 +439,6 @@ def attribute( # (*output_shape, *input_feature_shape) total_attrib[j] += cur_attr - if show_progress: attr_progress.close() @@ -452,15 +452,318 @@ def attribute( # `Tuple[Tensor, ...]`. return formatted_attr - # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. - def attribute_future(self) -> Callable: + def attribute_future( + self, + inputs: TensorOrTupleOfTensorsGeneric, + baselines: BaselineType = None, + target: TargetType = None, + additional_forward_args: Optional[Tuple[object, ...]] = None, + feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None, + n_samples: int = 25, + perturbations_per_eval: int = 1, + show_progress: bool = False, + ) -> Future[TensorOrTupleOfTensorsGeneric]: r""" This method is not implemented for ShapleyValueSampling. """ - raise NotImplementedError( - "attribute_future is not implemented for ShapleyValueSampling" + is_inputs_tuple = _is_tuple(inputs) + inputs_tuple, baselines = _format_input_baseline(inputs, baselines) + additional_forward_args = _format_additional_forward_args( + additional_forward_args + ) + formatted_feature_mask = _format_feature_mask(feature_mask, inputs_tuple) + reshaped_feature_mask = _shape_feature_mask( + formatted_feature_mask, inputs_tuple ) + assert ( + isinstance(perturbations_per_eval, int) and perturbations_per_eval >= 1 + ), "Ablations per evaluation must be at least 1." + + with torch.no_grad(): + baselines = _tensorize_baseline(inputs_tuple, baselines) + num_examples = inputs_tuple[0].shape[0] + + total_features = _get_max_feature_index(reshaped_feature_mask) + 1 + + if show_progress: + attr_progress = progress( + desc=f"{self.get_name()} attribution", + total=self._get_n_evaluations( + total_features, n_samples, perturbations_per_eval + ) + + 1, # add 1 for the initial eval + ) + attr_progress.update(0) + + initial_eval: Future[Tensor] = self._strict_run_forward_future( + self.forward_func, baselines, target, additional_forward_args + ) + + if show_progress: + attr_progress.update() + + prev_result_tuple: Future[ + Tuple[Tensor, Tensor, Size, List[Tensor], bool] + ] = initial_eval.then( + lambda inp=initial_eval: self._initialEvalToPrevResultsTuple( # type: ignore # noqa: E501 line too long + inp, + num_examples, + perturbations_per_eval, + reshaped_feature_mask, + inputs_tuple, + ) + ) + + iter_count = 0 + # Iterate for number of samples, generate a permutation of the features + # and evalute the incremental increase for each feature. + for feature_permutation in self.permutation_generator( + total_features, n_samples + ): + prev_result_tuple = prev_result_tuple.then( + lambda inp=prev_result_tuple: self._setPrevResultsToInitialEval(inp) # type: ignore # noqa: E501 line too long + ) + + iter_count += 1 + for ( + current_inputs, + current_add_args, + current_target, + current_masks, + ) in self._perturbation_generator( + inputs_tuple, + additional_forward_args, + target, + baselines, + reshaped_feature_mask, + feature_permutation, + perturbations_per_eval, + ): + if sum(torch.sum(mask).item() for mask in current_masks) == 0: + warnings.warn( + "Feature mask is missing some integers between 0 and " + "num_features, for optimal performance, make sure each" + " consecutive integer corresponds to a feature.", + stacklevel=1, + ) + # modified_eval dimensions: 1D tensor with length + # equal to #num_examples * #features in batch + modified_eval = self._strict_run_forward_future( + self.forward_func, + current_inputs, + current_target, + current_add_args, + ) + if show_progress: + attr_progress.update() + + assert isinstance(modified_eval, torch.Future), ( + "when using futures method, modified_eval should have " + f"Future type rather than {type(modified_eval)}" + ) + eval_futs: Future[ + List[ + Future[ + Union[ + Tuple[Tensor, Tensor, Size, List[Tensor], bool], + Tensor, + ] + ] + ] + ] = collect_all([prev_result_tuple, modified_eval]) + + prev_result_tuple = eval_futs.then( + lambda evals=eval_futs, masks=current_masks: self._evalFutToPrevResultsTuple( # type: ignore # noqa: E501 line too long + evals, num_examples, inputs_tuple, masks + ) + ) + + if show_progress: + attr_progress.close() + + # Divide total attributions by number of random permutations and return + # formatted attributions. + formatted_attr: Future[Union[Tensor, tuple[Tensor, ...]]] = ( + prev_result_tuple.then( + lambda inp=prev_result_tuple: self._prevResultTupleToFormattedAttr( # type: ignore # noqa: E501 line too long + inp, iter_count, is_inputs_tuple + ) + ) + ) + # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got + # `Tuple[Tensor, ...]`. + return formatted_attr # type: ignore + + def _initialEvalToPrevResultsTuple( + self, + initial_eval: Future[Tensor], + num_examples: int, + perturbations_per_eval: int, + reshaped_feature_mask: TensorOrTupleOfTensorsGeneric, + inputs_tuple: Tuple[Tensor, ...], + ) -> Tuple[Tensor, Tensor, Size, List[Tensor], bool]: + """Since the initial eval is a Future, it is easier to bundle the prev_result, + agg_output_mode, output_shape, and total_attrib together + as Shapley Value Feature Attributions are being calculated""" + try: + initial_eval_processed = initial_eval.value() + prev_result = initial_eval_processed + if not isinstance(initial_eval_processed, Tensor): + raise AssertionError( + "initial_eval_to_processed_initial_eval_fut: " + "initial_eval should be a Tensor" + ) + agg_output_mode = _find_output_mode_and_verify( + initial_eval_processed, + num_examples, + perturbations_per_eval, + reshaped_feature_mask, + allow_multi_outputs=True, + ) + output_shape = initial_eval_processed.shape + total_attrib: List[Tensor] = [ + torch.zeros( + tuple(output_shape) + tuple(input.shape[1:]), + dtype=torch.float, + device=inputs_tuple[0].device, + ) + for input in inputs_tuple + ] + result = ( + initial_eval_processed, + prev_result, + output_shape, + total_attrib, + agg_output_mode, + ) + except ShapleyValueFutureError as e: + raise ShapleyValueFutureError( + "_initial_eval_to_prev_results_tuple func failed" + ) from e + return result + + def _setPrevResultsToInitialEval( + self, + processed_initial_eval: Future[Tuple[Tensor, Tensor, Size, List[Tensor], bool]], + ) -> Tuple[Tensor, Tensor, Size, List[Tensor], bool]: + """At the beginning of each feature permutation, the prev_results is + reset to the initial eval, and this method helps set that up""" + (initial_eval, prev_results, output_shape, total_attrib, agg_output_mode) = ( + processed_initial_eval.value() + ) + prev_results = initial_eval + return (initial_eval, prev_results, output_shape, total_attrib, agg_output_mode) + + def _evalFutToPrevResultsTuple( + self, + eval_futs: Future[ + List[ + Union[ + Future[Tuple[Tensor, Tensor, Size, List[Tensor], bool]], + Future[Tensor], + ] + ] + ], + num_examples: int, + inputs_tuple: Tuple[Tensor, ...], + current_masks: Tuple[Tensor, ...], + ) -> Tuple[Tensor, Tensor, Size, List[Tensor], bool]: + """Helper method responsible for calculating + eval differences between the modified eval and prev_results + Tensor and storing them in total_attrib. Returns prev_results_tuple + with modified total_attrib and prev_results""" + prev_results_tuple = eval_futs.value()[0].value() + modified_eval = eval_futs.value()[1].value() + if not isinstance(modified_eval, Tensor) or not isinstance( + prev_results_tuple, tuple + ): + raise ShapleyValueFutureError( + "_eval_fut_to_prev_results_tuple func failed due to type mismatch" + ) + ( + initial_eval, + prev_results, + output_shape, + total_attrib, + agg_output_mode, + ) = prev_results_tuple + if agg_output_mode: + eval_diff = modified_eval - prev_results + prev_results = modified_eval + else: + # when perturb_per_eval > 1, every num_examples stands for + # one perturb. Since the perturbs are from a consecutive + # perumuation, each diff of a perturb is its eval minus + # the eval of the previous perturb + + all_eval = torch.cat((prev_results, modified_eval), dim=0) + eval_diff = all_eval[num_examples:] - all_eval[:-num_examples] + prev_results = all_eval[-num_examples:] + + for j in range(len(total_attrib)): + # format eval_diff to shape + # (n_perturb, *output_shape, 1,.. 1) + # where n_perturb may not be perturb_per_eval + # Append n_input_feature dim of 1 to make the tensor + # have the same dim as the mask tensor. + formatted_eval_diff = eval_diff.reshape( + (-1,) + tuple(output_shape) + (len(inputs_tuple[j].shape) - 1) * (1,) + ) + + # mask in shape (n_perturb, *mask_shape_broadcastable_to_input) + # reshape to + # ( + # n_perturb, + # *broadcastable_to_output_shape + # *broadcastable_to_input_feature_shape + # ) + cur_mask = current_masks[j] + cur_mask = cur_mask.reshape( + tuple(cur_mask.shape[:2]) + + (len(output_shape) - 1) * (1,) + + tuple(cur_mask.shape[2:]) + ) + + # aggregate n_perturb + cur_attr = (formatted_eval_diff * cur_mask.float()).sum(dim=0) + # (*output_shape, *input_feature_shape) + total_attrib[j] += cur_attr + + result = ( + initial_eval, + prev_results, + output_shape, + total_attrib, + agg_output_mode, + ) + return result + + def _prevResultTupleToFormattedAttr( + self, + prev_result_tuple: Future[ + Tuple[Tensor, Tensor, Tuple[int], List[Tensor], bool] + ], + iter_count: int, + is_inputs_tuple: bool, + ) -> Union[Tensor, Tuple[Tensor, ...]]: + """Helper method to format total_attrib, which is a + list of tensors, into formatted attributions, which + are either a single tensor or a tuple of tensors""" + + ( + _, + _, + _, + total_attrib, + _, + ) = prev_result_tuple.value() + attrib = tuple( + tensor_attrib_total / iter_count for tensor_attrib_total in total_attrib + ) + formatted_attr = _format_output(is_inputs_tuple, attrib) + return formatted_attr + def _perturbation_generator( self, inputs: Tuple[Tensor, ...], @@ -574,6 +877,39 @@ def _strict_run_forward(self, *args, **kwargs) -> Tensor: # ref: https://github.com/pytorch/pytorch/pull/21215 return torch.tensor([forward_output], dtype=cast(dtype, output_type)) + # pyre-fixme[2]: Parameter must be annotated. + def _strict_run_forward_future(self, *args, **kwargs) -> Future[Tensor]: + """ + A temp wrapper for global _run_forward util to force + forward outputtype assertion & conversion, but takes + into account the Future tensor type + """ + + def process_strict_run_forward(fut: Future[Tensor]) -> Tensor: + output = fut.value() + if isinstance(output, Tensor): + # format scalar to shape (1) so we can always + # assume non-empty output_shape + if not output.shape: + output = output.reshape(1) + return output + output_type = type(output) + assert output_type is int or output_type is float, ( + "the return of forward_func must be a Future of tensor, int, or float," + f" received: {output_type}" + ) + output = torch.tensor([output], dtype=cast(dtype, output_type)) + return output + + forward_output = _run_forward(*args, **kwargs) + assert isinstance(forward_output, torch.Future), ( + "The return type of forward_func must be a Future" + f" received: {type(forward_output)}" + ) + + return_output = forward_output.then(process_strict_run_forward) + return return_output + class ShapleyValues(ShapleyValueSampling): """ diff --git a/captum/testing/helpers/basic_models.py b/captum/testing/helpers/basic_models.py index 8c4685f75..6eaf58e5d 100644 --- a/captum/testing/helpers/basic_models.py +++ b/captum/testing/helpers/basic_models.py @@ -8,6 +8,7 @@ import torch.nn as nn import torch.nn.functional as F from torch import Tensor +from torch.futures import Future """ @no_type_check annotation is applied to type-hinted models to avoid errors @@ -477,6 +478,77 @@ def forward( return lin2_out +class BasicModel_MultiLayer_with_Future(nn.Module): + # This model is used to test the case where the model returns a future + def __init__(self, inplace: bool = False, multi_input_module: bool = False) -> None: + super().__init__() + # Linear 0 is simply identity transform + self.multi_input_module = multi_input_module + self.linear0 = nn.Linear(3, 3) + self.linear0.weight = nn.Parameter(torch.eye(3)) + self.linear0.bias = nn.Parameter(torch.zeros(3)) + self.linear1 = nn.Linear(3, 4) + self.linear1.weight = nn.Parameter(torch.ones(4, 3)) + self.linear1.bias = nn.Parameter(torch.tensor([-10.0, 1.0, 1.0, 1.0])) + + self.linear1_alt = nn.Linear(3, 4) + self.linear1_alt.weight = nn.Parameter(torch.ones(4, 3)) + self.linear1_alt.bias = nn.Parameter(torch.tensor([-10.0, 1.0, 1.0, 1.0])) + self.multi_relu = MultiRelu(inplace=inplace) + self.relu = nn.ReLU(inplace=inplace) + + self.linear2 = nn.Linear(4, 2) + self.linear2.weight = nn.Parameter(torch.ones(2, 4)) + self.linear2.bias = nn.Parameter(torch.tensor([-1.0, 1.0])) + + @no_type_check + # pyre-fixme[3]: Return type must be annotated. + def forward( + self, + x: Tensor, + add_input: Optional[Tensor] = None, + multidim_output: bool = False, + ): + input = x if add_input is None else x + add_input + lin0_out = self.linear0(input) + lin1_out = self.linear1(lin0_out) + if self.multi_input_module: + relu_out1, relu_out2 = self.multi_relu(lin1_out, self.linear1_alt(input)) + relu_out = relu_out1 + relu_out2 + # relu is not used when multi_input_module set to True, + # so this is to set an unsued layer intentionally for testing + # and it won't be part of return + self.relu(lin1_out) + else: + relu_out = self.relu(lin1_out) + # pyre-fixme [29]: `typing.Type[Future]` is not a function + result = Future() + lin2_out = self.linear2(relu_out) + if multidim_output: + stack_mid = torch.stack((lin2_out, 2 * lin2_out), dim=2) + result.set_result(torch.stack((stack_mid, 4 * stack_mid), dim=3)) + return result + else: + result.set_result(lin2_out) + return result + + +class BasicModelBoolInput_with_Future(nn.Module): + def __init__(self) -> None: + super().__init__() + self.mod = BasicModel_MultiLayer_with_Future() + + # pyre-fixme[3]: Return type must be annotated. + def forward( + self, + x: Tensor, + add_input: Optional[Tensor] = None, + mult: float = 10.0, + ): + assert x.dtype is torch.bool, "Input must be boolean" + return self.mod(x.float() * mult, add_input) + + class BasicModelBoolInput(nn.Module): def __init__(self) -> None: super().__init__() @@ -504,6 +576,17 @@ def forward(self, x1: Tensor, x2: Tensor, x3: Tensor, scale: int): return self.model(scale * (x1 + x2 + x3)) +class BasicModel_MultiLayer_MultiInput_with_Future(nn.Module): + def __init__(self) -> None: + super().__init__() + self.model = BasicModel_MultiLayer_with_Future() + + @no_type_check + # pyre-fixme[3]: Return type must be annotated. + def forward(self, x1: Tensor, x2: Tensor, x3: Tensor, scale: int): + return self.model(scale * (x1 + x2 + x3)) + + class BasicModel_MultiLayer_TrueMultiInput(nn.Module): def __init__(self) -> None: super().__init__() diff --git a/tests/attr/test_shapley.py b/tests/attr/test_shapley.py index 976adc55f..dbe781303 100644 --- a/tests/attr/test_shapley.py +++ b/tests/attr/test_shapley.py @@ -14,90 +14,174 @@ from captum.testing.helpers.basic_models import ( BasicModel_MultiLayer, BasicModel_MultiLayer_MultiInput, + BasicModel_MultiLayer_MultiInput_with_Future, + BasicModel_MultiLayer_with_Future, BasicModelBoolInput, + BasicModelBoolInput_with_Future, ) +from parameterized import parameterized +from torch.futures import Future class Test(BaseTest): - def test_simple_shapley_sampling(self) -> None: - net = BasicModel_MultiLayer() + @parameterized.expand([True, False]) + def test_simple_shapley_sampling(self, use_future) -> None: inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True) - self._shapley_test_assert( - net, - inp, - [[76.66666, 196.66666, 116.66666]], - perturbations_per_eval=(1, 2, 3), - n_samples=250, - ) + if use_future: + net_fut = BasicModel_MultiLayer_with_Future() + self._shapley_test_assert_future( + net_fut, + inp, + [[76.66666, 196.66666, 116.66666]], + perturbations_per_eval=(1, 2, 3), + n_samples=250, + ) + else: + net = BasicModel_MultiLayer() + self._shapley_test_assert( + net, + inp, + [[76.66666, 196.66666, 116.66666]], + perturbations_per_eval=(1, 2, 3), + n_samples=250, + ) - def test_simple_shapley_sampling_with_mask(self) -> None: - net = BasicModel_MultiLayer() + @parameterized.expand([True, False]) + def test_simple_shapley_sampling_with_mask(self, use_future) -> None: inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True) - self._shapley_test_assert( - net, - inp, - [[275.0, 275.0, 115.0]], - feature_mask=torch.tensor([[0, 0, 1]]), - perturbations_per_eval=(1, 2, 3), - ) + if use_future: + net_fut = BasicModel_MultiLayer_with_Future() + self._shapley_test_assert_future( + net_fut, + inp, + [[275.0, 275.0, 115.0]], + feature_mask=torch.tensor([[0, 0, 1]]), + perturbations_per_eval=(1, 2, 3), + ) + else: + net = BasicModel_MultiLayer() + self._shapley_test_assert( + net, + inp, + [[275.0, 275.0, 115.0]], + feature_mask=torch.tensor([[0, 0, 1]]), + perturbations_per_eval=(1, 2, 3), + ) - def test_simple_shapley_sampling_boolean(self) -> None: - net = BasicModelBoolInput() + @parameterized.expand([True, False]) + def test_simple_shapley_sampling_boolean(self, use_future) -> None: inp = torch.tensor([[True, False, True]]) - self._shapley_test_assert( - net, - inp, - [[35.0, 35.0, 35.0]], - feature_mask=torch.tensor([[0, 0, 1]]), - perturbations_per_eval=(1, 2, 3), - ) + if use_future: + net_fut = BasicModelBoolInput_with_Future() + self._shapley_test_assert_future( + net_fut, + inp, + [[35.0, 35.0, 35.0]], + feature_mask=torch.tensor([[0, 0, 1]]), + perturbations_per_eval=(1, 2, 3), + ) + else: + net = BasicModelBoolInput() + self._shapley_test_assert( + net, + inp, + [[35.0, 35.0, 35.0]], + feature_mask=torch.tensor([[0, 0, 1]]), + perturbations_per_eval=(1, 2, 3), + ) - def test_simple_shapley_sampling_boolean_with_baseline(self) -> None: - net = BasicModelBoolInput() + @parameterized.expand([True, False]) + def test_simple_shapley_sampling_boolean_with_baseline(self, use_future) -> None: inp = torch.tensor([[True, False, True]]) - self._shapley_test_assert( - net, - inp, - [[-40.0, -40.0, 0.0]], - feature_mask=torch.tensor([[0, 0, 1]]), - baselines=True, - perturbations_per_eval=(1, 2, 3), - ) + if use_future: + net_fut = BasicModelBoolInput_with_Future() + self._shapley_test_assert_future( + net_fut, + inp, + [[-40.0, -40.0, 0.0]], + feature_mask=torch.tensor([[0, 0, 1]]), + baselines=True, + perturbations_per_eval=(1, 2, 3), + ) + else: + net = BasicModelBoolInput() + self._shapley_test_assert( + net, + inp, + [[-40.0, -40.0, 0.0]], + feature_mask=torch.tensor([[0, 0, 1]]), + baselines=True, + perturbations_per_eval=(1, 2, 3), + ) - def test_simple_shapley_sampling_with_baselines(self) -> None: - net = BasicModel_MultiLayer() + @parameterized.expand([True, False]) + def test_simple_shapley_sampling_with_baselines(self, use_future) -> None: inp = torch.tensor([[20.0, 50.0, 30.0]]) - self._shapley_test_assert( - net, - inp, - [[248.0, 248.0, 104.0]], - feature_mask=torch.tensor([[0, 0, 1]]), - baselines=4, - perturbations_per_eval=(1, 2, 3), - ) + if use_future: + net_fut = BasicModel_MultiLayer_with_Future() + self._shapley_test_assert_future( + net_fut, + inp, + [[248.0, 248.0, 104.0]], + feature_mask=torch.tensor([[0, 0, 1]]), + baselines=4, + perturbations_per_eval=(1, 2, 3), + ) + else: + net = BasicModel_MultiLayer() + self._shapley_test_assert( + net, + inp, + [[248.0, 248.0, 104.0]], + feature_mask=torch.tensor([[0, 0, 1]]), + baselines=4, + perturbations_per_eval=(1, 2, 3), + ) - def test_multi_sample_shapley_sampling(self) -> None: - net = BasicModel_MultiLayer() + @parameterized.expand([True, False]) + def test_multi_sample_shapley_sampling(self, use_future) -> None: inp = torch.tensor([[2.0, 10.0, 3.0], [20.0, 50.0, 30.0]]) - self._shapley_test_assert( - net, - inp, - [[7.0, 32.5, 10.5], [76.66666, 196.66666, 116.66666]], - perturbations_per_eval=(1, 2, 3), - n_samples=200, - ) + if use_future: + net_fut = BasicModel_MultiLayer_with_Future() + self._shapley_test_assert_future( + net_fut, + inp, + [[7.0, 32.5, 10.5], [76.66666, 196.66666, 116.66666]], + perturbations_per_eval=(1, 2, 3), + n_samples=200, + ) + else: + net = BasicModel_MultiLayer() + self._shapley_test_assert( + net, + inp, + [[7.0, 32.5, 10.5], [76.66666, 196.66666, 116.66666]], + perturbations_per_eval=(1, 2, 3), + n_samples=200, + ) - def test_multi_sample_shapley_sampling_with_mask(self) -> None: - net = BasicModel_MultiLayer() + @parameterized.expand([True, False]) + def test_multi_sample_shapley_sampling_with_mask(self, use_future) -> None: inp = torch.tensor([[2.0, 10.0, 3.0], [20.0, 50.0, 30.0]], requires_grad=True) mask = torch.tensor([[0, 0, 1], [1, 1, 0]]) - self._shapley_test_assert( - net, - inp, - [[39.5, 39.5, 10.5], [275.0, 275.0, 115.0]], - feature_mask=mask, - perturbations_per_eval=(1, 2, 3), - ) + if use_future: + net_fut = BasicModel_MultiLayer_with_Future() + self._shapley_test_assert_future( + net_fut, + inp, + [[39.5, 39.5, 10.5], [275.0, 275.0, 115.0]], + feature_mask=mask, + perturbations_per_eval=(1, 2, 3), + ) + else: + net = BasicModel_MultiLayer() + self._shapley_test_assert( + net, + inp, + [[39.5, 39.5, 10.5], [275.0, 275.0, 115.0]], + feature_mask=mask, + perturbations_per_eval=(1, 2, 3), + ) def test_multi_input_shapley_sampling_without_mask(self) -> None: net = BasicModel_MultiLayer_MultiInput() @@ -118,6 +202,25 @@ def test_multi_input_shapley_sampling_without_mask(self) -> None: test_true_shapley=False, ) + def test_multi_input_shapley_sampling_without_mask_future(self) -> None: + net = BasicModel_MultiLayer_MultiInput_with_Future() + inp1 = torch.tensor([[23.0, 0.0, 0.0], [20.0, 50.0, 30.0]]) + inp2 = torch.tensor([[20.0, 0.0, 50.0], [0.0, 100.0, 0.0]]) + inp3 = torch.tensor([[0.0, 100.0, 10.0], [0.0, 10.0, 0.0]]) + expected = ( + [[90, 0, 0], [78.0, 198.0, 118.0]], + [[78, 0, 198], [0.0, 398.0, 0.0]], + [[0, 398, 38], [0.0, 38.0, 0.0]], + ) + self._shapley_test_assert_future( + net, + (inp1, inp2, inp3), + expected, + additional_input=(1,), + n_samples=200, + test_true_shapley=False, + ) + def test_multi_input_shapley_sampling_with_mask(self) -> None: net = BasicModel_MultiLayer_MultiInput() inp1 = torch.tensor([[23.0, 100.0, 0.0], [20.0, 50.0, 30.0]]) @@ -153,86 +256,201 @@ def test_multi_input_shapley_sampling_with_mask(self) -> None: perturbations_per_eval=(1, 2, 3), ) - def test_shapley_sampling_multi_task_output(self) -> None: - # return shape (batch size, 2) - net1 = BasicModel_MultiLayer() - - # return shape (batch size, 4) - def forward_func(*args, **kwargs): - net_output = net1(*args, **kwargs) - batch_size = net_output.size(0) - constant = torch.ones(batch_size, 2) - output = torch.cat( - [ - net_output, - constant, - ], - dim=-1, - ) - return output - - inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True) - - self._shapley_test_assert( - forward_func, - inp, - [ - [ - [76.66666, 196.66666, 116.66666], - [76.66666, 196.66666, 116.66666], - [0, 0, 0], - [0, 0, 0], - ] - ], - target=None, # no target, multi-task output for all classes + def test_multi_input_shapley_sampling_with_mask_future(self) -> None: + net = BasicModel_MultiLayer_MultiInput_with_Future() + inp1 = torch.tensor([[23.0, 100.0, 0.0], [20.0, 50.0, 30.0]]) + inp2 = torch.tensor([[20.0, 50.0, 30.0], [0.0, 100.0, 0.0]]) + inp3 = torch.tensor([[0.0, 100.0, 10.0], [2.0, 10.0, 3.0]]) + mask1 = torch.tensor([[1, 1, 1], [0, 1, 0]]) + mask2 = torch.tensor([[0, 1, 2]]) + mask3 = torch.tensor([[0, 1, 2], [0, 0, 0]]) + expected = ( + [[1088.6666, 1088.6666, 1088.6666], [255.0, 595.0, 255.0]], + [[76.6666, 1088.6666, 156.6666], [255.0, 595.0, 0.0]], + [[76.6666, 1088.6666, 156.6666], [255.0, 255.0, 255.0]], + ) + self._shapley_test_assert_future( + net, + (inp1, inp2, inp3), + expected, + additional_input=(1,), + feature_mask=(mask1, mask2, mask3), + ) + expected_with_baseline = ( + [[1040, 1040, 1040], [184, 580.0, 184]], + [[52, 1040, 132], [184, 580.0, -12.0]], + [[52, 1040, 132], [184, 184, 184]], + ) + self._shapley_test_assert_future( + net, + (inp1, inp2, inp3), + expected_with_baseline, + additional_input=(1,), + feature_mask=(mask1, mask2, mask3), + baselines=(2, 3.0, 4), perturbations_per_eval=(1, 2, 3), - n_samples=150, - test_true_shapley=True, ) - def test_shapley_sampling_multi_task_output_with_mask(self) -> None: + @parameterized.expand([True, False]) + def test_shapley_sampling_multi_task_output(self, use_future) -> None: # return shape (batch size, 2) - net1 = BasicModel_MultiLayer() + inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True) + if use_future: + net1_fut = BasicModel_MultiLayer_with_Future() + + def forward_func(*args, **kwargs): + net_output = net1_fut(*args, **kwargs) + net_output.wait() + batch_size = net_output.value().size(0) + constant = torch.ones(batch_size, 2) + output = torch.cat( + [ + net_output.value(), + constant, + ], + dim=-1, + ) + fut = Future() + fut.set_result(output) + return fut - # return shape (batch size, 4) - def forward_func(*args, **kwargs): - net_output = net1(*args, **kwargs) - batch_size = net_output.size(0) - constant = torch.ones(batch_size, 1) + self._shapley_test_assert_future( + forward_func, + inp, + [ + [ + [76.66666, 196.66666, 116.66666], + [76.66666, 196.66666, 116.66666], + [0, 0, 0], + [0, 0, 0], + ] + ], + target=None, # no target, multi-task output for all classes + perturbations_per_eval=(1, 2, 3), + n_samples=150, + test_true_shapley=True, + ) + else: + net1 = BasicModel_MultiLayer() + + def forward_func(*args, **kwargs): + net_output = net1(*args, **kwargs) + batch_size = net_output.size(0) + constant = torch.ones(batch_size, 2) + output = torch.cat( + [ + net_output, + constant, + ], + dim=-1, + ) + return output - output = torch.cat( + # return shape (batch size, 4) + self._shapley_test_assert( + forward_func, + inp, [ - net_output, - constant, + [ + [76.66666, 196.66666, 116.66666], + [76.66666, 196.66666, 116.66666], + [0, 0, 0], + [0, 0, 0], + ] ], - dim=-1, + target=None, # no target, multi-task output for all classes + perturbations_per_eval=(1, 2, 3), + n_samples=150, + test_true_shapley=True, ) - return output + @parameterized.expand([True, False]) + def test_shapley_sampling_multi_task_output_with_mask(self, use_future) -> None: + # return shape (batch size, 2) inp = torch.tensor([[20.0, 50.0, 30.0], [20.0, 50.0, 30.0]], requires_grad=True) mask = torch.tensor([[1, 1, 0], [0, 1, 1]]) + if use_future: + net1_fut = BasicModel_MultiLayer_with_Future() + + # return shape (batch size, 4) + def forward_func(*args, **kwargs): + net_output = net1_fut(*args, **kwargs) + net_output.wait() + batch_size = net_output.value().size(0) + constant = torch.ones(batch_size, 1) + + output = torch.cat( + [ + net_output.value(), + constant, + ], + dim=-1, + ) + fut = Future() + fut.set_result(output) + return fut - self._shapley_test_assert( - forward_func, - inp, - [ + self._shapley_test_assert_future( + forward_func, + inp, [ - [275.0, 275.0, 115.0], - [275.0, 275.0, 115.0], - [0, 0, 0], + [ + [275.0, 275.0, 115.0], + [275.0, 275.0, 115.0], + [0, 0, 0], + ], + [ + [75.0, 315.0, 315.0], + [75.0, 315.0, 315.0], + [0, 0, 0], + ], ], + target=None, # no target, multi-task output for all classes + perturbations_per_eval=(1, 2, 3), + n_samples=150, + test_true_shapley=True, + feature_mask=mask, + ) + else: + + net1 = BasicModel_MultiLayer() + + # return shape (batch size, 4) + def forward_func(*args, **kwargs): + net_output = net1(*args, **kwargs) + batch_size = net_output.size(0) + constant = torch.ones(batch_size, 1) + + output = torch.cat( + [ + net_output, + constant, + ], + dim=-1, + ) + return output + + self._shapley_test_assert( + forward_func, + inp, [ - [75.0, 315.0, 315.0], - [75.0, 315.0, 315.0], - [0, 0, 0], + [ + [275.0, 275.0, 115.0], + [275.0, 275.0, 115.0], + [0, 0, 0], + ], + [ + [75.0, 315.0, 315.0], + [75.0, 315.0, 315.0], + [0, 0, 0], + ], ], - ], - target=None, # no target, multi-task output for all classes - perturbations_per_eval=(1, 2, 3), - n_samples=150, - test_true_shapley=True, - feature_mask=mask, - ) + target=None, # no target, multi-task output for all classes + perturbations_per_eval=(1, 2, 3), + n_samples=150, + test_true_shapley=True, + feature_mask=mask, + ) # Remaining tests are for cases where forward function returns a scalar # per batch, as either a float, integer, 0d tensor or 1d tensor. @@ -388,15 +606,6 @@ def test_shapley_sampling_with_mask_and_show_progress(self, mock_stderr) -> None mock_stderr.seek(0) mock_stderr.truncate(0) - def test_futures_not_implemented(self) -> None: - net = BasicModel_MultiLayer() - - attributions = None - shapley_samp = ShapleyValueSampling(net) - with self.assertRaises(NotImplementedError): - attributions = shapley_samp.attribute_future() - self.assertEqual(attributions, None) - def _single_input_one_sample_batch_scalar_shapley_assert( self, func: Callable ) -> None: @@ -514,6 +723,54 @@ def _shapley_test_assert( self, attributions, expected_attr, mode="max", delta=0.001 ) + def _shapley_test_assert_future( + self, + model: Callable, + test_input: TensorOrTupleOfTensorsGeneric, + expected_attr, + feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None, + additional_input: Any = None, + perturbations_per_eval: Tuple[int, ...] = (1,), + baselines: BaselineType = None, + target: Union[None, int] = 0, + n_samples: int = 100, + delta: float = 1.0, + # leaving this false as it is not supported for future + test_true_shapley: bool = False, + show_progress: bool = False, + ) -> None: + for batch_size in perturbations_per_eval: + shapley_samp = ShapleyValueSampling(model) + attributions = shapley_samp.attribute_future( + test_input, + target=target, + feature_mask=feature_mask, + additional_forward_args=additional_input, + baselines=baselines, + perturbations_per_eval=batch_size, + n_samples=n_samples, + show_progress=show_progress, + ) + attributions.wait() + assertTensorTuplesAlmostEqual( + self, attributions.value(), expected_attr, delta=delta, mode="max" + ) + if test_true_shapley: + shapley_val = ShapleyValues(model) + attributions = shapley_val.attribute_future( + test_input, + target=target, + feature_mask=feature_mask, + additional_forward_args=additional_input, + baselines=baselines, + perturbations_per_eval=batch_size, + show_progress=show_progress, + ) + attributions.wait() + assertTensorTuplesAlmostEqual( + self, attributions.value(), expected_attr, mode="max", delta=0.001 + ) + if __name__ == "__main__": unittest.main()