Skip to content

Commit cd699ba

Browse files
esantorellafacebook-github-bot
authored andcommitted
Introduce SurrogateTestFunction
Summary: **Context**: This next diff will cut over uses of `SurrogateRunner` to use `ParamBasedTestProblemRunner` with a `test_problem` that is the newly introduced `SurrogateTestFunction`, and the following diff after that will bring us down to only one runner class for benchmarking by merging `ParamBasedTestProblemRunner` into `BenchmarkRunner`. Having only one runner will make it easier to enable asynchronous benchmarks. Currently, SurrogateRunner had its own logic for tracking when trials are completed, which would make it difficult to work in with asynchronicity. **Note on naming**: Some names have become non-intuitive in the process of benchmarking. To accord with some future changes I hope to make, I called a new class SurrogateTestFunction, whereas SurrogateParamBasedTestProblem would be more in line with the old naming. The name changes I hope to make: * ParamBasedTestProblemRunner -> nothing, absorbed into BenchmarkRunner * ParamBasedTestProblem -> TestFunction, to emphasize that all it does is generate data (rather than more generally specify the problem we are solving) and that it is deterministic, and to differentiate it from BenchmarkProblem. BenchmarkTestFunction would also be a candidate. * BoTorchTestProblem -> BoTorchTestFunction **Changes in this diff**: * Introduces SurrogateTestFunction, a ParamBasedTestProblem for surrogates, giving it the surrogate-related logic from SurrogateRunner Differential Revision: D64899032
1 parent 03b8c5d commit cd699ba

File tree

4 files changed

+270
-33
lines changed

4 files changed

+270
-33
lines changed

ax/benchmark/runners/surrogate.py

+80
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@
1111

1212
import torch
1313
from ax.benchmark.runners.base import BenchmarkRunner
14+
from ax.benchmark.runners.botorch_test import ParamBasedTestProblem
1415
from ax.core.base_trial import BaseTrial, TrialStatus
1516
from ax.core.observation import ObservationFeatures
1617
from ax.core.search_space import SearchSpaceDigest
18+
from ax.core.types import TParamValue
1719
from ax.modelbridge.torch import TorchModelBridge
1820
from ax.utils.common.base import Base
1921
from ax.utils.common.equality import equality_typechecker
@@ -22,6 +24,84 @@
2224
from torch import Tensor
2325

2426

