
Commit 0d66aa0

esantorella authored and facebook-github-bot committed
Unit tests for integration of acqfs with their input constructors, with LearnedObjective and constraints (#2112)
Summary:
Pull Request resolved: #2112

Additional test using acqf with LearnedObjective

Reviewed By: ItsMrLin

Differential Revision: D51269389

fbshipit-source-id: 6a1603dec06ab0dc51c2389ef894ea9110066e89
1 parent 6bb9e31 commit 0d66aa0

7 files changed: +244 −15
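In outline, the commit adds integration tests that route a LearnedObjective (and, separately, outcome constraints) through each acquisition function's input constructor, plus the small library fixes those tests exposed. A minimal sketch of the pattern under test, assuming a fitted outcome model `model`, a fitted preference model `pref_model`, a `SupervisedDataset` `training_data`, and training inputs `train_x` (all built as in the new test file below):

    import torch
    from botorch.acquisition.input_constructors import get_acqf_input_constructor
    from botorch.acquisition.logei import qLogNoisyExpectedImprovement
    from botorch.acquisition.objective import LearnedObjective

    # Objective that scores outcomes with a learned preference model.
    objective = LearnedObjective(pref_model=pref_model, sample_shape=torch.Size([8]))

    # Let the registered input constructor assemble the acqf kwargs.
    input_constructor = get_acqf_input_constructor(acqf_cls=qLogNoisyExpectedImprovement)
    inputs = input_constructor(
        model=model,
        training_data=training_data,
        objective=objective,
        X_baseline=train_x,
    )
    acqf = qLogNoisyExpectedImprovement(**inputs)
    # 5 t-batches of q=3 candidates in d=2; one acquisition value per t-batch.
    acq_val = acqf(torch.rand(5, 3, 2, dtype=torch.double))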

botorch/acquisition/input_constructors.py

+1 −1

@@ -1363,7 +1363,7 @@ def get_best_f_mc(
         objective=objective,
         posterior_transform=posterior_transform,
         X_baseline=X_baseline,
-    )
+    ).squeeze()
 
 
 def optimize_objective(
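The added `.squeeze()` makes `get_best_f_mc` return a 0-d (scalar) tensor rather than one with trailing singleton dimensions; the updated test expectations below reflect this. A small illustration in plain PyTorch:

    import torch

    Y = torch.tensor([[0.1], [0.7], [0.3]])  # n x 1 training outcomes
    best_keepdim = Y.max(dim=0).values       # tensor([0.7000]), shape (1,)
    best_scalar = Y.max()                    # tensor(0.7000), 0-d scalar
    assert best_keepdim.squeeze().shape == best_scalar.shape  # both 0-d after squeeze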

botorch/acquisition/objective.py

+3
@@ -605,6 +605,9 @@ def forward(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
         )
         samples = samples.to(torch.float64)
 
+        if samples.ndim < 3:
+            raise ValueError("samples should have at least 3 dimensions.")
+
         posterior = self.pref_model.posterior(samples)
         if isinstance(self.pref_model, DeterministicModel):
             # return preference posterior mean
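With this guard, `LearnedObjective.forward` rejects unbatched 2-d input up front instead of failing deeper in the preference model. A minimal sketch of the new behavior, using a toy deterministic preference model (supported per the `DeterministicModel` branch above); the summing model here is an illustrative stand-in, not the library's own:

    import torch
    from botorch.acquisition.objective import LearnedObjective
    from botorch.models.deterministic import GenericDeterministicModel

    # Toy deterministic "preference model": utility = sum of the m outcomes.
    pref_model = GenericDeterministicModel(f=lambda Y: Y.sum(dim=-1, keepdim=True))
    obj = LearnedObjective(pref_model=pref_model)

    obj(torch.rand(4, 8, 2))  # OK: 3-d input (sample_shape x q x m)
    obj(torch.rand(8, 2))     # now raises: "samples should have at least 3 dimensions."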

botorch/acquisition/utils.py

+1 −2

@@ -56,7 +56,7 @@ def repeat_to_match_aug_dim(target_tensor: Tensor, reference_tensor: Tensor) ->
         matches that of `reference_tensor`.
         The shape will be `(augmented_sample * sample_size) x batch_shape x q x m`.
 
-    Example:
+    Examples:
         >>> import torch
         >>> target_tensor = torch.arange(3).repeat(2, 1).T
         >>> target_tensor
@@ -71,7 +71,6 @@ def repeat_to_match_aug_dim(target_tensor: Tensor, reference_tensor: Tensor) ->
                 [1, 1],
                 [2, 2]])
     """
-
     augmented_sample_num, remainder = divmod(
         reference_tensor.shape[0], target_tensor.shape[0]
    )
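The diff elides the middle of the doctest; for context, a usage sketch of `repeat_to_match_aug_dim` consistent with the docstring shown (the exact interleaving order of the repeated rows is in the elided output):

    import torch
    from botorch.acquisition.utils import repeat_to_match_aug_dim

    target = torch.arange(3).repeat(2, 1).T   # shape (3, 2): rows [0,0], [1,1], [2,2]
    reference = torch.zeros(6, 2)             # leading dim is 2x target's
    out = repeat_to_match_aug_dim(target, reference)
    print(out.shape)  # torch.Size([6, 2]): target repeated along dim 0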

test/acquisition/test_input_constructors.py

+3 −3

@@ -182,21 +182,21 @@ def test_get_best_f_mc(self) -> None:
         best_f = get_best_f_mc(training_data=self.blockX_blockY)
         self.assertEqual(best_f, get_best_f_mc(self.blockX_blockY[0]))
 
-        best_f_expected = self.blockX_blockY[0].Y.max(dim=0).values
+        best_f_expected = self.blockX_blockY[0].Y.max()
         self.assertAllClose(best_f, best_f_expected)
         with self.assertRaisesRegex(UnsupportedError, "require an objective"):
             get_best_f_mc(training_data=self.blockX_multiY)
         obj = LinearMCObjective(weights=torch.rand(2))
         best_f = get_best_f_mc(training_data=self.blockX_multiY, objective=obj)
 
         multi_Y = torch.cat([d.Y for d in self.blockX_multiY.values()], dim=-1)
-        best_f_expected = (multi_Y @ obj.weights).amax(dim=-1, keepdim=True)
+        best_f_expected = (multi_Y @ obj.weights).max()
         self.assertAllClose(best_f, best_f_expected)
         post_tf = ScalarizedPosteriorTransform(weights=torch.ones(2))
         best_f = get_best_f_mc(
             training_data=self.blockX_multiY, posterior_transform=post_tf
         )
-        best_f_expected = (multi_Y.sum(dim=-1)).amax(dim=-1, keepdim=True)
+        best_f_expected = multi_Y.sum(dim=-1).max()
         self.assertAllClose(best_f, best_f_expected)
 
     @mock.patch("botorch.acquisition.input_constructors.optimize_acqf")
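The expected values switch from `amax(dim=-1, keepdim=True)` to a plain `.max()` because `get_best_f_mc` now squeezes its output to a scalar; both reduce over the same elements here:

    import torch

    multi_Y = torch.rand(5, 2)
    w = torch.rand(2)
    old = (multi_Y @ w).amax(dim=-1, keepdim=True)  # shape (1,)
    new = (multi_Y @ w).max()                       # 0-d scalar, same value
    assert torch.equal(old.squeeze(), new)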

test/acquisition/test_integration.py

+224
@@ -0,0 +1,224 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from itertools import product
+from typing import Dict
+from warnings import catch_warnings, simplefilter
+
+import torch
+from botorch.acquisition.input_constructors import get_acqf_input_constructor
+from botorch.acquisition.logei import (
+    qLogExpectedImprovement,
+    qLogNoisyExpectedImprovement,
+)
+from botorch.acquisition.monte_carlo import (
+    qExpectedImprovement,
+    qNoisyExpectedImprovement,
+    qProbabilityOfImprovement,
+)
+from botorch.acquisition.objective import LearnedObjective
+from botorch.exceptions.warnings import InputDataWarning
+from botorch.fit import fit_gpytorch_mll
+from botorch.models import SingleTaskGP
+from botorch.sampling.normal import SobolQMCNormalSampler
+from botorch.utils.datasets import SupervisedDataset
+from botorch.utils.testing import BotorchTestCase
+from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
+
+
+class TestObjectiveAndConstraintIntegration(BotorchTestCase):
+    def setUp(self) -> None:
+        self.q = 3
+        self.d = 2
+        self.tkwargs = {"device": self.device, "dtype": torch.double}
+
+    def _get_acqf_inputs(self, train_batch_shape: torch.Size, m: int) -> Dict:
+        train_x = torch.rand((*train_batch_shape, 5, self.d), **self.tkwargs)
+        y = torch.rand((*train_batch_shape, 5, m), **self.tkwargs)
+
+        training_data = SupervisedDataset(
+            X=train_x,
+            Y=y,
+            feature_names=[f"x{i}" for i in range(self.d)],
+            outcome_names=[f"y{i}" for i in range(m)],
+        )
+        utility = y.sum(-1).unsqueeze(-1)
+
+        with catch_warnings():
+            simplefilter("ignore", category=InputDataWarning)
+            model = SingleTaskGP(train_x, y)
+            mll = ExactMarginalLogLikelihood(model.likelihood, model)
+            fit_gpytorch_mll(mll=mll)
+
+        with catch_warnings():
+            simplefilter("ignore", category=InputDataWarning)
+            pref_model = SingleTaskGP(y, utility)
+            pref_mll = ExactMarginalLogLikelihood(pref_model.likelihood, pref_model)
+            fit_gpytorch_mll(mll=pref_mll)
+        return {
+            "training_data": training_data,
+            "model": model,
+            "pref_model": pref_model,
+            "train_x": train_x,
+        }
+
+    def _base_test_with_learned_objective(
+        self,
+        train_batch_shape: torch.Size,
+        prune_baseline: bool,
+        test_batch_shape: torch.Size,
+    ) -> None:
+        acq_inputs = self._get_acqf_inputs(train_batch_shape=train_batch_shape, m=4)
+
+        pref_sample_shapes = [1, 8]
+        test_acqf_classes_and_kws = [
+            # Not yet working
+            # (qExpectedImprovement, {}),
+            # (qProbabilityOfImprovement, {}),
+            # (qLogExpectedImprovement, {}),
+            (qNoisyExpectedImprovement, {"prune_baseline": prune_baseline}),
+            (qLogNoisyExpectedImprovement, {"prune_baseline": prune_baseline}),
+        ]
+
+        for (acqf_cls, kws), pref_sample_shape in product(
+            test_acqf_classes_and_kws, pref_sample_shapes
+        ):
+            with self.subTest(
+                train_batch_shape=train_batch_shape,
+                test_batch_shape=test_batch_shape,
+                prune_baseline=prune_baseline,
+                acqf_cls=acqf_cls,
+                pref_sample_shape=pref_sample_shape,
+            ):
+                objective = LearnedObjective(
+                    pref_model=acq_inputs["pref_model"],
+                    sample_shape=torch.Size([pref_sample_shape]),
+                )
+                test_x = torch.rand(
+                    (*test_batch_shape, *train_batch_shape, self.q, self.d),
+                    **self.tkwargs,
+                )
+                input_constructor = get_acqf_input_constructor(acqf_cls=acqf_cls)
+
+                inputs = input_constructor(
+                    objective=objective,
+                    model=acq_inputs["model"],
+                    training_data=acq_inputs["training_data"],
+                    X_baseline=acq_inputs["train_x"],
+                    sampler=SobolQMCNormalSampler(torch.Size([4])),
+                    **kws,
+                )
+                acqf = acqf_cls(**inputs)
+                acq_val = acqf(test_x)
+                self.assertEqual(acq_val.shape.numel(), test_x.shape[:-2].numel())
+
+    def test_with_learned_objective_train_data_not_batched(self) -> None:
+        train_batch_shape = []
+        test_batch_shapes = [[], [1], [2]]
+        for test_batch_shape in test_batch_shapes:
+            self._base_test_with_learned_objective(
+                train_batch_shape=torch.Size(train_batch_shape),
+                prune_baseline=True,
+                test_batch_shape=torch.Size(test_batch_shape),
+            )
+
+    def test_with_learned_objective_train_data_1d_batch(self) -> None:
+        train_batch_shape = [1]
+        test_batch_shapes = [[], [1], [2]]
+        for test_batch_shape in test_batch_shapes:
+            self._base_test_with_learned_objective(
+                train_batch_shape=torch.Size(train_batch_shape),
+                # Batched inputs `X_baseline` are currently unsupported by
+                # prune_inferior_points
+                prune_baseline=False,
+                test_batch_shape=torch.Size(test_batch_shape),
+            )
+
+    def test_with_learned_objective_train_data_batched(self) -> None:
+        train_batch_shape = [3]
+        test_batch_shapes = [[], [1], [2]]
+        for test_batch_shape in test_batch_shapes:
+            self._base_test_with_learned_objective(
+                train_batch_shape=torch.Size(train_batch_shape),
+                # Batched inputs `X_baseline` are currently unsupported by
+                # prune_inferior_points
+                prune_baseline=False,
+                test_batch_shape=torch.Size(test_batch_shape),
+            )
+
+    def _base_test_without_learned_objective(
+        self,
+        train_batch_shape: torch.Size,
+        prune_baseline: bool,
+        test_batch_shape: torch.Size,
+    ) -> None:
+        inputs = self._get_acqf_inputs(train_batch_shape=train_batch_shape, m=1)
+        constraints = [lambda y: y[..., 0]]
+        test_x = torch.rand(
+            (*test_batch_shape, *train_batch_shape, self.q, self.d), **self.tkwargs
+        )
+
+        input_constructor_kwargs = {
+            "model": inputs["model"],
+            "training_data": inputs["training_data"],
+            "X_baseline": inputs["train_x"],
+            "sampler": SobolQMCNormalSampler(torch.Size([4])),
+        }
+
+        for acqf_cls, kws in [
+            (qNoisyExpectedImprovement, {"prune_baseline": prune_baseline}),
+            (qLogNoisyExpectedImprovement, {"prune_baseline": prune_baseline}),
+            (qExpectedImprovement, {}),
+            (qProbabilityOfImprovement, {}),
+            (qLogExpectedImprovement, {}),
+        ]:
+            # Not working.
+            if train_batch_shape.numel() > 1 and acqf_cls == qLogExpectedImprovement:
+                continue
+            input_constructor = get_acqf_input_constructor(acqf_cls=acqf_cls)
+
+            with self.subTest(
+                "no objective or constraints",
+                train_batch_shape=train_batch_shape,
+                prune_baseline=prune_baseline,
+                test_batch_shape=test_batch_shape,
+                acqf_cls=acqf_cls,
+            ):
+                acqf = acqf_cls(**input_constructor(**input_constructor_kwargs, **kws))
+                acq_val = acqf(test_x)
+                self.assertEqual(acq_val.shape.numel(), test_x.shape[:-2].numel())
+
+            with self.subTest(
+                "constrained",
+                train_batch_shape=train_batch_shape,
+                prune_baseline=prune_baseline,
+                test_batch_shape=test_batch_shape,
+                acqf_cls=acqf_cls,
+            ):
+                acqf = acqf_cls(
+                    **input_constructor(
+                        constraints=constraints, **input_constructor_kwargs, **kws
+                    )
+                )
+                acq_val = acqf(test_x)
+                self.assertEqual(acq_val.shape.numel(), test_x.shape[:-2].numel())
+
+    def test_without_learned_objective(self) -> None:
+        train_batch_shapes = [[], [1], [2]]
+        test_batch_shapes = [[], [1], [3]]
+        for train_batch_shape, test_batch_shape in product(
+            train_batch_shapes, test_batch_shapes
+        ):
+            # Batched inputs `X_baseline` are currently unsupported by
+            # prune_inferior_points
+            prune_baseline_ = [False] if len(train_batch_shape) > 0 else [False, True]
+            for prune_baseline in prune_baseline_:
+                self._base_test_without_learned_objective(
+                    train_batch_shape=torch.Size(train_batch_shape),
+                    prune_baseline=prune_baseline,
+                    test_batch_shape=torch.Size(test_batch_shape),
+                )
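(In the constrained subtest, `acq_val = acqf(test_x)` has been moved before the assertion; as committed, the assertion checked the stale `acq_val` from the previous subtest.) The key assertion in both helpers is that the acquisition function returns one value per t-batch: `acq_val.shape.numel() == test_x.shape[:-2].numel()`, i.e. everything but the trailing `q x d` candidate block indexes a batch. The module can be run standalone with, e.g., `python -m pytest test/acquisition/test_integration.py`.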

test/acquisition/test_monte_carlo.py

−2

@@ -221,8 +221,6 @@ def test_q_expected_improvement_batch(self):
             acqf(X)
             self.assertTrue(torch.equal(acqf.sampler.base_samples, bs))
 
-        # TODO: Test different objectives (incl. constraints)
-
 
 class TestQNoisyExpectedImprovement(BotorchTestCase):
     def test_q_noisy_expected_improvement(self):

test/acquisition/test_objective.py

+12 −7

@@ -459,13 +459,13 @@ def test_learned_preference_objective(self) -> None:
         og_sample_shape = 3
         large_sample_shape = 256
         batch_size = 2
-        n = 8
+        q = 8
         test_X = torch.rand(
-            torch.Size((og_sample_shape, batch_size, n, self.x_dim)),
+            torch.Size((og_sample_shape, batch_size, q, self.x_dim)),
             dtype=torch.float64,
         )
         large_X = torch.rand(
-            torch.Size((large_sample_shape, batch_size, n, self.x_dim)),
+            torch.Size((large_sample_shape, batch_size, q, self.x_dim)),
             dtype=torch.float64,
         )
 
@@ -476,19 +476,24 @@ def test_learned_preference_objective(self) -> None:
         first_call_output = pref_obj(test_X)
         self.assertEqual(
             first_call_output.shape,
-            torch.Size([og_sample_shape * DEFAULT_NUM_PREF_SAMPLES, batch_size, n]),
+            torch.Size([og_sample_shape * DEFAULT_NUM_PREF_SAMPLES, batch_size, q]),
         )
         # Making sure the sampler has correct base_samples shape
         self.assertEqual(
             pref_obj.sampler.base_samples.shape,
-            torch.Size([DEFAULT_NUM_PREF_SAMPLES, og_sample_shape, 1, n]),
+            torch.Size([DEFAULT_NUM_PREF_SAMPLES, og_sample_shape, 1, q]),
         )
         # Passing through a same-shaped X again shouldn't change the base sample
         previous_base_samples = pref_obj.sampler.base_samples
         another_test_X = torch.rand_like(test_X)
         pref_obj(another_test_X)
         self.assertIs(pref_obj.sampler.base_samples, previous_base_samples)
 
+        with self.assertRaisesRegex(
+            ValueError, "samples should have at least 3 dimensions."
+        ):
+            pref_obj(torch.rand(q, self.x_dim))
+
         # test when sampler has multiple preference samples
         with self.subTest("Multiple samples"):
             num_samples = 256
@@ -498,7 +503,7 @@ def test_learned_preference_objective(self) -> None:
             )
             self.assertEqual(
                 pref_obj(test_X).shape,
-                torch.Size([num_samples * og_sample_shape, batch_size, n]),
+                torch.Size([num_samples * og_sample_shape, batch_size, q]),
             )
 
             avg_obj_val = pref_obj(large_X).mean(dim=0)
@@ -513,7 +518,7 @@ def test_learned_preference_objective(self) -> None:
             pref_obj = LearnedObjective(pref_model=mean_pref_model)
             self.assertEqual(
                 pref_obj(test_X).shape,
-                torch.Size([og_sample_shape, batch_size, n]),
+                torch.Size([og_sample_shape, batch_size, q]),
             )
 
             # the order of samples shouldn't matter
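For readers tracking the shape assertions: LearnedObjective augments the leading MC sample dimension, mapping outcome samples of shape `s x b x q x m` to utilities of shape `(k * s) x b x q`, where `k` is the number of preference-model samples per outcome sample (`sample_shape`, defaulting to `DEFAULT_NUM_PREF_SAMPLES`); with a deterministic preference model the leading dimension is unchanged, as the final hunk asserts. A sketch of the arithmetic, with `k` as a stand-in value:

    import torch

    s, b, q = 3, 2, 8   # og_sample_shape, batch_size, q
    k = 16              # stand-in for DEFAULT_NUM_PREF_SAMPLES; import the real constant
    samples_shape = torch.Size([s, b, q, 1])   # outcome samples fed to the objective
    expected_out = torch.Size([k * s, b, q])   # shape asserted for pref_obj(test_X)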
