Skip to content

Commit e6c51c9

Browse files
jjunchofacebook-github-bot
authored andcommitted
Add test coverage for futures in ShapleyValueSampling 2/n
Summary: This diffs adds more testing coverage for attribute_future to ShapleyValueSampling unit tests that handle shapley sampling with boolean inputs Reviewed By: cyrjano Differential Revision: D68230069
1 parent f1a7a34 commit e6c51c9

File tree

2 files changed

+64
-19
lines changed

2 files changed

+64
-19
lines changed

captum/testing/helpers/basic_models.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,22 @@ def forward(
533533
return result
534534

535535

536+
class BasicModelBoolInput_with_Future(nn.Module):
537+
def __init__(self) -> None:
538+
super().__init__()
539+
self.mod = BasicModel_MultiLayer_with_Future()
540+
541+
# pyre-fixme[3]: Return type must be annotated.
542+
def forward(
543+
self,
544+
x: Tensor,
545+
add_input: Optional[Tensor] = None,
546+
mult: float = 10.0,
547+
):
548+
assert x.dtype is torch.bool, "Input must be boolean"
549+
return self.mod(x.float() * mult, add_input)
550+
551+
536552
class BasicModelBoolInput(nn.Module):
537553
def __init__(self) -> None:
538554
super().__init__()

tests/attr/test_shapley.py

Lines changed: 48 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
BasicModel_MultiLayer_MultiInput,
1717
BasicModel_MultiLayer_with_Future,
1818
BasicModelBoolInput,
19+
BasicModelBoolInput_with_Future,
1920
)
2021
from parameterized import parameterized
2122
from torch.futures import Future
@@ -66,28 +67,56 @@ def test_simple_shapley_sampling_with_mask(self, use_future) -> None:
6667
perturbations_per_eval=(1, 2, 3),
6768
)
6869

69-
def test_simple_shapley_sampling_boolean(self) -> None:
70-
net = BasicModelBoolInput()
70+
@parameterized.expand([True, False])
71+
def test_simple_shapley_sampling_boolean(self, use_future) -> None:
72+
if use_future:
73+
net = BasicModelBoolInput_with_Future()
74+
else:
75+
net = BasicModelBoolInput()
7176
inp = torch.tensor([[True, False, True]])
72-
self._shapley_test_assert(
73-
net,
74-
inp,
75-
[[35.0, 35.0, 35.0]],
76-
feature_mask=torch.tensor([[0, 0, 1]]),
77-
perturbations_per_eval=(1, 2, 3),
78-
)
77+
if use_future:
78+
self._shapley_test_assert_future(
79+
net,
80+
inp,
81+
[[35.0, 35.0, 35.0]],
82+
feature_mask=torch.tensor([[0, 0, 1]]),
83+
perturbations_per_eval=(1, 2, 3),
84+
)
85+
else:
86+
self._shapley_test_assert(
87+
net,
88+
inp,
89+
[[35.0, 35.0, 35.0]],
90+
feature_mask=torch.tensor([[0, 0, 1]]),
91+
perturbations_per_eval=(1, 2, 3),
92+
)
7993

80-
def test_simple_shapley_sampling_boolean_with_baseline(self) -> None:
81-
net = BasicModelBoolInput()
94+
@parameterized.expand([True, False])
95+
def test_simple_shapley_sampling_boolean_with_baseline(self, use_future) -> None:
96+
if use_future:
97+
net = BasicModelBoolInput_with_Future()
98+
else:
99+
net = BasicModelBoolInput()
82100
inp = torch.tensor([[True, False, True]])
83-
self._shapley_test_assert(
84-
net,
85-
inp,
86-
[[-40.0, -40.0, 0.0]],
87-
feature_mask=torch.tensor([[0, 0, 1]]),
88-
baselines=True,
89-
perturbations_per_eval=(1, 2, 3),
90-
)
101+
if use_future:
102+
self._shapley_test_assert_future(
103+
net,
104+
inp,
105+
[[-40.0, -40.0, 0.0]],
106+
feature_mask=torch.tensor([[0, 0, 1]]),
107+
baselines=True,
108+
perturbations_per_eval=(1, 2, 3),
109+
)
110+
else:
111+
112+
self._shapley_test_assert(
113+
net,
114+
inp,
115+
[[-40.0, -40.0, 0.0]],
116+
feature_mask=torch.tensor([[0, 0, 1]]),
117+
baselines=True,
118+
perturbations_per_eval=(1, 2, 3),
119+
)
91120

92121
@parameterized.expand([True, False])
93122
def test_simple_shapley_sampling_with_baselines(self, use_future) -> None:

0 commit comments

Comments
 (0)