|
8 | 8 | from unittest import mock
|
9 | 9 |
|
10 | 10 | import torch
|
11 |
| -from botorch.acquisition import monte_carlo |
| 11 | +from botorch.acquisition import analytic, monte_carlo, multi_objective |
| 12 | +from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction |
12 | 13 | from botorch.acquisition.multi_objective import (
|
13 | 14 | MCMultiOutputObjective,
|
14 | 15 | monte_carlo as moo_monte_carlo,
|
|
18 | 19 | MCAcquisitionObjective,
|
19 | 20 | ScalarizedPosteriorTransform,
|
20 | 21 | )
|
| 22 | +from botorch.acquisition.proximal import ProximalAcquisitionFunction |
21 | 23 | from botorch.acquisition.utils import (
|
22 | 24 | expand_trace_observations,
|
23 | 25 | get_acquisition_function,
|
24 | 26 | get_infeasible_cost,
|
| 27 | + is_nonnegative, |
| 28 | + isinstance_af, |
25 | 29 | project_to_sample_points,
|
26 | 30 | project_to_target_fidelity,
|
27 | 31 | prune_inferior_points,
|
@@ -606,6 +610,61 @@ def test_get_infeasible_cost(self):
|
606 | 610 | self.assertAllClose(M4, torch.tensor([1.0], **tkwargs))
|
607 | 611 |
|
608 | 612 |
|
| 613 | +class TestIsNonnegative(BotorchTestCase): |
| 614 | + def test_is_nonnegative(self): |
| 615 | + nonneg_afs = ( |
| 616 | + analytic.ExpectedImprovement, |
| 617 | + analytic.ConstrainedExpectedImprovement, |
| 618 | + analytic.ProbabilityOfImprovement, |
| 619 | + analytic.NoisyExpectedImprovement, |
| 620 | + monte_carlo.qExpectedImprovement, |
| 621 | + monte_carlo.qNoisyExpectedImprovement, |
| 622 | + monte_carlo.qProbabilityOfImprovement, |
| 623 | + multi_objective.analytic.ExpectedHypervolumeImprovement, |
| 624 | + multi_objective.monte_carlo.qExpectedHypervolumeImprovement, |
| 625 | + multi_objective.monte_carlo.qNoisyExpectedHypervolumeImprovement, |
| 626 | + ) |
| 627 | + mm = MockModel( |
| 628 | + MockPosterior( |
| 629 | + mean=torch.rand(1, 1, device=self.device), |
| 630 | + variance=torch.ones(1, 1, device=self.device), |
| 631 | + ) |
| 632 | + ) |
| 633 | + acq_func = analytic.ExpectedImprovement(model=mm, best_f=-1.0) |
| 634 | + with mock.patch( |
| 635 | + "botorch.acquisition.utils.isinstance_af", return_value=True |
| 636 | + ) as mock_isinstance_af: |
| 637 | + self.assertTrue(is_nonnegative(acq_function=acq_func)) |
| 638 | + mock_isinstance_af.assert_called_once() |
| 639 | + cargs, _ = mock_isinstance_af.call_args |
| 640 | + self.assertIs(cargs[0], acq_func) |
| 641 | + self.assertEqual(cargs[1], nonneg_afs) |
| 642 | + acq_func = analytic.UpperConfidenceBound(model=mm, beta=2.0) |
| 643 | + self.assertFalse(is_nonnegative(acq_function=acq_func)) |
| 644 | + |
| 645 | + |
| 646 | +class TestIsinstanceAf(BotorchTestCase): |
| 647 | + def test_isinstance_af(self): |
| 648 | + mm = MockModel( |
| 649 | + MockPosterior( |
| 650 | + mean=torch.rand(1, 1, device=self.device), |
| 651 | + variance=torch.ones(1, 1, device=self.device), |
| 652 | + ) |
| 653 | + ) |
| 654 | + acq_func = analytic.ExpectedImprovement(model=mm, best_f=-1.0) |
| 655 | + self.assertTrue(isinstance_af(acq_func, analytic.ExpectedImprovement)) |
| 656 | + self.assertFalse(isinstance_af(acq_func, analytic.UpperConfidenceBound)) |
| 657 | + wrapped_af = FixedFeatureAcquisitionFunction( |
| 658 | + acq_function=acq_func, d=2, columns=[1], values=[0.0] |
| 659 | + ) |
| 660 | + # test base af class |
| 661 | + self.assertTrue(isinstance_af(wrapped_af, analytic.ExpectedImprovement)) |
| 662 | + self.assertFalse(isinstance_af(wrapped_af, analytic.UpperConfidenceBound)) |
| 663 | + # test wrapper class |
| 664 | + self.assertTrue(isinstance_af(wrapped_af, FixedFeatureAcquisitionFunction)) |
| 665 | + self.assertFalse(isinstance_af(wrapped_af, ProximalAcquisitionFunction)) |
| 666 | + |
| 667 | + |
609 | 668 | class TestPruneInferiorPoints(BotorchTestCase):
|
610 | 669 | def test_prune_inferior_points(self):
|
611 | 670 | for dtype in (torch.float, torch.double):
|
|
0 commit comments