Skip to content

Commit 93d6fd6

Browse files
Balandat authored and facebook-github-bot committed
Undo dependency on botorch master
Summary: Ax released a version that depended on changes on botorch master (#159). This changes things back so it should work with botorch 0.1.3.

Reviewed By: lena-kashtelyan

Differential Revision: D17093727

fbshipit-source-id: 7d08d4205743c8135fb4adde73befa45a0204816
1 parent 564c6fc commit 93d6fd6

File tree

3 files changed

+20
-19
lines changed

3 files changed

+20
-19
lines changed

ax/models/tests/test_botorch_model.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -148,14 +148,11 @@ def test_BotorchModel(self, dtype=torch.float, cuda=False):
148148
n = 3
149149

150150
X_dummy = torch.tensor([[[1.0, 2.0, 3.0]]], dtype=dtype, device=device)
151-
acq_dummy = torch.tensor(0.0, dtype=dtype, device=device)
152151
model_gen_options = {}
153152
# test sequential optimize
154153
with mock.patch(
155-
"ax.models.torch.botorch_defaults.optimize_acqf",
156-
return_value=(X_dummy, acq_dummy),
154+
"ax.models.torch.botorch_defaults.sequential_optimize", return_value=X_dummy
157155
) as mock_optimize_acqf:
158-
159156
Xgen, wgen = model.gen(
160157
n=n,
161158
bounds=bounds,
@@ -173,8 +170,7 @@ def test_BotorchModel(self, dtype=torch.float, cuda=False):
173170

174171
# test joint optimize
175172
with mock.patch(
176-
"ax.models.torch.botorch_defaults.optimize_acqf",
177-
return_value=(X_dummy, acq_dummy),
173+
"ax.models.torch.botorch_defaults.joint_optimize", return_value=X_dummy
178174
) as mock_optimize_acqf:
179175
Xgen, wgen = model.gen(
180176
n=n,

ax/models/torch/botorch.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
Optional[Callable[[Tensor], Tensor]],
5656
Any,
5757
],
58-
Tuple[Tensor, Tensor],
58+
Tensor,
5959
]
6060

6161

@@ -146,16 +146,15 @@ class BotorchModel(TorchModel):
146146
fixed_features,
147147
rounding_func,
148148
**kwargs,
149-
) -> (candidates, acq_values)
149+
) -> candidates
150150
151151
Here `acq_function` is a BoTorch `AcquisitionFunction`, `bounds` is a
152152
tensor containing bounds on the parameters, `n` is the number of
153153
candidates to be generated, `inequality_constraints` are inequality
154154
constraints on parameter values, `fixed_features` specifies features that
155155
should be fixed during generation, and `rounding_func` is a callback
156156
that rounds an optimization result appropriately. `candidates` is
157-
a tensor of generated candidates, and `acq_values` are the acquisition
158-
values associated with the candidates. For additional details on the
157+
a tensor of generated candidates. For additional details on the
159158
arguments, see `scipy_optimizer`.
160159
"""
161160

@@ -316,7 +315,7 @@ def gen(
316315

317316
botorch_rounding_func = get_rounding_func(rounding_func)
318317

319-
candidates, _ = self.acqf_optimizer( # pyre-ignore: [28]
318+
candidates = self.acqf_optimizer( # pyre-ignore: [28]
320319
acq_function=checked_cast(AcquisitionFunction, acquisition_function),
321320
bounds=bounds_,
322321
n=n,

ax/models/torch/botorch_defaults.py

+14-8
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from botorch.models.model import Model
1919
from botorch.models.model_list_gp_regression import ModelListGP
2020
from botorch.models.multitask import FixedNoiseMultiTaskGP, MultiTaskGP
21-
from botorch.optim.optimize import optimize_acqf
21+
from botorch.optim.optimize import joint_optimize, sequential_optimize
2222
from botorch.utils import (
2323
get_objective_weights_transform,
2424
get_outcome_constraint_transforms,
@@ -204,7 +204,7 @@ def scipy_optimizer(
204204
fixed_features: Optional[Dict[int, float]] = None,
205205
rounding_func: Optional[Callable[[Tensor], Tensor]] = None,
206206
**kwargs: Any,
207-
) -> Tuple[Tensor, Tensor]:
207+
) -> Tensor:
208208
r"""Optimizer using scipy's minimize module on a numpy-adpator.
209209
210210
Args:
@@ -233,12 +233,15 @@ def scipy_optimizer(
233233
num_restarts: int = kwargs.get("num_restarts", 20)
234234
raw_samples: int = kwargs.get("num_raw_samples", 50 * num_restarts)
235235

236-
sequential = not kwargs.get("joint_optimization", False)
237-
# use SLSQP by default for small problems since it yields faster wall times
238-
if sequential and "method" not in kwargs:
239-
kwargs["method"] = "SLSQP"
236+
if kwargs.get("joint_optimization", False):
237+
optimize = joint_optimize
238+
else:
239+
optimize = sequential_optimize
240+
# use SLSQP by default for small problems since it yields faster wall times
241+
if "method" not in kwargs:
242+
kwargs["method"] = "SLSQP"
240243

241-
return optimize_acqf(
244+
X = optimize(
242245
acq_function=acq_function,
243246
bounds=bounds,
244247
q=n,
@@ -248,8 +251,11 @@ def scipy_optimizer(
248251
inequality_constraints=inequality_constraints,
249252
fixed_features=fixed_features,
250253
post_processing_func=rounding_func,
251-
sequential=not kwargs.get("joint_optimization", False),
252254
)
255+
# TODO: Un-hack this once botorch #234 is part of a stable release
256+
if isinstance(X, tuple):
257+
X, _ = X
258+
return X
253259

254260

255261
def _get_model(

0 commit comments

Comments (0)