Skip to content

Commit 44b6400

Browse files
SamuelGabrielfacebook-github-bot
authored andcommitted
Improve Documentation of Testing Base Classes in BoTorch (#2770)
Summary: Pull Request resolved: #2770 I added doc strings (and some abstract methods). I came across the following issues doing so: - There are multiple private functions, like `_get_random_data`. These are used outside, though, thus the editor shows them as unused. Should they maybe not be private instead? - The double naming of `BaseTestProblem` and `functions` is confusing for me. Is there a reason for this? Should we maybe change it to only use one of the two? - Only `TestCorruptedProblemsMixin.rosenbrock_problem` seems to be used from that Mixin. Should we maybe change this to not be a Mixin, but just the rosenbrock variable instad, or not have the `outlier_generator` inside the Mixin? - I think we should delete `rsample_from_base_samples`. It is not used (if my command+click works here properly) and hard to understand. Reviewed By: esantorella Differential Revision: D71198064 fbshipit-source-id: 9dd660045aaa3424f09063be0efb747b5cfc321d
1 parent f612ad9 commit 44b6400

13 files changed

+252
-84
lines changed

botorch/utils/testing.py

+192-20
Large diffs are not rendered by default.

test/acquisition/test_objective.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from botorch.models.transforms.input import Normalize
2929
from botorch.posteriors import GPyTorchPosterior
3030
from botorch.utils import apply_constraints
31-
from botorch.utils.testing import _get_test_posterior, BotorchTestCase
31+
from botorch.utils.testing import BotorchTestCase, get_test_posterior
3232
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
3333
from linear_operator.operators.dense_linear_operator import to_linear_operator
3434

@@ -67,7 +67,7 @@ def test_scalarized_posterior_transform(self):
6767
offset = torch.rand(1).item()
6868
weights = torch.randn(m, device=self.device, dtype=dtype)
6969
obj = ScalarizedPosteriorTransform(weights=weights, offset=offset)
70-
posterior = _get_test_posterior(
70+
posterior = get_test_posterior(
7171
batch_shape, m=m, device=self.device, dtype=dtype
7272
)
7373
mean, covar = (

test/models/test_gp_regression.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from botorch.sampling import SobolQMCNormalSampler
2020
from botorch.utils.datasets import SupervisedDataset
2121
from botorch.utils.test_helpers import get_pvar_expected
22-
from botorch.utils.testing import _get_random_data, BotorchTestCase
22+
from botorch.utils.testing import BotorchTestCase, get_random_data
2323
from gpytorch.kernels import RBFKernel
2424
from gpytorch.likelihoods import FixedNoiseGaussianLikelihood, GaussianLikelihood
2525
from gpytorch.means import ConstantMean, ZeroMean
@@ -38,7 +38,7 @@ def _get_model_and_data(
3838
**tkwargs,
3939
):
4040
extra_model_kwargs = extra_model_kwargs or {}
41-
train_X, train_Y = _get_random_data(batch_shape=batch_shape, m=m, **tkwargs)
41+
train_X, train_Y = get_random_data(batch_shape=batch_shape, m=m, **tkwargs)
4242
model_kwargs = {
4343
"train_X": train_X,
4444
"train_Y": train_Y,
@@ -182,7 +182,7 @@ def test_default_transforms(self):
182182
("Default", "None", "Log"), # Outcome transform
183183
):
184184
tkwargs = {"device": self.device, "dtype": dtype}
185-
train_X, train_Y = _get_random_data(batch_shape=batch_shape, m=m, **tkwargs)
185+
train_X, train_Y = get_random_data(batch_shape=batch_shape, m=m, **tkwargs)
186186

187187
model_kwargs = {}
188188
if octf == "None":
@@ -241,7 +241,7 @@ def test_condition_on_observations(self):
241241
# test condition_on_observations
242242
fant_shape = torch.Size([2])
243243
# fantasize at different input points
244-
X_fant, Y_fant = _get_random_data(
244+
X_fant, Y_fant = get_random_data(
245245
batch_shape=fant_shape + batch_shape, m=m, n=3, **tkwargs
246246
)
247247
c_kwargs = (
@@ -459,7 +459,7 @@ def _get_model_and_data(
459459
**tkwargs,
460460
):
461461
extra_model_kwargs = extra_model_kwargs or {}
462-
train_X, train_Y = _get_random_data(batch_shape=batch_shape, m=m, **tkwargs)
462+
train_X, train_Y = get_random_data(batch_shape=batch_shape, m=m, **tkwargs)
463463
model_kwargs = {
464464
"train_X": train_X,
465465
"train_Y": train_Y,

test/models/test_gp_regression_fidelity.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from botorch.posteriors import GPyTorchPosterior
1717
from botorch.sampling import SobolQMCNormalSampler
1818
from botorch.utils.datasets import SupervisedDataset
19-
from botorch.utils.testing import _get_random_data, BotorchTestCase
19+
from botorch.utils.testing import BotorchTestCase, get_random_data
2020
from gpytorch.kernels.scale_kernel import ScaleKernel
2121
from gpytorch.likelihoods import FixedNoiseGaussianLikelihood
2222
from gpytorch.means import ConstantMean
@@ -30,7 +30,7 @@ def _get_random_data_with_fidelity(
3030
r"""Construct test data.
3131
For this test, by convention the trailing dimensions are the fidelity dimensions
3232
"""
33-
train_x, train_y = _get_random_data(
33+
train_x, train_y = get_random_data(
3434
batch_shape=batch_shape, m=m, d=d, n=n, **tkwargs
3535
)
3636
s = torch.rand(n, n_fidelity, **tkwargs).repeat(batch_shape + torch.Size([1, 1]))

test/models/test_gp_regression_mixed.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from botorch.sampling import SobolQMCNormalSampler
1919
from botorch.utils.datasets import SupervisedDataset
2020
from botorch.utils.test_helpers import get_pvar_expected
21-
from botorch.utils.testing import _get_random_data, BotorchTestCase
21+
from botorch.utils.testing import BotorchTestCase, get_random_data
2222
from gpytorch.kernels.kernel import AdditiveKernel, ProductKernel
2323
from gpytorch.kernels.rbf_kernel import RBFKernel
2424
from gpytorch.kernels.scale_kernel import ScaleKernel
@@ -44,7 +44,7 @@ def test_gp(self):
4444
# to test without that transform we need to explicitly pass in `None`.
4545
outcome_transform_kwargs = {} if use_octf else {"outcome_transform": None}
4646

47-
train_X, train_Y = _get_random_data(
47+
train_X, train_Y = get_random_data(
4848
batch_shape=batch_shape, m=m, d=d, **tkwargs
4949
)
5050
cat_dims = list(range(ncat))
@@ -151,7 +151,7 @@ def test_condition_on_observations__(self):
151151
(torch.Size([2]), 1, 2, torch.double, False),
152152
):
153153
tkwargs = {"device": self.device, "dtype": dtype}
154-
train_X, train_Y = _get_random_data(
154+
train_X, train_Y = get_random_data(
155155
batch_shape=batch_shape, m=m, d=d, **tkwargs
156156
)
157157
cat_dims = list(range(ncat))
@@ -169,7 +169,7 @@ def test_condition_on_observations__(self):
169169
fant_shape = torch.Size([2])
170170

171171
# fantasize at different input points
172-
X_fant, Y_fant = _get_random_data(
172+
X_fant, Y_fant = get_random_data(
173173
fant_shape + batch_shape, m=m, d=d, n=3, **tkwargs
174174
)
175175
additional_kwargs = (
@@ -255,7 +255,7 @@ def test_fantasize(self):
255255
(torch.Size([2]), 1, 2, torch.double, False),
256256
):
257257
tkwargs = {"device": self.device, "dtype": dtype}
258-
train_X, train_Y = _get_random_data(
258+
train_X, train_Y = get_random_data(
259259
batch_shape=batch_shape, m=m, d=d, **tkwargs
260260
)
261261
train_Yvar = torch.full_like(train_Y, 0.1) if observed_noise else None
@@ -281,7 +281,7 @@ def test_subset_model(self):
281281
(torch.float, torch.double),
282282
):
283283
tkwargs = {"device": self.device, "dtype": dtype}
284-
train_X, train_Y = _get_random_data(
284+
train_X, train_Y = get_random_data(
285285
batch_shape=batch_shape, m=m, d=d, **tkwargs
286286
)
287287
cat_dims = list(range(ncat))
@@ -307,7 +307,7 @@ def test_construct_inputs(self):
307307
(torch.Size(), torch.Size([2])), (1, 2), (torch.float, torch.double)
308308
):
309309
tkwargs = {"device": self.device, "dtype": dtype}
310-
X, Y = _get_random_data(batch_shape=batch_shape, m=1, d=d, **tkwargs)
310+
X, Y = get_random_data(batch_shape=batch_shape, m=1, d=d, **tkwargs)
311311
cat_dims = list(range(ncat))
312312
training_data = SupervisedDataset(
313313
X,

test/models/test_gpytorch.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from botorch.posteriors.gpytorch import GPyTorchPosterior
3030
from botorch.sampling.normal import SobolQMCNormalSampler
3131
from botorch.utils.test_helpers import SimpleGPyTorchModel
32-
from botorch.utils.testing import _get_random_data, BotorchTestCase
32+
from botorch.utils.testing import BotorchTestCase, get_random_data
3333
from gpytorch import ExactMarginalLogLikelihood
3434
from gpytorch.distributions import MultivariateNormal
3535
from gpytorch.kernels import RBFKernel, ScaleKernel
@@ -492,7 +492,7 @@ def test_model_list_gpytorch_model(self):
492492
self.assertIsInstance(posterior, GPyTorchPosterior)
493493
self.assertEqual(posterior.mean.shape, torch.Size([2, 2]))
494494
# test multioutput
495-
train_x_raw, train_y = _get_random_data(
495+
train_x_raw, train_y = get_random_data(
496496
batch_shape=torch.Size(), m=1, n=10, **tkwargs
497497
)
498498
task_idx = torch.cat(

test/models/test_latent_kronecker_gp.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from botorch.fit import fit_gpytorch_mll
1515
from botorch.models.latent_kronecker_gp import LatentKroneckerGP, MinMaxStandardize
1616
from botorch.models.transforms import Normalize
17-
from botorch.utils.testing import _get_random_data, BotorchTestCase
17+
from botorch.utils.testing import BotorchTestCase, get_random_data
1818
from botorch.utils.types import DEFAULT
1919
from gpytorch.kernels import MaternKernel, RBFKernel, ScaleKernel
2020
from gpytorch.likelihoods import FixedNoiseGaussianLikelihood, GaussianLikelihood
@@ -27,7 +27,7 @@
2727
def _get_data_with_missing_entries(
2828
n_train: int, d: int, m: int, batch_shape: torch.Size, tkwargs: dict
2929
):
30-
train_X, train_Y = _get_random_data(
30+
train_X, train_Y = get_random_data(
3131
batch_shape=batch_shape, m=m, d=d, n=n_train, **tkwargs
3232
)
3333

@@ -586,7 +586,7 @@ def test_iterative_methods(self):
586586
(torch.float, torch.double), # dtype
587587
):
588588
tkwargs = {"device": self.device, "dtype": dtype}
589-
train_X, train_Y = _get_random_data(
589+
train_X, train_Y = get_random_data(
590590
batch_shape=batch_shape, m=m, d=d, n=n_train, **tkwargs
591591
)
592592

@@ -619,7 +619,7 @@ def test_iterative_methods(self):
619619
def test_not_implemented(self):
620620
batch_shape = torch.Size([])
621621
tkwargs = {"device": self.device, "dtype": torch.double}
622-
train_X, train_Y = _get_random_data(batch_shape=batch_shape, m=1, **tkwargs)
622+
train_X, train_Y = get_random_data(batch_shape=batch_shape, m=1, **tkwargs)
623623

624624
model = LatentKroneckerGP(
625625
train_X=train_X,

test/models/test_model_list_gp_regression.py

+5-9
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from botorch.sampling.base import MCSampler
2424
from botorch.sampling.list_sampler import ListSampler
2525
from botorch.sampling.normal import IIDNormalSampler
26-
from botorch.utils.testing import _get_random_data, BotorchTestCase
26+
from botorch.utils.testing import BotorchTestCase, get_random_data
2727
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
2828
from gpytorch.kernels import RBFKernel
2929
from gpytorch.likelihoods import LikelihoodList
@@ -41,13 +41,9 @@
4141
def _get_model(
4242
fixed_noise=False, outcome_transform: str = "None", use_intf=False, **tkwargs
4343
) -> ModelListGP:
44-
train_x1, train_y1 = _get_random_data(
45-
batch_shape=torch.Size(), m=1, n=10, **tkwargs
46-
)
44+
train_x1, train_y1 = get_random_data(batch_shape=torch.Size(), m=1, n=10, **tkwargs)
4745
train_y1 = torch.exp(train_y1)
48-
train_x2, train_y2 = _get_random_data(
49-
batch_shape=torch.Size(), m=1, n=11, **tkwargs
50-
)
46+
train_x2, train_y2 = get_random_data(batch_shape=torch.Size(), m=1, n=11, **tkwargs)
5147
if outcome_transform == "Standardize":
5248
octfs = [Standardize(m=1), Standardize(m=1)]
5349
elif outcome_transform == "Log":
@@ -280,7 +276,7 @@ def test_ModelListGP_fixed_noise(self) -> None:
280276

281277
def test_ModelListGP_single(self):
282278
tkwargs = {"device": self.device, "dtype": torch.float}
283-
train_x1, train_y1 = _get_random_data(
279+
train_x1, train_y1 = get_random_data(
284280
batch_shape=torch.Size(), m=1, n=10, **tkwargs
285281
)
286282
model1 = SingleTaskGP(train_X=train_x1, train_Y=train_y1)
@@ -296,7 +292,7 @@ def test_ModelListGP_multi_task(self, use_outcome_transform: bool = False):
296292
outcome_transform_kwargs = (
297293
{} if use_outcome_transform else {"outcome_transform": None}
298294
)
299-
train_x_raw, train_y = _get_random_data(
295+
train_x_raw, train_y = get_random_data(
300296
batch_shape=torch.Size(), m=1, n=10, **tkwargs
301297
)
302298
task_idx = torch.cat(

test/optim/test_initializers.py

+18-18
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@
4848
from botorch.sampling.normal import IIDNormalSampler
4949
from botorch.utils.sampling import manual_seed, unnormalize
5050
from botorch.utils.testing import (
51-
_get_max_violation_of_bounds,
52-
_get_max_violation_of_constraints,
5351
BotorchTestCase,
52+
get_max_violation_of_bounds,
53+
get_max_violation_of_constraints,
5454
MockAcquisitionFunction,
5555
MockModel,
5656
MockPosterior,
@@ -61,11 +61,11 @@ class TestBoundsAndConstraintCheckers(BotorchTestCase):
6161
def test_bounds_check(self) -> None:
6262
bounds = torch.tensor([[1, 2], [3, 4]], device=self.device)
6363
samples = torch.tensor([[2, 3], [2, 3.1]], device=self.device)[None, :, :]
64-
result = _get_max_violation_of_bounds(samples, bounds)
64+
result = get_max_violation_of_bounds(samples, bounds)
6565
self.assertAlmostEqual(result, -0.9, delta=1e-6)
6666

6767
samples = torch.tensor([[2, 3], [2, 4.1]], device=self.device)[None, :, :]
68-
result = _get_max_violation_of_bounds(samples, bounds)
68+
result = get_max_violation_of_bounds(samples, bounds)
6969
self.assertAlmostEqual(result, 0.1, delta=1e-6)
7070

7171
def test_constraint_check(self) -> None:
@@ -77,10 +77,10 @@ def test_constraint_check(self) -> None:
7777
)
7878
]
7979
samples = torch.tensor([[2, 3], [2, 3.1]], device=self.device)[None, :, :]
80-
result = _get_max_violation_of_constraints(samples, constraints, equality=True)
80+
result = get_max_violation_of_constraints(samples, constraints, equality=True)
8181
self.assertAlmostEqual(result, 0.1, delta=1e-6)
8282

83-
result = _get_max_violation_of_constraints(samples, constraints, equality=False)
83+
result = get_max_violation_of_constraints(samples, constraints, equality=False)
8484
self.assertAlmostEqual(result, 0.0, delta=1e-6)
8585

8686

@@ -268,7 +268,7 @@ def test_gen_batch_initial_conditions(self):
268268
self.assertEqual(batch_initial_conditions.device, bounds.device)
269269
self.assertEqual(batch_initial_conditions.dtype, bounds.dtype)
270270
self.assertLess(
271-
_get_max_violation_of_bounds(batch_initial_conditions, bounds),
271+
get_max_violation_of_bounds(batch_initial_conditions, bounds),
272272
1e-6,
273273
)
274274
batch_shape = (
@@ -347,7 +347,7 @@ def test_gen_batch_initial_conditions_topn(self):
347347
self.assertEqual(batch_initial_conditions.device, bounds.device)
348348
self.assertEqual(batch_initial_conditions.dtype, bounds.dtype)
349349
self.assertLess(
350-
_get_max_violation_of_bounds(batch_initial_conditions, bounds),
350+
get_max_violation_of_bounds(batch_initial_conditions, bounds),
351351
1e-6,
352352
)
353353
batch_shape = (
@@ -409,7 +409,7 @@ def test_gen_batch_initial_conditions_highdim(self):
409409
self.assertEqual(batch_initial_conditions.device, bounds.device)
410410
self.assertEqual(batch_initial_conditions.dtype, bounds.dtype)
411411
self.assertLess(
412-
_get_max_violation_of_bounds(batch_initial_conditions, bounds), 1e-6
412+
get_max_violation_of_bounds(batch_initial_conditions, bounds), 1e-6
413413
)
414414
if ffs is not None:
415415
for idx, val in ffs.items():
@@ -637,18 +637,18 @@ def _to_self_device(
637637
return None if x is None else x.to(device=self.device)
638638

639639
self.assertLess(
640-
_get_max_violation_of_bounds(_to_self_device(samples), bounds), tol
640+
get_max_violation_of_bounds(_to_self_device(samples), bounds), tol
641641
)
642642

643643
self.assertLess(
644-
_get_max_violation_of_constraints(
644+
get_max_violation_of_constraints(
645645
_to_self_device(samples), constraints=equalities, equality=True
646646
),
647647
tol,
648648
)
649649

650650
self.assertLess(
651-
_get_max_violation_of_constraints(
651+
get_max_violation_of_constraints(
652652
_to_self_device(samples),
653653
constraints=inequalities,
654654
equality=False,
@@ -708,19 +708,19 @@ def test_gen_batch_initial_conditions_constraints(self):
708708
self.assertEqual(batch_initial_conditions.device, bounds.device)
709709
self.assertEqual(batch_initial_conditions.dtype, bounds.dtype)
710710
self.assertLess(
711-
_get_max_violation_of_bounds(batch_initial_conditions, bounds),
711+
get_max_violation_of_bounds(batch_initial_conditions, bounds),
712712
1e-6,
713713
)
714714
self.assertLess(
715-
_get_max_violation_of_constraints(
715+
get_max_violation_of_constraints(
716716
batch_initial_conditions,
717717
inequality_constraints,
718718
equality=False,
719719
),
720720
1e-6,
721721
)
722722
self.assertLess(
723-
_get_max_violation_of_constraints(
723+
get_max_violation_of_constraints(
724724
batch_initial_conditions,
725725
equality_constraints,
726726
equality=True,
@@ -821,7 +821,7 @@ def test_gen_batch_initial_conditions_interpoint_constraints(self):
821821
batch_initial_conditions[1, 2, 0],
822822
)
823823
self.assertLess(
824-
_get_max_violation_of_constraints(
824+
get_max_violation_of_constraints(
825825
batch_initial_conditions,
826826
inequality_constraints,
827827
equality=False,
@@ -886,7 +886,7 @@ def generator(n: int, q: int, seed: int | None):
886886
self.assertEqual(batch_initial_conditions.dtype, bounds.dtype)
887887
self.assertTrue((batch_initial_conditions[..., -1] == 0.42).all())
888888
self.assertLess(
889-
_get_max_violation_of_bounds(batch_initial_conditions, bounds),
889+
get_max_violation_of_bounds(batch_initial_conditions, bounds),
890890
1e-6,
891891
)
892892
if ffs is not None:
@@ -981,7 +981,7 @@ def test_gen_batch_initial_conditions_fixed_X_fantasies(self):
981981
self.assertEqual(batch_initial_conditions.device, bounds.device)
982982
self.assertEqual(batch_initial_conditions.dtype, bounds.dtype)
983983
self.assertLess(
984-
_get_max_violation_of_bounds(batch_initial_conditions, bounds),
984+
get_max_violation_of_bounds(batch_initial_conditions, bounds),
985985
1e-6,
986986
)
987987
batch_shape = (

0 commit comments

Comments
 (0)