Skip to content

Commit 0253ca3

Browse files
David Eriksson authored and facebook-github-bot committed
Change list surrogate construction (#1247)
Summary: Pull Request resolved: #1247 Change `submodel_outcome_transforms`, `submodel_input_transforms`, `submodel_covar_module_class`, `submodel_covar_module_options`, `submodel_likelihood_class`, `submodel_likelihood_options` to only accept one input for all models. Reviewed By: lena-kashtelyan Differential Revision: D40164298 fbshipit-source-id: 2799f300069600a7cfa7ebee3d6783c7ae03d346
1 parent 6912b80 commit 0253ca3

File tree

2 files changed

+147
-104
lines changed

2 files changed

+147
-104
lines changed

ax/models/torch/botorch_modular/list_surrogate.py

+60-41
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
from __future__ import annotations
88

99
import inspect
10+
from copy import deepcopy
1011

1112
from logging import Logger
1213
from typing import Any, Dict, List, Optional, Type
@@ -49,16 +50,22 @@ class ListSurrogate(Surrogate):
4950
``submodel_options_per_outcome[submodel_outcome]`` (individual).
5051
mll_class: ``MarginalLogLikelihood`` class to use for model-fitting.
5152
mll_options: Dictionary of options / kwargs for the MLL.
52-
submodel_outcome_transforms: A dictionary mapping each outcome to a
53-
BoTorch outcome transform. Gets passed down to the BoTorch ``Model``s.
53+
submodel_outcome_transforms: An outcome transform that will be used
54+
by all outcomes. Gets passed down to the BoTorch ``Model``s.
5455
To use multiple outcome transforms on a submodel, chain them
5556
together using ``ChainedOutcomeTransform``.
56-
submodel_input_transforms: A dictionary mapping each outcome to a
57-
BoTorch input transform. Gets passed down to the BoTorch ``Model``.
57+
submodel_input_transforms: An input transform that will be used
58+
by all outcomes. Gets passed down to the BoTorch ``Model``.
5859
If sharing a single ``InputTransform`` object across submodels is
5960
preferred, pass in a dictionary where each outcome key references the
6061
same ``InputTransform`` object. To use multiple input transforms on
6162
a submodel, chain them together using ``ChainedInputTransform``.
63+
submodel_covar_module_class: A covar module that will be used by all outcomes.
64+
submodel_covar_module_options: Options for a BoTorch covar module or options
65+
that will be used by all outcomes.
66+
submodel_likelihood_class: A likelihood that will be used by all outcomes.
67+
submodel_likelihood_options: Options for a BoTorch likelihood or options that
68+
will be used by all outcomes.
6269
"""
6370

6471
botorch_submodel_class_per_outcome: Dict[str, Type[Model]]
@@ -67,12 +74,12 @@ class ListSurrogate(Surrogate):
6774
submodel_options: Dict[str, Any]
6875
mll_class: Type[MarginalLogLikelihood]
6976
mll_options: Dict[str, Any]
70-
submodel_outcome_transforms: Dict[str, OutcomeTransform]
71-
submodel_input_transforms: Dict[str, InputTransform]
72-
submodel_covar_module_class: Dict[str, Type[Kernel]]
73-
submodel_covar_module_options: Dict[str, Dict[str, Any]]
74-
submodel_likelihood_class: Dict[str, Type[Likelihood]]
75-
submodel_likelihood_options: Dict[str, Dict[str, Any]]
77+
submodel_outcome_transforms: Optional[OutcomeTransform]
78+
submodel_input_transforms: Optional[InputTransform]
79+
submodel_covar_module_class: Optional[Type[Kernel]]
80+
submodel_covar_module_options: Dict[str, Any]
81+
submodel_likelihood_class: Optional[Type[Likelihood]]
82+
submodel_likelihood_options: Dict[str, Any]
7683
_model: Optional[Model] = None
7784
# Special setting for surrogates instantiated via `Surrogate.from_botorch`,
7885
# to avoid re-constructing the underlying BoTorch model on `Surrogate.fit`
@@ -87,12 +94,12 @@ def __init__(
8794
submodel_options: Optional[Dict[str, Any]] = None,
8895
mll_class: Type[MarginalLogLikelihood] = ExactMarginalLogLikelihood,
8996
mll_options: Optional[Dict[str, Any]] = None,
90-
submodel_outcome_transforms: Optional[Dict[str, OutcomeTransform]] = None,
91-
submodel_input_transforms: Optional[Dict[str, InputTransform]] = None,
92-
submodel_covar_module_class: Optional[Dict[str, Type[Kernel]]] = None,
93-
submodel_covar_module_options: Optional[Dict[str, Dict[str, Any]]] = None,
94-
submodel_likelihood_class: Optional[Dict[str, Type[Likelihood]]] = None,
95-
submodel_likelihood_options: Optional[Dict[str, Dict[str, Any]]] = None,
97+
submodel_outcome_transforms: Optional[OutcomeTransform] = None,
98+
submodel_input_transforms: Optional[InputTransform] = None,
99+
submodel_covar_module_class: Optional[Type[Kernel]] = None,
100+
submodel_covar_module_options: Optional[Dict[str, Any]] = None,
101+
submodel_likelihood_class: Optional[Type[Likelihood]] = None,
102+
submodel_likelihood_options: Optional[Dict[str, Any]] = None,
96103
) -> None:
97104
if not bool(botorch_submodel_class_per_outcome) ^ bool(botorch_submodel_class):
98105
raise ValueError( # pragma: no cover
@@ -106,11 +113,11 @@ def __init__(
106113
self.botorch_submodel_class = botorch_submodel_class
107114
self.submodel_options_per_outcome = submodel_options_per_outcome or {}
108115
self.submodel_options = submodel_options or {}
109-
self.submodel_outcome_transforms = submodel_outcome_transforms or {}
110-
self.submodel_input_transforms = submodel_input_transforms or {}
111-
self.submodel_covar_module_class = submodel_covar_module_class or {}
116+
self.submodel_outcome_transforms = submodel_outcome_transforms
117+
self.submodel_input_transforms = submodel_input_transforms
118+
self.submodel_covar_module_class = submodel_covar_module_class
112119
self.submodel_covar_module_options = submodel_covar_module_options or {}
113-
self.submodel_likelihood_class = submodel_likelihood_class or {}
120+
self.submodel_likelihood_class = submodel_likelihood_class
114121
self.submodel_likelihood_options = submodel_likelihood_options or {}
115122
super().__init__(
116123
botorch_model_class=ModelListGP,
@@ -159,7 +166,6 @@ def construct(
159166
# Construct input perturbation if doing robust optimization.
160167
# NOTE: Doing this here rather than in `_set_formatted_inputs` to make sure
161168
# we use the same perturbations for each sub-model.
162-
submodel_input_transforms = self.submodel_input_transforms.copy()
163169
robust_digest: Optional[Dict[str, Any]] = kwargs.get("robust_digest", None)
164170
if robust_digest is not None:
165171
if len(robust_digest["environmental_variables"]):
@@ -176,15 +182,15 @@ def construct(
176182
perturbation_set=samples, multiplicative=robust_digest["multiplicative"]
177183
)
178184

179-
for m in metric_names:
180-
if submodel_input_transforms.get(m) is not None:
181-
# TODO: Support mixing with user supplied transforms.
182-
raise NotImplementedError(
183-
"User supplied input transforms are not supported "
184-
"in robust optimization."
185-
)
186-
else:
187-
submodel_input_transforms[m] = perturbation
185+
if self.submodel_input_transforms is not None:
186+
# TODO: Support mixing with user supplied transforms.
187+
raise NotImplementedError(
188+
"User supplied input transforms are not supported "
189+
"in robust optimization."
190+
)
191+
submodel_input_transforms = perturbation
192+
else:
193+
submodel_input_transforms = self.submodel_input_transforms
188194

189195
submodels = []
190196
for m, dataset in zip(metric_names, datasets):
@@ -218,20 +224,33 @@ def construct(
218224
# way to filter the arguments. See the comment in `Surrogate.construct`
219225
# regarding potential use of a `ModelFactory` in the future.
220226
model_cls_args = inspect.getfullargspec(model_cls).args
221-
covar_module_class = self.submodel_covar_module_class.get(m)
222-
covar_module_options = self.submodel_covar_module_options.get(m)
223-
likelihood_class = self.submodel_likelihood_class.get(m)
224-
likelihood_options = self.submodel_likelihood_options.get(m)
225-
outcome_transform = self.submodel_outcome_transforms.get(m)
226-
input_transform = submodel_input_transforms.get(m)
227-
228227
self._set_formatted_inputs(
229228
formatted_model_inputs=formatted_model_inputs,
230229
inputs=[
231-
["covar_module", covar_module_class, covar_module_options, None],
232-
["likelihood", likelihood_class, likelihood_options, None],
233-
["outcome_transform", None, None, outcome_transform],
234-
["input_transform", None, None, input_transform],
230+
[
231+
"covar_module",
232+
self.submodel_covar_module_class,
233+
self.submodel_covar_module_options,
234+
None,
235+
],
236+
[
237+
"likelihood",
238+
self.submodel_likelihood_class,
239+
self.submodel_likelihood_options,
240+
None,
241+
],
242+
[
243+
"outcome_transform",
244+
None,
245+
None,
246+
deepcopy(self.submodel_outcome_transforms),
247+
],
248+
[
249+
"input_transform",
250+
None,
251+
None,
252+
deepcopy(submodel_input_transforms),
253+
],
235254
],
236255
dataset=dataset,
237256
botorch_model_class_args=model_cls_args,

ax/models/torch/tests/test_list_surrogate.py

+87-63
Original file line number | Diff line number | Diff line change
@@ -4,10 +4,10 @@
44
# This source code is licensed under the MIT license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import math
78
from unittest.mock import Mock, patch
89

910
import numpy as np
10-
1111
import torch
1212
from ax.core.search_space import SearchSpaceDigest
1313
from ax.exceptions.core import UserInputError
@@ -37,7 +37,8 @@
3737
GaussianLikelihood,
3838
Likelihood, # noqa: F401
3939
)
40-
from gpytorch.mlls import ExactMarginalLogLikelihood, LeaveOneOutPseudoLikelihood
40+
from gpytorch.mlls import ExactMarginalLogLikelihood
41+
4142

4243
SURROGATE_PATH = f"{Surrogate.__module__}"
4344
UTILS_PATH = f"{choose_model_class.__module__}"
@@ -58,6 +59,9 @@ def setUp(self) -> None:
5859
Xs1, Ys1, Yvars1, bounds, _, _, _ = get_torch_test_data(
5960
dtype=self.dtype, task_features=self.search_space_digest.task_features
6061
)
62+
# Change the inputs/outputs a bit so the data isn't identical
63+
Xs1[0] *= 2
64+
Ys1[0] += 1
6165
Xs2, Ys2, Yvars2, _, _, _, _ = get_torch_test_data(
6266
dtype=self.dtype, task_features=self.search_space_digest.task_features
6367
)
@@ -352,19 +356,12 @@ def test_fit(
352356
)
353357

354358
def test_with_botorch_transforms(self) -> None:
355-
input_transforms = {"outcome_1": Normalize(d=3), "outcome_2": Normalize(d=3)}
356-
outcome_transforms = {
357-
"outcome_1": Standardize(m=1),
358-
"outcome_2": Standardize(m=1),
359-
}
359+
input_transforms = Normalize(d=3)
360+
outcome_transforms = Standardize(m=1)
360361
surrogate = ListSurrogate(
361362
botorch_submodel_class=SingleTaskGPWithDifferentConstructor,
362363
mll_class=ExactMarginalLogLikelihood,
363-
# pyre-fixme[6]: For 3rd param expected `Optional[Dict[str,
364-
# OutcomeTransform]]` but got `Dict[str, Standardize]`.
365364
submodel_outcome_transforms=outcome_transforms,
366-
# pyre-fixme[6]: For 4th param expected `Optional[Dict[str,
367-
# InputTransform]]` but got `Dict[str, Normalize]`.
368365
submodel_input_transforms=input_transforms,
369366
)
370367
with self.assertRaisesRegex(UserInputError, "The BoTorch model class"):
@@ -375,23 +372,34 @@ def test_with_botorch_transforms(self) -> None:
375372
surrogate = ListSurrogate(
376373
botorch_submodel_class=SingleTaskGP,
377374
mll_class=ExactMarginalLogLikelihood,
378-
# pyre-fixme[6]: For 3rd param expected `Optional[Dict[str,
379-
# OutcomeTransform]]` but got `Dict[str, Standardize]`.
380375
submodel_outcome_transforms=outcome_transforms,
381-
# pyre-fixme[6]: For 4th param expected `Optional[Dict[str,
382-
# InputTransform]]` but got `Dict[str, Normalize]`.
383376
submodel_input_transforms=input_transforms,
384377
)
385378
surrogate.construct(
386379
datasets=self.supervised_training_data,
387380
metric_names=self.outcomes,
388381
)
389-
models = surrogate.model.models
390-
for i, outcome in enumerate(("outcome_1", "outcome_2")):
391-
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C._TensorBase...
392-
self.assertIs(models[i].outcome_transform, outcome_transforms[outcome])
393-
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C._TensorBase...
394-
self.assertIs(models[i].input_transform, input_transforms[outcome])
382+
# pyre-ignore [9]
383+
models: torch.nn.modules.container.ModuleList = surrogate.model.models
384+
for i in range(2):
385+
self.assertIsInstance(models[i].outcome_transform, Standardize)
386+
self.assertIsInstance(models[i].input_transform, Normalize)
387+
self.assertEqual(models[0].outcome_transform.means.item(), 4.5)
388+
self.assertEqual(models[1].outcome_transform.means.item(), 3.5)
389+
self.assertAlmostEqual(
390+
models[0].outcome_transform.stdvs.item(), 1 / math.sqrt(2)
391+
)
392+
self.assertAlmostEqual(
393+
models[1].outcome_transform.stdvs.item(), 1 / math.sqrt(2)
394+
)
395+
self.assertTrue(
396+
torch.all(
397+
torch.isclose(
398+
models[0].input_transform.bounds,
399+
2 * models[1].input_transform.bounds, # pyre-ignore
400+
)
401+
)
402+
)
395403

396404
def test_serialize_attributes_as_kwargs(self) -> None:
397405
expected = self.surrogate.__dict__
@@ -411,48 +419,64 @@ def test_serialize_attributes_as_kwargs(self) -> None:
411419
self.assertEqual(self.surrogate._serialize_attributes_as_kwargs(), expected)
412420

413421
def test_construct_custom_model(self) -> None:
414-
noise_con1, noise_con2 = Interval(1e-6, 1e-1), GreaterThan(1e-4)
415-
surrogate = ListSurrogate(
416-
botorch_submodel_class=SingleTaskGP,
417-
mll_class=LeaveOneOutPseudoLikelihood,
418-
submodel_covar_module_class={
419-
"outcome_1": RBFKernel,
420-
"outcome_2": MaternKernel,
421-
},
422-
submodel_covar_module_options={
423-
"outcome_1": {"ard_num_dims": 1},
424-
"outcome_2": {"ard_num_dims": 3},
425-
},
426-
submodel_likelihood_class={
427-
"outcome_1": GaussianLikelihood,
428-
"outcome_2": GaussianLikelihood,
429-
},
430-
submodel_likelihood_options={
431-
"outcome_1": {"noise_constraint": noise_con1},
432-
"outcome_2": {"noise_constraint": noise_con2},
433-
},
434-
)
435-
surrogate.construct(
436-
datasets=self.supervised_training_data,
437-
metric_names=self.outcomes,
438-
)
439-
# pyre-fixme[16]: Optional type has no attribute `models`.
440-
self.assertEqual(len(surrogate._model.models), 2)
441-
self.assertEqual(surrogate.mll_class, LeaveOneOutPseudoLikelihood)
442-
for i, m in enumerate(surrogate._model.models):
443-
self.assertEqual(type(m.likelihood), GaussianLikelihood)
444-
if i == 0:
445-
self.assertEqual(type(m.covar_module), RBFKernel)
446-
self.assertEqual(m.covar_module.ard_num_dims, 1)
447-
self.assertEqual(
448-
m.likelihood.noise_covar.raw_noise_constraint, noise_con1
449-
)
450-
else:
422+
noise_constraint = Interval(1e-4, 10.0)
423+
for submodel_covar_module_options, submodel_likelihood_options in [
424+
[{"ard_num_dims": 3}, {"noise_constraint": noise_constraint}],
425+
[{}, {}],
426+
]:
427+
surrogate = ListSurrogate(
428+
botorch_submodel_class=SingleTaskGP,
429+
mll_class=ExactMarginalLogLikelihood,
430+
submodel_covar_module_class=MaternKernel,
431+
submodel_covar_module_options=submodel_covar_module_options,
432+
submodel_likelihood_class=GaussianLikelihood,
433+
submodel_likelihood_options=submodel_likelihood_options,
434+
submodel_input_transforms=Normalize(d=3),
435+
submodel_outcome_transforms=Standardize(m=1),
436+
)
437+
surrogate.construct(
438+
datasets=self.supervised_training_data,
439+
metric_names=self.outcomes,
440+
)
441+
# pyre-fixme[16]: Optional type has no attribute `models`.
442+
self.assertEqual(len(surrogate._model.models), 2)
443+
self.assertEqual(surrogate.mll_class, ExactMarginalLogLikelihood)
444+
# Make sure we properly copied the transforms
445+
self.assertNotEqual(
446+
id(surrogate._model.models[0].input_transform),
447+
id(surrogate._model.models[1].input_transform),
448+
)
449+
self.assertNotEqual(
450+
id(surrogate._model.models[0].outcome_transform),
451+
id(surrogate._model.models[1].outcome_transform),
452+
)
453+
454+
for m in surrogate._model.models:
455+
self.assertEqual(type(m.likelihood), GaussianLikelihood)
451456
self.assertEqual(type(m.covar_module), MaternKernel)
452-
self.assertEqual(m.covar_module.ard_num_dims, 3)
453-
self.assertEqual(
454-
m.likelihood.noise_covar.raw_noise_constraint, noise_con2
455-
)
457+
if submodel_covar_module_options:
458+
self.assertEqual(m.covar_module.ard_num_dims, 3)
459+
else:
460+
self.assertEqual(m.covar_module.ard_num_dims, None)
461+
if submodel_likelihood_options:
462+
self.assertEqual(
463+
type(m.likelihood.noise_covar.raw_noise_constraint), Interval
464+
)
465+
self.assertEqual(
466+
m.likelihood.noise_covar.raw_noise_constraint.lower_bound,
467+
noise_constraint.lower_bound,
468+
)
469+
self.assertEqual(
470+
m.likelihood.noise_covar.raw_noise_constraint.upper_bound,
471+
noise_constraint.upper_bound,
472+
)
473+
else:
474+
self.assertEqual(
475+
type(m.likelihood.noise_covar.raw_noise_constraint), GreaterThan
476+
)
477+
self.assertEqual(
478+
m.likelihood.noise_covar.raw_noise_constraint.lower_bound, 1e-4
479+
)
456480

457481
def test_w_robust_digest(self) -> None:
458482
surrogate = ListSurrogate(
@@ -470,7 +494,7 @@ def test_w_robust_digest(self) -> None:
470494
"environmental_variables": [],
471495
"multiplicative": False,
472496
}
473-
surrogate.submodel_input_transforms = {self.outcomes[0]: Normalize(d=1)}
497+
surrogate.submodel_input_transforms = Normalize(d=1)
474498
with self.assertRaisesRegex(NotImplementedError, "input transforms"):
475499
surrogate.construct(
476500
datasets=self.supervised_training_data,

0 commit comments

Comments
 (0)