
Commit 382ea27

Merge branch 'master' into ernie
2 parents 0dd0271 + 4ca5c2c

5 files changed: +109 -7 lines

captum/_utils/typing.py (+1)

@@ -24,6 +24,7 @@
 TupleOrTensorOrBoolGeneric = TypeVar(
     "TupleOrTensorOrBoolGeneric", Tuple[Tensor, ...], Tensor, bool
 )
+PassThroughOutputType = TypeVar("PassThroughOutputType")
 ModuleOrModuleList = TypeVar("ModuleOrModuleList", Module, List[Module])
 TargetType = Union[None, int, Tuple[int, ...], Tensor, List[Tuple[int, ...]], List[int]]
 BaselineTupleType = Union[None, Tuple[Union[Tensor, int, float], ...]]
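Note: PassThroughOutputType is an unconstrained TypeVar, so a function annotated with it for both input and output is declared to return exactly the type it was given. A standalone sketch (the passthrough helper is illustrative, not part of the commit):

from typing import TypeVar

PassThroughOutputType = TypeVar("PassThroughOutputType")

def passthrough(value: PassThroughOutputType) -> PassThroughOutputType:
    # Identity: a type checker infers passthrough(Tensor) -> Tensor,
    # passthrough(dict) -> dict, and so on.
    return value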

captum/attr/_utils/common.py (+1 -1)

@@ -364,7 +364,7 @@ def _find_output_mode_and_verify(
             "returns a scalar."
         )
     else:
-        agg_output_mode = False
+        agg_output_mode = perturbations_per_eval == 1
         if not allow_multi_outputs:
             assert (
                 isinstance(initial_eval, torch.Tensor) and initial_eval[0].numel() == 1
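Note: this branch previously pinned agg_output_mode to False; it now keys off perturbations_per_eval. A minimal sketch of just the new predicate (the function name infer_agg_output_mode is invented for illustration; the real logic lives inside _find_output_mode_and_verify, and the reading of its intent is inferred from the diff, not the commit message):

def infer_agg_output_mode(perturbations_per_eval: int) -> bool:
    # Sketch of the changed branch only: a non-scalar output can still be
    # treated as aggregated over the whole batch, but only when each model
    # call evaluates a single perturbation; batched perturbations rule it out.
    return perturbations_per_eval == 1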

captum/testing/helpers/basic_models.py (+77 -1)

@@ -7,6 +7,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from captum._utils.typing import PassThroughOutputType
 from torch import Tensor
 from torch.futures import Future

@@ -417,6 +418,76 @@ def forward(self, input1, input2, input3=None):
         return self.linear2(self.relu(self.linear1(embeddings))).sum(1)


+class GradientUnsupportedLayerOutput(nn.Module):
+    """
+    This layer is used to test the case where the model returns a layer that
+    is not supported by the gradient computation.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    @no_type_check
+    def forward(
+        self, unsupported_layer_output: PassThroughOutputType
+    ) -> PassThroughOutputType:
+        return unsupported_layer_output
+
+
+class BasicModel_GradientLayerAttribution(nn.Module):
+    def __init__(
+        self,
+        inplace: bool = False,
+        unsupported_layer_output: PassThroughOutputType = None,
+    ) -> None:
+        super().__init__()
+        # Linear 0 is simply identity transform
+        self.unsupported_layer_output = unsupported_layer_output
+        self.linear0 = nn.Linear(3, 3)
+        self.linear0.weight = nn.Parameter(torch.eye(3))
+        self.linear0.bias = nn.Parameter(torch.zeros(3))
+        self.linear1 = nn.Linear(3, 4)
+        self.linear1.weight = nn.Parameter(torch.ones(4, 3))
+        self.linear1.bias = nn.Parameter(torch.tensor([-10.0, 1.0, 1.0, 1.0]))
+
+        self.linear1_alt = nn.Linear(3, 4)
+        self.linear1_alt.weight = nn.Parameter(torch.ones(4, 3))
+        self.linear1_alt.bias = nn.Parameter(torch.tensor([-10.0, 1.0, 1.0, 1.0]))
+
+        self.relu = nn.ReLU(inplace=inplace)
+        self.relu_alt = nn.ReLU(inplace=False)
+        self.unsupportedLayer = GradientUnsupportedLayerOutput()
+
+        self.linear2 = nn.Linear(4, 2)
+        self.linear2.weight = nn.Parameter(torch.ones(2, 4))
+        self.linear2.bias = nn.Parameter(torch.tensor([-1.0, 1.0]))
+
+        self.linear3 = nn.Linear(4, 2)
+        self.linear3.weight = nn.Parameter(torch.ones(2, 4))
+        self.linear3.bias = nn.Parameter(torch.tensor([-1.0, 1.0]))
+
+    @no_type_check
+    def forward(self, x: Tensor, add_input: Optional[Tensor] = None) -> Tensor:
+        input = x if add_input is None else x + add_input
+        lin0_out = self.linear0(input)
+        lin1_out = self.linear1(lin0_out)
+        lin1_out_alt = self.linear1_alt(lin0_out)
+
+        if self.unsupported_layer_output is not None:
+            self.unsupportedLayer(self.unsupported_layer_output)
+            # unsupportedLayer is unused in the forward func.
+        self.relu_alt(
+            lin1_out_alt
+        )  # relu_alt's output is supported but it's unused in the forward func.
+
+        relu_out = self.relu(lin1_out)
+        lin2_out = self.linear2(relu_out)
+
+        lin3_out = self.linear3(lin1_out_alt).to(torch.int64)
+
+        return torch.cat((lin2_out, lin3_out), dim=1)
+
+
 class MultiRelu(nn.Module):
     def __init__(self, inplace: bool = False) -> None:
         super().__init__()

@@ -429,7 +500,11 @@ def forward(self, arg1: Tensor, arg2: Tensor) -> Tuple[Tensor, Tensor]:


 class BasicModel_MultiLayer(nn.Module):
-    def __init__(self, inplace: bool = False, multi_input_module: bool = False) -> None:
+    def __init__(
+        self,
+        inplace: bool = False,
+        multi_input_module: bool = False,
+    ) -> None:
         super().__init__()
         # Linear 0 is simply identity transform
         self.multi_input_module = multi_input_module

@@ -461,6 +536,7 @@ def forward(
         input = x if add_input is None else x + add_input
         lin0_out = self.linear0(input)
         lin1_out = self.linear1(lin0_out)
+
         if self.multi_input_module:
             relu_out1, relu_out2 = self.multi_relu(lin1_out, self.linear1_alt(input))
             relu_out = relu_out1 + relu_out2
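Note: BasicModel_GradientLayerAttribution routes an arbitrary, possibly non-Tensor payload through GradientUnsupportedLayerOutput, giving layer-attribution tests a module whose output gradients cannot flow through. A hypothetical usage sketch (the dict payload and input shapes are illustrative only, not from the commit):

import torch

# Default path: unsupportedLayer is constructed but never exercised.
model = BasicModel_GradientLayerAttribution()
out = model(torch.randn(2, 3))  # concatenated linear2/linear3 heads, shape (2, 4)

# Unsupported path: any non-Tensor payload makes hooks on model.unsupportedLayer
# observe an output the gradient machinery cannot handle.
model_bad = BasicModel_GradientLayerAttribution(
    unsupported_layer_output={"scores": torch.ones(2, 4)}
)
out_bad = model_bad(torch.randn(2, 3))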

tests/attr/test_data_parallel.py (+1 -1)

@@ -41,7 +41,7 @@
 """

 # Distributed Data Parallel env setup
-os.environ["MASTER_ADDR"] = "127.0.0.1"
+os.environ["MASTER_ADDR"] = "localhost"
 os.environ["MASTER_PORT"] = "29500"
 dist.init_process_group(backend="gloo", rank=0, world_size=1)
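Note: this hunk only swaps the rendezvous address. For context, a minimal sketch of the same single-process "gloo" setup with its matching teardown (destroy_process_group is the standard torch.distributed API; the teardown call is not part of this hunk):

import os
import torch.distributed as dist

os.environ["MASTER_ADDR"] = "localhost"  # rendezvous host for the "gloo" backend
os.environ["MASTER_PORT"] = "29500"      # any free port works
dist.init_process_group(backend="gloo", rank=0, world_size=1)
# ... run tests ...
dist.destroy_process_group()  # release the group when done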

tests/attr/test_shapley.py (+29 -4)

@@ -806,6 +806,30 @@ def func_future(*inp):
             lambda *inp: func_to_use(*inp), use_future=use_future
         )

+    @parameterized.expand([True, False])
+    def test_mutli_inp_shapley_batch_scalar_tensor_expanded(self, use_future) -> None:
+        def func(*inp):
+            sum_val = torch.sum(net(*inp)).item()
+            return torch.tensor([sum_val, sum_val + 2.0, sum_val + 3.0])
+
+        def func_future(*inp):
+            temp = net_fut(*inp)
+            temp.wait()
+            sum_val = torch.sum(temp.value()).item()
+            fut = Future()
+            fut.set_result(torch.tensor([sum_val, sum_val + 2.0, sum_val + 3.0]))
+            return fut
+
+        if use_future:
+            net_fut = BasicModel_MultiLayer_MultiInput_with_Future()
+            func_to_use = func_future
+        else:
+            net = BasicModel_MultiLayer_MultiInput()
+            func_to_use = func
+        self._multi_input_batch_scalar_shapley_assert(
+            lambda *inp: func_to_use(*inp), use_future=use_future, expanded_output=True
+        )
+
     @unittest.mock.patch("sys.stderr", new_callable=io.StringIO)
     def test_shapley_sampling_with_show_progress(self, mock_stderr) -> None:
         net = BasicModel_MultiLayer()

@@ -947,18 +971,19 @@ def _single_int_input_multi_sample_batch_scalar_shapley_assert(
         )

     def _multi_input_batch_scalar_shapley_assert(
-        self, func: Callable, use_future: bool = False
+        self, func: Callable, use_future: bool = False, expanded_output: bool = False
     ) -> None:
         inp1 = torch.tensor([[23.0, 100.0, 0.0], [20.0, 50.0, 30.0]])
         inp2 = torch.tensor([[20.0, 50.0, 30.0], [0.0, 100.0, 0.0]])
         inp3 = torch.tensor([[0.0, 100.0, 10.0], [20.0, 10.0, 13.0]])
         mask1 = torch.tensor([[1, 1, 1]])
         mask2 = torch.tensor([[0, 1, 2]])
         mask3 = torch.tensor([[0, 1, 2]])
+        out_mult = 3 if expanded_output else 1
         expected = (
-            [[3850.6666, 3850.6666, 3850.6666]],
-            [[306.6666, 3850.6666, 410.6666]],
-            [[306.6666, 3850.6666, 410.6666]],
+            [[3850.6666, 3850.6666, 3850.6666]] * out_mult,
+            [[306.6666, 3850.6666, 410.6666]] * out_mult,
+            [[306.6666, 3850.6666, 410.6666]] * out_mult,
         )
         if use_future:
             self._shapley_test_assert_future(
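Note: with expanded_output=True the wrapped forward returns three values per call, so each expected attribution row is tiled once per output element via plain Python list multiplication, e.g.:

row = [[306.6666, 3850.6666, 410.6666]]
assert row * 3 == [
    [306.6666, 3850.6666, 410.6666],
    [306.6666, 3850.6666, 410.6666],
    [306.6666, 3850.6666, 410.6666],
]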
