Skip to content

Commit

Permalink
Adding async future functionality to ShapleyValues
Browse files Browse the repository at this point in the history
Summary: This diff implements the attribute_future method for the ShapleyValueSampling class.

Reviewed By: cyrjano

Differential Revision: D68158802
  • Loading branch information
jjuncho authored and facebook-github-bot committed Jan 17, 2025
1 parent b487891 commit 691e477
Show file tree
Hide file tree
Showing 4 changed files with 446 additions and 17 deletions.
8 changes: 8 additions & 0 deletions captum/_utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,11 @@ class FeatureAblationFutureError(Exception):
FeatureAblation attribution call"""

pass


class ShapleyValueFutureError(Exception):
    """Signals a failure inside the callback chain of a ShapleyValue
    attribution call."""
331 changes: 323 additions & 8 deletions captum/attr/_core/shapley_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import itertools
import math
import warnings
from typing import Callable, cast, Iterable, Optional, Sequence, Tuple, Union
from typing import Callable, cast, Iterable, List, Optional, Sequence, Tuple, Union

import torch
from captum._utils.common import (
Expand All @@ -20,6 +20,7 @@
_is_tuple,
_run_forward,
)
from captum._utils.exceptions import ShapleyValueFutureError
from captum._utils.progress import progress
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
from captum.attr._utils.attribution import PerturbationAttribution
Expand All @@ -29,7 +30,8 @@
_tensorize_baseline,
)
from captum.log import log_usage
from torch import dtype, Tensor
from torch import dtype, Size, Tensor
from torch.futures import collect_all, Future


def _all_perm_generator(num_features: int, num_samples: int) -> Iterable[Sequence[int]]:
Expand Down Expand Up @@ -394,7 +396,6 @@ def attribute(
)
if show_progress:
attr_progress.update()

if agg_output_mode:
eval_diff = modified_eval - prev_results
prev_results = modified_eval
Expand Down Expand Up @@ -438,7 +439,6 @@ def attribute(

# (*output_shape, *input_feature_shape)
total_attrib[j] += cur_attr

if show_progress:
attr_progress.close()

Expand All @@ -452,14 +452,298 @@ def attribute(
# `Tuple[Tensor, ...]`.
return formatted_attr

def attribute_future(
    self,
    inputs: TensorOrTupleOfTensorsGeneric,
    baselines: BaselineType = None,
    target: TargetType = None,
    additional_forward_args: Optional[Tuple[object, ...]] = None,
    feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
    n_samples: int = 25,
    perturbations_per_eval: int = 1,
    show_progress: bool = False,
) -> Future[TensorOrTupleOfTensorsGeneric]:
    r"""
    Compute sampled Shapley value attributions for a ``forward_func`` that
    returns a ``torch.futures.Future``.

    Every forward evaluation is scheduled through
    ``_strict_run_forward_future`` and the attribution bookkeeping is chained
    onto the resulting Futures with ``then`` / ``collect_all``. This method
    itself never waits on a forward result; it returns the Future at the end
    of that chain.

    Args:
        inputs: Input tensor or tuple of input tensors to attribute.
        baselines: Reference values; the initial evaluation is run on the
            tensorized baselines.
        target: Forwarded to ``forward_func`` through the forward helper.
        additional_forward_args: Extra arguments forwarded to
            ``forward_func``.
        feature_mask: Grouping of input elements into features; ``None``
            lets ``_format_feature_mask`` build one.
        n_samples: Passed to ``self.permutation_generator`` together with
            the number of features.
        perturbations_per_eval: Number of perturbations batched per forward
            call; must be an ``int`` >= 1.
        show_progress: When ``True`` a progress bar is created. It advances
            as forward calls are *scheduled*, not as their Futures
            complete.

    Returns:
        A Future whose value is the formatted attributions (the summed
        attributions divided by the number of permutations processed).

    Raises:
        AssertionError: If ``perturbations_per_eval`` is not an ``int`` >= 1,
            or if ``forward_func`` does not return a Future.
    """
    is_inputs_tuple = _is_tuple(inputs)
    inputs_tuple, baselines = _format_input_baseline(inputs, baselines)
    additional_forward_args = _format_additional_forward_args(
        additional_forward_args
    )
    formatted_feature_mask = _format_feature_mask(feature_mask, inputs_tuple)
    reshaped_feature_mask = _shape_feature_mask(
        formatted_feature_mask, inputs_tuple
    )

    assert (
        isinstance(perturbations_per_eval, int) and perturbations_per_eval >= 1
    ), "Ablations per evaluation must be at least 1."

    with torch.no_grad():
        baselines = _tensorize_baseline(inputs_tuple, baselines)
        num_examples = inputs_tuple[0].shape[0]

        total_features = _get_max_feature_index(reshaped_feature_mask) + 1

        if show_progress:
            attr_progress = progress(
                desc=f"{self.get_name()} attribution",
                total=self._get_n_evaluations(
                    total_features, n_samples, perturbations_per_eval
                )
                + 1,  # add 1 for the initial eval
            )
            attr_progress.update(0)

        initial_eval = self._strict_run_forward_future(
            self.forward_func, baselines, target, additional_forward_args
        )

        if show_progress:
            attr_progress.update()

        # ``Future.then`` invokes the callback with the completed future as
        # its only positional argument, so that is what ``fut`` receives.
        # The state tuple produced here is threaded through every later
        # callback in the chain.
        prev_result_tuple = initial_eval.then(
            lambda fut: self._initial_eval_to_prev_results_tuple(
                fut,
                num_examples,
                perturbations_per_eval,
                reshaped_feature_mask,
                inputs_tuple,
            )
        )

        iter_count = 0
        # Iterate for number of samples, generate a permutation of the features
        # and evaluate the incremental increase for each feature.
        for feature_permutation in self.permutation_generator(
            total_features, n_samples
        ):
            # Each permutation starts again from the baseline evaluation.
            prev_result_tuple = prev_result_tuple.then(
                lambda fut: self._set_prev_results_to_initial_eval(fut)
            )

            iter_count += 1
            for (
                current_inputs,
                current_add_args,
                current_target,
                current_masks,
            ) in self._perturbation_generator(
                inputs_tuple,
                additional_forward_args,
                target,
                baselines,
                reshaped_feature_mask,
                feature_permutation,
                perturbations_per_eval,
            ):
                if sum(torch.sum(mask).item() for mask in current_masks) == 0:
                    warnings.warn(
                        "Feature mask is missing some integers between 0 and "
                        "num_features, for optimal performance, make sure each"
                        " consecutive integer corresponds to a feature.",
                        stacklevel=1,
                    )
                # modified_eval dimensions: 1D tensor with length
                # equal to #num_examples * #features in batch
                modified_eval = self._strict_run_forward_future(
                    self.forward_func,
                    current_inputs,
                    current_target,
                    current_add_args,
                )
                if show_progress:
                    attr_progress.update()

                assert isinstance(modified_eval, torch.Future), (
                    "when using futures method, modified_eval should have "
                    f"Future type rather than {type(modified_eval)}"
                )
                eval_futs = collect_all([prev_result_tuple, modified_eval])
                # ``current_masks`` is rebound on every iteration, and the
                # callback may run after the loop has advanced, so it is
                # bound as a default argument at definition time.
                prev_result_tuple = eval_futs.then(
                    lambda fut, masks=current_masks: (
                        self._eval_fut_to_prev_results_tuple(
                            fut, num_examples, inputs_tuple, masks
                        )
                    )
                )

        if show_progress:
            attr_progress.close()

    # Divide total attributions by number of random permutations and return
    # formatted attributions.
    formatted_attr = prev_result_tuple.then(
        lambda fut: self._prev_result_tuple_to_formatted_attr(
            fut, iter_count, is_inputs_tuple
        )
    )
    # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got
    #  `Tuple[Tensor, ...]`.
    return formatted_attr

def _initial_eval_to_prev_results_tuple(
    self,
    initial_eval: Future[Tensor],
    num_examples: int,
    perturbations_per_eval: int,
    reshaped_feature_mask: TensorOrTupleOfTensorsGeneric,
    inputs_tuple: Tuple[Tensor, ...],
) -> Tuple[Tensor, Tensor, Size, List[Tensor], bool]:
    """Turn the completed initial evaluation into the chain's state tuple.

    Because the initial eval arrives as a Future, the values needed by the
    later callbacks are bundled into one tuple that is passed along as the
    Shapley value attributions are accumulated.

    Args:
        initial_eval: Completed Future holding the forward output on the
            baselines.
        num_examples: Forwarded to ``_find_output_mode_and_verify``.
        perturbations_per_eval: Forwarded to ``_find_output_mode_and_verify``.
        reshaped_feature_mask: Forwarded to ``_find_output_mode_and_verify``.
        inputs_tuple: Inputs; one zero-initialised accumulator is created
            per entry.

    Returns:
        ``(initial_eval, prev_result, output_shape, total_attrib,
        agg_output_mode)`` where ``prev_result`` starts out as the initial
        evaluation and every ``total_attrib`` entry has shape
        ``(*output_shape, *input.shape[1:])``.

    Raises:
        AssertionError: If the Future's value is not a ``Tensor``.
        ShapleyValueFutureError: If reading the Future's value raises one.
    """
    # Only reading the Future can surface an error from the callback chain,
    # so that is the only statement guarded by the handler.
    try:
        initial_eval_processed = initial_eval.value()
    except ShapleyValueFutureError as e:
        raise ShapleyValueFutureError(
            "_initial_eval_to_prev_results_tuple func failed"
        ) from e

    if not isinstance(initial_eval_processed, Tensor):
        raise AssertionError(
            "_initial_eval_to_prev_results_tuple: "
            "initial_eval should be a Tensor"
        )

    agg_output_mode = _find_output_mode_and_verify(
        initial_eval_processed,
        num_examples,
        perturbations_per_eval,
        reshaped_feature_mask,
        allow_multi_outputs=True,
    )
    output_shape = initial_eval_processed.shape
    # Accumulators are all placed on the device of the first input.
    total_attrib: List[Tensor] = [
        torch.zeros(
            tuple(output_shape) + tuple(input.shape[1:]),
            dtype=torch.float,
            device=inputs_tuple[0].device,
        )
        for input in inputs_tuple
    ]
    return (
        initial_eval_processed,
        initial_eval_processed,  # prev_result starts as the initial eval
        output_shape,
        total_attrib,
        agg_output_mode,
    )

def _set_prev_results_to_initial_eval(
    self,
    processed_initial_eval: Future[Tuple[Tensor, Tensor, Size, List[Tensor], bool]],
) -> Tuple[Tensor, Tensor, Size, List[Tensor], bool]:
    """Rewind the running evaluation to the initial one.

    Called at the start of each feature permutation: the second slot of the
    state tuple (the previous result) is replaced by the first slot (the
    initial eval); the other entries are passed through untouched.
    """
    state = processed_initial_eval.value()
    baseline_eval, _, out_shape, attrib_totals, is_agg_mode = state
    return (baseline_eval, baseline_eval, out_shape, attrib_totals, is_agg_mode)

def _eval_fut_to_prev_results_tuple(
    self,
    eval_futs: Future[
        List[
            Union[
                Future[Tuple[Tensor, Tensor, Size, List[Tensor], bool]],
                Future[Tensor],
            ]
        ]
    ],
    num_examples: int,
    inputs_tuple: Tuple[Tensor, ...],
    current_masks: Tuple[Tensor, ...],
) -> Tuple[Tensor, Tensor, Size, List[Tensor], bool]:
    """Fold one perturbed evaluation into the running attribution totals.

    ``eval_futs`` holds a list whose first element is the Future of the
    state tuple and whose second element is the Future of the perturbed
    forward output. The difference between that output and the previous
    result is weighted by the feature masks and added in place to
    ``total_attrib``.

    Returns:
        The state tuple with the updated ``total_attrib`` and with
        ``prev_results`` advanced to the latest evaluation.

    Raises:
        ShapleyValueFutureError: If the perturbed eval is not a ``Tensor``
            or the state is not a ``tuple``.
    """
    completed = eval_futs.value()
    state = completed[0].value()
    perturbed_eval = completed[1].value()
    if not (isinstance(perturbed_eval, Tensor) and isinstance(state, tuple)):
        raise ShapleyValueFutureError(
            "_eval_fut_to_prev_results_tuple func failed due to type mismatch"
        )

    initial_eval, prev_results, output_shape, total_attrib, agg_output_mode = state

    if agg_output_mode:
        eval_diff = perturbed_eval - prev_results
        prev_results = perturbed_eval
    else:
        # With perturbations_per_eval > 1 every block of num_examples rows
        # is one perturbation. The perturbations come from consecutive
        # steps of a permutation, so each one's diff is its eval minus the
        # eval of the perturbation before it.
        stacked = torch.cat((prev_results, perturbed_eval), dim=0)
        eval_diff = stacked[num_examples:] - stacked[:-num_examples]
        prev_results = stacked[-num_examples:]

    out_dims = tuple(output_shape)
    for j in range(len(total_attrib)):
        # eval_diff -> (n_perturb, *output_shape, 1, ..., 1); n_perturb may
        # be smaller than perturbations_per_eval. The trailing ones give it
        # as many dims as the mask tensor.
        feature_ones = (len(inputs_tuple[j].shape) - 1) * (1,)
        shaped_diff = eval_diff.reshape((-1,) + out_dims + feature_ones)

        # mask arrives as (n_perturb, *mask_shape_broadcastable_to_input)
        # and becomes (n_perturb, *broadcastable_to_output_shape,
        # *broadcastable_to_input_feature_shape).
        mask = current_masks[j]
        mask = mask.reshape(
            tuple(mask.shape[:2])
            + (len(output_shape) - 1) * (1,)
            + tuple(mask.shape[2:])
        )

        # Sum over n_perturb -> (*output_shape, *input_feature_shape).
        total_attrib[j] += (shaped_diff * mask.float()).sum(dim=0)

    return (initial_eval, prev_results, output_shape, total_attrib, agg_output_mode)

def _prev_result_tuple_to_formatted_attr(
    self,
    prev_result_tuple: Future[Tuple[Tensor, Tensor, Size, List[Tensor], bool]],
    iter_count: int,
    is_inputs_tuple: bool,
) -> Union[Tensor, Tuple[Tensor, ...]]:
    """Average the accumulated attributions and format them for the caller.

    Args:
        prev_result_tuple: Completed Future holding the state tuple; only
            its ``total_attrib`` entry (a list of tensors) is used.
        iter_count: Number of permutations that were processed; each total
            is divided by it.
        is_inputs_tuple: Passed to ``_format_output`` to choose between a
            single tensor and a tuple of tensors.

    Returns:
        The formatted attributions produced by ``_format_output``.
    """
    _, _, _, total_attrib, _ = prev_result_tuple.value()
    attrib = tuple(
        tensor_attrib_total / iter_count for tensor_attrib_total in total_attrib
    )
    return _format_output(is_inputs_tuple, attrib)

def _perturbation_generator(
self,
Expand Down Expand Up @@ -574,6 +858,37 @@ def _strict_run_forward(self, *args, **kwargs) -> Tensor:
# ref: https://github.com/pytorch/pytorch/pull/21215
return torch.tensor([forward_output], dtype=cast(dtype, output_type))

# pyre-fixme[2]: Parameter must be annotated.
def _strict_run_forward_future(self, *args, **kwargs) -> Future[Tensor]:
    """
    Future-aware counterpart of the temporary ``_run_forward`` wrapper.

    Calls the global ``_run_forward`` util, asserts that it returned a
    Future, and chains a callback that enforces the output type and
    converts the result to a non-scalar ``Tensor``.
    """

    def _coerce_to_tensor(fut: Future[Tensor]) -> Tensor:
        result = fut.value()
        if not isinstance(result, Tensor):
            result_type = type(result)
            assert result_type is int or result_type is float, (
                "the return of forward_func must be a Future of tensor, int, or float,"
                f" received: {result_type}"
            )
            return torch.tensor([result], dtype=cast(dtype, result_type))
        # A scalar tensor is reshaped to shape (1) so callers can always
        # assume a non-empty output_shape.
        return result if result.shape else result.reshape(1)

    pending = _run_forward(*args, **kwargs)
    assert isinstance(pending, torch.Future), (
        "The return type of forward_func must be a Future"
        f" received: {type(pending)}"
    )
    return pending.then(_coerce_to_tensor)


class ShapleyValues(ShapleyValueSampling):
"""
Expand Down
Loading

0 comments on commit 691e477

Please sign in to comment.