|
26 | 26 | from botorch.acquisition.utils import project_to_sample_points
|
27 | 27 | from botorch.exceptions.errors import UnsupportedError
|
28 | 28 | from botorch.models import SingleTaskGP
|
| 29 | +from botorch.optim.optimize import optimize_acqf |
29 | 30 | from botorch.posteriors.gpytorch import GPyTorchPosterior
|
30 | 31 | from botorch.sampling.samplers import IIDNormalSampler, SobolQMCNormalSampler
|
31 | 32 | from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
|
@@ -559,6 +560,48 @@ def test_fixed_evaluation_qMFKG(self):
|
559 | 560 | raw_samples=1,
|
560 | 561 | )
|
561 | 562 |
|
| 563 | + def test_optimize_w_posterior_transform(self): |
| 564 | + # This is mainly testing that we can optimize without errors. |
| 565 | + for dtype in (torch.float, torch.double): |
| 566 | + tkwargs = {"dtype": dtype, "device": self.device} |
| 567 | + mean = torch.tensor([1.0, 0.5], **tkwargs).expand(2, 1, 2) |
| 568 | + cov = torch.tensor([[1.0, 0.1], [0.1, 0.5]], **tkwargs).expand(2, 2, 2) |
| 569 | + posterior = GPyTorchPosterior(MultitaskMultivariateNormal(mean, cov)) |
| 570 | + model = MockModel(posterior) |
| 571 | + n_f = 4 |
| 572 | + mean = torch.tensor([1.0, 0.5], **tkwargs).expand(n_f, 2, 1, 2) |
| 573 | + cov = torch.tensor([[1.0, 0.1], [0.1, 0.5]], **tkwargs).expand(n_f, 2, 2, 2) |
| 574 | + posterior = GPyTorchPosterior(MultitaskMultivariateNormal(mean, cov)) |
| 575 | + mfm = MockModel(posterior) |
| 576 | + bounds = torch.zeros(2, 2, **tkwargs) |
| 577 | + bounds[1] = 1 |
| 578 | + options = {"num_inner_restarts": 2, "raw_inner_samples": 2} |
| 579 | + with mock.patch.object(MockModel, "fantasize", return_value=mfm): |
| 580 | + kg = qMultiFidelityKnowledgeGradient( |
| 581 | + model=model, |
| 582 | + num_fantasies=n_f, |
| 583 | + posterior_transform=ScalarizedPosteriorTransform( |
| 584 | + weights=torch.rand(2, **tkwargs) |
| 585 | + ), |
| 586 | + ) |
| 587 | + # Mocking this to get around grad issues. |
| 588 | + with mock.patch( |
| 589 | + f"{optimize_acqf.__module__}.gen_candidates_scipy", |
| 590 | + return_value=( |
| 591 | + torch.zeros(2, n_f + 1, 2, **tkwargs), |
| 592 | + torch.zeros(2, **tkwargs), |
| 593 | + ), |
| 594 | + ): |
| 595 | + candidate, value = optimize_acqf( |
| 596 | + acq_function=kg, |
| 597 | + bounds=bounds, |
| 598 | + q=1, |
| 599 | + num_restarts=2, |
| 600 | + raw_samples=2, |
| 601 | + options=options, |
| 602 | + ) |
| 603 | + self.assertTrue(torch.equal(candidate, torch.zeros(1, 2, **tkwargs))) |
| 604 | + |
562 | 605 |
|
563 | 606 | class TestKGUtils(BotorchTestCase):
|
564 | 607 | def test_get_value_function(self):
|
|
0 commit comments