Skip to content

Commit e1069b5

Browse files
committed
Extended tests for explanation tasks and added test for config space searcher.
1 parent 5a8c1f9 commit e1069b5

3 files changed

Lines changed: 222 additions & 19 deletions

File tree

tests/fixtures/simple_setup.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import pytest
66
from ConfigSpace import Configuration, ConfigurationSpace, UniformFloatHyperparameter
77

8+
from hypershap import ExplanationTask
9+
810

911
@pytest.fixture(scope="session")
1012
def simple_config_space() -> ConfigurationSpace:
@@ -38,7 +40,17 @@ def evaluate(self, x: Configuration) -> float:
3840
Returns: The value of the configuration.
3941
4042
"""
41-
return self.a_coeff * x["a"] + self.b_coeff * x["b"]
43+
return self.value(x["a"], x["b"])
44+
45+
def value(self, a: float, b: float) -> float:
46+
"""Evaluate the value of a configuration.
47+
48+
Args:
49+
a: The value for hyperparameter a.
50+
b: The value for hyperparameter b.
51+
52+
"""
53+
return self.a_coeff * a + self.b_coeff * b
4254

4355

4456
@pytest.fixture(scope="session")
@@ -49,3 +61,12 @@ def simple_blackbox_function() -> SimpleBlackboxFunction:
4961
5062
"""
5163
return SimpleBlackboxFunction(0.7, 2.0)
64+
65+
66+
@pytest.fixture(scope="session")
67+
def simple_base_et(
68+
simple_config_space: ConfigurationSpace,
69+
simple_blackbox_function: SimpleBlackboxFunction,
70+
) -> ExplanationTask:
71+
"""Return a base explanation task for the simple setup."""
72+
return ExplanationTask.from_function(simple_config_space, simple_blackbox_function.evaluate)

tests/test_explanation_task.py

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
from typing import TYPE_CHECKING
66

7-
import pytest
87
from sklearn.ensemble import RandomForestRegressor
98

109
if TYPE_CHECKING:
@@ -13,7 +12,13 @@
1312
from tests.fixtures.simple_setup import SimpleBlackboxFunction
1413

1514
from hypershap import ExplanationTask
16-
from hypershap.task import BaselineExplanationTask, TunabilityExplanationTask
15+
from hypershap.task import (
16+
BaselineExplanationTask,
17+
MistunabilityExplanationTask,
18+
MultiBaselineExplanationTask,
19+
SensitivityExplanationTask,
20+
TunabilityExplanationTask,
21+
)
1722

1823

