|
16 | 16 | BasicModel_MultiLayer_MultiInput,
|
17 | 17 | BasicModel_MultiLayer_with_Future,
|
18 | 18 | BasicModelBoolInput,
|
| 19 | + BasicModelBoolInput_with_Future, |
19 | 20 | )
|
20 | 21 | from parameterized import parameterized
|
21 | 22 | from torch.futures import Future
|
@@ -66,28 +67,56 @@ def test_simple_shapley_sampling_with_mask(self, use_future) -> None:
|
66 | 67 | perturbations_per_eval=(1, 2, 3),
|
67 | 68 | )
|
68 | 69 |
|
69 |
| - def test_simple_shapley_sampling_boolean(self) -> None: |
70 |
| - net = BasicModelBoolInput() |
| 70 | + @parameterized.expand([True, False]) |
| 71 | + def test_simple_shapley_sampling_boolean(self, use_future) -> None: |
| 72 | + if use_future: |
| 73 | + net = BasicModelBoolInput_with_Future() |
| 74 | + else: |
| 75 | + net = BasicModelBoolInput() |
71 | 76 | inp = torch.tensor([[True, False, True]])
|
72 |
| - self._shapley_test_assert( |
73 |
| - net, |
74 |
| - inp, |
75 |
| - [[35.0, 35.0, 35.0]], |
76 |
| - feature_mask=torch.tensor([[0, 0, 1]]), |
77 |
| - perturbations_per_eval=(1, 2, 3), |
78 |
| - ) |
| 77 | + if use_future: |
| 78 | + self._shapley_test_assert_future( |
| 79 | + net, |
| 80 | + inp, |
| 81 | + [[35.0, 35.0, 35.0]], |
| 82 | + feature_mask=torch.tensor([[0, 0, 1]]), |
| 83 | + perturbations_per_eval=(1, 2, 3), |
| 84 | + ) |
| 85 | + else: |
| 86 | + self._shapley_test_assert( |
| 87 | + net, |
| 88 | + inp, |
| 89 | + [[35.0, 35.0, 35.0]], |
| 90 | + feature_mask=torch.tensor([[0, 0, 1]]), |
| 91 | + perturbations_per_eval=(1, 2, 3), |
| 92 | + ) |
79 | 93 |
|
80 |
| - def test_simple_shapley_sampling_boolean_with_baseline(self) -> None: |
81 |
| - net = BasicModelBoolInput() |
| 94 | + @parameterized.expand([True, False]) |
| 95 | + def test_simple_shapley_sampling_boolean_with_baseline(self, use_future) -> None: |
| 96 | + if use_future: |
| 97 | + net = BasicModelBoolInput_with_Future() |
| 98 | + else: |
| 99 | + net = BasicModelBoolInput() |
82 | 100 | inp = torch.tensor([[True, False, True]])
|
83 |
| - self._shapley_test_assert( |
84 |
| - net, |
85 |
| - inp, |
86 |
| - [[-40.0, -40.0, 0.0]], |
87 |
| - feature_mask=torch.tensor([[0, 0, 1]]), |
88 |
| - baselines=True, |
89 |
| - perturbations_per_eval=(1, 2, 3), |
90 |
| - ) |
| 101 | + if use_future: |
| 102 | + self._shapley_test_assert_future( |
| 103 | + net, |
| 104 | + inp, |
| 105 | + [[-40.0, -40.0, 0.0]], |
| 106 | + feature_mask=torch.tensor([[0, 0, 1]]), |
| 107 | + baselines=True, |
| 108 | + perturbations_per_eval=(1, 2, 3), |
| 109 | + ) |
| 110 | + else: |
| 111 | + |
| 112 | + self._shapley_test_assert( |
| 113 | + net, |
| 114 | + inp, |
| 115 | + [[-40.0, -40.0, 0.0]], |
| 116 | + feature_mask=torch.tensor([[0, 0, 1]]), |
| 117 | + baselines=True, |
| 118 | + perturbations_per_eval=(1, 2, 3), |
| 119 | + ) |
91 | 120 |
|
92 | 121 | @parameterized.expand([True, False])
|
93 | 122 | def test_simple_shapley_sampling_with_baselines(self, use_future) -> None:
|
|
0 commit comments