Skip to content

Commit c6781a5

Browse files
jjuncho authored and facebook-github-bot committed
Adding async future functionality to ShapleyValues (#1487)
Summary: This diff implements the attribute_future method for the ShapleyValueSampling class. Reviewed By: cyrjano Differential Revision: D68158802
1 parent b487891 commit c6781a5

File tree

4 files changed

+446
-17
lines changed

4 files changed

+446
-17
lines changed

captum/_utils/exceptions.py

+8
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,11 @@ class FeatureAblationFutureError(Exception):
99
FeatureAblation attribution call"""
1010

1111
pass
12+
13+
14+
class ShapleyValueFutureError(Exception):
15+
"""This custom error is raised when an error
16+
occurs within the callback chain of a
17+
ShapleyValue attribution call"""
18+
19+
pass

captum/attr/_core/shapley_value.py

+323-8
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import itertools
66
import math
77
import warnings
8-
from typing import Callable, cast, Iterable, Optional, Sequence, Tuple, Union
8+
from typing import Callable, cast, Iterable, List, Optional, Sequence, Tuple, Union
99

1010
import torch
1111
from captum._utils.common import (
@@ -20,6 +20,7 @@
2020
_is_tuple,
2121
_run_forward,
2222
)
23+
from captum._utils.exceptions import ShapleyValueFutureError
2324
from captum._utils.progress import progress
2425
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
2526
from captum.attr._utils.attribution import PerturbationAttribution
@@ -29,7 +30,8 @@
2930
_tensorize_baseline,
3031
)
3132
from captum.log import log_usage
32-
from torch import dtype, Tensor
33+
from torch import dtype, Size, Tensor
34+
from torch.futures import collect_all, Future
3335

3436

3537
def _all_perm_generator(num_features: int, num_samples: int) -> Iterable[Sequence[int]]:
@@ -394,7 +396,6 @@ def attribute(
394396
)
395397
if show_progress:
396398
attr_progress.update()
397-
398399
if agg_output_mode:
399400
eval_diff = modified_eval - prev_results
400401
prev_results = modified_eval
@@ -438,7 +439,6 @@ def attribute(
438439

439440
# (*output_shape, *input_feature_shape)
440441
total_attrib[j] += cur_attr
441-
442442
if show_progress:
443443
attr_progress.close()
444444

@@ -452,14 +452,298 @@ def attribute(
452452
# `Tuple[Tensor, ...]`.
453453
return formatted_attr
454454

455-
# pyre-fixme[24] Generic type `Callable` expects 2 type parameters.
456-
def attribute_future(self) -> Callable:
455+
def attribute_future(
456+
self,
457+
inputs: TensorOrTupleOfTensorsGeneric,
458+
baselines: BaselineType = None,
459+
target: TargetType = None,
460+
additional_forward_args: Optional[Tuple[object, ...]] = None,
461+
feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
462+
n_samples: int = 25,
463+
perturbations_per_eval: int = 1,
464+
show_progress: bool = False,
465+
) -> Future[TensorOrTupleOfTensorsGeneric]:
457466
r"""
458467
This method is not implemented for ShapleyValueSampling.
459468
"""
460-
raise NotImplementedError(
461-
"attribute_future is not implemented for ShapleyValueSampling"
469+
is_inputs_tuple = _is_tuple(inputs)
470+
inputs_tuple, baselines = _format_input_baseline(inputs, baselines)
471+
additional_forward_args = _format_additional_forward_args(
472+
additional_forward_args
462473
)
474+
formatted_feature_mask = _format_feature_mask(feature_mask, inputs_tuple)
475+
reshaped_feature_mask = _shape_feature_mask(
476+
formatted_feature_mask, inputs_tuple
477+
)
478+
479+
assert (
480+
isinstance(perturbations_per_eval, int) and perturbations_per_eval >= 1
481+
), "Ablations per evaluation must be at least 1."
482+
483+
with torch.no_grad():
484+
baselines = _tensorize_baseline(inputs_tuple, baselines)
485+
num_examples = inputs_tuple[0].shape[0]
486+
487+
total_features = _get_max_feature_index(reshaped_feature_mask) + 1
488+
489+
if show_progress:
490+
attr_progress = progress(
491+
desc=f"{self.get_name()} attribution",
492+
total=self._get_n_evaluations(
493+
total_features, n_samples, perturbations_per_eval
494+
)
495+
+ 1, # add 1 for the initial eval
496+
)
497+
attr_progress.update(0)
498+
499+
initial_eval = self._strict_run_forward_future(
500+
self.forward_func, baselines, target, additional_forward_args
501+
)
502+
503+
if show_progress:
504+
attr_progress.update()
505+
506+
prev_result_tuple = initial_eval.then(
507+
lambda initial_eval=initial_eval: self._initial_eval_to_prev_results_tuple(
508+
initial_eval,
509+
num_examples,
510+
perturbations_per_eval,
511+
reshaped_feature_mask,
512+
inputs_tuple,
513+
)
514+
)
515+
516+
iter_count = 0
517+
# Iterate for number of samples, generate a permutation of the features
518+
# and evaluate the incremental increase for each feature.
519+
for feature_permutation in self.permutation_generator(
520+
total_features, n_samples
521+
):
522+
prev_result_tuple = prev_result_tuple.then(
523+
lambda prev_result_tuple=prev_result_tuple: self._set_prev_results_to_initial_eval(
524+
prev_result_tuple
525+
)
526+
)
527+
528+
iter_count += 1
529+
for (
530+
current_inputs,
531+
current_add_args,
532+
current_target,
533+
current_masks,
534+
) in self._perturbation_generator(
535+
inputs_tuple,
536+
additional_forward_args,
537+
target,
538+
baselines,
539+
reshaped_feature_mask,
540+
feature_permutation,
541+
perturbations_per_eval,
542+
):
543+
if sum(torch.sum(mask).item() for mask in current_masks) == 0:
544+
warnings.warn(
545+
"Feature mask is missing some integers between 0 and "
546+
"num_features, for optimal performance, make sure each"
547+
" consecutive integer corresponds to a feature.",
548+
stacklevel=1,
549+
)
550+
# modified_eval dimensions: 1D tensor with length
551+
# equal to #num_examples * #features in batch
552+
modified_eval = self._strict_run_forward_future(
553+
self.forward_func,
554+
current_inputs,
555+
current_target,
556+
current_add_args,
557+
)
558+
if show_progress:
559+
attr_progress.update()
560+
561+
assert isinstance(modified_eval, torch.Future), (
562+
"when using futures method, modified_eval should have "
563+
f"Future type rather than {type(modified_eval)}"
564+
)
565+
eval_futs = collect_all([prev_result_tuple, modified_eval])
566+
prev_result_tuple = eval_futs.then(
567+
lambda eval_futs=eval_futs, num_examples=num_examples, inputs_tuple=inputs_tuple, current_masks=current_masks: self._eval_fut_to_prev_results_tuple(
568+
eval_futs, num_examples, inputs_tuple, current_masks
569+
)
570+
)
571+
572+
if show_progress:
573+
attr_progress.close()
574+
575+
# Divide total attributions by number of random permutations and return
576+
# formatted attributions.
577+
formatted_attr = prev_result_tuple.then(
578+
lambda prev_result_tuple=prev_result_tuple: self._prev_result_tuple_to_formatted_attr(
579+
prev_result_tuple, iter_count, is_inputs_tuple
580+
)
581+
)
582+
# pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got
583+
# `Tuple[Tensor, ...]`.
584+
return formatted_attr
585+
586+
def _initial_eval_to_prev_results_tuple(
587+
self,
588+
initial_eval: Future[Tensor],
589+
num_examples: int,
590+
perturbations_per_eval: int,
591+
reshaped_feature_mask: TensorOrTupleOfTensorsGeneric,
592+
inputs_tuple: Tuple[Tensor, ...],
593+
) -> Tuple[Tensor, Tensor, Size, List[Tensor], bool]:
594+
"""Since the initial eval is a Future, it is easier to bundle the prev_result, agg_output_mode, output_shape, and total_attrib together
595+
as Shapley Value Feature Attributions are being calculated"""
596+
try:
597+
initial_eval_processed = initial_eval.value()
598+
prev_result = initial_eval_processed
599+
if not isinstance(initial_eval_processed, Tensor):
600+
raise AssertionError(
601+
"initial_eval_to_processed_initial_eval_fut: "
602+
"initial_eval should be a Tensor"
603+
)
604+
agg_output_mode = _find_output_mode_and_verify(
605+
initial_eval_processed,
606+
num_examples,
607+
perturbations_per_eval,
608+
reshaped_feature_mask,
609+
allow_multi_outputs=True,
610+
)
611+
output_shape = initial_eval_processed.shape
612+
total_attrib: List[Tensor] = [
613+
torch.zeros(
614+
tuple(output_shape) + tuple(input.shape[1:]),
615+
dtype=torch.float,
616+
device=inputs_tuple[0].device,
617+
)
618+
for input in inputs_tuple
619+
]
620+
result = (
621+
initial_eval_processed,
622+
prev_result,
623+
output_shape,
624+
total_attrib,
625+
agg_output_mode,
626+
)
627+
except ShapleyValueFutureError as e:
628+
raise ShapleyValueFutureError(
629+
"_initial_eval_to_prev_results_tuple func failed"
630+
) from e
631+
return result
632+
633+
def _set_prev_results_to_initial_eval(
634+
self,
635+
processed_initial_eval: Future[Tuple[Tensor, Tensor, Size, List[Tensor], bool]],
636+
) -> Tuple[Tensor, Tensor, Size, List[Tensor], bool]:
637+
"""At the beginning of each feature permutation, the prev_results is reset to the initial eval, and this method helps set that up"""
638+
(initial_eval, prev_results, output_shape, total_attrib, agg_output_mode) = (
639+
processed_initial_eval.value()
640+
)
641+
prev_results = initial_eval
642+
return (initial_eval, prev_results, output_shape, total_attrib, agg_output_mode)
643+
644+
def _eval_fut_to_prev_results_tuple(
645+
self,
646+
eval_futs: Future[
647+
List[
648+
Union[
649+
Future[Tuple[Tensor, Tensor, Size, List[Tensor], bool]],
650+
Future[Tensor],
651+
]
652+
]
653+
],
654+
num_examples: int,
655+
inputs_tuple: Tuple[Tensor, ...],
656+
current_masks: Tuple[Tensor, ...],
657+
) -> Tuple[Tensor, Tensor, Size, List[Tensor], bool]:
658+
"""Helper method responsible for calculating eval differences between the modified eval and prev_results Tensor and storing them in total_attrib. Returns prev_results_tuple with modified total_attrib and prev_results"""
659+
prev_results_tuple = eval_futs.value()[0].value()
660+
modified_eval = eval_futs.value()[1].value()
661+
if not isinstance(modified_eval, Tensor) or not isinstance(
662+
prev_results_tuple, tuple
663+
):
664+
raise ShapleyValueFutureError(
665+
"_eval_fut_to_prev_results_tuple func failed due to type mismatch"
666+
)
667+
(
668+
initial_eval,
669+
prev_results,
670+
output_shape,
671+
total_attrib,
672+
agg_output_mode,
673+
) = prev_results_tuple
674+
if agg_output_mode:
675+
eval_diff = modified_eval - prev_results
676+
prev_results = modified_eval
677+
else:
678+
# when perturb_per_eval > 1, every num_examples stands for
679+
# one perturb. Since the perturbs are from a consecutive
680+
# permutation, each diff of a perturb is its eval minus
681+
# the eval of the previous perturb
682+
683+
all_eval = torch.cat((prev_results, modified_eval), dim=0)
684+
eval_diff = all_eval[num_examples:] - all_eval[:-num_examples]
685+
prev_results = all_eval[-num_examples:]
686+
687+
for j in range(len(total_attrib)):
688+
# format eval_diff to shape
689+
# (n_perturb, *output_shape, 1,.. 1)
690+
# where n_perturb may not be perturb_per_eval
691+
# Append n_input_feature dim of 1 to make the tensor
692+
# have the same dim as the mask tensor.
693+
formatted_eval_diff = eval_diff.reshape(
694+
(-1,) + tuple(output_shape) + (len(inputs_tuple[j].shape) - 1) * (1,)
695+
)
696+
697+
# mask in shape (n_perturb, *mask_shape_broadcastable_to_input)
698+
# reshape to
699+
# (
700+
# n_perturb,
701+
# *broadcastable_to_output_shape
702+
# *broadcastable_to_input_feature_shape
703+
# )
704+
cur_mask = current_masks[j]
705+
cur_mask = cur_mask.reshape(
706+
tuple(cur_mask.shape[:2])
707+
+ (len(output_shape) - 1) * (1,)
708+
+ tuple(cur_mask.shape[2:])
709+
)
710+
711+
# aggregate n_perturb
712+
cur_attr = (formatted_eval_diff * cur_mask.float()).sum(dim=0)
713+
# (*output_shape, *input_feature_shape)
714+
total_attrib[j] += cur_attr
715+
716+
result = (
717+
initial_eval,
718+
prev_results,
719+
output_shape,
720+
total_attrib,
721+
agg_output_mode,
722+
)
723+
return result
724+
725+
def _prev_result_tuple_to_formatted_attr(
726+
self,
727+
prev_result_tuple: Future[
728+
Tuple[Tensor, Tensor, Tuple[int], List[Tensor], bool]
729+
],
730+
iter_count: int,
731+
is_inputs_tuple: bool,
732+
) -> Union[Tensor, Tuple[Tensor, ...]]:
733+
"""Helper method to format total_attrib, which is a list of tensors, into formatted attributions, which are either a single tensor or a tuple of tensors"""
734+
735+
(
736+
_,
737+
_,
738+
_,
739+
total_attrib,
740+
_,
741+
) = prev_result_tuple.value()
742+
attrib = tuple(
743+
tensor_attrib_total / iter_count for tensor_attrib_total in total_attrib
744+
)
745+
formatted_attr = _format_output(is_inputs_tuple, attrib)
746+
return formatted_attr
463747

464748
def _perturbation_generator(
465749
self,
@@ -574,6 +858,37 @@ def _strict_run_forward(self, *args, **kwargs) -> Tensor:
574858
# ref: https://github.com/pytorch/pytorch/pull/21215
575859
return torch.tensor([forward_output], dtype=cast(dtype, output_type))
576860

861+
# pyre-fixme[2]: Parameter must be annotated.
862+
def _strict_run_forward_future(self, *args, **kwargs) -> Future[Tensor]:
863+
"""
864+
A temp wrapper for global _run_forward util to force forward output
865+
type assertion & conversion, but takes into account the Future tensor type
866+
"""
867+
868+
def process_strict_run_forward(fut: Future[Tensor]) -> Tensor:
869+
output = fut.value()
870+
if isinstance(output, Tensor):
871+
# format scalar to shape (1) so we can always assume non-empty output_shape
872+
if not output.shape:
873+
output = output.reshape(1)
874+
return output
875+
output_type = type(output)
876+
assert output_type is int or output_type is float, (
877+
"the return of forward_func must be a Future of tensor, int, or float,"
878+
f" received: {output_type}"
879+
)
880+
output = torch.tensor([output], dtype=cast(dtype, output_type))
881+
return output
882+
883+
forward_output = _run_forward(*args, **kwargs)
884+
assert isinstance(forward_output, torch.Future), (
885+
"The return type of forward_func must be a Future"
886+
f" received: {type(forward_output)}"
887+
)
888+
889+
return_output = forward_output.then(process_strict_run_forward)
890+
return return_output
891+
577892

578893
class ShapleyValues(ShapleyValueSampling):
579894
"""

0 commit comments

Comments
 (0)