Skip to content

Commit f8eec90

Browse files
sdaultonfacebook-github-bot
authored andcommitted
Make fully Bayesian imports lazy (facebook#5148)
Summary: X-link: meta-pytorch/botorch#3265 `import botorch` eagerly loaded Pyro via two import paths: 1. `botorch.fit` → `from pyro.infer.mcmc import MCMC, NUTS` 2. `botorch.models.__init__` → `fully_bayesian.py` → Pyro imports This diff makes fully Bayesian imports lazy so that `import botorch` no longer eagerly loads Pyro. The imports are deferred until `fit_fully_bayesian_model_nuts` is called or `SaasFullyBayesianSingleTaskGP`/`SaasFullyBayesianMultiTaskGP` are accessed from `botorch.models`. Differential Revision: D99687047
1 parent 96cf12b commit f8eec90

2 files changed

Lines changed: 3 additions & 2 deletions

File tree

ax/generators/torch/tests/test_surrogate.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,9 @@
5151
from ax.utils.testing.torch_stubs import get_torch_test_data
5252
from ax.utils.testing.utils import generic_equals
5353
from botorch.exceptions.errors import ModelFittingError
54-
from botorch.models import ModelListGP, SaasFullyBayesianSingleTaskGP, SingleTaskGP
54+
from botorch.models import ModelListGP, SingleTaskGP
5555
from botorch.models.deterministic import GenericDeterministicModel
56+
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
5657
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
5758
from botorch.models.gp_regression_mixed import MixedSingleTaskGP
5859
from botorch.models.model import Model, ModelList # noqa: F401 -- used in Mocks.

ax/utils/testing/tests/test_mock.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def test_botorch_mocks(self) -> None:
4545

4646
def test_fully_bayesian_mocks(self) -> None:
4747
experiment = get_branin_experiment(with_completed_batch=True)
48-
with patch("botorch.fit.MCMC", wraps=MCMC) as mock_mcmc:
48+
with patch("pyro.infer.mcmc.MCMC", wraps=MCMC) as mock_mcmc:
4949
with mock_botorch_optimize_context_manager():
5050
Generators.SAASBO(experiment=experiment, data=experiment.lookup_data())
5151
mock_mcmc.assert_called_once()

0 commit comments

Comments
 (0)