Skip to content

Commit 25badc4

Browse files
sdaultonfacebook-github-bot
authored andcommitted
acquisition function wrapper (#1532)
Summary: Pull Request resolved: #1532 Add a wrapper for modifying inputs/outputs. This is useful for not only probabilistic reparameterization, but will also simplify other integrated AFs (e.g. MCMC) as well as fixed feature AFs and things like prior-guided AFs Differential Revision: D41629186 fbshipit-source-id: 7d77fee09746b10e6533372621d9c49df8d8ab5a
1 parent 4d4b47d commit 25badc4

File tree

8 files changed

+144
-47
lines changed

8 files changed

+144
-47
lines changed

botorch/acquisition/fixed_feature.py

+7-19
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@
1616

1717
import torch
1818
from botorch.acquisition.acquisition import AcquisitionFunction
19+
from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper
1920
from torch import Tensor
20-
from torch.nn import Module
2121

2222

23-
class FixedFeatureAcquisitionFunction(AcquisitionFunction):
23+
class FixedFeatureAcquisitionFunction(AbstractAcquisitionFunctionWrapper):
2424
"""A wrapper around AquisitionFunctions to fix a subset of features.
2525
2626
Example:
@@ -56,8 +56,7 @@ def __init__(
5656
combination of `Tensor`s and numbers which can be broadcasted
5757
to form a tensor with trailing dimension size of `d_f`.
5858
"""
59-
Module.__init__(self)
60-
self.acq_func = acq_function
59+
AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=acq_function)
6160
dtype = torch.float
6261
device = torch.device("cpu")
6362
self.d = d
@@ -126,24 +125,13 @@ def forward(self, X: Tensor):
126125
X_full = self._construct_X_full(X)
127126
return self.acq_func(X_full)
128127

129-
@property
130-
def X_pending(self):
131-
r"""Return the `X_pending` of the base acquisition function."""
132-
try:
133-
return self.acq_func.X_pending
134-
except (ValueError, AttributeError):
135-
raise ValueError(
136-
f"Base acquisition function {type(self.acq_func).__name__} "
137-
"does not have an `X_pending` attribute."
138-
)
139-
140-
@X_pending.setter
141-
def X_pending(self, X_pending: Optional[Tensor]):
128+
def set_X_pending(self, X_pending: Optional[Tensor]):
142129
r"""Sets the `X_pending` of the base acquisition function."""
143130
if X_pending is not None:
144-
self.acq_func.X_pending = self._construct_X_full(X_pending)
131+
full_X_pending = self._construct_X_full(X_pending)
145132
else:
146-
self.acq_func.X_pending = X_pending
133+
full_X_pending = None
134+
self.acq_func.set_X_pending(full_X_pending)
147135

148136
def _construct_X_full(self, X: Tensor) -> Tensor:
149137
r"""Constructs the full input for the base acquisition function.

botorch/acquisition/penalized.py

+5-19
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,8 @@
1515

1616
import torch
1717
from botorch.acquisition.acquisition import AcquisitionFunction
18-
from botorch.acquisition.analytic import AnalyticAcquisitionFunction
1918
from botorch.acquisition.objective import GenericMCObjective
20-
from botorch.exceptions import UnsupportedError
19+
from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper
2120
from torch import Tensor
2221

2322

@@ -139,7 +138,7 @@ def forward(self, X: Tensor) -> Tensor:
139138
return regularization_term
140139

141140

142-
class PenalizedAcquisitionFunction(AcquisitionFunction):
141+
class PenalizedAcquisitionFunction(AbstractAcquisitionFunctionWrapper):
143142
r"""Single-outcome acquisition function regularized by the given penalty.
144143
145144
The usage is similar to:
@@ -161,29 +160,16 @@ def __init__(
161160
penalty_func: The regularization function.
162161
regularization_parameter: Regularization parameter used in optimization.
163162
"""
164-
super().__init__(model=raw_acqf.model)
165-
self.raw_acqf = raw_acqf
163+
AcquisitionFunction.__init__(self, model=raw_acqf.model)
164+
AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=raw_acqf)
166165
self.penalty_func = penalty_func
167166
self.regularization_parameter = regularization_parameter
168167

169168
def forward(self, X: Tensor) -> Tensor:
170-
raw_value = self.raw_acqf(X=X)
169+
raw_value = self.acq_func(X=X)
171170
penalty_term = self.penalty_func(X)
172171
return raw_value - self.regularization_parameter * penalty_term
173172

174-
@property
175-
def X_pending(self) -> Optional[Tensor]:
176-
return self.raw_acqf.X_pending
177-
178-
def set_X_pending(self, X_pending: Optional[Tensor] = None) -> None:
179-
if not isinstance(self.raw_acqf, AnalyticAcquisitionFunction):
180-
self.raw_acqf.set_X_pending(X_pending=X_pending)
181-
else:
182-
raise UnsupportedError(
183-
"The raw acquisition function is Analytic and does not account "
184-
"for X_pending yet."
185-
)
186-
187173

188174
def group_lasso_regularizer(X: Tensor, groups: List[List[int]]) -> Tensor:
189175
r"""Computes the group lasso regularization function for the given point.

botorch/acquisition/proximal.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
import torch
1717
from botorch.acquisition import AcquisitionFunction
18+
19+
from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper
1820
from botorch.exceptions.errors import UnsupportedError
1921
from botorch.models import ModelListGP
2022
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
@@ -25,7 +27,7 @@
2527
from torch.nn import Module
2628

2729

28-
class ProximalAcquisitionFunction(AcquisitionFunction):
30+
class ProximalAcquisitionFunction(AbstractAcquisitionFunctionWrapper):
2931
"""A wrapper around AcquisitionFunctions to add proximal weighting of the
3032
acquisition function. The acquisition function is
3133
weighted via a squared exponential centered at the last training point,
@@ -70,17 +72,14 @@ def __init__(
7072
beta: If not None, apply a softplus transform to the base acquisition
7173
function, allows negative base acquisition function values.
7274
"""
73-
Module.__init__(self)
74-
75-
self.acq_func = acq_function
75+
AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=acq_function)
7676
model = self.acq_func.model
7777

7878
if hasattr(acq_function, "X_pending"):
7979
if acq_function.X_pending is not None:
8080
raise UnsupportedError(
8181
"Proximal acquisition function requires `X_pending` to be None."
8282
)
83-
self.X_pending = acq_function.X_pending
8483

8584
self.register_buffer("proximal_weights", proximal_weights)
8685
self.register_buffer(
@@ -91,6 +90,12 @@ def __init__(
9190

9291
_validate_model(model, proximal_weights)
9392

93+
def set_X_pending(self, X_pending: Optional[Tensor]) -> None:
94+
r"""Sets the `X_pending` of the base acquisition function."""
95+
raise UnsupportedError(
96+
"Proximal acquisition function does not support `X_pending`."
97+
)
98+
9499
@t_batch_mode_transform(expected_q=1, assert_output_shape=False)
95100
def forward(self, X: Tensor) -> Tensor:
96101
r"""Evaluate base acquisition function with proximal weighting.

botorch/acquisition/wrapper.py

+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
r"""
8+
A wrapper classes around AcquisitionFunctions to modify inputs and outputs.
9+
"""
10+
11+
from __future__ import annotations
12+
13+
from abc import ABC, abstractmethod
14+
from typing import Optional
15+
16+
from botorch.acquisition.acquisition import AcquisitionFunction
17+
from torch import Tensor
18+
from torch.nn import Module
19+
20+
21+
class AbstractAcquisitionFunctionWrapper(AcquisitionFunction, ABC):
22+
r"""Abstract acquisition wrapper."""
23+
24+
def __init__(self, acq_function: AcquisitionFunction) -> None:
25+
Module.__init__(self)
26+
self.acq_func = acq_function
27+
28+
@property
29+
def X_pending(self) -> Optional[Tensor]:
30+
r"""Return the `X_pending` of the base acquisition function."""
31+
try:
32+
return self.acq_func.X_pending
33+
except (ValueError, AttributeError):
34+
raise ValueError(
35+
f"Base acquisition function {type(self.acq_func).__name__} "
36+
"does not have an `X_pending` attribute."
37+
)
38+
39+
def set_X_pending(self, X_pending: Optional[Tensor]) -> None:
40+
r"""Sets the `X_pending` of the base acquisition function."""
41+
self.acq_func.set_X_pending(X_pending)
42+
43+
@abstractmethod
44+
def forward(self, X: Tensor) -> Tensor:
45+
r"""Evaluate the wrapped acquisition function on the candidate set X.
46+
47+
Args:
48+
X: A `(b) x q x d`-dim Tensor of `(b)` t-batches with `q` `d`-dim
49+
design points each.
50+
51+
Returns:
52+
A `(b)`-dim Tensor of acquisition function values at the given
53+
design points `X`.
54+
"""
55+
pass # pragma: no cover

sphinx/source/acquisition.rst

+7-2
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@ Analytic Acquisition Function API
2121
.. autoclass:: AnalyticAcquisitionFunction
2222
:members:
2323

24+
Acquisition Function Wrapper API
25+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
26+
.. automodule:: botorch.acquisition.wrapper
27+
:members:
28+
2429
Cached Cholesky Acquisition Function API
2530
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2631
.. automodule:: botorch.acquisition.cached_cholesky
@@ -65,7 +70,7 @@ Multi-Objective Analytic Acquisition Functions
6570
.. automodule:: botorch.acquisition.multi_objective.analytic
6671
:members:
6772
:exclude-members: MultiObjectiveAnalyticAcquisitionFunction
68-
73+
6974
Multi-Objective Joint Entropy Search Acquisition Functions
7075
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
7176
.. automodule:: botorch.acquisition.multi_objective.joint_entropy_search
@@ -86,7 +91,7 @@ Multi-Objective Multi-Fidelity Acquisition Functions
8691
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
8792
.. automodule:: botorch.acquisition.multi_objective.multi_fidelity
8893
:members:
89-
94+
9095
Multi-Objective Predictive Entropy Search Acquisition Functions
9196
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
9297
.. automodule:: botorch.acquisition.multi_objective.predictive_entropy_search

test/acquisition/test_fixed_feature.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def test_fixed_features(self):
8787
qEI_ff.set_X_pending(X_pending[..., :-1])
8888
self.assertAllClose(qEI.X_pending, X_pending)
8989
# test setting to None
90-
qEI_ff.X_pending = None
90+
qEI_ff.set_X_pending(None)
9191
self.assertIsNone(qEI_ff.X_pending)
9292

9393
# test gradient

test/acquisition/test_proximal.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -209,9 +209,15 @@ def test_proximal(self):
209209

210210
# test for x_pending points
211211
pending_acq = DummyAcquisitionFunction(model)
212-
pending_acq.set_X_pending(torch.rand(3, 3, device=self.device, dtype=dtype))
212+
X_pending = torch.rand(3, 3, device=self.device, dtype=dtype)
213+
pending_acq.set_X_pending(X_pending)
213214
with self.assertRaises(UnsupportedError):
214215
ProximalAcquisitionFunction(pending_acq, proximal_weights)
216+
# test setting pending points
217+
pending_acq.set_X_pending(None)
218+
af = ProximalAcquisitionFunction(pending_acq, proximal_weights)
219+
with self.assertRaises(UnsupportedError):
220+
af.set_X_pending(X_pending)
215221

216222
# test model with multi-batch training inputs
217223
train_X = torch.rand(5, 2, 3, device=self.device, dtype=dtype)

test/acquisition/test_wrapper.py

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from botorch.acquisition.analytic import ExpectedImprovement
9+
from botorch.acquisition.monte_carlo import qExpectedImprovement
10+
from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper
11+
from botorch.exceptions.errors import UnsupportedError
12+
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
13+
14+
15+
class DummyWrapper(AbstractAcquisitionFunctionWrapper):
16+
def forward(self, X):
17+
return self.acq_func(X)
18+
19+
20+
class TestAbstractAcquisitionFunctionWrapper(BotorchTestCase):
21+
def test_abstract_acquisition_function_wrapper(self):
22+
for dtype in (torch.float, torch.double):
23+
mm = MockModel(
24+
MockPosterior(
25+
mean=torch.rand(1, 1, dtype=dtype, device=self.device),
26+
variance=torch.ones(1, 1, dtype=dtype, device=self.device),
27+
)
28+
)
29+
acq_func = ExpectedImprovement(model=mm, best_f=-1.0)
30+
wrapped_af = DummyWrapper(acq_function=acq_func)
31+
self.assertIs(wrapped_af.acq_func, acq_func)
32+
# test forward
33+
X = torch.rand(1, 1, dtype=dtype, device=self.device)
34+
with torch.no_grad():
35+
wrapped_val = wrapped_af(X)
36+
af_val = acq_func(X)
37+
self.assertEqual(wrapped_val.item(), af_val.item())
38+
39+
# test X_pending
40+
with self.assertRaises(ValueError):
41+
self.assertIsNone(wrapped_af.X_pending)
42+
with self.assertRaises(UnsupportedError):
43+
wrapped_af.set_X_pending(X)
44+
acq_func = qExpectedImprovement(model=mm, best_f=-1.0)
45+
wrapped_af = DummyWrapper(acq_function=acq_func)
46+
self.assertIsNone(wrapped_af.X_pending)
47+
wrapped_af.set_X_pending(X)
48+
self.assertTrue(torch.equal(X, wrapped_af.X_pending))
49+
self.assertTrue(torch.equal(X, acq_func.X_pending))
50+
wrapped_af.set_X_pending(None)
51+
self.assertIsNone(wrapped_af.X_pending)
52+
self.assertIsNone(acq_func.X_pending)

0 commit comments

Comments
 (0)