Skip to content

Commit 6f96f36

Browse files
esantorellafacebook-github-bot
authored andcommitted
Fit model instead of mocking out model fitting in test_botorch_moo_defaults.FrontierEvaluatorTest (#3241)
Summary: * Fit a model in which fitting had been mocked out before * Move some logic out of the `setUp` so that we don't fit a model in tests where it isn't needed * In `test_pareto_frontier_raise_error_when_missing_data`, use an unfit model Reviewed By: saitcakmak Differential Revision: D68214262
1 parent 3c68bd3 commit 6f96f36

File tree

1 file changed

+55
-39
lines changed

1 file changed

+55
-39
lines changed

ax/models/tests/test_botorch_moo_defaults.py

+55-39
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@
99
from contextlib import ExitStack
1010
from typing import Any, cast
1111
from unittest import mock
12+
from warnings import catch_warnings, simplefilter
1213

1314
import numpy as np
1415
import torch
1516
from ax.core.search_space import SearchSpaceDigest
1617
from ax.models.torch.botorch_defaults import NO_OBSERVED_POINTS_MESSAGE
18+
from ax.models.torch.botorch_modular.model import BoTorchModel
1719
from ax.models.torch.botorch_moo import MultiObjectiveBotorchModel
1820
from ax.models.torch.botorch_moo_defaults import (
1921
get_outcome_constraint_transforms,
@@ -24,12 +26,15 @@
2426
pareto_frontier_evaluator,
2527
)
2628
from ax.models.torch.utils import _get_X_pending_and_observed
29+
from ax.models.torch_base import TorchModel
2730
from ax.utils.common.random import with_rng_seed
2831
from ax.utils.common.testutils import TestCase
32+
from ax.utils.testing.mock import mock_botorch_optimize_context_manager
2933
from botorch.models.gp_regression import SingleTaskGP
3034
from botorch.utils.datasets import SupervisedDataset
3135
from botorch.utils.multi_objective.hypervolume import infer_reference_point
3236
from botorch.utils.testing import MockModel, MockPosterior
37+
from gpytorch.utils.warnings import NumericalWarning
3338
from torch._tensor import Tensor
3439

3540

@@ -42,12 +47,23 @@
4247
FIT_MODEL_MO_PATH = "ax.models.torch.botorch_defaults.fit_gpytorch_mll"
4348

4449

45-
# pyre-fixme[2]: Parameter must be annotated.
46-
def dummy_predict(model, X) -> tuple[Tensor, Tensor]:
47-
# Add column to X that is a product of previous elements.
48-
mean = torch.cat([X, torch.prod(X, dim=1).reshape(-1, 1)], dim=1)
49-
cov = torch.zeros(mean.shape[0], mean.shape[1], mean.shape[1])
50-
return mean, cov
50+
def _fit_model(
51+
model: TorchModel, X: torch.Tensor, Y: torch.Tensor, Yvar: torch.Tensor
52+
) -> None:
53+
bounds = [(0.0, 4.0), (0.0, 4.0)]
54+
datasets = [
55+
SupervisedDataset(
56+
X=X,
57+
Y=Y,
58+
Yvar=Yvar,
59+
feature_names=["x1", "x2"],
60+
outcome_names=["a", "b", "c"],
61+
)
62+
]
63+
search_space_digest = SearchSpaceDigest(feature_names=["x1", "x2"], bounds=bounds)
64+
with mock_botorch_optimize_context_manager(), catch_warnings():
65+
simplefilter(action="ignore", category=NumericalWarning)
66+
model.fit(datasets=datasets, search_space_digest=search_space_digest)
5167

5268

5369
class FrontierEvaluatorTest(TestCase):
@@ -66,45 +82,24 @@ def setUp(self) -> None:
6682
]
6783
)
6884
self.Yvar = torch.zeros(5, 3)
69-
self.outcome_constraints = (
70-
torch.tensor([[0.0, 0.0, 1.0]]),
71-
torch.tensor([[3.5]]),
72-
)
7385
self.objective_thresholds = torch.tensor([0.5, 1.5])
7486
self.objective_weights = torch.tensor([1.0, 1.0])
75-
bounds = [(0.0, 4.0), (0.0, 4.0)]
76-
self.model = MultiObjectiveBotorchModel(model_predictor=dummy_predict)
77-
with mock.patch(FIT_MODEL_MO_PATH) as _mock_fit_model:
78-
self.model.fit(
79-
datasets=[
80-
SupervisedDataset(
81-
X=self.X,
82-
Y=self.Y,
83-
Yvar=self.Yvar,
84-
feature_names=["x1", "x2"],
85-
outcome_names=["a", "b", "c"],
86-
)
87-
],
88-
search_space_digest=SearchSpaceDigest(
89-
feature_names=["x1", "x2"],
90-
bounds=bounds,
91-
),
92-
)
93-
_mock_fit_model.assert_called_once()
9487

