Skip to content

Commit deeca64

Browse files
Balandatfacebook-github-bot
authored andcommitted
Short-cut t_batch_mode_transform decorator on non-tensor inputs (#991)
Summary: This essentially makes `t_batch_mode_tarnsform` a nullop in case the argument `X` passed to the acquisition function is not a `torch.Tensor` object. This allows using acquisition functions that use models with non non-standard input types such as strings, which is the case in some applications. Currently this just touches the decorator; in the future we should consider changing the types and signatures of the acquisition functions and models throughout to natively support this more generally. cc wjmaddox Pull Request resolved: #991 Reviewed By: dme65 Differential Revision: D32903859 Pulled By: Balandat fbshipit-source-id: c3abc8b40db307358807fe60014e2c0e1fe49c58
1 parent c6bc8f9 commit deeca64

File tree

2 files changed

+22
-8
lines changed

2 files changed

+22
-8
lines changed

botorch/utils/transforms.py

+17-8
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@
1212

1313
import warnings
1414
from functools import wraps
15-
from typing import Any, Callable, List, Optional
15+
from typing import Any, Callable, List, Optional, TypeVar
1616

1717
import torch
1818
from torch import Tensor
1919

20+
ACQF = TypeVar("AcquisitionFunction")
21+
2022

2123
def squeeze_last_dim(Y: Tensor) -> Tensor:
2224
r"""Squeeze the last dimension of a Tensor.
@@ -172,19 +174,19 @@ def _verify_output_shape(acqf: Any, X: Tensor, output: Tensor) -> bool:
172174
def t_batch_mode_transform(
173175
expected_q: Optional[int] = None,
174176
assert_output_shape: bool = True,
175-
) -> Callable[[Callable[[Any, Tensor], Any]], Callable[[Any, Tensor], Any]]:
176-
r"""Factory for decorators taking a t-batched `X` tensor.
177+
) -> Callable[[Callable[[ACQF, Any], Any]], Callable[[ACQF, Any], Any]]:
178+
r"""Factory for decorators enabling consistent t-batch behavior.
177179
178180
This method creates decorators for instance methods to transform an input tensor
179181
`X` to t-batch mode (i.e. with at least 3 dimensions). This assumes the tensor
180182
has a q-batch dimension. The decorator also checks the q-batch size if `expected_q`
181183
is provided, and the output shape if `assert_output_shape` is `True`.
182184
183185
Args:
184-
expected_q: The expected q-batch size of X. If specified, this will raise an
185-
AssertionError if X's q-batch size does not equal expected_q.
186+
expected_q: The expected q-batch size of `X`. If specified, this will raise an
187+
AssertionError if `X`'s q-batch size does not equal expected_q.
186188
assert_output_shape: If `True`, this will raise an AssertionError if the
187-
output shape does not match either the t-batch shape of X,
189+
output shape does not match either the t-batch shape of `X`,
188190
or the `acqf.model.batch_shape` for acquisition functions using
189191
batched models.
190192
@@ -202,9 +204,16 @@ def t_batch_mode_transform(
202204
>>> ...
203205
"""
204206

205-
def decorator(method: Callable[[Any, Tensor], Any]) -> Callable[[Any, Tensor], Any]:
207+
def decorator(
208+
method: Callable[[ACQF, Any], Any],
209+
) -> Callable[[ACQF, Any], Any]:
206210
@wraps(method)
207-
def decorated(acqf: Any, X: Tensor, *args: Any, **kwargs: Any) -> Any:
211+
def decorated(acqf: ACQF, X: Any, *args: Any, **kwargs: Any) -> Any:
212+
213+
# Allow using acquisition functions for other inputs (e.g. lists of strings)
214+
if not isinstance(X, Tensor):
215+
return method(acqf, X, *args, **kwargs)
216+
208217
if X.dim() < 2:
209218
raise ValueError(
210219
f"{type(acqf).__name__} requires X to have at least 2 dimensions,"

test/utils/test_transforms.py

+5
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,11 @@ def test_t_batch_mode_transform(self):
191191
Xout = c.broadcast_batch_shape_method(X)
192192
self.assertEqual(Xout.shape, c.model.batch_shape)
193193

194+
# test with non-tensor argument
195+
X = ((3, 4), {"foo": True})
196+
Xout = c.q_method(X)
197+
self.assertEqual(X, Xout)
198+
194199

195200
class TestConcatenatePendingPoints(BotorchTestCase):
196201
def test_concatenate_pending_points(self):

0 commit comments

Comments
 (0)