Enable multi-task attribution for Shapley (#1173)
Summary:

Support multi-task attribution in `ShapleyValues` and `ShapleyValueSampling`.

Assuming the return of `forward_func` has shape `(*output_shape)`, the attribution result will have shape `(*output_shape, *input_shape[1:])`. Existing use cases become special cases where `output_shape` is `(1,)` or `(batch_size,)`.
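A minimal usage sketch of the new behavior. The toy `forward_func`, weights, and sizes below are assumptions made for illustration, not code from this commit:

```python
import torch
from captum.attr import ShapleyValueSampling

batch_size, n_features, n_tasks = 4, 3, 5

def forward_func(inp):
    # one score per task per example -> output_shape == (batch_size, n_tasks)
    weight = torch.ones(n_features, n_tasks)
    return inp @ weight

sv = ShapleyValueSampling(forward_func)
inp = torch.rand(batch_size, n_features)

# With no target, the whole (batch_size, n_tasks) output is attributed, so the
# result has shape (*output_shape, *input_shape[1:]) == (batch_size, n_tasks, n_features)
attr = sv.attribute(inp, n_samples=25)
print(attr.shape)  # torch.Size([4, 5, 3])
```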

Reviewed By: vivekmig

Differential Revision: D48696578
aobo-y authored and facebook-github-bot committed Oct 25, 2023
1 parent b8817e8 commit 9857837
Showing 3 changed files with 106 additions and 21 deletions.
82 changes: 64 additions & 18 deletions captum/attr/_core/shapley_value.py
@@ -300,21 +300,31 @@ def attribute(
)
attr_progress.update(0)

initial_eval = _run_forward(
initial_eval = self._strict_run_forward(
self.forward_func, baselines, target, additional_forward_args
)

if show_progress:
attr_progress.update()

agg_output_mode = _find_output_mode_and_verify(
initial_eval, num_examples, perturbations_per_eval, feature_mask
initial_eval,
num_examples,
perturbations_per_eval,
feature_mask,
allow_multi_outputs=True,
)

# Initialize attribution totals and counts
output_shape = initial_eval.shape
n_outputs = initial_eval.numel()

# attr shape (*output_shape, *input_feature_shape)
total_attrib = [
torch.zeros_like(
input[0:1] if agg_output_mode else input, dtype=torch.float
torch.zeros(
(*output_shape, *input.shape[1:]),
dtype=torch.float,
device=inputs[0].device,
)
for input in inputs
]
@@ -349,7 +359,7 @@ def attribute(
)
# modified_eval dimensions: 1D tensor with length
# equal to #num_examples * #features in batch
modified_eval = _run_forward(
modified_eval = self._strict_run_forward(
self.forward_func,
current_inputs,
current_target,
@@ -362,23 +372,34 @@
eval_diff = modified_eval - prev_results
prev_results = modified_eval
else:
# when perturb_per_eval > 1, every num_examples rows correspond to
# one perturbation. Since the perturbations come from a consecutive
# permutation, the diff of each perturbation is its eval minus
# the eval of the previous perturbation
all_eval = torch.cat((prev_results, modified_eval), dim=0)
eval_diff = all_eval[num_examples:] - all_eval[:-num_examples]
prev_results = all_eval[-num_examples:]

for j in range(len(total_attrib)):
current_eval_diff = eval_diff
if not agg_output_mode:
# current_eval_diff dimensions:
# (#features in batch, #num_examples, 1,.. 1)
# (contains 1 more dimension than inputs). This adds extra
# dimensions of 1 to make the tensor broadcastable with the
# inputs tensor.
current_eval_diff = current_eval_diff.reshape(
(-1, num_examples) + (len(inputs[j].shape) - 1) * (1,)
)
total_attrib[j] += (
current_eval_diff * current_masks[j].float()
).sum(dim=0)
# reshape eval_diff to shape
# (n_perturb, n_outputs, 1, ..., 1)
# where n_perturb may be smaller than perturb_per_eval.
# Trailing dims of 1 are appended, one per input feature dim,
# so the tensor has the same number of dims as the mask tensor.
formatted_eval_diff = eval_diff.reshape(
(-1, n_outputs) + (len(inputs[j].shape) - 1) * (1,)
)

# mask in shape (n_perturb, *mask_shape_broadcastable_to_input)
# aggregate n_perturb
cur_attr = (formatted_eval_diff * current_masks[j].float()).sum(
dim=0
)

# (n_outputs, *input_feature_shape) -> (*output_shape, *input_feature_shape)
total_attrib[j] += cur_attr.reshape(
(*output_shape, *cur_attr.shape[1:])
)

if show_progress:
attr_progress.close()
@@ -476,6 +497,31 @@ def _get_n_evaluations(self, total_features, n_samples, perturbations_per_eval):
"""return the total number of forward evaluations needed"""
return math.ceil(total_features / perturbations_per_eval) * n_samples

def _strict_run_forward(self, *args, **kwargs) -> Tensor:
"""
A temp wrapper for global _run_forward util to force forward output
type assertion & conversion.
Remove after the strict logic is supported by all attr classes
"""
forward_output = _run_forward(*args, **kwargs)
if isinstance(forward_output, Tensor):
# format scalar to shape (1) so we can always assume non-empty output_shape
if not forward_output.shape:
forward_output = forward_output.reshape(1)

return forward_output

output_type = type(forward_output)
assert output_type is int or output_type is float, (
"the return of forward_func must be a tensor, int, or float,"
f" received: {forward_output}"
)

# using python built-in type as torch dtype
# int -> torch.int64, float -> torch.float64
# ref: https://github.com/pytorch/pytorch/pull/21215
return torch.tensor([forward_output], dtype=output_type)
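
For intuition, here is a small standalone sketch (not part of the commit) that mirrors what `_strict_run_forward` does to the forward output; the function name and structure are this example's own:

```python
import torch

def normalize_forward_output(out):
    # Mirrors the conversion logic shown above
    if isinstance(out, torch.Tensor):
        # 0-d tensors become shape (1,) so output_shape is never empty
        return out.reshape(1) if not out.shape else out
    assert type(out) in (int, float), (
        "the return of forward_func must be a tensor, int, or float"
    )
    # Python built-in types used as torch dtypes: int -> torch.int64, float -> torch.float64
    return torch.tensor([out], dtype=type(out))

print(normalize_forward_output(torch.tensor(3.0)).shape)  # torch.Size([1])
print(normalize_forward_output(7).dtype)                   # torch.int64
print(normalize_forward_output(2.5).dtype)                 # torch.float64
```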


class ShapleyValues(ShapleyValueSampling):
"""
8 changes: 5 additions & 3 deletions captum/attr/_utils/common.py
@@ -318,6 +318,7 @@ def _find_output_mode_and_verify(
num_examples: int,
perturbations_per_eval: int,
feature_mask: Union[None, TensorOrTupleOfTensorsGeneric],
allow_multi_outputs: bool = False,
) -> bool:
"""
This method identifies whether the model outputs a single output for a batch
@@ -346,9 +347,10 @@
)
else:
agg_output_mode = False
assert (
isinstance(initial_eval, torch.Tensor) and initial_eval[0].numel() == 1
), "Target should identify a single element in the model output."
if not allow_multi_outputs:
assert (
isinstance(initial_eval, torch.Tensor) and initial_eval[0].numel() == 1
), "Target should identify a single element in the model output."
return agg_output_mode


37 changes: 37 additions & 0 deletions tests/attr/test_shapley.py
@@ -151,6 +151,43 @@ def test_multi_input_shapley_sampling_with_mask(self) -> None:
perturbations_per_eval=(1, 2, 3),
)

def test_shapley_sampling_multi_task_output(self) -> None:
# return shape (batch size, 2)
net1 = BasicModel_MultiLayer()

# return shape (batch size, 4)
def forward_func(*args, **kwargs):
net_output = net1(*args, **kwargs)
batch_size = net_output.size(0)
constant = torch.ones(batch_size, 2)
output = torch.cat(
[
net_output,
constant,
],
dim=-1,
)
return output

inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True)

self._shapley_test_assert(
forward_func,
inp,
[
[
[76.66666, 196.66666, 116.66666],
[76.66666, 196.66666, 116.66666],
[0, 0, 0],
[0, 0, 0],
]
],
target=None, # no target, multi-task output for all classes
perturbations_per_eval=(1, 2, 3),
n_samples=150,
test_true_shapley=True,
)

# Remaining tests are for cases where forward function returns a scalar
# per batch, as either a float, integer, 0d tensor or 1d tensor.
def test_single_shapley_batch_scalar_float(self) -> None:
