|
14 | 14 | from captum.testing.helpers.basic_models import (
|
15 | 15 | BasicModel_MultiLayer,
|
16 | 16 | BasicModel_MultiLayer_MultiInput,
|
| 17 | + BasicModel_MultiLayer_MultiInput_with_Future, |
17 | 18 | BasicModel_MultiLayer_with_Future,
|
18 | 19 | BasicModelBoolInput,
|
19 | 20 | BasicModelBoolInput_with_Future,
|
@@ -206,6 +207,25 @@ def test_multi_input_shapley_sampling_without_mask(self) -> None:
|
206 | 207 | test_true_shapley=False,
|
207 | 208 | )
|
208 | 209 |
|
| 210 | + def test_multi_input_shapley_sampling_without_mask_future(self) -> None: |
| 211 | + net = BasicModel_MultiLayer_MultiInput_with_Future() |
| 212 | + inp1 = torch.tensor([[23.0, 0.0, 0.0], [20.0, 50.0, 30.0]]) |
| 213 | + inp2 = torch.tensor([[20.0, 0.0, 50.0], [0.0, 100.0, 0.0]]) |
| 214 | + inp3 = torch.tensor([[0.0, 100.0, 10.0], [0.0, 10.0, 0.0]]) |
| 215 | + expected = ( |
| 216 | + [[90, 0, 0], [78.0, 198.0, 118.0]], |
| 217 | + [[78, 0, 198], [0.0, 398.0, 0.0]], |
| 218 | + [[0, 398, 38], [0.0, 38.0, 0.0]], |
| 219 | + ) |
| 220 | + self._shapley_test_assert_future( |
| 221 | + net, |
| 222 | + (inp1, inp2, inp3), |
| 223 | + expected, |
| 224 | + additional_input=(1,), |
| 225 | + n_samples=200, |
| 226 | + test_true_shapley=False, |
| 227 | + ) |
| 228 | + |
209 | 229 | def test_multi_input_shapley_sampling_with_mask(self) -> None:
|
210 | 230 | net = BasicModel_MultiLayer_MultiInput()
|
211 | 231 | inp1 = torch.tensor([[23.0, 100.0, 0.0], [20.0, 50.0, 30.0]])
|
@@ -241,6 +261,41 @@ def test_multi_input_shapley_sampling_with_mask(self) -> None:
|
241 | 261 | perturbations_per_eval=(1, 2, 3),
|
242 | 262 | )
|
243 | 263 |
|
| 264 | + def test_multi_input_shapley_sampling_with_mask_future(self) -> None: |
| 265 | + net = BasicModel_MultiLayer_MultiInput_with_Future() |
| 266 | + inp1 = torch.tensor([[23.0, 100.0, 0.0], [20.0, 50.0, 30.0]]) |
| 267 | + inp2 = torch.tensor([[20.0, 50.0, 30.0], [0.0, 100.0, 0.0]]) |
| 268 | + inp3 = torch.tensor([[0.0, 100.0, 10.0], [2.0, 10.0, 3.0]]) |
| 269 | + mask1 = torch.tensor([[1, 1, 1], [0, 1, 0]]) |
| 270 | + mask2 = torch.tensor([[0, 1, 2]]) |
| 271 | + mask3 = torch.tensor([[0, 1, 2], [0, 0, 0]]) |
| 272 | + expected = ( |
| 273 | + [[1088.6666, 1088.6666, 1088.6666], [255.0, 595.0, 255.0]], |
| 274 | + [[76.6666, 1088.6666, 156.6666], [255.0, 595.0, 0.0]], |
| 275 | + [[76.6666, 1088.6666, 156.6666], [255.0, 255.0, 255.0]], |
| 276 | + ) |
| 277 | + self._shapley_test_assert_future( |
| 278 | + net, |
| 279 | + (inp1, inp2, inp3), |
| 280 | + expected, |
| 281 | + additional_input=(1,), |
| 282 | + feature_mask=(mask1, mask2, mask3), |
| 283 | + ) |
| 284 | + expected_with_baseline = ( |
| 285 | + [[1040, 1040, 1040], [184, 580.0, 184]], |
| 286 | + [[52, 1040, 132], [184, 580.0, -12.0]], |
| 287 | + [[52, 1040, 132], [184, 184, 184]], |
| 288 | + ) |
| 289 | + self._shapley_test_assert_future( |
| 290 | + net, |
| 291 | + (inp1, inp2, inp3), |
| 292 | + expected_with_baseline, |
| 293 | + additional_input=(1,), |
| 294 | + feature_mask=(mask1, mask2, mask3), |
| 295 | + baselines=(2, 3.0, 4), |
| 296 | + perturbations_per_eval=(1, 2, 3), |
| 297 | + ) |
| 298 | + |
244 | 299 | @parameterized.expand([True, False])
|
245 | 300 | def test_shapley_sampling_multi_task_output(self, use_future) -> None:
|
246 | 301 | # return shape (batch size, 2)
|
|
0 commit comments