Skip to content

Commit 5e736f8

Browse files
committed
Added a function to the random config space searcher that caches already evaluated coalitions and ensures monotonicity for min/max search settings.
Furthermore, a fallback solution was introduced to cater for cases where all samples configs become invalid according to conditions after blinding.
1 parent c767532 commit 5e736f8

4 files changed

Lines changed: 108 additions & 3 deletions

File tree

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
# v0.0.6
2+
3+
## Improvements
4+
- Added fallback for configuration spaces with conditions resulting in all configurations being filtered out.
5+
- Added caching and a function in RandomConfigSpaceSearcher to ensure monotonicity of the value function.
6+
17
# v0.0.5
28

39
## Features

src/hypershap/utils.py

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,15 +154,77 @@ def search(self, coalition: np.ndarray) -> float:
154154
validity = np.apply_along_axis(self._is_valid, axis=1, arr=temp_random_sample)
155155
filtered_samples = temp_random_sample[validity]
156156

157+
if coalition.any():
158+
filtered_samples = []
159+
157160
if len(filtered_samples) < 0.05 * len(temp_random_sample): # pragma: no cover
158161
logger.warning(
159162
"WARNING: Due to blinding less than 5% of the samples in the random search remain valid. "
160163
"Consider increasing the sampling budget of the random search.",
161164
)
162165

163166
# predict performance values with the help of the surrogate model for the filtered configurations
164-
vals: np.ndarray = np.array(self.explanation_task.get_single_surrogate_model().evaluate(filtered_samples))
167+
if len(filtered_samples) > 0:
168+
vals: np.ndarray = np.array(
169+
self.explanation_task.get_single_surrogate_model().evaluate(filtered_samples),
170+
)
171+
else:
172+
logger.warning(
173+
"WARNING: After filtering for conditions, no configurations were left, thus, using baseline value.",
174+
)
175+
vals = np.array([self.search(np.array([False] * len(coalition)))])
165176
else:
166177
vals: np.ndarray = np.array(self.explanation_task.get_single_surrogate_model().evaluate(temp_random_sample))
167178

168-
return evaluate_aggregation(self.mode, vals)
179+
# determine the final, aggregated value of the coalition
180+
res = evaluate_aggregation(self.mode, vals)
181+
182+
# in case we are maximizing or minimizing, ensure that the value function is monotone
183+
if self.mode in (Aggregation.MAX, Aggregation.MIN):
184+
res = self._ensure_monotonicity(coalition, res)
185+
186+
# cache the coalition's value
187+
self.coalition_cache[str(coalition.tolist())] = res
188+
189+
return res
190+
191+
def _ensure_monotonicity(self, coalition: np.ndarray, value: float) -> float:
192+
"""Ensure that the value function is monotonically increasing/decreasing depending on whether we want to maximize or minimize respectively.
193+
194+
Args:
195+
coalition: The current coalition.
196+
value: The value of the coalition as determined by searching.
197+
198+
Returns: The monotonicity-ensured value of the coalition.
199+
200+
"""
201+
monotone_value = value
202+
checked_one = False
203+
204+
for i in range(len(coalition)):
205+
if coalition[i]: # check whether the entry is True and set it to False to check for a cached result
206+
temp_coalition = coalition.copy()
207+
temp_coalition[i] = False
208+
if str(temp_coalition.tolist()) in self.coalition_cache:
209+
checked_one = True
210+
monotone_value = evaluate_aggregation(
211+
self.mode,
212+
np.array([
213+
monotone_value,
214+
self.coalition_cache[str(temp_coalition.tolist())],
215+
]),
216+
)
217+
218+
if not checked_one and coalition.any():
219+
logger.warning(
220+
"Could not ensure monotonicity as none of the coalitions with one player less has been cached so far.",
221+
)
222+
223+
if value < monotone_value:
224+
logger.debug(
225+
"Ensured monotonicity with a sub-coalition's value. Increased the value of the current coalition from %s to %s.",
226+
value,
227+
monotone_value,
228+
)
229+
230+
return monotone_value

tests/fixtures/simple_setup.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33
from __future__ import annotations
44

55
import pytest
6-
from ConfigSpace import Configuration, ConfigurationSpace, LessThanCondition, UniformFloatHyperparameter
6+
from ConfigSpace import (
7+
Configuration,
8+
ConfigurationSpace,
9+
GreaterThanCondition,
10+
LessThanCondition,
11+
UniformFloatHyperparameter,
12+
)
713

814
from hypershap import ExplanationTask
915

@@ -88,10 +94,34 @@ def simple_cond_config_space() -> ConfigurationSpace:
8894
return config_space
8995

9096

97+
@pytest.fixture(scope="session")
98+
def simple_act_config_space() -> ConfigurationSpace:
99+
"""Return a simple config space with activation structure for testing."""
100+
config_space = ConfigurationSpace()
101+
config_space.seed(42)
102+
103+
a = UniformFloatHyperparameter("a", 0, 1, 0)
104+
b = UniformFloatHyperparameter("b", 0, 1, 0)
105+
config_space.add(a)
106+
config_space.add(b)
107+
108+
config_space.add(GreaterThanCondition(b, a, 0.3))
109+
return config_space
110+
111+
91112
@pytest.fixture(scope="session")
92113
def simple_cond_base_et(
93114
simple_cond_config_space: ConfigurationSpace,
94115
simple_blackbox_function: SimpleBlackboxFunction,
95116
) -> ExplanationTask:
96117
"""Return a base explanation task for the simple setup with conditions."""
97118
return ExplanationTask.from_function(simple_cond_config_space, simple_blackbox_function.evaluate)
119+
120+
121+
@pytest.fixture(scope="session")
122+
def simple_act_base_et(
123+
simple_act_config_space: ConfigurationSpace,
124+
simple_blackbox_function: SimpleBlackboxFunction,
125+
) -> ExplanationTask:
126+
"""Return a base explanation task for the simple setup with conditions."""
127+
return ExplanationTask.from_function(simple_act_config_space, simple_blackbox_function.evaluate)

tests/test_extended_settings.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,10 @@ def test_tunability_with_conditions(simple_cond_base_et: ExplanationTask) -> Non
7979
hypershap = HyperSHAP(simple_cond_base_et)
8080
iv = hypershap.tunability(simple_cond_base_et.config_space.get_default_configuration())
8181
assert iv is not None, "Interaction values should not be none."
82+
83+
84+
def test_tunability_with_activation_structures(simple_act_base_et: ExplanationTask) -> None:
85+
"""Test the tunability task with a configuration space that has conditions."""
86+
hypershap = HyperSHAP(simple_act_base_et)
87+
iv = hypershap.tunability(simple_act_base_et.config_space.get_default_configuration())
88+
assert iv is not None, "Interaction values should not be none."

0 commit comments

Comments
 (0)