
Commit 436138e

npielawski authored and facebook-github-bot committed
Implements Multi-Fidelity GIBBON (Lower Bound MES) acquisition. (#1185)
Summary:
## Motivation

Since `qLowerBoundMaxValueEntropy` provides a cheap approximation to `qMaxValueEntropy`, this PR implements the multi-fidelity counterpart of `qMultiFidelityMaxValueEntropy` so that the approximation can be used in a multi-fidelity setting as well.

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes.

Pull Request resolved: #1185

Test Plan: Copied the unit test for `qMultiFidelityMaxValueEntropy` without changes (besides the class name). The `qMultiFidelityLowerBoundMaxValueEntropy` class is identical to `qMultiFidelityMaxValueEntropy` except for a different `_compute_information_gain` method, so it doesn't require a separate set of tests.

Reviewed By: dme65

Differential Revision: D35746956

Pulled By: Balandat

fbshipit-source-id: f937399cc2c83d28cd7ab3b6f44533ae6013a061
1 parent 7d974aa commit 436138e
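
The new class is intended as a drop-in replacement for `qMultiFidelityMaxValueEntropy`. Below is a minimal usage sketch for the cost-aware multi-fidelity setting, assuming the constructor mirrors the parent class's `project` and `cost_aware_utility` arguments; the toy data, the fidelity column index, and the cost parameters are illustrative, not taken from this PR:

```python
from functools import partial

import torch
from botorch.acquisition.cost_aware import InverseCostWeightedUtility
from botorch.acquisition.max_value_entropy_search import (
    qMultiFidelityLowerBoundMaxValueEntropy,
)
from botorch.acquisition.utils import project_to_target_fidelity
from botorch.models.cost import AffineFidelityCostModel
from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP

# Toy data: 2 design dims plus 1 fidelity dim in the last column.
train_X = torch.rand(20, 3)
train_Y = torch.sin(train_X[:, :1] * 6) + 0.1 * torch.randn(20, 1)
model = SingleTaskMultiFidelityGP(train_X, train_Y, data_fidelity=2)

# Project candidates to the target fidelity when sampling max values.
target_fidelities = {2: 1.0}
project = partial(project_to_target_fidelity, target_fidelities=target_fidelities)

# Trade off information gain against an (assumed) affine evaluation cost.
cost_model = AffineFidelityCostModel(fidelity_weights={2: 1.0}, fixed_cost=5.0)
cost_aware_utility = InverseCostWeightedUtility(cost_model=cost_model)

candidate_set = torch.rand(1000, 3)
mf_gibbon = qMultiFidelityLowerBoundMaxValueEntropy(
    model=model,
    candidate_set=candidate_set,
    project=project,
    cost_aware_utility=cost_aware_utility,
)
acq_value = mf_gibbon(torch.rand(1, 1, 3))  # t-batch of q=1 design points
```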

File tree

3 files changed: +65 −3 lines

botorch/acquisition/__init__.py (+2)

@@ -34,6 +34,7 @@
 from botorch.acquisition.max_value_entropy_search import (
     MaxValueBase,
     qLowerBoundMaxValueEntropy,
+    qMultiFidelityLowerBoundMaxValueEntropy,
     qMaxValueEntropy,
     qMultiFidelityMaxValueEntropy,
 )
@@ -81,6 +82,7 @@
     "MaxValueBase",
     "qMultiFidelityKnowledgeGradient",
     "qMaxValueEntropy",
+    "qMultiFidelityLowerBoundMaxValueEntropy",
     "qLowerBoundMaxValueEntropy",
     "qMultiFidelityMaxValueEntropy",
     "qMultiStepLookahead",

botorch/acquisition/max_value_entropy_search.py (+50)

@@ -795,6 +795,56 @@ def forward(self, X: Tensor) -> Tensor:
         return ig.mean(dim=0)  # average over the fantasies
 
 
+class qMultiFidelityLowerBoundMaxValueEntropy(qMultiFidelityMaxValueEntropy):
+    r"""Multi-fidelity acquisition function for General-purpose Information-Based
+    Bayesian optimization (GIBBON).
+
+    The acquisition function for multi-fidelity max-value entropy search
+    with support for trace observations. See [Takeno2020mfmves]_
+    for a detailed discussion of the basic ideas on multi-fidelity MES
+    (note that this implementation is somewhat different). This acquisition function
+    is similar to `qMultiFidelityMaxValueEntropy` but computes the information gain
+    from the lower bound described in [Moss2021gibbon]_.
+
+    The model must be single-outcome, unless using a PosteriorTransform.
+    The batch case `q > 1` is supported through cyclic optimization and fantasies.
+
+    Example:
+        >>> model = SingleTaskGP(train_X, train_Y)
+        >>> candidate_set = torch.rand(1000, bounds.size(1))
+        >>> candidate_set = bounds[0] + (bounds[1] - bounds[0]) * candidate_set
+        >>> MF_qGIBBON = qMultiFidelityLowerBoundMaxValueEntropy(model, candidate_set)
+        >>> mf_gibbon = MF_qGIBBON(test_X)
+    """
+
+    def _compute_information_gain(
+        self, X: Tensor, mean_M: Tensor, variance_M: Tensor, covar_mM: Tensor
+    ) -> Tensor:
+        r"""Compute GIBBON's approximation of information gain at the design points `X`.
+
+        When using GIBBON for batch optimization (i.e., `q > 1`), we calculate the
+        additional information provided by adding a new candidate point to the current
+        batch of design points (`X_pending`), rather than calculating the information
+        provided by the whole batch. This allows a modest computational saving.
+
+        Args:
+            X: A `batch_shape x 1 x d`-dim Tensor of `batch_shape` t-batches
+                with `1` `d`-dim design point each.
+            mean_M: A `batch_shape x 1`-dim Tensor of means.
+            variance_M: A `batch_shape x 1`-dim Tensor of variances
+                consisting of `batch_shape` t-batches with `num_fantasies` fantasies.
+            covar_mM: A `batch_shape x num_fantasies x (1 + num_trace_observations)`
+                -dim Tensor of covariances.
+
+        Returns:
+            A `num_fantasies x batch_shape`-dim Tensor of information gains at the
+            given design points `X`.
+        """
+        return qLowerBoundMaxValueEntropy._compute_information_gain(
+            self, X=X, mean_M=mean_M, variance_M=variance_M, covar_mM=covar_mM
+        )
+
+
 def _sample_max_value_Thompson(
     model: Model,
     candidate_set: Tensor,
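
The override simply reuses `qLowerBoundMaxValueEntropy`'s information-gain computation. For intuition, the quantity that GIBBON lower-bounds reduces, for a single candidate and under the truncated-normal moment matching of [Moss2021gibbon], to roughly the following self-contained sketch; the function name, its tensor inputs, and the explicit `rho` argument are illustrative simplifications, not BoTorch's internal API:

```python
import torch
from torch.distributions import Normal

def gibbon_information_gain(mean_f, var_f, rho, max_value_samples):
    """Illustrative GIBBON lower bound on the information gain for q = 1.

    mean_f, var_f: posterior mean/variance of the target-fidelity objective f(x).
    rho: correlation between the actual observation y(x) (possibly at a lower
        fidelity, possibly noisy) and f(x); rho = 1 for exact observations of f.
    max_value_samples: M-dim tensor of sampled maxima g_m of f.
    """
    normal = Normal(0.0, 1.0)
    # Standardized gap between each sampled maximum and the posterior mean.
    beta = (max_value_samples - mean_f) / var_f.clamp_min(1e-10).sqrt()
    # Hazard-type ratio phi(beta) / Phi(beta) of the upper-truncated normal.
    r = torch.exp(normal.log_prob(beta)) / normal.cdf(beta).clamp_min(1e-10)
    # Fraction of Var[f] removed by conditioning on f* <= g_m:
    # Var[f | f* <= g_m] = var_f * (1 - R_m).
    R = beta * r + r.pow(2)
    # Moment matching gives Var[y | f* <= g_m] ~= var_y * (1 - rho^2 * R_m),
    # so the entropy reduction, averaged over the M max-value samples, is:
    return -0.5 * torch.log((1 - rho**2 * R).clamp_min(1e-10)).mean()
```

The multi-fidelity aspect enters through `rho`: a cheap low-fidelity observation is less correlated with the target-fidelity objective, which shrinks the effective variance reduction `rho**2 * R` and hence the information gain.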

test/acquisition/test_max_value_entropy_search.py (+13 −3)

@@ -15,6 +15,7 @@
     qLowerBoundMaxValueEntropy,
     qMaxValueEntropy,
     qMultiFidelityMaxValueEntropy,
+    qMultiFidelityLowerBoundMaxValueEntropy,
 )
 from botorch.acquisition.objective import (
     PosteriorTransform,
@@ -241,14 +242,16 @@ def test_q_lower_bound_max_value_entropy(self):
         with self.assertRaisesRegex(UnsupportedError, "X_pending is not None"):
             qGIBBON(X)
 
-    def test_q_multi_fidelity_max_value_entropy(self):
+    def test_q_multi_fidelity_max_value_entropy(
+        self, acqf_class=qMultiFidelityMaxValueEntropy
+    ):
         for dtype in (torch.float, torch.double):
             torch.manual_seed(7)
             mm = MESMockModel()
             train_inputs = torch.rand(10, 2, device=self.device, dtype=dtype)
             mm.train_inputs = (train_inputs,)
             candidate_set = torch.rand(10, 2, device=self.device, dtype=dtype)
-            qMF_MVE = qMultiFidelityMaxValueEntropy(
+            qMF_MVE = acqf_class(
                 model=mm, candidate_set=candidate_set, num_mv_samples=10
             )
 
@@ -277,7 +280,7 @@ def test_q_multi_fidelity_max_value_entropy(self):
             pt = ScalarizedPosteriorTransform(
                 weights=torch.ones(2, device=self.device, dtype=dtype)
             )
-            qMF_MVE = qMultiFidelityMaxValueEntropy(
+            qMF_MVE = acqf_class(
                 model=mm,
                 candidate_set=candidate_set,
                 num_mv_samples=10,
@@ -286,6 +289,13 @@ def test_q_multi_fidelity_max_value_entropy(self):
             X = torch.rand(1, 2, device=self.device, dtype=dtype)
             self.assertEqual(qMF_MVE(X).shape, torch.Size([1]))
 
+    def test_q_multi_fidelity_lower_bound_max_value_entropy(self):
+        # Same test as for MF-MES, since GIBBON only differs in the way it
+        # computes the information gain.
+        self.test_q_multi_fidelity_max_value_entropy(
+            acqf_class=qMultiFidelityLowerBoundMaxValueEntropy
+        )
+
     def test_sample_max_value_Gumbel(self):
         for dtype in (torch.float, torch.double):
             torch.manual_seed(7)
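
Beyond the unit tests, one plausible way to exercise the batch behavior described in the docstring (`q > 1` via cyclic optimization) is greedy batch construction through `X_pending`. A hedged sketch, reusing the `mf_gibbon` instance and `bounds` shape from the earlier example; the optimizer settings are illustrative, and it assumes the model/transform combination supports `X_pending`:

```python
import torch
from botorch.optim import optimize_acqf

# Design space: 2 design dims plus 1 fidelity dim, all in [0, 1].
bounds = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])

batch, pending = [], None
for _ in range(3):  # build a q = 3 batch one point at a time
    mf_gibbon.set_X_pending(pending)  # condition on points already chosen
    candidate, _ = optimize_acqf(
        acq_function=mf_gibbon,
        bounds=bounds,
        q=1,
        num_restarts=10,
        raw_samples=512,
    )
    batch.append(candidate)
    pending = torch.cat(batch, dim=0)
```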