1924
def _check_explanation_task(
@@ -83,24 +88,46 @@ def test_explanation_task_from_model(
8388
_check_explanation_task(explanation_task, simple_config_space, simple_blackbox_function)
8489

8590

86-
@pytest.fixture
87-
def base_et(
88-
simple_config_space: ConfigurationSpace,
89-
simple_blackbox_function: SimpleBlackboxFunction,
90-
) -> ExplanationTask:
91-
"""Return a base explanation task for the simple setup."""
92-
return ExplanationTask.from_function(simple_config_space, simple_blackbox_function.evaluate)
93-
94-
95-
def test_baseline_explanation_task(base_et: ExplanationTask) -> None:
91+
def test_baseline_explanation_task(simple_base_et: ExplanationTask) -> None:
9692
"""Test the baseline explanation task."""
97-
config = base_et.config_space.sample_configuration()
98-
et = BaselineExplanationTask(base_et.config_space, base_et.surrogate_model, baseline_config=config)
93+
config = simple_base_et.config_space.sample_configuration()
94+
et = BaselineExplanationTask(simple_base_et.config_space, simple_base_et.surrogate_model, baseline_config=config)
9995
assert et.baseline_config == config, "Baseline explanation task should have the proper baseline config."
10096

10197

102-
def test_tunability_explanation_task(base_et: ExplanationTask) -> None:
98+
def test_multibaseline_explanation_task(simple_base_et: ExplanationTask) -> None:
99+
"""Test the multibaseline explanation task."""
100+
baseline_configs = simple_base_et.config_space.sample_configuration(2)
101+
et = MultiBaselineExplanationTask(
102+
simple_base_et.config_space,
103+
simple_base_et.surrogate_model,
104+
baseline_configs=baseline_configs,
105+
)
106+
assert et.baseline_configs == baseline_configs, (
107+
"Multibaseline explanation task should have the proper baseline configs."
108+
)
109+
110+
111+
def test_tunability_explanation_task(simple_base_et: ExplanationTask) -> None:
103112
"""Test the tunability explanation task."""
104-
config = base_et.config_space.sample_configuration()
105-
et = TunabilityExplanationTask(base_et.config_space, base_et.surrogate_model, baseline_config=config)
106-
assert et.baseline_config == config, "Baseline explanation task should have the proper baseline config."
113+
config = simple_base_et.config_space.sample_configuration()
114+
et = TunabilityExplanationTask(simple_base_et.config_space, simple_base_et.surrogate_model, baseline_config=config)
115+
assert et.baseline_config == config, "Tunability explanation task should have the proper baseline config."
116+
117+
118+
def test_sensitivity_explanation_task(simple_base_et: ExplanationTask) -> None:
119+
"""Test the sensitivity explanation task."""
120+
config = simple_base_et.config_space.sample_configuration()
121+
et = SensitivityExplanationTask(simple_base_et.config_space, simple_base_et.surrogate_model, baseline_config=config)
122+
assert et.baseline_config == config, "Sensitivity explanation task should have the proper baseline config."
123+
124+
125+
def test_mistunability_explanation_task(simple_base_et: ExplanationTask) -> None:
126+
"""Test the mistunability explanation task."""
127+
config = simple_base_et.config_space.sample_configuration()
128+
et = MistunabilityExplanationTask(
129+
simple_base_et.config_space,
130+
simple_base_et.surrogate_model,
131+
baseline_config=config,
132+
)
133+
assert et.baseline_config == config, "Mistunability explanation task should have the proper baseline config."

tests/test_utils.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
"""Tests for the utils module."""
2+
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING
6+
7+
import numpy as np
8+
import pytest
9+
from ConfigSpace import UniformFloatHyperparameter
10+
11+
if TYPE_CHECKING:
12+
from hypershap import ExplanationTask
13+
from tests.fixtures.simple_setup import SimpleBlackboxFunction
14+
15+
from hypershap.task import BaselineExplanationTask
16+
from hypershap.utils import RandomConfigSpaceSearcher, UnknownModeError
17+
18+
DEFAULT_MODE = "max"
19+
N_SAMPLES = 100_000
20+
EPSILON = 0.1
21+
22+
23+
@pytest.fixture(scope="module")
24+
def random_cs(simple_base_et: ExplanationTask) -> RandomConfigSpaceSearcher:
25+
"""Fixture for creating a random config space searcher."""
26+
baseline_et = BaselineExplanationTask(
27+
simple_base_et.config_space,
28+
simple_base_et.surrogate_model,
29+
baseline_config=simple_base_et.config_space.get_default_configuration(),
30+
)
31+
32+
return RandomConfigSpaceSearcher(
33+
explanation_task=baseline_et,
34+
mode=DEFAULT_MODE,
35+
n_samples=N_SAMPLES,
36+
)
37+
38+
39+
def test_unavailable_mode(simple_base_et: ExplanationTask) -> None:
40+
"""Test that unavailable modes raise an exception."""
41+
baseline_et = BaselineExplanationTask(
42+
simple_base_et.config_space,
43+
simple_base_et.surrogate_model,
44+
baseline_config=simple_base_et.config_space.get_default_configuration(),
45+
)
46+
47+
try:
48+
RandomConfigSpaceSearcher(
49+
explanation_task=baseline_et,
50+
mode="abc",
51+
n_samples=N_SAMPLES,
52+
)
53+
except UnknownModeError:
54+
assert True, "Unknown mode error expected"
55+
else:
56+
pytest.fail("Unknown mode error expected")
57+
58+
59+
def test_n_samples(random_cs: RandomConfigSpaceSearcher) -> None:
60+
"""Test whether random config space searcher draws the given number of samples."""
61+
assert random_cs.random_sample.shape[0] == N_SAMPLES, (
62+
"Number of samples should be the same as the number of samples in the explanation task."
63+
)
64+
65+
66+
def test_empty_coalition_search(random_cs: RandomConfigSpaceSearcher) -> None:
67+
"""Test random config space searcher for an empty coalition."""
68+
et = random_cs.explanation_task
69+
res = random_cs.search(np.array([False] * random_cs.explanation_task.get_num_hyperparameters()))
70+
assert res == et.surrogate_model.evaluate_config(et.config_space.get_default_configuration()), (
71+
"If no hyperparameter is activated for searching, the resulting max performance should be equal to default performance."
72+
)
73+
74+
75+
def test_grand_coalition_max_search(
76+
random_cs: RandomConfigSpaceSearcher,
77+
simple_blackbox_function: SimpleBlackboxFunction,
78+
) -> None:
79+
"""Test random config space searcher for max aggregation."""
80+
et = random_cs.explanation_task
81+
res = random_cs.search(np.array([True] * random_cs.explanation_task.get_num_hyperparameters()))
82+
83+
if isinstance(et.config_space["a"], UniformFloatHyperparameter) and isinstance(
84+
et.config_space["b"],
85+
UniformFloatHyperparameter,
86+
):
87+
a: UniformFloatHyperparameter = et.config_space["a"]
88+
b: UniformFloatHyperparameter = et.config_space["b"]
89+
a_upper = a.upper
90+
b_upper = b.upper
91+
max_value = simple_blackbox_function.value(a_upper, b_upper)
92+
else:
93+
raise TypeError
94+
95+
assert abs(max_value - res < EPSILON), "The max performance should be equal to the upper boundaries value."
96+
97+
98+
def test_grand_coalition_min_search(
99+
random_cs: RandomConfigSpaceSearcher,
100+
simple_blackbox_function: SimpleBlackboxFunction,
101+
) -> None:
102+
"""Test random config space searcher for min aggregation."""
103+
et = random_cs.explanation_task
104+
random_cs.mode = "min"
105+
res = random_cs.search(np.array([True] * random_cs.explanation_task.get_num_hyperparameters()))
106+
107+
if isinstance(et.config_space["a"], UniformFloatHyperparameter) and isinstance(
108+
et.config_space["b"],
109+
UniformFloatHyperparameter,
110+
):
111+
a: UniformFloatHyperparameter = et.config_space["a"]
112+
b: UniformFloatHyperparameter = et.config_space["b"]
113+
a_lower = a.lower
114+
b_lower = b.lower
115+
min_value = simple_blackbox_function.value(a_lower, b_lower)
116+
else:
117+
raise TypeError
118+
119+
assert abs(res - min_value < EPSILON), "The min performance should be equal to the lower boundaries value."
120+
121+
122+
def test_grand_coalition_avg_search(
123+
random_cs: RandomConfigSpaceSearcher,
124+
simple_blackbox_function: SimpleBlackboxFunction,
125+
) -> None:
126+
"""Test random config space searcher for avg aggregation."""
127+
et = random_cs.explanation_task
128+
random_cs.mode = "avg"
129+
res = random_cs.search(np.array([True] * random_cs.explanation_task.get_num_hyperparameters()))
130+
131+
if isinstance(et.config_space["a"], UniformFloatHyperparameter) and isinstance(
132+
et.config_space["b"],
133+
UniformFloatHyperparameter,
134+
):
135+
a: UniformFloatHyperparameter = et.config_space["a"]
136+
b: UniformFloatHyperparameter = et.config_space["b"]
137+
a_middle = a.lower + (a.upper - a.lower) / 2
138+
b_middle = b.lower + (b.upper - b.lower) / 2
139+
avg_value = simple_blackbox_function.value(a_middle, b_middle)
140+
else:
141+
raise TypeError
142+
143+
assert abs(res - avg_value < EPSILON), "The avg aggregation should be equal to the middle performance."
144+
145+
146+
def test_baseline_coalition_var_search(
147+
random_cs: RandomConfigSpaceSearcher,
148+
) -> None:
149+
"""Test random config space searcher for avg aggregation."""
150+
random_cs.mode = "var"
151+
res = random_cs.search(np.array([False] * random_cs.explanation_task.get_num_hyperparameters()))
152+
expected_var = 0
153+
assert abs(res - expected_var < EPSILON), (
154+
"If no hyperparameter is activated for searching, the variance should be 0."
155+
)

0 commit comments

Comments
 (0)