Skip to content

Commit e143fd2

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Add e2e tests with HSS (facebook#2345)
Summary: Pull Request resolved: facebook#2345 Adds tests for candidate generation, prediction & cross validation using HSS with MBM defaults. Notably, the tests do not involve any mocks, except for `fast_botorch_optimize`. The tests also highlight a weakness of current setup as it does not allow making predictions using valid parameterizations, when those parameterizations lack the inactive parameters (i.e. they're not full parmeterizations). Reviewed By: Balandat Differential Revision: D55952289 fbshipit-source-id: 90dbb923d387d05ef05af4e013e6cec38174f8e7
1 parent 203557c commit e143fd2

File tree

1 file changed

+241
-0
lines changed

1 file changed

+241
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
from contextlib import ExitStack
9+
from random import random
10+
from typing import List
11+
12+
from ax.core.experiment import Experiment
13+
from ax.core.objective import Objective
14+
from ax.core.observation import ObservationFeatures
15+
from ax.core.optimization_config import OptimizationConfig
16+
from ax.core.parameter import (
17+
ChoiceParameter,
18+
FixedParameter,
19+
ParameterType,
20+
RangeParameter,
21+
)
22+
from ax.core.search_space import HierarchicalSearchSpace
23+
from ax.core.trial import Trial
24+
from ax.metrics.noisy_function import GenericNoisyFunctionMetric
25+
from ax.modelbridge.cross_validation import cross_validate
26+
from ax.modelbridge.registry import Models
27+
from ax.runners.synthetic import SyntheticRunner
28+
from ax.utils.common.constants import Keys
29+
from ax.utils.common.testutils import TestCase
30+
from ax.utils.common.typeutils import checked_cast, not_none
31+
from ax.utils.testing.mock import fast_botorch_optimize
32+
33+
34+
class TestHierarchicalSearchSpace(TestCase):
35+
"""Tests for various modelbridge functionality with commonly used transforms
36+
using hierarchical search spaces (HSS).
37+
"""
38+
39+
def setUp(self) -> None:
40+
int_range = RangeParameter(
41+
name="int_range",
42+
parameter_type=ParameterType.INT,
43+
lower=0,
44+
upper=10,
45+
)
46+
str_choice = ChoiceParameter(
47+
name="str_choice",
48+
parameter_type=ParameterType.STRING,
49+
values=["a", "b", "c"],
50+
)
51+
fixed_root = FixedParameter(
52+
name="root",
53+
parameter_type=ParameterType.STRING,
54+
value="root",
55+
dependents={"root": ["int_range", "str_choice"]},
56+
)
57+
# This HSS does not have a real hierarchy.
58+
self.non_hierarchical_hss = HierarchicalSearchSpace(
59+
parameters=[
60+
fixed_root,
61+
int_range,
62+
str_choice,
63+
]
64+
)
65+
choice_root = ChoiceParameter(
66+
name="root",
67+
parameter_type=ParameterType.STRING,
68+
values=["range", "choice"],
69+
dependents={"range": ["int_range"], "choice": ["str_choice"]},
70+
)
71+
# This HSS has a simple hierarchy -- one parameter on each branch.
72+
self.simple_hss = HierarchicalSearchSpace(
73+
parameters=[choice_root, int_range, str_choice]
74+
)
75+
fixed_leaf = FixedParameter(
76+
name="fixed_leaf",
77+
parameter_type=ParameterType.STRING,
78+
value="leaf",
79+
)
80+
middle_choice = ChoiceParameter(
81+
name="middle_choice",
82+
parameter_type=ParameterType.INT,
83+
values=[0, 1],
84+
dependents={0: ["fixed_leaf"], 1: ["int_range", "str_choice"]},
85+
)
86+
int_choice = ChoiceParameter(
87+
name="int_choice",
88+
parameter_type=ParameterType.INT,
89+
values=[0, 1, 2, 3],
90+
is_ordered=False,
91+
)
92+
float_range = RangeParameter(
93+
name="float_range",
94+
parameter_type=ParameterType.FLOAT,
95+
lower=0.0,
96+
upper=5.0,
97+
)
98+
choice_root2 = ChoiceParameter(
99+
name="root2",
100+
parameter_type=ParameterType.BOOL,
101+
values=[True, False],
102+
dependents={True: ["middle_choice", "float_range"], False: ["int_choice"]},
103+
)
104+
# This HSS has a more complex, multi-level hierarchy.
105+
self.complex_hss = HierarchicalSearchSpace(
106+
parameters=[
107+
choice_root2,
108+
int_choice,
109+
middle_choice,
110+
float_range,
111+
fixed_leaf,
112+
int_range,
113+
str_choice,
114+
]
115+
)
116+
117+
@fast_botorch_optimize
118+
def _test_gen_base(
119+
self,
120+
hss: HierarchicalSearchSpace,
121+
expected_num_candidate_params: List[int],
122+
num_sobol_trials: int = 5,
123+
num_bo_trials: int = 5,
124+
) -> Experiment:
125+
"""Test Sobol & MBM candidate generation with HSS using default transforms.
126+
127+
Args:
128+
hss: The hierarchical search space to test.
129+
expected_num_candidate_params: The expected number of parameters in each
130+
candidate. This list should include all possible values, since different
131+
branches of HSS may have different numbers of parameters.
132+
num_sobol_trials: The number of Sobol trials to run.
133+
num_bo_trials: The number of BO trials to run.
134+
135+
Returns:
136+
The experiment with the generated candidates. This can be used to chain
137+
tests for other functionality that requires data.
138+
"""
139+
experiment = Experiment(
140+
name="test_experiment",
141+
search_space=hss,
142+
optimization_config=OptimizationConfig(
143+
objective=Objective(
144+
metric=GenericNoisyFunctionMetric(
145+
name="random", f=lambda _: random()
146+
),
147+
minimize=True,
148+
)
149+
),
150+
runner=SyntheticRunner(),
151+
)
152+
153+
sobol = Models.SOBOL(search_space=hss)
154+
for _ in range(num_sobol_trials):
155+
trial = experiment.new_trial(generator_run=sobol.gen(n=1))
156+
trial.run().mark_completed()
157+
158+
for _ in range(num_bo_trials):
159+
mbm = Models.BOTORCH_MODULAR(
160+
experiment=experiment, data=experiment.fetch_data()
161+
)
162+
trial = experiment.new_trial(generator_run=mbm.gen(n=1))
163+
trial.run().mark_completed()
164+
165+
for t in experiment.trials.values():
166+
trial = checked_cast(Trial, t)
167+
arm = not_none(trial.arm)
168+
self.assertIn(len(arm.parameters), expected_num_candidate_params)
169+
# Check that the trials have the full parameterization recorded.
170+
full_parameterization = not_none(
171+
trial._get_candidate_metadata(arm_name=arm.name)
172+
)[Keys.FULL_PARAMETERIZATION]
173+
self.assertEqual(full_parameterization.keys(), hss.parameters.keys())
174+
175+
return experiment
176+
177+
@fast_botorch_optimize
178+
def _base_test_predict_and_cv(
179+
self,
180+
experiment: Experiment,
181+
expect_errors_with_final_parameterization: bool = False,
182+
) -> None:
183+
"""Test predict and cross validation with a given experiment.
184+
The predict tests are done using the full parameterization, the
185+
final parameterization with the full parameterization recorded in
186+
metadata, and with the final parameterization only. When the final
187+
parameterization lacks some parameters, this may error out.
188+
`expect_errors_with_final_parameterization` arg is used to handle
189+
the `KeyError` that is expected (but should be fixed) in this setting.
190+
"""
191+
mbm = Models.BOTORCH_MODULAR(
192+
experiment=experiment, data=experiment.fetch_data()
193+
)
194+
for t in experiment.trials.values():
195+
trial = checked_cast(Trial, t)
196+
arm = not_none(trial.arm)
197+
final_parameterization = arm.parameters
198+
full_parameterization = not_none(
199+
trial._get_candidate_metadata(arm_name=arm.name)
200+
)[Keys.FULL_PARAMETERIZATION]
201+
# Predict with full parameterization -- this should always work.
202+
mbm.predict([ObservationFeatures(parameters=full_parameterization)])
203+
# Predict with final parameterization -- this may error out :(.
204+
with ExitStack() as es:
205+
if expect_errors_with_final_parameterization:
206+
es.enter_context(self.assertRaises(KeyError))
207+
mbm.predict([ObservationFeatures(parameters=final_parameterization)])
208+
# Predict with final parameterization but include the full parameterization
209+
# in the metadata. This is similar to what happens inside cross_validate.
210+
mbm.predict(
211+
[
212+
ObservationFeatures(
213+
parameters=final_parameterization,
214+
metadata={Keys.FULL_PARAMETERIZATION: full_parameterization},
215+
)
216+
]
217+
)
218+
cv_res = cross_validate(model=mbm)
219+
self.assertEqual(len(cv_res), len(experiment.trials))
220+
221+
def test_with_non_hierarchical_hss(self) -> None:
222+
experiment = self._test_gen_base(
223+
hss=self.non_hierarchical_hss, expected_num_candidate_params=[3]
224+
)
225+
self._base_test_predict_and_cv(experiment=experiment)
226+
227+
def test_with_simple_hss(self) -> None:
228+
experiment = self._test_gen_base(
229+
hss=self.simple_hss, expected_num_candidate_params=[2]
230+
)
231+
self._base_test_predict_and_cv(
232+
experiment=experiment, expect_errors_with_final_parameterization=True
233+
)
234+
235+
def test_with_complex_hss(self) -> None:
236+
experiment = self._test_gen_base(
237+
hss=self.complex_hss, expected_num_candidate_params=[2, 4, 5]
238+
)
239+
self._base_test_predict_and_cv(
240+
experiment=experiment, expect_errors_with_final_parameterization=True
241+
)

0 commit comments

Comments
 (0)