9588
def test_pareto_frontier_raise_error_when_missing_data(self) -> None:
9689
with self.assertRaises(ValueError):
9790
pareto_frontier_evaluator(
98-
model=self.model,
91+
model=MultiObjectiveBotorchModel(),
9992
objective_thresholds=self.objective_thresholds,
10093
objective_weights=self.objective_weights,
10194
Yvar=self.Yvar,
10295
)
10396

10497
def test_pareto_frontier_evaluator_raw(self) -> None:
98+
model = BoTorchModel()
99+
_fit_model(model=model, X=self.X, Y=self.Y, Yvar=self.Yvar)
105100
Yvar = torch.diag_embed(self.Yvar)
106101
Y, cov, indx = pareto_frontier_evaluator(
107-
model=self.model,
102+
model=model,
108103
objective_weights=self.objective_weights,
109104
objective_thresholds=self.objective_thresholds,
110105
Y=self.Y,
@@ -118,7 +113,7 @@ def test_pareto_frontier_evaluator_raw(self) -> None:
118113

119114
# Omit objective_thresholds
120115
Y, cov, indx = pareto_frontier_evaluator(
121-
model=self.model,
116+
model=model,
122117
objective_weights=self.objective_weights,
123118
Y=self.Y,
124119
Yvar=Yvar,
@@ -131,7 +126,7 @@ def test_pareto_frontier_evaluator_raw(self) -> None:
131126

132127
# Change objective_weights so goal is to minimize b
133128
Y, cov, indx = pareto_frontier_evaluator(
134-
model=self.model,
129+
model=model,
135130
objective_weights=torch.tensor([1.0, -1.0]),
136131
objective_thresholds=self.objective_thresholds,
137132
Y=self.Y,
@@ -146,7 +141,7 @@ def test_pareto_frontier_evaluator_raw(self) -> None:
146141

147142
# test no points better than reference point
148143
Y, cov, indx = pareto_frontier_evaluator(
149-
model=self.model,
144+
model=model,
150145
objective_weights=self.objective_weights,
151146
objective_thresholds=torch.full_like(self.objective_thresholds, 100.0),
152147
Y=self.Y,
@@ -157,8 +152,25 @@ def test_pareto_frontier_evaluator_raw(self) -> None:
157152
self.assertTrue(torch.equal(torch.tensor([], dtype=torch.long), indx))
158153

159154
def test_pareto_frontier_evaluator_predict(self) -> None:
160-
Y, cov, indx = pareto_frontier_evaluator(
161-
model=self.model,
155+
def dummy_predict(
156+
model: MultiObjectiveBotorchModel,
157+
X: Tensor,
158+
use_posterior_predictive: bool = False,
159+
) -> tuple[Tensor, Tensor]:
160+
# Add column to X that is a product of previous elements.
161+
mean = torch.cat([X, torch.prod(X, dim=1).reshape(-1, 1)], dim=1)
162+
cov = torch.zeros(mean.shape[0], mean.shape[1], mean.shape[1])
163+
return mean, cov
164+
165+
# pyre-fixme: Incompatible parameter type [6]: In call
166+
# `MultiObjectiveBotorchModel.__init__`, for argument `model_predictor`,
167+
# expected `typing.Callable[[Model, Tensor, bool], Tuple[Tensor,
168+
# Tensor]]` but got named arguments
169+
model = MultiObjectiveBotorchModel(model_predictor=dummy_predict)
170+
_fit_model(model=model, X=self.X, Y=self.Y, Yvar=self.Yvar)
171+
172+
Y, _, indx = pareto_frontier_evaluator(
173+
model=model,
162174
objective_weights=self.objective_weights,
163175
objective_thresholds=self.objective_thresholds,
164176
X=self.X,
@@ -170,13 +182,17 @@ def test_pareto_frontier_evaluator_predict(self) -> None:
170182
self.assertTrue(torch.equal(torch.arange(2, 4), indx))
171183

172184
def test_pareto_frontier_evaluator_with_outcome_constraints(self) -> None:
173-
Y, cov, indx = pareto_frontier_evaluator(
174-
model=self.model,
185+
model = MultiObjectiveBotorchModel()
186+
Y, _, indx = pareto_frontier_evaluator(
187+
model=model,
175188
objective_weights=self.objective_weights,
176189
objective_thresholds=self.objective_thresholds,
177190
Y=self.Y,
178191
Yvar=self.Yvar,
179-
outcome_constraints=self.outcome_constraints,
192+
outcome_constraints=(
193+
torch.tensor([[0.0, 0.0, 1.0]]),
194+
torch.tensor([[3.5]]),
195+
),
180196
)
181197
pred = self.Y[2, :]
182198
self.assertTrue(

0 commit comments

Comments
 (0)