Skip to content

Commit 3a396b2

Browse files
sdaultonfacebook-github-bot
authored andcommitted
Add isinstance_af
Summary: Creates a new helper method for checking both if a given AF is an instance of a class or if the given AF wraps a base AF that is an instance of a class Differential Revision: D43127722 fbshipit-source-id: 0ec0131cf1c7512c10ab14e5d0a0a20cf3025688
1 parent c5ad87b commit 3a396b2

File tree

2 files changed

+75
-3
lines changed

2 files changed

+75
-3
lines changed

botorch/acquisition/utils.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from __future__ import annotations
1212

1313
import math
14-
from typing import Callable, Dict, List, Optional, Union
14+
from typing import Any, Callable, Dict, List, Optional, Union
1515

1616
import torch
1717
from botorch.acquisition import analytic, monte_carlo, multi_objective # noqa F401
@@ -22,6 +22,7 @@
2222
MCAcquisitionObjective,
2323
PosteriorTransform,
2424
)
25+
from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper
2526
from botorch.exceptions.errors import UnsupportedError
2627
from botorch.models.fully_bayesian import MCMC_DIM
2728
from botorch.models.model import Model
@@ -253,6 +254,18 @@ def objective(Y: Tensor, X: Optional[Tensor] = None):
253254
return -(lb.clamp_max(0.0))
254255

255256

257+
def isinstance_af(
258+
__obj: object,
259+
__class_or_tuple: Union[type, tuple[Union[type, tuple[Any, ...]], ...]],
260+
) -> bool:
261+
r"""A variant of isinstance first checks for the acq_func attribute on wrapped acquisition functions."""
262+
if isinstance(__obj, AbstractAcquisitionFunctionWrapper):
263+
isinstance_base_af = isinstance(__obj.acq_func, __class_or_tuple)
264+
else:
265+
isinstance_base_af = False
266+
return isinstance_base_af or isinstance(__obj, __class_or_tuple)
267+
268+
256269
def is_nonnegative(acq_function: AcquisitionFunction) -> bool:
257270
r"""Determine whether a given acquisition function is non-negative.
258271
@@ -267,7 +280,7 @@ def is_nonnegative(acq_function: AcquisitionFunction) -> bool:
267280
>>> qEI = qExpectedImprovement(model, best_f=0.1)
268281
>>> is_nonnegative(qEI) # returns True
269282
"""
270-
return isinstance(
283+
return isinstance_af(
271284
acq_function,
272285
(
273286
analytic.ExpectedImprovement,

test/acquisition/test_utils.py

+60-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
from unittest import mock
99

1010
import torch
11-
from botorch.acquisition import monte_carlo
11+
from botorch.acquisition import analytic, monte_carlo, multi_objective
12+
from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction
1213
from botorch.acquisition.multi_objective import (
1314
MCMultiOutputObjective,
1415
monte_carlo as moo_monte_carlo,
@@ -18,10 +19,13 @@
1819
MCAcquisitionObjective,
1920
ScalarizedPosteriorTransform,
2021
)
22+
from botorch.acquisition.proximal import ProximalAcquisitionFunction
2123
from botorch.acquisition.utils import (
2224
expand_trace_observations,
2325
get_acquisition_function,
2426
get_infeasible_cost,
27+
is_nonnegative,
28+
isinstance_af,
2529
project_to_sample_points,
2630
project_to_target_fidelity,
2731
prune_inferior_points,
@@ -606,6 +610,61 @@ def test_get_infeasible_cost(self):
606610
self.assertAllClose(M4, torch.tensor([1.0], **tkwargs))
607611

608612

613+
class TestIsNonnegative(BotorchTestCase):
614+
def test_is_nonnegative(self):
615+
nonneg_afs = (
616+
analytic.ExpectedImprovement,
617+
analytic.ConstrainedExpectedImprovement,
618+
analytic.ProbabilityOfImprovement,
619+
analytic.NoisyExpectedImprovement,
620+
monte_carlo.qExpectedImprovement,
621+
monte_carlo.qNoisyExpectedImprovement,
622+
monte_carlo.qProbabilityOfImprovement,
623+
multi_objective.analytic.ExpectedHypervolumeImprovement,
624+
multi_objective.monte_carlo.qExpectedHypervolumeImprovement,
625+
multi_objective.monte_carlo.qNoisyExpectedHypervolumeImprovement,
626+
)
627+
mm = MockModel(
628+
MockPosterior(
629+
mean=torch.rand(1, 1, device=self.device),
630+
variance=torch.ones(1, 1, device=self.device),
631+
)
632+
)
633+
acq_func = analytic.ExpectedImprovement(model=mm, best_f=-1.0)
634+
with mock.patch(
635+
"botorch.acquisition.utils.isinstance_af", return_value=True
636+
) as mock_isinstance_af:
637+
self.assertTrue(is_nonnegative(acq_function=acq_func))
638+
mock_isinstance_af.assert_called_once()
639+
cargs, _ = mock_isinstance_af.call_args
640+
self.assertIs(cargs[0], acq_func)
641+
self.assertEqual(cargs[1], nonneg_afs)
642+
acq_func = analytic.UpperConfidenceBound(model=mm, beta=2.0)
643+
self.assertFalse(is_nonnegative(acq_function=acq_func))
644+
645+
646+
class TestIsinstanceAf(BotorchTestCase):
647+
def test_isinstance_af(self):
648+
mm = MockModel(
649+
MockPosterior(
650+
mean=torch.rand(1, 1, device=self.device),
651+
variance=torch.ones(1, 1, device=self.device),
652+
)
653+
)
654+
acq_func = analytic.ExpectedImprovement(model=mm, best_f=-1.0)
655+
self.assertTrue(isinstance_af(acq_func, analytic.ExpectedImprovement))
656+
self.assertFalse(isinstance_af(acq_func, analytic.UpperConfidenceBound))
657+
wrapped_af = FixedFeatureAcquisitionFunction(
658+
acq_function=acq_func, d=2, columns=[1], values=[0.0]
659+
)
660+
# test base af class
661+
self.assertTrue(isinstance_af(wrapped_af, analytic.ExpectedImprovement))
662+
self.assertFalse(isinstance_af(wrapped_af, analytic.UpperConfidenceBound))
663+
# test wrapper class
664+
self.assertTrue(isinstance_af(wrapped_af, FixedFeatureAcquisitionFunction))
665+
self.assertFalse(isinstance_af(wrapped_af, ProximalAcquisitionFunction))
666+
667+
609668
class TestPruneInferiorPoints(BotorchTestCase):
610669
def test_prune_inferior_points(self):
611670
for dtype in (torch.float, torch.double):

0 commit comments

Comments
 (0)