Skip to content

Commit 7d974aa

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Fix missing posterior_transform in gen_one_shot_kg_initial_conditions (#1187)
Summary: Pull Request resolved: #1187 Fixes #1186. Reviewed By: Balandat Differential Revision: D35757067 fbshipit-source-id: 31dc24fc1aed8b43c6b3385f367d690bd21bf0aa
1 parent 824a4a9 commit 7d974aa

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

botorch/optim/initializers.py

+1
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ def gen_one_shot_kg_initial_conditions(
293293
value_function = _get_value_function(
294294
model=acq_function.model,
295295
objective=acq_function.objective,
296+
posterior_transform=acq_function.posterior_transform,
296297
sampler=acq_function.inner_sampler,
297298
project=getattr(acq_function, "project", None),
298299
)

test/acquisition/test_knowledge_gradient.py

+43
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from botorch.acquisition.utils import project_to_sample_points
2727
from botorch.exceptions.errors import UnsupportedError
2828
from botorch.models import SingleTaskGP
29+
from botorch.optim.optimize import optimize_acqf
2930
from botorch.posteriors.gpytorch import GPyTorchPosterior
3031
from botorch.sampling.samplers import IIDNormalSampler, SobolQMCNormalSampler
3132
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
@@ -559,6 +560,48 @@ def test_fixed_evaluation_qMFKG(self):
559560
raw_samples=1,
560561
)
561562

563+
def test_optimize_w_posterior_transform(self):
564+
# This is mainly testing that we can optimize without errors.
565+
for dtype in (torch.float, torch.double):
566+
tkwargs = {"dtype": dtype, "device": self.device}
567+
mean = torch.tensor([1.0, 0.5], **tkwargs).expand(2, 1, 2)
568+
cov = torch.tensor([[1.0, 0.1], [0.1, 0.5]], **tkwargs).expand(2, 2, 2)
569+
posterior = GPyTorchPosterior(MultitaskMultivariateNormal(mean, cov))
570+
model = MockModel(posterior)
571+
n_f = 4
572+
mean = torch.tensor([1.0, 0.5], **tkwargs).expand(n_f, 2, 1, 2)
573+
cov = torch.tensor([[1.0, 0.1], [0.1, 0.5]], **tkwargs).expand(n_f, 2, 2, 2)
574+
posterior = GPyTorchPosterior(MultitaskMultivariateNormal(mean, cov))
575+
mfm = MockModel(posterior)
576+
bounds = torch.zeros(2, 2, **tkwargs)
577+
bounds[1] = 1
578+
options = {"num_inner_restarts": 2, "raw_inner_samples": 2}
579+
with mock.patch.object(MockModel, "fantasize", return_value=mfm):
580+
kg = qMultiFidelityKnowledgeGradient(
581+
model=model,
582+
num_fantasies=n_f,
583+
posterior_transform=ScalarizedPosteriorTransform(
584+
weights=torch.rand(2, **tkwargs)
585+
),
586+
)
587+
# Mocking this to get around grad issues.
588+
with mock.patch(
589+
f"{optimize_acqf.__module__}.gen_candidates_scipy",
590+
return_value=(
591+
torch.zeros(2, n_f + 1, 2, **tkwargs),
592+
torch.zeros(2, **tkwargs),
593+
),
594+
):
595+
candidate, value = optimize_acqf(
596+
acq_function=kg,
597+
bounds=bounds,
598+
q=1,
599+
num_restarts=2,
600+
raw_samples=2,
601+
options=options,
602+
)
603+
self.assertTrue(torch.equal(candidate, torch.zeros(1, 2, **tkwargs)))
604+
562605

563606
class TestKGUtils(BotorchTestCase):
564607
def test_get_value_function(self):

0 commit comments

Comments
 (0)