Skip to content

Commit bf300db

Browse files
saitcakmakmeta-codesync[bot]
authored andcommitted
Add TorchAdapter.botorch_model property (#4827)
Summary: Pull Request resolved: #4827 This is a convenience method that should save us from doing calls like `client._generation_strategy.adapter.generator.surrogate.model` and instead simplify it to `client._generation_strategy.adapter.botorch_model`. Reviewed By: Balandat Differential Revision: D91588495 fbshipit-source-id: 2cc66a3dbe847ca37ae2b5ccf50b8eb813bfd54c
1 parent 4b89d30 commit bf300db

File tree

5 files changed

+53
-15
lines changed

5 files changed

+53
-15
lines changed

ax/adapter/tests/test_torch_adapter.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from ax.core.search_space import SearchSpace, SearchSpaceDigest
4444
from ax.core.types import ComparisonOp
4545
from ax.exceptions.core import DataRequiredError, UnsupportedError, UserInputError
46+
from ax.exceptions.model import ModelError
4647
from ax.generators.torch.botorch_modular.generator import BoTorchGenerator
4748
from ax.generators.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec
4849
from ax.generators.torch.botorch_modular.utils import ModelConfig
@@ -1391,6 +1392,37 @@ def test_moo_with_derived_parameter(self) -> None:
13911392
)
13921393
self.assertEqual(len(objective_thresholds), len(experiment.metrics))
13931394

1395+
def test_botorch_model_property(self) -> None:
1396+
experiment = get_branin_experiment(with_completed_trial=True)
1397+
# Case: Invalid generator.
1398+
adapter = TorchAdapter(
1399+
generator=TorchGenerator(),
1400+
experiment=experiment,
1401+
transforms=Cont_X_trans,
1402+
)
1403+
with self.assertRaisesRegex(UnsupportedError, "BoTorchGenerator"):
1404+
adapter.botorch_model
1405+
1406+
# Case: Model not fitted yet.
1407+
adapter = TorchAdapter(
1408+
generator=BoTorchGenerator(),
1409+
experiment=experiment,
1410+
transforms=Cont_X_trans,
1411+
fit_on_init=False,
1412+
)
1413+
with self.assertRaisesRegex(ModelError, "has not yet been constructed"):
1414+
adapter.botorch_model
1415+
1416+
# Case: Model fitted.
1417+
generator = BoTorchGenerator()
1418+
adapter = TorchAdapter(
1419+
generator=generator,
1420+
experiment=experiment,
1421+
transforms=Cont_X_trans,
1422+
)
1423+
self.assertIs(adapter.botorch_model, generator.surrogate.model)
1424+
self.assertIsInstance(adapter.botorch_model, SingleTaskGP)
1425+
13941426

