Skip to content

Commit a05d643

Browse files
Carl Hvarfnerfacebook-github-bot
authored andcommitted
Core infrastructure and analytic tests
Summary: Adds test harness infrastructure for acquisition function testing, including factory functions (make_mock_model, make_X, make_trained_gp), a spec dataclass (AcquisitionSpec), and test mixins. Analytic acquisition functions (EI, LogEI, PI, LogPI, UCB, PosteriorMean, PosteriorStandardDeviation) now use a shared AnalyticAcquisitionTestMixin for dtype, batch-shape, and output-shape tests. Differential Revision: D93691051
1 parent 082c8a2 commit a05d643

5 files changed

Lines changed: 364 additions & 341 deletions

File tree

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""Test harness for acquisition function testing."""
8+
9+
from .factories import make_trained_gp, make_X
10+
from .mixins import (
11+
AcquisitionSpec,
12+
AcquisitionTestMixin,
13+
AnalyticAcquisitionTestMixin,
14+
loop_filtered_specs,
15+
)
16+
17+
18+
__all__ = [
19+
"AcquisitionSpec",
20+
"AcquisitionTestMixin",
21+
"AnalyticAcquisitionTestMixin",
22+
"loop_filtered_specs",
23+
"make_trained_gp",
24+
"make_X",
25+
]
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""Factory functions for creating test fixtures."""
8+
9+
from __future__ import annotations
10+
11+
import torch
12+
from botorch.models import SingleTaskGP
13+
from torch import Tensor
14+
15+
16+
def make_X(
17+
batch_shape: list[int] | None = None,
18+
q: int = 1,
19+
d: int = 2,
20+
dtype: torch.dtype = torch.double,
21+
device: torch.device | None = None,
22+
) -> Tensor:
23+
"""Create a random input tensor for testing.
24+
25+
Args:
26+
batch_shape: The batch shape for the tensor. Defaults to [].
27+
q: The number of candidates. Defaults to 1.
28+
d: The dimension of the input space. Defaults to 2.
29+
dtype: The dtype of the tensor. Defaults to torch.double.
30+
device: The device for the tensor. Defaults to None (CPU).
31+
32+
Returns:
33+
A tensor of shape (*batch_shape, q, d) with random values in [0, 1).
34+
"""
35+
if batch_shape is None:
36+
batch_shape = []
37+
return torch.rand(*batch_shape, q, d, dtype=dtype, device=device)
38+
39+
40+
def make_trained_gp(
41+
n_train: int = 5,
42+
d: int = 2,
43+
m: int = 1,
44+
dtype: torch.dtype = torch.double,
45+
device: torch.device | None = None,
46+
with_known_noise: bool = False,
47+
) -> SingleTaskGP:
48+
"""Create a SingleTaskGP with random training data for testing.
49+
50+
Args:
51+
n_train: The number of training points. Defaults to 5.
52+
d: The dimension of the input space. Defaults to 2.
53+
m: The number of outputs. Defaults to 1.
54+
dtype: The dtype of the tensors. Defaults to torch.double.
55+
device: The device for the tensors. Defaults to None (CPU).
56+
with_known_noise: If True, include train_Yvar. Defaults to False.
57+
58+
Returns:
59+
A SingleTaskGP fitted with random training data where train_X has shape
60+
(n_train, d) and train_Y has shape (n_train, m).
61+
"""
62+
train_X = torch.rand(n_train, d, dtype=dtype, device=device)
63+
train_Y = torch.rand(n_train, m, dtype=dtype, device=device)
64+
if with_known_noise:
65+
train_Yvar = torch.full_like(train_Y, 0.25)
66+
return SingleTaskGP(train_X, train_Y, train_Yvar=train_Yvar)
67+
return SingleTaskGP(train_X, train_Y)

