
Commit 28aada1

Merge branch 'master' into ablation
2 parents 9eaa367 + b917b2a

7 files changed: +160, -33 lines

captum/_utils/typing.py (+1)

```diff
@@ -24,6 +24,7 @@
 TupleOrTensorOrBoolGeneric = TypeVar(
     "TupleOrTensorOrBoolGeneric", Tuple[Tensor, ...], Tensor, bool
 )
+PassThroughOutputType = TypeVar("PassThroughOutputType")
 ModuleOrModuleList = TypeVar("ModuleOrModuleList", Module, List[Module])
 TargetType = Union[None, int, Tuple[int, ...], Tensor, List[Tuple[int, ...]], List[int]]
 BaselineTupleType = Union[None, Tuple[Union[Tensor, int, float], ...]]
```
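The new `PassThroughOutputType` is an unconstrained `TypeVar`, so a helper annotated with it is declared to return exactly the type it receives. A minimal sketch of the pattern (the `passthrough` function below is illustrative, not part of the commit):

```python
from typing import TypeVar

PassThroughOutputType = TypeVar("PassThroughOutputType")

def passthrough(value: PassThroughOutputType) -> PassThroughOutputType:
    # An unconstrained TypeVar binds to the argument's type, so a type
    # checker infers passthrough(1) -> int and passthrough("a") -> str.
    return value
```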

captum/attr/_core/feature_ablation.py (+38, -18)

```diff
@@ -3,7 +3,18 @@
 # pyre-strict
 
 import math
-from typing import Any, Callable, cast, Generator, List, Optional, Tuple, TypeVar, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+    Union,
+)
 
 import torch
 from captum._utils.common import (
@@ -465,13 +476,21 @@ def _attribute_with_cross_tensor_feature_masks(
         attrib_type: dtype,
         **kwargs: Any,
     ) -> Tuple[List[Tensor], List[Tensor]]:
+        feature_idx_to_tensor_idx: Dict[int, List[int]] = {}
+        for i, mask in enumerate(formatted_feature_mask):
+            for feature_idx in torch.unique(mask):
+                if feature_idx.item() not in feature_idx_to_tensor_idx:
+                    feature_idx_to_tensor_idx[feature_idx.item()] = []
+                feature_idx_to_tensor_idx[feature_idx.item()].append(i)
+
         for (
             current_inputs,
             current_mask,
         ) in self._ablation_generator(
             formatted_inputs,
             baselines,
             formatted_feature_mask,
+            feature_idx_to_tensor_idx,
             **kwargs,
         ):
             # modified_eval has (n_feature_perturbed * n_outputs) elements
@@ -511,27 +530,28 @@ def _ablation_generator(
         inputs: Tuple[Tensor, ...],
         baselines: BaselineType,
         input_mask: Tuple[Tensor, ...],
+        feature_idx_to_tensor_idx: Dict[int, List[int]],
         **kwargs: Any,
     ) -> Generator[
         Tuple[
             Tuple[Tensor, ...],
-            Tuple[Tensor, ...],
+            Tuple[Optional[Tensor], ...],
         ],
         None,
         None,
     ]:
-        unique_feature_ids = torch.unique(
-            torch.cat([mask.flatten() for mask in input_mask])
-        ).tolist()
-
         if isinstance(baselines, torch.Tensor):
             baselines = baselines.reshape((1,) + tuple(baselines.shape))
 
         # Process one feature per time, rather than processing every input tensor
-        for feature_idx in unique_feature_ids:
+        for feature_idx in feature_idx_to_tensor_idx.keys():
             ablated_inputs, current_masks = (
                 self._construct_ablated_input_across_tensors(
-                    inputs, input_mask, baselines, feature_idx
+                    inputs,
+                    input_mask,
+                    baselines,
+                    feature_idx,
+                    feature_idx_to_tensor_idx[feature_idx],
                 )
             )
             yield ablated_inputs, current_masks
@@ -542,18 +562,17 @@ def _construct_ablated_input_across_tensors(
         input_mask: Tuple[Tensor, ...],
         baselines: BaselineType,
         feature_idx: int,
-    ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:
+        tensor_idxs: List[int],
+    ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
 
         ablated_inputs = []
-        current_masks = []
+        current_masks: List[Optional[Tensor]] = []
         for i, input_tensor in enumerate(inputs):
-            mask = input_mask[i]
-            tensor_mask = mask == feature_idx
-            if not tensor_mask.any():
+            if i not in tensor_idxs:
                 ablated_inputs.append(input_tensor)
-                current_masks.append(torch.zeros_like(tensor_mask))
+                current_masks.append(None)
                 continue
-            tensor_mask = tensor_mask.to(input_tensor.device).long()
+            tensor_mask = (input_mask[i] == feature_idx).to(input_tensor.device).long()
             baseline = baselines[i] if isinstance(baselines, tuple) else baselines
             if isinstance(baseline, torch.Tensor):
                 baseline = baseline.reshape(
@@ -1173,7 +1192,7 @@ def _process_ablated_out(
     def _process_ablated_out_full(
         self,
         modified_eval: Tensor,
-        current_mask: Tuple[Tensor, ...],
+        current_mask: Tuple[Optional[Tensor], ...],
         flattened_initial_eval: Tensor,
         inputs: TensorOrTupleOfTensorsGeneric,
         n_outputs: int,
@@ -1195,9 +1214,10 @@ def _process_ablated_out_full(
 
         if self.use_weights:
             for weight, mask in zip(weights, current_mask):
-                weight += mask.float().sum(dim=0)
+                if mask is not None:
+                    weight += mask.float().sum(dim=0)
         for i, mask in enumerate(current_mask):
-            if inputs[i].numel() == 0:
+            if mask is None or inputs[i].numel() == 0:
                 continue
             eval_diff = eval_diff.reshape(
                 eval_diff_shape + (inputs[i].dim() - 1) * (1,)
```
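The heart of this change is the precomputed `feature_idx_to_tensor_idx` map: the masks are scanned once up front, so the per-feature loop no longer compares every tensor's mask against every feature id. A standalone sketch of the same bookkeeping (using `collections.defaultdict` for brevity, where the commit builds the dict by hand):

```python
from collections import defaultdict
from typing import Dict, List, Tuple

import torch
from torch import Tensor

def build_feature_idx_to_tensor_idx(
    feature_mask: Tuple[Tensor, ...],
) -> Dict[int, List[int]]:
    # Map each feature id to the indices of the input tensors whose mask
    # contains it, so ablation can skip tensors unrelated to a feature.
    mapping: Dict[int, List[int]] = defaultdict(list)
    for tensor_idx, mask in enumerate(feature_mask):
        for feature_idx in torch.unique(mask).tolist():
            mapping[feature_idx].append(tensor_idx)
    return dict(mapping)

# Feature 0 appears only in the first tensor, feature 1 in both, 2 in the second.
masks = (torch.tensor([[0, 1]]), torch.tensor([[1, 2]]))
print(build_feature_idx_to_tensor_idx(masks))  # {0: [0], 1: [0, 1], 2: [1]}
```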

captum/attr/_core/feature_permutation.py (+13, -8)

```diff
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 
 # pyre-strict
-from typing import Any, Callable, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union
 
 import torch
 from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
@@ -26,15 +26,15 @@ def _permute_feature(x: Tensor, feature_mask: Tensor) -> Tensor:
 
 
 def _permute_features_across_tensors(
-    inputs: Tuple[Tensor, ...], feature_masks: Tuple[Tensor, ...]
+    inputs: Tuple[Tensor, ...], feature_masks: Tuple[Optional[Tensor], ...]
 ) -> Tuple[Tensor, ...]:
     """
     Permutes features across multiple input tensors using the corresponding
    feature masks.
     """
     permuted_outputs = []
     for input_tensor, feature_mask in zip(inputs, feature_masks):
-        if not feature_mask.any():
+        if feature_mask is None or not feature_mask.any():
             permuted_outputs.append(input_tensor)
             continue
         n = input_tensor.size(0)
@@ -103,7 +103,7 @@ def __init__(
         forward_func: Callable[..., Union[int, float, Tensor, Future[Tensor]]],
         perm_func: Callable[[Tensor, Tensor], Tensor] = _permute_feature,
         perm_func_cross_tensor: Callable[
-            [Tuple[Tensor, ...], Tuple[Tensor, ...]], Tuple[Tensor, ...]
+            [Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]], Tuple[Tensor, ...]
         ] = _permute_features_across_tensors,
     ) -> None:
         r"""
@@ -392,9 +392,14 @@ def _construct_ablated_input_across_tensors(
         input_mask: Tuple[Tensor, ...],
         baselines: BaselineType,
         feature_idx: int,
-    ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:
-        feature_masks = tuple(
-            (mask == feature_idx).to(inputs[0].device) for mask in input_mask
-        )
+        tensor_idxs: List[int],
+    ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
+        current_masks: List[Optional[Tensor]] = []
+        for i, mask in enumerate(input_mask):
+            if i in tensor_idxs:
+                current_masks.append((mask == feature_idx).to(inputs[0].device))
+            else:
+                current_masks.append(None)
+        feature_masks = tuple(current_masks)
         permuted_outputs = self.perm_func_cross_tensor(inputs, feature_masks)
         return permuted_outputs, feature_masks
```
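Downstream of that map, tensors that do not contain the current feature now get a `None` mask rather than an all-`False` one, and the permutation helper passes them through untouched. A simplified sketch of that skip logic (standalone, with `torch.where` standing in for the library's masked permutation):

```python
from typing import Optional, Tuple

import torch
from torch import Tensor

def permute_masked_features(
    inputs: Tuple[Tensor, ...], feature_masks: Tuple[Optional[Tensor], ...]
) -> Tuple[Tensor, ...]:
    # A None mask means "this tensor does not contain the current feature",
    # so the tensor is returned unchanged instead of being permuted.
    outputs = []
    for input_tensor, mask in zip(inputs, feature_masks):
        if mask is None or not mask.any():
            outputs.append(input_tensor)
            continue
        perm = torch.randperm(input_tensor.size(0))
        # Where the mask is set, take values from a batch-permuted copy.
        outputs.append(torch.where(mask.bool(), input_tensor[perm], input_tensor))
    return tuple(outputs)
```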

captum/attr/_utils/common.py (+1, -1)

```diff
@@ -364,7 +364,7 @@ def _find_output_mode_and_verify(
                 "returns a scalar."
             )
     else:
-        agg_output_mode = False
+        agg_output_mode = perturbations_per_eval == 1
         if not allow_multi_outputs:
             assert (
                 isinstance(initial_eval, torch.Tensor) and initial_eval[0].numel() == 1
```
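A hedged reading of this one-line fix: an output that is not one scalar per example can only be interpreted as a single aggregate evaluation, and that interpretation is only safe when one perturbation is evaluated per forward call; batching several perturbations would make such an expanded output ambiguous. A toy illustration of the distinction (standalone, not the Captum API; the shapes follow the new Shapley test below):

```python
import torch

num_examples = 2          # batch size fed to the model
perturbations_per_eval = 1

# A forward output with 3 elements for a 2-example batch is not
# per-example scores, so it must be one aggregate (multi-output) eval.
initial_eval = torch.tensor([10.0, 12.0, 13.0])
per_example = initial_eval.numel() == num_examples * perturbations_per_eval

# Mirrors the fixed line `agg_output_mode = perturbations_per_eval == 1`:
agg_output_mode = (not per_example) and perturbations_per_eval == 1
print(agg_output_mode)  # True
```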

captum/testing/helpers/basic_models.py (+77, -1)

```diff
@@ -7,6 +7,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from captum._utils.typing import PassThroughOutputType
 from torch import Tensor
 from torch.futures import Future
 
@@ -417,6 +418,76 @@ def forward(self, input1, input2, input3=None):
         return self.linear2(self.relu(self.linear1(embeddings))).sum(1)
 
 
+class GradientUnsupportedLayerOutput(nn.Module):
+    """
+    This layer is used to test the case where the model returns a layer that
+    is not supported by the gradient computation.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    @no_type_check
+    def forward(
+        self, unsupported_layer_output: PassThroughOutputType
+    ) -> PassThroughOutputType:
+        return unsupported_layer_output
+
+
+class BasicModel_GradientLayerAttribution(nn.Module):
+    def __init__(
+        self,
+        inplace: bool = False,
+        unsupported_layer_output: PassThroughOutputType = None,
+    ) -> None:
+        super().__init__()
+        # Linear 0 is simply identity transform
+        self.unsupported_layer_output = unsupported_layer_output
+        self.linear0 = nn.Linear(3, 3)
+        self.linear0.weight = nn.Parameter(torch.eye(3))
+        self.linear0.bias = nn.Parameter(torch.zeros(3))
+        self.linear1 = nn.Linear(3, 4)
+        self.linear1.weight = nn.Parameter(torch.ones(4, 3))
+        self.linear1.bias = nn.Parameter(torch.tensor([-10.0, 1.0, 1.0, 1.0]))
+
+        self.linear1_alt = nn.Linear(3, 4)
+        self.linear1_alt.weight = nn.Parameter(torch.ones(4, 3))
+        self.linear1_alt.bias = nn.Parameter(torch.tensor([-10.0, 1.0, 1.0, 1.0]))
+
+        self.relu = nn.ReLU(inplace=inplace)
+        self.relu_alt = nn.ReLU(inplace=False)
+        self.unsupportedLayer = GradientUnsupportedLayerOutput()
+
+        self.linear2 = nn.Linear(4, 2)
+        self.linear2.weight = nn.Parameter(torch.ones(2, 4))
+        self.linear2.bias = nn.Parameter(torch.tensor([-1.0, 1.0]))
+
+        self.linear3 = nn.Linear(4, 2)
+        self.linear3.weight = nn.Parameter(torch.ones(2, 4))
+        self.linear3.bias = nn.Parameter(torch.tensor([-1.0, 1.0]))
+
+    @no_type_check
+    def forward(self, x: Tensor, add_input: Optional[Tensor] = None) -> Tensor:
+        input = x if add_input is None else x + add_input
+        lin0_out = self.linear0(input)
+        lin1_out = self.linear1(lin0_out)
+        lin1_out_alt = self.linear1_alt(lin0_out)
+
+        if self.unsupported_layer_output is not None:
+            self.unsupportedLayer(self.unsupported_layer_output)
+            # unsupportedLayer is unused in the forward func.
+        self.relu_alt(
+            lin1_out_alt
+        )  # relu_alt's output is supported but it's unused in the forward func.
+
+        relu_out = self.relu(lin1_out)
+        lin2_out = self.linear2(relu_out)
+
+        lin3_out = self.linear3(lin1_out_alt).to(torch.int64)
+
+        return torch.cat((lin2_out, lin3_out), dim=1)
+
+
 class MultiRelu(nn.Module):
     def __init__(self, inplace: bool = False) -> None:
         super().__init__()
@@ -429,7 +500,11 @@ def forward(self, arg1: Tensor, arg2: Tensor) -> Tuple[Tensor, Tensor]:
 
 
 class BasicModel_MultiLayer(nn.Module):
-    def __init__(self, inplace: bool = False, multi_input_module: bool = False) -> None:
+    def __init__(
+        self,
+        inplace: bool = False,
+        multi_input_module: bool = False,
+    ) -> None:
         super().__init__()
         # Linear 0 is simply identity transform
         self.multi_input_module = multi_input_module
@@ -461,6 +536,7 @@ def forward(
         input = x if add_input is None else x + add_input
         lin0_out = self.linear0(input)
         lin1_out = self.linear1(lin0_out)
+
         if self.multi_input_module:
             relu_out1, relu_out2 = self.multi_relu(lin1_out, self.linear1_alt(input))
             relu_out = relu_out1 + relu_out2
```
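A brief usage sketch of the new helper model (assuming the definitions above; the output width is 4 because `lin2_out` and `lin3_out` each contribute two columns):

```python
import torch

# Default path: the gradient-unsupported layer is never invoked.
model = BasicModel_GradientLayerAttribution()
print(model(torch.randn(2, 3)).shape)  # torch.Size([2, 4])

# Passing a non-None unsupported_layer_output routes it through
# GradientUnsupportedLayerOutput inside forward, exercising the
# gradient-unsupported code path in layer attribution tests.
model = BasicModel_GradientLayerAttribution(unsupported_layer_output=torch.zeros(1))
_ = model(torch.randn(2, 3))
```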

tests/attr/test_data_parallel.py (+1, -1)

```diff
@@ -41,7 +41,7 @@
 """
 
 # Distributed Data Parallel env setup
-os.environ["MASTER_ADDR"] = "127.0.0.1"
+os.environ["MASTER_ADDR"] = "localhost"
 os.environ["MASTER_PORT"] = "29500"
 dist.init_process_group(backend="gloo", rank=0, world_size=1)
 
```

tests/attr/test_shapley.py (+29, -4)

```diff
@@ -806,6 +806,30 @@ def func_future(*inp):
             lambda *inp: func_to_use(*inp), use_future=use_future
         )
 
+    @parameterized.expand([True, False])
+    def test_mutli_inp_shapley_batch_scalar_tensor_expanded(self, use_future) -> None:
+        def func(*inp):
+            sum_val = torch.sum(net(*inp)).item()
+            return torch.tensor([sum_val, sum_val + 2.0, sum_val + 3.0])
+
+        def func_future(*inp):
+            temp = net_fut(*inp)
+            temp.wait()
+            sum_val = torch.sum(temp.value()).item()
+            fut = Future()
+            fut.set_result(torch.tensor([sum_val, sum_val + 2.0, sum_val + 3.0]))
+            return fut
+
+        if use_future:
+            net_fut = BasicModel_MultiLayer_MultiInput_with_Future()
+            func_to_use = func_future
+        else:
+            net = BasicModel_MultiLayer_MultiInput()
+            func_to_use = func
+        self._multi_input_batch_scalar_shapley_assert(
+            lambda *inp: func_to_use(*inp), use_future=use_future, expanded_output=True
+        )
+
     @unittest.mock.patch("sys.stderr", new_callable=io.StringIO)
     def test_shapley_sampling_with_show_progress(self, mock_stderr) -> None:
         net = BasicModel_MultiLayer()
@@ -947,18 +971,19 @@ def _single_int_input_multi_sample_batch_scalar_shapley_assert(
         )
 
     def _multi_input_batch_scalar_shapley_assert(
-        self, func: Callable, use_future: bool = False
+        self, func: Callable, use_future: bool = False, expanded_output: bool = False
     ) -> None:
         inp1 = torch.tensor([[23.0, 100.0, 0.0], [20.0, 50.0, 30.0]])
         inp2 = torch.tensor([[20.0, 50.0, 30.0], [0.0, 100.0, 0.0]])
         inp3 = torch.tensor([[0.0, 100.0, 10.0], [20.0, 10.0, 13.0]])
         mask1 = torch.tensor([[1, 1, 1]])
         mask2 = torch.tensor([[0, 1, 2]])
         mask3 = torch.tensor([[0, 1, 2]])
+        out_mult = 3 if expanded_output else 1
         expected = (
-            [[3850.6666, 3850.6666, 3850.6666]],
-            [[306.6666, 3850.6666, 410.6666]],
-            [[306.6666, 3850.6666, 410.6666]],
+            [[3850.6666, 3850.6666, 3850.6666]] * out_mult,
+            [[306.6666, 3850.6666, 410.6666]] * out_mult,
+            [[306.6666, 3850.6666, 410.6666]] * out_mult,
         )
         if use_future:
             self._shapley_test_assert_future(
```
