Skip to content

Commit

Permalink
Add test coverage for futures in ShapleyValueSampling 2/n
Browse files Browse the repository at this point in the history
Summary: This diffs adds more testing coverage for attribute_future to ShapleyValueSampling unit tests that handle shapley sampling with boolean inputs

Reviewed By: cyrjano

Differential Revision: D68230069
  • Loading branch information
jjuncho authored and facebook-github-bot committed Jan 18, 2025
1 parent f1a7a34 commit e6c51c9
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 19 deletions.
16 changes: 16 additions & 0 deletions captum/testing/helpers/basic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,22 @@ def forward(
return result


class BasicModelBoolInput_with_Future(nn.Module):
def __init__(self) -> None:
super().__init__()
self.mod = BasicModel_MultiLayer_with_Future()

# pyre-fixme[3]: Return type must be annotated.
def forward(
self,
x: Tensor,
add_input: Optional[Tensor] = None,
mult: float = 10.0,
):
assert x.dtype is torch.bool, "Input must be boolean"
return self.mod(x.float() * mult, add_input)


class BasicModelBoolInput(nn.Module):
def __init__(self) -> None:
super().__init__()
Expand Down
67 changes: 48 additions & 19 deletions tests/attr/test_shapley.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
BasicModel_MultiLayer_MultiInput,
BasicModel_MultiLayer_with_Future,
BasicModelBoolInput,
BasicModelBoolInput_with_Future,
)
from parameterized import parameterized
from torch.futures import Future
Expand Down Expand Up @@ -66,28 +67,56 @@ def test_simple_shapley_sampling_with_mask(self, use_future) -> None:
perturbations_per_eval=(1, 2, 3),
)

def test_simple_shapley_sampling_boolean(self) -> None:
net = BasicModelBoolInput()
@parameterized.expand([True, False])
def test_simple_shapley_sampling_boolean(self, use_future) -> None:
if use_future:
net = BasicModelBoolInput_with_Future()
else:
net = BasicModelBoolInput()
inp = torch.tensor([[True, False, True]])
self._shapley_test_assert(
net,
inp,
[[35.0, 35.0, 35.0]],
feature_mask=torch.tensor([[0, 0, 1]]),
perturbations_per_eval=(1, 2, 3),
)
if use_future:
self._shapley_test_assert_future(
net,
inp,
[[35.0, 35.0, 35.0]],
feature_mask=torch.tensor([[0, 0, 1]]),
perturbations_per_eval=(1, 2, 3),
)
else:
self._shapley_test_assert(
net,
inp,
[[35.0, 35.0, 35.0]],
feature_mask=torch.tensor([[0, 0, 1]]),
perturbations_per_eval=(1, 2, 3),
)

def test_simple_shapley_sampling_boolean_with_baseline(self) -> None:
net = BasicModelBoolInput()
@parameterized.expand([True, False])
def test_simple_shapley_sampling_boolean_with_baseline(self, use_future) -> None:
if use_future:
net = BasicModelBoolInput_with_Future()
else:
net = BasicModelBoolInput()
inp = torch.tensor([[True, False, True]])
self._shapley_test_assert(
net,
inp,
[[-40.0, -40.0, 0.0]],
feature_mask=torch.tensor([[0, 0, 1]]),
baselines=True,
perturbations_per_eval=(1, 2, 3),
)
if use_future:
self._shapley_test_assert_future(
net,
inp,
[[-40.0, -40.0, 0.0]],
feature_mask=torch.tensor([[0, 0, 1]]),
baselines=True,
perturbations_per_eval=(1, 2, 3),
)
else:

self._shapley_test_assert(
net,
inp,
[[-40.0, -40.0, 0.0]],
feature_mask=torch.tensor([[0, 0, 1]]),
baselines=True,
perturbations_per_eval=(1, 2, 3),
)

@parameterized.expand([True, False])
def test_simple_shapley_sampling_with_baselines(self, use_future) -> None:
Expand Down

0 comments on commit e6c51c9

Please sign in to comment.