Skip to content

Commit

Permalink
Add test coverage for futures in ShapleyValueSampling 3/n (#1489)
Browse files Browse the repository at this point in the history
Summary:

This diffs adds more testing coverage for attribute_future to ShapleyValueSampling unit tests that use the BasicModel_MultiLayer_MultiInput model. Since these tests are larger, they will not be parameterized for readability

Reviewed By: csauper

Differential Revision: D68346488
  • Loading branch information
jjuncho authored and facebook-github-bot committed Jan 18, 2025
1 parent d47943f commit e63c640
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 0 deletions.
11 changes: 11 additions & 0 deletions captum/testing/helpers/basic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,17 @@ def forward(self, x1: Tensor, x2: Tensor, x3: Tensor, scale: int):
return self.model(scale * (x1 + x2 + x3))


class BasicModel_MultiLayer_MultiInput_with_Future(nn.Module):
def __init__(self) -> None:
super().__init__()
self.model = BasicModel_MultiLayer_with_Future()

@no_type_check
# pyre-fixme[3]: Return type must be annotated.
def forward(self, x1: Tensor, x2: Tensor, x3: Tensor, scale: int):
return self.model(scale * (x1 + x2 + x3))


class BasicModel_MultiLayer_TrueMultiInput(nn.Module):
def __init__(self) -> None:
super().__init__()
Expand Down
55 changes: 55 additions & 0 deletions tests/attr/test_shapley.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from captum.testing.helpers.basic_models import (
BasicModel_MultiLayer,
BasicModel_MultiLayer_MultiInput,
BasicModel_MultiLayer_MultiInput_with_Future,
BasicModel_MultiLayer_with_Future,
BasicModelBoolInput,
BasicModelBoolInput_with_Future,
Expand Down Expand Up @@ -206,6 +207,25 @@ def test_multi_input_shapley_sampling_without_mask(self) -> None:
test_true_shapley=False,
)

def test_multi_input_shapley_sampling_without_mask_future(self) -> None:
net = BasicModel_MultiLayer_MultiInput_with_Future()
inp1 = torch.tensor([[23.0, 0.0, 0.0], [20.0, 50.0, 30.0]])
inp2 = torch.tensor([[20.0, 0.0, 50.0], [0.0, 100.0, 0.0]])
inp3 = torch.tensor([[0.0, 100.0, 10.0], [0.0, 10.0, 0.0]])
expected = (
[[90, 0, 0], [78.0, 198.0, 118.0]],
[[78, 0, 198], [0.0, 398.0, 0.0]],
[[0, 398, 38], [0.0, 38.0, 0.0]],
)
self._shapley_test_assert_future(
net,
(inp1, inp2, inp3),
expected,
additional_input=(1,),
n_samples=200,
test_true_shapley=False,
)

def test_multi_input_shapley_sampling_with_mask(self) -> None:
net = BasicModel_MultiLayer_MultiInput()
inp1 = torch.tensor([[23.0, 100.0, 0.0], [20.0, 50.0, 30.0]])
Expand Down Expand Up @@ -241,6 +261,41 @@ def test_multi_input_shapley_sampling_with_mask(self) -> None:
perturbations_per_eval=(1, 2, 3),
)

def test_multi_input_shapley_sampling_with_mask_future(self) -> None:
net = BasicModel_MultiLayer_MultiInput_with_Future()
inp1 = torch.tensor([[23.0, 100.0, 0.0], [20.0, 50.0, 30.0]])
inp2 = torch.tensor([[20.0, 50.0, 30.0], [0.0, 100.0, 0.0]])
inp3 = torch.tensor([[0.0, 100.0, 10.0], [2.0, 10.0, 3.0]])
mask1 = torch.tensor([[1, 1, 1], [0, 1, 0]])
mask2 = torch.tensor([[0, 1, 2]])
mask3 = torch.tensor([[0, 1, 2], [0, 0, 0]])
expected = (
[[1088.6666, 1088.6666, 1088.6666], [255.0, 595.0, 255.0]],
[[76.6666, 1088.6666, 156.6666], [255.0, 595.0, 0.0]],
[[76.6666, 1088.6666, 156.6666], [255.0, 255.0, 255.0]],
)
self._shapley_test_assert_future(
net,
(inp1, inp2, inp3),
expected,
additional_input=(1,),
feature_mask=(mask1, mask2, mask3),
)
expected_with_baseline = (
[[1040, 1040, 1040], [184, 580.0, 184]],
[[52, 1040, 132], [184, 580.0, -12.0]],
[[52, 1040, 132], [184, 184, 184]],
)
self._shapley_test_assert_future(
net,
(inp1, inp2, inp3),
expected_with_baseline,
additional_input=(1,),
feature_mask=(mask1, mask2, mask3),
baselines=(2, 3.0, 4),
perturbations_per_eval=(1, 2, 3),
)

@parameterized.expand([True, False])
def test_shapley_sampling_multi_task_output(self, use_future) -> None:
# return shape (batch size, 2)
Expand Down

0 comments on commit e63c640

Please sign in to comment.