Skip to content

Commit 4a13876

Browse files
sdaultonfacebook-github-bot
authored andcommitted
Make fully Bayesian imports lazy (#3265)
Summary: X-link: facebook/Ax#5148 `import botorch` eagerly loaded Pyro via two import paths: 1. `botorch.fit` → `from pyro.infer.mcmc import MCMC, NUTS` 2. `botorch.models.__init__` → `fully_bayesian.py` → Pyro imports This diff makes fully Bayesian imports lazy so that `import botorch` no longer eagerly loads Pyro. The imports are deferred until `fit_fully_bayesian_model_nuts` is called or `SaasFullyBayesianSingleTaskGP`/`SaasFullyBayesianMultiTaskGP` are accessed from `botorch.models`. Reviewed By: saitcakmak Differential Revision: D99687047
1 parent 1d5fd01 commit 4a13876

3 files changed

Lines changed: 31 additions & 7 deletions

File tree

botorch/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
test_functions,
1717
)
1818
from botorch.cross_validation import batch_cross_validation
19-
from botorch.fit import fit_fully_bayesian_model_nuts, fit_gpytorch_mll
19+
from botorch.fit import fit_gpytorch_mll
2020
from botorch.generation.gen import (
2121
gen_candidates_scipy,
2222
gen_candidates_torch,
@@ -49,6 +49,14 @@
4949
gp_settings.max_eager_kernel_size._global_value = 4096
5050

5151

52+
def __getattr__(name: str):
53+
if name == "fit_fully_bayesian_model_nuts":
54+
from botorch.fit import fit_fully_bayesian_model_nuts
55+
56+
return fit_fully_bayesian_model_nuts
57+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
58+
59+
5260
__all__ = [
5361
"acquisition",
5462
"batch_cross_validation",

botorch/fit.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,13 @@
1212
from copy import deepcopy
1313
from functools import partial
1414
from itertools import filterfalse
15-
from typing import Any
15+
from typing import Any, TYPE_CHECKING
1616
from warnings import catch_warnings, simplefilter, warn_explicit, WarningMessage
1717

1818
from botorch.exceptions.errors import ModelFittingError, UnsupportedError
1919
from botorch.exceptions.warnings import OptimizationWarning
2020
from botorch.logging import logger
2121
from botorch.models import SingleTaskGP
22-
from botorch.models.fully_bayesian import AbstractFullyBayesianSingleTaskGP
23-
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
2422
from botorch.models.map_saas import get_map_saas_model
2523
from botorch.models.model_list_gp_regression import ModelListGP
2624
from botorch.models.transforms.input import InputTransform
@@ -43,11 +41,14 @@
4341
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
4442
from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood
4543
from linear_operator.utils.errors import NotPSDError
46-
from pyro.infer.mcmc import MCMC, NUTS
4744
from torch import device, Tensor
4845
from torch.nn import Parameter
4946
from torch.utils.data import DataLoader
5047

48+
if TYPE_CHECKING:
49+
from botorch.models.fully_bayesian import AbstractFullyBayesianSingleTaskGP
50+
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
51+
5152

5253
def _debug_warn(w: WarningMessage) -> bool:
5354
if _LBFGSB_MAXITER_MAXFUN_REGEX.search(str(w.message)):
@@ -363,6 +364,8 @@ def fit_fully_bayesian_model_nuts(
363364
>>> gp = SaasFullyBayesianSingleTaskGP(train_X, train_Y)
364365
>>> fit_fully_bayesian_model_nuts(gp)
365366
"""
367+
from pyro.infer.mcmc import MCMC, NUTS
368+
366369
model.train()
367370

368371
# Do inference with NUTS

botorch/models/__init__.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
GenericDeterministicModel,
1515
PosteriorMeanModel,
1616
)
17-
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
18-
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
1917
from botorch.models.gp_regression import SingleTaskGP
2018
from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP
2119
from botorch.models.gp_regression_mixed import MixedSingleTaskGP
@@ -30,6 +28,21 @@
3028
from botorch.models.multitask import KroneckerMultiTaskGP, MultiTaskGP
3129
from botorch.models.pairwise_gp import PairwiseGP, PairwiseLaplaceMarginalLogLikelihood
3230

31+
_FULLY_BAYESIAN_LAZY_IMPORTS = {
32+
"SaasFullyBayesianSingleTaskGP": "botorch.models.fully_bayesian",
33+
"SaasFullyBayesianMultiTaskGP": "botorch.models.fully_bayesian_multitask",
34+
}
35+
36+
37+
def __getattr__(name: str):
38+
if name in _FULLY_BAYESIAN_LAZY_IMPORTS:
39+
import importlib
40+
41+
module = importlib.import_module(_FULLY_BAYESIAN_LAZY_IMPORTS[name])
42+
return getattr(module, name)
43+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
44+
45+
3346
__all__ = [
3447
"add_saas_prior",
3548
"AdditiveMapSaasSingleTaskGP",

0 commit comments

Comments
 (0)