Skip to content

Commit 822385d

Browse files
jjunchofacebook-github-bot
authored andcommitted
Add test coverage for futures in ShapleyValueSampling 3/n
Summary: This diffs adds more testing coverage for attribute_future to ShapleyValueSampling unit tests that use the BasicModel_MultiLayer_MultiInput model. Since these tests are larger, they will not be parameterized for readability Reviewed By: csauper Differential Revision: D68346488
1 parent e6c51c9 commit 822385d

File tree

2 files changed

+66
-0
lines changed

2 files changed

+66
-0
lines changed

captum/testing/helpers/basic_models.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,17 @@ def forward(self, x1: Tensor, x2: Tensor, x3: Tensor, scale: int):
576576
return self.model(scale * (x1 + x2 + x3))
577577

578578

579+
class BasicModel_MultiLayer_MultiInput_with_Future(nn.Module):
580+
def __init__(self) -> None:
581+
super().__init__()
582+
self.model = BasicModel_MultiLayer_with_Future()
583+
584+
@no_type_check
585+
# pyre-fixme[3]: Return type must be annotated.
586+
def forward(self, x1: Tensor, x2: Tensor, x3: Tensor, scale: int):
587+
return self.model(scale * (x1 + x2 + x3))
588+
589+
579590
class BasicModel_MultiLayer_TrueMultiInput(nn.Module):
580591
def __init__(self) -> None:
581592
super().__init__()

tests/attr/test_shapley.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from captum.testing.helpers.basic_models import (
1515
BasicModel_MultiLayer,
1616
BasicModel_MultiLayer_MultiInput,
17+
BasicModel_MultiLayer_MultiInput_with_Future,
1718
BasicModel_MultiLayer_with_Future,
1819
BasicModelBoolInput,
1920
BasicModelBoolInput_with_Future,
@@ -206,6 +207,25 @@ def test_multi_input_shapley_sampling_without_mask(self) -> None:
206207
test_true_shapley=False,
207208
)
208209

210+
def test_multi_input_shapley_sampling_without_mask_future(self) -> None:
211+
net = BasicModel_MultiLayer_MultiInput_with_Future()
212+
inp1 = torch.tensor([[23.0, 0.0, 0.0], [20.0, 50.0, 30.0]])
213+
inp2 = torch.tensor([[20.0, 0.0, 50.0], [0.0, 100.0, 0.0]])
214+
inp3 = torch.tensor([[0.0, 100.0, 10.0], [0.0, 10.0, 0.0]])
215+
expected = (
216+
[[90, 0, 0], [78.0, 198.0, 118.0]],
217+
[[78, 0, 198], [0.0, 398.0, 0.0]],
218+
[[0, 398, 38], [0.0, 38.0, 0.0]],
219+
)
220+
self._shapley_test_assert_future(
221+
net,
222+
(inp1, inp2, inp3),
223+
expected,
224+
additional_input=(1,),
225+
n_samples=200,
226+
test_true_shapley=False,
227+
)
228+
209229
def test_multi_input_shapley_sampling_with_mask(self) -> None:
210230
net = BasicModel_MultiLayer_MultiInput()
211231
inp1 = torch.tensor([[23.0, 100.0, 0.0], [20.0, 50.0, 30.0]])
@@ -241,6 +261,41 @@ def test_multi_input_shapley_sampling_with_mask(self) -> None:
241261
perturbations_per_eval=(1, 2, 3),
242262
)
243263

264+
def test_multi_input_shapley_sampling_with_mask_future(self) -> None:
265+
net = BasicModel_MultiLayer_MultiInput_with_Future()
266+
inp1 = torch.tensor([[23.0, 100.0, 0.0], [20.0, 50.0, 30.0]])
267+
inp2 = torch.tensor([[20.0, 50.0, 30.0], [0.0, 100.0, 0.0]])
268+
inp3 = torch.tensor([[0.0, 100.0, 10.0], [2.0, 10.0, 3.0]])
269+
mask1 = torch.tensor([[1, 1, 1], [0, 1, 0]])
270+
mask2 = torch.tensor([[0, 1, 2]])
271+
mask3 = torch.tensor([[0, 1, 2], [0, 0, 0]])
272+
expected = (
273+
[[1088.6666, 1088.6666, 1088.6666], [255.0, 595.0, 255.0]],
274+
[[76.6666, 1088.6666, 156.6666], [255.0, 595.0, 0.0]],
275+
[[76.6666, 1088.6666, 156.6666], [255.0, 255.0, 255.0]],
276+
)
277+
self._shapley_test_assert_future(
278+
net,
279+
(inp1, inp2, inp3),
280+
expected,
281+
additional_input=(1,),
282+
feature_mask=(mask1, mask2, mask3),
283+
)
284+
expected_with_baseline = (
285+
[[1040, 1040, 1040], [184, 580.0, 184]],
286+
[[52, 1040, 132], [184, 580.0, -12.0]],
287+
[[52, 1040, 132], [184, 184, 184]],
288+
)
289+
self._shapley_test_assert_future(
290+
net,
291+
(inp1, inp2, inp3),
292+
expected_with_baseline,
293+
additional_input=(1,),
294+
feature_mask=(mask1, mask2, mask3),
295+
baselines=(2, 3.0, 4),
296+
perturbations_per_eval=(1, 2, 3),
297+
)
298+
244299
@parameterized.expand([True, False])
245300
def test_shapley_sampling_multi_task_output(self, use_future) -> None:
246301
# return shape (batch size, 2)

0 commit comments

Comments
 (0)