Commit 4ca5c2c

Vivek Miglani authored and facebook-github-bot committed
Enable aggregate mode for any case where perturbations_per_eval == 1 (#1525)
Summary:
Pull Request resolved: #1525

Shapley Values currently have issues with per-task importance: when the forward function returns more than one output and perturbations_per_eval == 1, aggregate mode should be applied to collate the perturbation results. This updates the logic to appropriately handle multiple outputs (not matching the batch size) when perturbations_per_eval == 1.

Reviewed By: MarcioPorto

Differential Revision: D70832826

fbshipit-source-id: 52e1e40d599f662ac522eae4830560cf1338f7e1
1 parent c9688bb · commit 4ca5c2c

File tree (2 files changed, +30 -5 lines):
  captum/attr/_utils/common.py
  tests/attr/test_shapley.py


captum/attr/_utils/common.py (+1 -1)

@@ -364,7 +364,7 @@ def _find_output_mode_and_verify(
             "returns a scalar."
         )
     else:
-        agg_output_mode = False
+        agg_output_mode = perturbations_per_eval == 1
         if not allow_multi_outputs:
             assert (
                 isinstance(initial_eval, torch.Tensor) and initial_eval[0].numel() == 1
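
With this change, _find_output_mode_and_verify enables aggregate mode whenever the forward output does not match the batch size and perturbations_per_eval == 1, instead of disabling it outright. Below is a rough sketch of the kind of per-task forward function this unblocks; the linear model, the per_task_forward wrapper, and the shapes are illustrative assumptions, not code from this commit.

import torch
import torch.nn as nn
from captum.attr import ShapleyValueSampling

torch.manual_seed(0)
net = nn.Linear(3, 2)  # illustrative stand-in for a 2-task model (not from the commit)

def per_task_forward(inp):
    # Collapse the batch into one aggregate score per task; the output
    # size (2) intentionally does not match the batch size of `inp`.
    return net(inp).sum(dim=0)

inputs = torch.rand(4, 3)
feature_mask = torch.tensor([[0, 1, 2]])  # feature groups shared across the batch

sv = ShapleyValueSampling(per_task_forward)
# With perturbations_per_eval=1, the multi-output (non-batch-sized) result is
# now collated in aggregate mode, giving one attribution row per task output.
attr = sv.attribute(inputs, feature_mask=feature_mask, perturbations_per_eval=1)

Previously, agg_output_mode was hard-coded to False in this branch, so outputs like the one above were collated as if they were per-example results.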

tests/attr/test_shapley.py (+29 -4)

@@ -806,6 +806,30 @@ def func_future(*inp):
             lambda *inp: func_to_use(*inp), use_future=use_future
         )
 
+    @parameterized.expand([True, False])
+    def test_mutli_inp_shapley_batch_scalar_tensor_expanded(self, use_future) -> None:
+        def func(*inp):
+            sum_val = torch.sum(net(*inp)).item()
+            return torch.tensor([sum_val, sum_val + 2.0, sum_val + 3.0])
+
+        def func_future(*inp):
+            temp = net_fut(*inp)
+            temp.wait()
+            sum_val = torch.sum(temp.value()).item()
+            fut = Future()
+            fut.set_result(torch.tensor([sum_val, sum_val + 2.0, sum_val + 3.0]))
+            return fut
+
+        if use_future:
+            net_fut = BasicModel_MultiLayer_MultiInput_with_Future()
+            func_to_use = func_future
+        else:
+            net = BasicModel_MultiLayer_MultiInput()
+            func_to_use = func
+        self._multi_input_batch_scalar_shapley_assert(
+            lambda *inp: func_to_use(*inp), use_future=use_future, expanded_output=True
+        )
+
     @unittest.mock.patch("sys.stderr", new_callable=io.StringIO)
     def test_shapley_sampling_with_show_progress(self, mock_stderr) -> None:
         net = BasicModel_MultiLayer()
@@ -947,18 +971,19 @@ def _single_int_input_multi_sample_batch_scalar_shapley_assert(
         )
 
     def _multi_input_batch_scalar_shapley_assert(
-        self, func: Callable, use_future: bool = False
+        self, func: Callable, use_future: bool = False, expanded_output: bool = False
     ) -> None:
         inp1 = torch.tensor([[23.0, 100.0, 0.0], [20.0, 50.0, 30.0]])
         inp2 = torch.tensor([[20.0, 50.0, 30.0], [0.0, 100.0, 0.0]])
         inp3 = torch.tensor([[0.0, 100.0, 10.0], [20.0, 10.0, 13.0]])
         mask1 = torch.tensor([[1, 1, 1]])
         mask2 = torch.tensor([[0, 1, 2]])
         mask3 = torch.tensor([[0, 1, 2]])
+        out_mult = 3 if expanded_output else 1
         expected = (
-            [[3850.6666, 3850.6666, 3850.6666]],
-            [[306.6666, 3850.6666, 410.6666]],
-            [[306.6666, 3850.6666, 410.6666]],
+            [[3850.6666, 3850.6666, 3850.6666]] * out_mult,
+            [[306.6666, 3850.6666, 410.6666]] * out_mult,
+            [[306.6666, 3850.6666, 410.6666]] * out_mult,
         )
         if use_future:
             self._shapley_test_assert_future(
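
On the test side, the helper's expanded_output path simply repeats each expected attribution row out_mult times. That is the right expectation here because the three outputs produced by func and func_future differ only by constant offsets (sum, sum + 2.0, sum + 3.0), so their Shapley values are identical. A standalone illustration of the expansion follows; it is not test code, just a demonstration of the list-multiplication idiom used above.

# Python list multiplication repeats the single expected row once per forward output.
out_mult = 3
expected_row = [[306.6666, 3850.6666, 410.6666]]
expected = expected_row * out_mult
assert expected == [
    [306.6666, 3850.6666, 410.6666],
    [306.6666, 3850.6666, 410.6666],
    [306.6666, 3850.6666, 410.6666],
]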