27+
@dataclass(kw_only=True)
28+
class SurrogateTestFunction(ParamBasedTestProblem):
29+
"""
30+
Data-generating function for surrogate benchmark problems.
31+
32+
Args:
33+
name: The name of the runner.
34+
outcome_names: Names of outcomes to return in `evaluate_true`, if the
35+
surrogate produces more outcomes than are needed.
36+
_surrogate: Either `None`, or a `TorchModelBridge` surrogate to use
37+
for generating observations. If `None`, `get_surrogate_and_datasets`
38+
must not be None and will be used to generate the surrogate when it
39+
is needed.
40+
_datasets: Either `None`, or the `SupervisedDataset`s used to fit
41+
the surrogate model. If `None`, `get_surrogate_and_datasets` must
42+
not be None and will be used to generate the datasets when they are
43+
needed.
44+
get_surrogate_and_datasets: Function that returns the surrogate and
45+
datasets, to allow for lazy construction. If
46+
`get_surrogate_and_datasets` is not provided, `surrogate` and
47+
`datasets` must be provided, and vice versa.
48+
"""
49+
50+
name: str
51+
outcome_names: list[str]
52+
_surrogate: TorchModelBridge | None = None
53+
_datasets: list[SupervisedDataset] | None = None
54+
get_surrogate_and_datasets: (
55+
None | Callable[[], tuple[TorchModelBridge, list[SupervisedDataset]]]
56+
) = None
57+
58+
def __post_init__(self) -> None:
59+
if self.get_surrogate_and_datasets is None and (
60+
self._surrogate is None or self._datasets is None
61+
):
62+
raise ValueError(
63+
"If `get_surrogate_and_datasets` is None, `_surrogate` "
64+
"and `_datasets` must not be None, and vice versa."
65+
)
66+
67+
def set_surrogate_and_datasets(self) -> None:
68+
self._surrogate, self._datasets = none_throws(self.get_surrogate_and_datasets)()
69+
70+
@property
71+
def surrogate(self) -> TorchModelBridge:
72+
if self._surrogate is None:
73+
self.set_surrogate_and_datasets()
74+
return none_throws(self._surrogate)
75+
76+
@property
77+
def datasets(self) -> list[SupervisedDataset]:
78+
if self._datasets is None:
79+
self.set_surrogate_and_datasets()
80+
return none_throws(self._datasets)
81+
82+
def evaluate_true(self, params: Mapping[str, TParamValue]) -> Tensor:
83+
# We're ignoring the uncertainty predictions of the surrogate model here and
84+
# use the mean predictions as the outcomes (before potentially adding noise)
85+
means, _ = self.surrogate.predict(
86+
# pyre-fixme[6]: params is a Mapping, but ObservationFeatures expects a Dict
87+
observation_features=[ObservationFeatures(params)]
88+
)
89+
means = [means[name][0] for name in self.outcome_names]
90+
return torch.tensor(
91+
means,
92+
device=self.surrogate.device,
93+
dtype=self.surrogate.dtype,
94+
)
95+
96+
@equality_typechecker
97+
def __eq__(self, other: Base) -> bool:
98+
if type(other) is not type(self):
99+
return False
100+
101+
# Don't check surrogate, datasets, or callable
102+
return self.name == other.name
103+
104+
25105
@dataclass
26106
class SurrogateRunner(BenchmarkRunner):
27107
"""Runner for surrogate benchmark problems.

ax/benchmark/tests/runners/test_botorch_test_problem.py

+59-26
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
# pyre-strict
88

99

10+
from contextlib import nullcontext
1011
from dataclasses import replace
1112
from itertools import product
12-
from unittest.mock import Mock
13+
from unittest.mock import Mock, patch
1314

1415
import numpy as np
1516

@@ -19,13 +20,17 @@
1920
BoTorchTestProblem,
2021
ParamBasedTestProblemRunner,
2122
)
23+
from ax.benchmark.runners.surrogate import SurrogateTestFunction
2224
from ax.core.arm import Arm
2325
from ax.core.base_trial import TrialStatus
2426
from ax.core.trial import Trial
2527
from ax.exceptions.core import UnsupportedError
2628
from ax.utils.common.testutils import TestCase
2729
from ax.utils.common.typeutils import checked_cast
28-
from ax.utils.testing.benchmark_stubs import TestParamBasedTestProblem
30+
from ax.utils.testing.benchmark_stubs import (
31+
get_soo_surrogate_test_function,
32+
TestParamBasedTestProblem,
33+
)
2934
from botorch.test_functions.multi_objective import BraninCurrin
3035
from botorch.test_functions.synthetic import Ackley, ConstrainedHartmann, Hartmann
3136
from botorch.utils.transforms import normalize
@@ -130,34 +135,35 @@ def test_synthetic_runner(self) -> None:
130135
for num_outcomes in (1, 2)
131136
for noise_std in (0.0, [float(i) for i in range(num_outcomes)])
132137
]
133-
for test_problem, noise_std, num_outcomes in botorch_cases + param_based_cases:
134-
is_constrained = isinstance(
135-
test_problem, BoTorchTestProblem
136-
) and isinstance(test_problem.botorch_problem, ConstrainedHartmann)
137-
num_constraints = 1 if is_constrained else 0
138-
outcome_names = [
139-
f"objective_{i}" for i in range(num_outcomes - num_constraints)
140-
] + ["constraint"] * num_constraints
138+
surrogate_cases = [
139+
(get_soo_surrogate_test_function(lazy=False), noise_std, 1)
140+
for noise_std in (0.0, 1.0, [0.0], [1.0])
141+
]
142+
for test_problem, noise_std, num_outcomes in (
143+
botorch_cases + param_based_cases + surrogate_cases
144+
):
145+
# Set up outcome names
146+
if isinstance(test_problem, BoTorchTestProblem):
147+
if isinstance(test_problem.botorch_problem, ConstrainedHartmann):
148+
outcome_names = ["objective_0", "constraint"]
149+
else:
150+
outcome_names = ["objective_0"]
151+
elif isinstance(test_problem, TestParamBasedTestProblem):
152+
outcome_names = [f"objective_{i}" for i in range(num_outcomes)]
153+
elif isinstance(test_problem, SurrogateTestFunction):
154+
outcome_names = ["branin"]
141155

156+
# Set up runner
142157
runner = ParamBasedTestProblemRunner(
143158
test_problem=test_problem,
144159
outcome_names=outcome_names,
145160
noise_std=noise_std,
146161
)
147-
modified_bounds = (
148-
test_problem.modified_bounds
149-
if isinstance(test_problem, BoTorchTestProblem)
150-
else None
151-
)
152-
153-
test_description: str = (
154-
f"test problem: {test_problem.__class__.__name__}, "
155-
f"modified_bounds: {modified_bounds}, "
156-
f"noise_std: {noise_std}."
157-
)
158-
is_botorch = isinstance(test_problem, BoTorchTestProblem)
159162

160-
with self.subTest(f"Test basic construction, {test_description}"):
163+
test_description = f"{test_problem=}, {noise_std=}"
164+
with self.subTest(
165+
f"Test basic construction, {test_problem=}, {noise_std=}"
166+
):
161167
self.assertIs(runner.test_problem, test_problem)
162168
self.assertEqual(runner.outcome_names, outcome_names)
163169
if isinstance(noise_std, list):
@@ -183,6 +189,7 @@ def test_synthetic_runner(self) -> None:
183189
test_problem.botorch_problem.bounds.dtype, torch.double
184190
)
185191

192+
is_botorch = isinstance(test_problem, BoTorchTestProblem)
186193
with self.subTest(f"test `get_Y_true()`, {test_description}"):
187194
dim = 6 if is_botorch else 9
188195
X = torch.rand(1, dim, dtype=torch.double)
@@ -195,7 +202,20 @@ def test_synthetic_runner(self) -> None:
195202
)
196203
params = dict(zip(param_names, (x.item() for x in X.unbind(-1))))
197204

198-
Y = runner.get_Y_true(params=params)
205+
with (
206+
nullcontext()
207+
if not isinstance(test_problem, SurrogateTestFunction)
208+
else patch.object(
209+
# pyre-fixme: ParamBasedTestProblem` has no attribute
210+
# `_surrogate`.
211+
runner.test_problem._surrogate,
212+
"predict",
213+
return_value=({"branin": [4.2]}, None),
214+
)
215+
):
216+
Y = runner.get_Y_true(params=params)
217+
oracle = runner.evaluate_oracle(parameters=params)
218+
199219
if (
200220
isinstance(test_problem, BoTorchTestProblem)
201221
and test_problem.modified_bounds is not None
@@ -221,12 +241,13 @@ def test_synthetic_runner(self) -> None:
221241
)
222242
else:
223243
expected_Y = obj
244+
elif isinstance(test_problem, SurrogateTestFunction):
245+
expected_Y = torch.tensor([4.2], dtype=torch.double)
224246
else:
225247
expected_Y = torch.full(
226248
torch.Size([2]), X.pow(2).sum().item(), dtype=torch.double
227249
)
228250
self.assertTrue(torch.allclose(Y, expected_Y))
229-
oracle = runner.evaluate_oracle(parameters=params)
230251
self.assertTrue(np.equal(Y.numpy(), oracle).all())
231252

232253
with self.subTest(f"test `run()`, {test_description}"):
@@ -237,7 +258,19 @@ def test_synthetic_runner(self) -> None:
237258
trial.arms = [arm]
238259
trial.arm = arm
239260
trial.index = 0
240-
res = runner.run(trial=trial)
261+
262+
with (
263+
nullcontext()
264+
if not isinstance(test_problem, SurrogateTestFunction)
265+
else patch.object(
266+
# pyre-fixme: ParamBasedTestProblem` has no attribute
267+
# `_surrogate`.
268+
runner.test_problem._surrogate,
269+
"predict",
270+
return_value=({"branin": [4.2]}, None),
271+
)
272+
):
273+
res = runner.run(trial=trial)
241274
self.assertEqual({"Ys", "Ystds", "outcome_names"}, res.keys())
242275
self.assertEqual({"0_0"}, res["Ys"].keys())
243276

ax/benchmark/tests/runners/test_surrogate_runner.py

+74-4
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,82 @@
88
from unittest.mock import MagicMock, patch
99

1010
import torch
11-
from ax.benchmark.runners.surrogate import SurrogateRunner
11+
from ax.benchmark.runners.surrogate import SurrogateRunner, SurrogateTestFunction
1212
from ax.core.parameter import ParameterType, RangeParameter
1313
from ax.core.search_space import SearchSpace
1414
from ax.modelbridge.torch import TorchModelBridge
1515
from ax.utils.common.testutils import TestCase
16-
from ax.utils.testing.benchmark_stubs import get_soo_surrogate
16+
from ax.utils.testing.benchmark_stubs import (
17+
get_soo_surrogate_legacy,
18+
get_soo_surrogate_test_function,
19+
)
20+
21+
22+
class TestSurrogateTestFunction(TestCase):
23+
def test_surrogate_test_function(self) -> None:
24+
# Construct a search space with log-scale parameters.
25+
for noise_std in (0.0, 0.1, {"dummy_metric": 0.2}):
26+
with self.subTest(noise_std=noise_std):
27+
surrogate = MagicMock()
28+
mock_mean = torch.tensor([[0.1234]], dtype=torch.double)
29+
surrogate.predict = MagicMock(return_value=(mock_mean, 0))
30+
surrogate.device = torch.device("cpu")
31+
surrogate.dtype = torch.double
32+
test_function = SurrogateTestFunction(
33+
name="test test function",
34+
outcome_names=["dummy metric"],
35+
_surrogate=surrogate,
36+
_datasets=[],
37+
)
38+
self.assertEqual(test_function.name, "test test function")
39+
self.assertIs(test_function.surrogate, surrogate)
40+
41+
def test_lazy_instantiation(self) -> None:
42+
test_function = get_soo_surrogate_test_function()
43+
44+
self.assertIsNone(test_function._surrogate)
45+
self.assertIsNone(test_function._datasets)
46+
47+
# Accessing `surrogate` sets datasets and surrogate
48+
self.assertIsInstance(test_function.surrogate, TorchModelBridge)
49+
self.assertIsInstance(test_function._surrogate, TorchModelBridge)
50+
self.assertIsInstance(test_function._datasets, list)
51+
52+
# Accessing `datasets` also sets datasets and surrogate
53+
test_function = get_soo_surrogate_test_function()
54+
self.assertIsInstance(test_function.datasets, list)
55+
self.assertIsInstance(test_function._surrogate, TorchModelBridge)
56+
self.assertIsInstance(test_function._datasets, list)
57+
58+
with patch.object(
59+
test_function,
60+
"get_surrogate_and_datasets",
61+
wraps=test_function.get_surrogate_and_datasets,
62+
) as mock_get_surrogate_and_datasets:
63+
test_function.surrogate
64+
mock_get_surrogate_and_datasets.assert_not_called()
65+
66+
def test_instantiation_raises_with_missing_args(self) -> None:
67+
with self.assertRaisesRegex(
68+
ValueError, "If `get_surrogate_and_datasets` is None, `_surrogate` and "
69+
):
70+
SurrogateTestFunction(name="test runner", outcome_names=[])
71+
72+
def test_equality(self) -> None:
73+
def _construct_test_function(name: str) -> SurrogateTestFunction:
74+
return SurrogateTestFunction(
75+
name=name,
76+
_surrogate=MagicMock(),
77+
_datasets=[],
78+
outcome_names=["dummy_metric"],
79+
)
80+
81+
runner_1 = _construct_test_function("test 1")
82+
runner_2 = _construct_test_function("test 2")
83+
runner_1a = _construct_test_function("test 1")
84+
self.assertEqual(runner_1, runner_1a)
85+
self.assertNotEqual(runner_1, runner_2)
86+
self.assertNotEqual(runner_1, 1)
1787

1888

1989
class TestSurrogateRunner(TestCase):
@@ -49,7 +119,7 @@ def test_surrogate_runner(self) -> None:
49119
self.assertEqual(runner.noise_stds, noise_std)
50120

51121
def test_lazy_instantiation(self) -> None:
52-
runner = get_soo_surrogate().runner
122+
runner = get_soo_surrogate_legacy().runner
53123

54124
self.assertIsNone(runner._surrogate)
55125
self.assertIsNone(runner._datasets)
@@ -60,7 +130,7 @@ def test_lazy_instantiation(self) -> None:
60130
self.assertIsInstance(runner._datasets, list)
61131

62132
# Accessing `datasets` also sets datasets and surrogate
63-
runner = get_soo_surrogate().runner
133+
runner = get_soo_surrogate_legacy().runner
64134
self.assertIsInstance(runner.datasets, list)
65135
self.assertIsInstance(runner._surrogate, TorchModelBridge)
66136
self.assertIsInstance(runner._datasets, list)

0 commit comments

Comments
 (0)