|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | +# |
| 4 | +# This source code is licensed under the MIT license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import torch |
| 8 | +from botorch.acquisition.analytic import ExpectedImprovement |
| 9 | +from botorch.acquisition.monte_carlo import qExpectedImprovement |
| 10 | +from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper |
| 11 | +from botorch.exceptions.errors import UnsupportedError |
| 12 | +from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior |
| 13 | + |
| 14 | + |
| 15 | +class DummyWrapper(AbstractAcquisitionFunctionWrapper): |
| 16 | + def forward(self, X): |
| 17 | + return self.acq_func(X) |
| 18 | + |
| 19 | + |
| 20 | +class TestAbstractAcquisitionFunctionWrapper(BotorchTestCase): |
| 21 | + def test_abstract_acquisition_function_wrapper(self): |
| 22 | + for dtype in (torch.float, torch.double): |
| 23 | + mm = MockModel( |
| 24 | + MockPosterior( |
| 25 | + mean=torch.rand(1, 1, dtype=dtype, device=self.device), |
| 26 | + variance=torch.ones(1, 1, dtype=dtype, device=self.device), |
| 27 | + ) |
| 28 | + ) |
| 29 | + acq_func = ExpectedImprovement(model=mm, best_f=-1.0) |
| 30 | + wrapped_af = DummyWrapper(acq_function=acq_func) |
| 31 | + self.assertIs(wrapped_af.acq_func, acq_func) |
| 32 | + # test forward |
| 33 | + X = torch.rand(1, 1, dtype=dtype, device=self.device) |
| 34 | + with torch.no_grad(): |
| 35 | + wrapped_val = wrapped_af(X) |
| 36 | + af_val = acq_func(X) |
| 37 | + self.assertEqual(wrapped_val.item(), af_val.item()) |
| 38 | + |
| 39 | + # test X_pending |
| 40 | + with self.assertRaises(ValueError): |
| 41 | + self.assertIsNone(wrapped_af.X_pending) |
| 42 | + with self.assertRaises(UnsupportedError): |
| 43 | + wrapped_af.set_X_pending(X) |
| 44 | + acq_func = qExpectedImprovement(model=mm, best_f=-1.0) |
| 45 | + wrapped_af = DummyWrapper(acq_function=acq_func) |
| 46 | + self.assertIsNone(wrapped_af.X_pending) |
| 47 | + wrapped_af.set_X_pending(X) |
| 48 | + self.assertTrue(torch.equal(X, wrapped_af.X_pending)) |
| 49 | + self.assertTrue(torch.equal(X, acq_func.X_pending)) |
| 50 | + wrapped_af.set_X_pending(None) |
| 51 | + self.assertIsNone(wrapped_af.X_pending) |
| 52 | + self.assertIsNone(acq_func.X_pending) |
0 commit comments