test/acquisition/harness/mixins.py

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""Base specs and test mixins for acquisition function testing."""
8+
9+
from __future__ import annotations
10+
11+
from dataclasses import dataclass, field
12+
from functools import wraps
13+
from typing import Any, Callable
14+
15+
import torch
16+
from botorch.acquisition.acquisition import AcquisitionFunction
17+
18+
from .factories import make_trained_gp, make_X
19+
20+
21+
def loop_filtered_specs(test_method: Callable) -> Callable:
22+
"""Decorator that runs a test method for each acquisition spec.
23+
24+
Automatically skips specs that have the test name in their bypass_tests list.
25+
The decorated method receives `spec` as its first argument after `self`.
26+
27+
Usage:
28+
@loop_filtered_specs
29+
def test_something(self, spec: AcquisitionSpec) -> None:
30+
# Test code here - no need for manual iteration or bypass checks
31+
"""
32+
33+
@wraps(test_method)
34+
def wrapper(self) -> None:
35+
test_name = test_method.__name__
36+
for spec in self.acquisition_specs:
37+
if test_name in spec.bypass_tests:
38+
continue
39+
with self.subTest(cls=spec.cls.__name__):
40+
test_method(self, spec)
41+
42+
return wrapper
43+
44+
45+
@dataclass
46+
class AcquisitionSpec:
47+
"""Base spec for analytic and simple acquisition functions.
48+
49+
Attributes:
50+
cls: The acquisition function class to test
51+
required_kwargs: Dict of required constructor arguments
52+
requires_X_observed: If True, pass X_observed (model training inputs)
53+
to acquisition function constructor.
54+
requires_fixed_noise: If True, the acquisition function requires a model
55+
with fixed/known observation noise (FixedNoiseGaussianLikelihood).
56+
convert_tensor_kwargs: If True, convert tensor kwargs to the test's
57+
dtype and device. Defaults to True.
58+
bypass_tests: List of test names to skip for this acquisition function.
59+
Defaults to empty list (run all tests).
60+
"""
61+
62+
cls: type[AcquisitionFunction]
63+
required_kwargs: dict[str, Any] = field(default_factory=dict)
64+
requires_X_observed: bool = False
65+
requires_fixed_noise: bool = False
66+
convert_tensor_kwargs: bool = True
67+
bypass_tests: list[str] = field(default_factory=list)
68+
69+
def get_kwargs(self, dtype: torch.dtype, device: torch.device) -> dict[str, Any]:
70+
"""Get required_kwargs with tensors converted to the specified dtype/device.
71+
72+
Args:
73+
dtype: The target dtype for tensor conversion.
74+
device: The target device for tensor conversion.
75+
76+
Returns:
77+
A copy of required_kwargs with all Tensor values converted to the
78+
specified dtype and device if convert_tensor_kwargs is True.
79+
"""
80+
if not self.convert_tensor_kwargs:
81+
return dict(self.required_kwargs)
82+
kwargs = {}
83+
for key, value in self.required_kwargs.items():
84+
if isinstance(value, torch.Tensor):
85+
kwargs[key] = value.to(dtype=dtype, device=device)
86+
else:
87+
kwargs[key] = value
88+
return kwargs
89+
90+
91+
class AcquisitionTestMixin:
92+
"""Mixin providing standard tests for acquisition functions.
93+
94+
Subclasses should override `acquisition_specs` to return a list of
95+
AcquisitionSpec instances defining which acquisition functions to test.
96+
"""
97+
98+
@property
99+
def acquisition_specs(self) -> list[AcquisitionSpec]:
100+
"""Return the list of AcquisitionSpec instances to test."""
101+
return []
102+
103+
def _make_model(
104+
self,
105+
spec: AcquisitionSpec,
106+
dtype: torch.dtype,
107+
m: int = 1,
108+
):
109+
"""Create a model for testing.
110+
111+
Args:
112+
spec: The acquisition spec defining the test configuration.
113+
dtype: The dtype for the model tensors.
114+
m: The number of outputs. Defaults to 1.
115+
116+
Returns:
117+
A SingleTaskGP with random training data.
118+
"""
119+
return make_trained_gp(
120+
n_train=5,
121+
d=2,
122+
m=m,
123+
dtype=dtype,
124+
device=self.device,
125+
with_known_noise=spec.requires_fixed_noise,
126+
)
127+
128+
def _make_acquisition(
129+
self,
130+
spec: AcquisitionSpec,
131+
model,
132+
dtype: torch.dtype,
133+
):
134+
"""Create an acquisition function for testing.
135+
136+
Args:
137+
spec: The acquisition spec defining the test configuration.
138+
model: The model to use for the acquisition function.
139+
dtype: The dtype for tensors.
140+
141+
Returns:
142+
An instance of the acquisition function specified by the spec.
143+
"""
144+
kwargs = spec.get_kwargs(dtype=dtype, device=self.device)
145+
if spec.requires_X_observed:
146+
kwargs["X_observed"] = model.train_inputs[0]
147+
return spec.cls(model=model, **kwargs)
148+
149+
@loop_filtered_specs
150+
def test_dtype(self, spec: AcquisitionSpec) -> None:
151+
"""Test acquisition function with different dtypes."""
152+
for dtype in (torch.float, torch.double):
153+
with self.subTest(dtype=dtype):
154+
model = self._make_model(spec=spec, dtype=dtype)
155+
acqf = self._make_acquisition(spec=spec, model=model, dtype=dtype)
156+
X = make_X(batch_shape=[4], q=1, dtype=dtype, device=self.device)
157+
value = acqf(X)
158+
self.assertEqual(value.dtype, dtype)
159+
self.assertEqual(value.device.type, self.device.type)
160+
161+
@loop_filtered_specs
162+
def test_output_shapes(self, spec: AcquisitionSpec) -> None:
163+
"""Test acquisition function with different batch shapes."""
164+
model = self._make_model(spec=spec, dtype=torch.double)
165+
acqf = self._make_acquisition(spec=spec, model=model, dtype=torch.double)
166+
for batch_shape in [[5], [5, 3]]:
167+
with self.subTest(batch_shape=batch_shape):
168+
X = make_X(batch_shape=batch_shape, q=1, device=self.device)
169+
value = acqf(X)
170+
expected_shape = torch.Size(batch_shape)
171+
self.assertEqual(value.shape, expected_shape)
172+
173+
@loop_filtered_specs
174+
def test_fixed_noise(self, spec: AcquisitionSpec) -> None:
175+
"""Test acquisition function requiring X_observed with fixed noise model."""
176+
model = self._make_model(spec=spec, dtype=torch.double)
177+
acqf = self._make_acquisition(spec=spec, model=model, dtype=torch.double)
178+
X = make_X(batch_shape=[4], q=1, device=self.device)
179+
value = acqf(X)
180+
self.assertEqual(value.shape, torch.Size([4]))
181+
182+
183+
class AnalyticAcquisitionTestMixin(AcquisitionTestMixin):
184+
"""Mixin for analytic acquisition functions.
185+
186+
Inherits dtype/device and batch shape tests from AcquisitionTestMixin.
187+
"""
188+
189+
pass

0 commit comments

Comments
 (0)