13951427
class AdapterWithPLBOTest(TestCase):
13961428
"""Test the PLBO (Preference-Learning-guided BO) step in BOPE (Bayesian

ax/adapter/torch.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@
8686
from ax.generators.types import TConfig
8787
from ax.utils.common.constants import Keys
8888
from ax.utils.common.logger import get_logger
89+
from botorch.models.model import Model
8990
from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset
9091
from pyre_extensions import none_throws
9192
from torch import Tensor
@@ -216,6 +217,16 @@ def feature_importances(self, metric_signature: str) -> dict[str, float]:
216217
importances_arr = importances_dict[metric_signature].flatten()
217218
return dict(zip(self.parameters, importances_arr, strict=True))
218219

220+
@property
221+
def botorch_model(self) -> Model:
222+
"""Returns the underlying BoTorch model for BoTorchGenerator."""
223+
if not isinstance(self.generator, BoTorchGenerator):
224+
raise UnsupportedError(
225+
"Generator must be a BoTorchGenerator to "
226+
f"access botorch_model. Found {type(self.generator)}."
227+
)
228+
return self.generator.surrogate.model
229+
219230
def infer_objective_thresholds(
220231
self,
221232
search_space: SearchSpace | None = None,
@@ -254,7 +265,7 @@ def infer_objective_thresholds(
254265
)
255266
# Infer objective thresholds.
256267
if isinstance(self.generator, BoTorchGenerator):
257-
model = self.generator.surrogate.model
268+
model = self.botorch_model
258269
Xs = self.generator.surrogate.Xs
259270
else:
260271
raise UnsupportedError(

ax/benchmark/benchmark_test_functions/surrogate.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,7 @@ def wrap_surrogate_in_deterministic_model(self) -> None:
104104
def surrogate(self) -> TorchAdapter:
105105
if self._surrogate is None:
106106
self._surrogate = none_throws(self.get_surrogate)()
107-
if not isinstance(
108-
# pyre-ignore[16]: `ax.generators.torch_base.TorchGenerator` has no
109-
# attribute `surrogate`.
110-
self._surrogate.generator.surrogate.model,
111-
DeterministicModel,
112-
):
107+
if not isinstance(self._surrogate.botorch_model, DeterministicModel):
113108
self.wrap_surrogate_in_deterministic_model()
114109
return none_throws(self._surrogate)
115110

ax/benchmark/tests/benchmark_test_functions/test_surrogate_test_function.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
)
2525
from botorch.models.deterministic import PosteriorMeanModel
2626
from botorch.sampling.pathwise.posterior_samplers import MatheronPathModel
27+
from pyre_extensions import assert_is_instance
2728

2829

2930
class TestSurrogateTestFunction(TestCase):
@@ -237,11 +238,11 @@ def test_ensemble_sampling(self) -> None:
237238

238239
# Access surrogate to trigger wrapping
239240
surrogate = test_function.surrogate
240-
# pyre-ignore[16]: Access base_model through deterministic wrapper
241-
wrapped_model = surrogate.generator.surrogate.model
241+
# Access base_model through deterministic wrapper
242+
wrapped_model = assert_is_instance(surrogate.botorch_model, PosteriorMeanModel)
242243

243244
# Check that exactly one model has weight 1.0 and others have weight 0.0
244-
weights = wrapped_model.ensemble_weights
245+
weights = assert_is_instance(wrapped_model.ensemble_weights, torch.Tensor)
245246
self.assertEqual(weights.sum().item(), 1.0)
246247
self.assertEqual((weights == 1.0).sum().item(), 1)
247248
self.assertEqual((weights == 0.0).sum().item(), len(weights) - 1)
@@ -262,5 +263,4 @@ def test_ensemble_no_sampling(self) -> None:
262263

263264
# Access surrogate to trigger wrapping
264265
surrogate = test_function.surrogate
265-
# pyre-ignore[16]: Access base_model through deterministic wrapper
266-
self.assertIsNone(surrogate.generator.surrogate.model.ensemble_weights)
266+
self.assertIsNone(surrogate.botorch_model.ensemble_weights)

ax/utils/sensitivity/tests/test_sensitivity.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ class SensitivityAnalysisTest(TestCase):
7373
def setUp(self) -> None:
7474
super().setUp()
7575
set_rng_seed(0)
76-
self.model = get_adapter().generator.surrogate.model
77-
self.saas_model = get_adapter(saasbo=True).generator.surrogate.model
76+
self.model = get_adapter().botorch_model
77+
self.saas_model = get_adapter(saasbo=True).botorch_model
7878

7979
def test_DgsmGpMean(self) -> None:
8080
bounds = torch.tensor([(0.0, 1.0) for _ in range(2)]).t()
@@ -418,7 +418,7 @@ def test_SobolGPMean_SAASBO_Ax_utils(self) -> None:
418418
**sobol_kwargs,
419419
)
420420
ind_deriv = compute_derivatives_from_model_list(
421-
model_list=[adapter.generator.surrogate.model],
421+
model_list=[adapter.botorch_model],
422422
bounds=torch.tensor(adapter.generator.search_space_digest.bounds).T,
423423
discrete_features=discrete_features,
424424
fixed_features=None,

0 commit comments

Comments
 (0)