Skip to content

Commit 94766a7

Browse files
sdaultonmeta-codesync[bot]
authored andcommitted
Back out "Replace Pyro with NumPyro for fully Bayesian NUTS inference (96% reduction in fit time)" (meta-pytorch#3263)
Summary: Pull Request resolved: meta-pytorch#3263 X-link: facebook/Ax#5136 Reviewed By: Balandat Differential Revision: D99446023 fbshipit-source-id: 76efd2e595d48589ee3252efe407e31521044468
1 parent 8857df9 commit 94766a7

12 files changed

Lines changed: 1913 additions & 2072 deletions

File tree

botorch/fit.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from typing import Any
1616
from warnings import catch_warnings, simplefilter, warn_explicit, WarningMessage
1717

18-
import jax
1918
from botorch.exceptions.errors import ModelFittingError, UnsupportedError
2019
from botorch.exceptions.warnings import OptimizationWarning
2120
from botorch.logging import logger
@@ -44,7 +43,7 @@
4443
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
4544
from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood
4645
from linear_operator.utils.errors import NotPSDError
47-
from numpyro.infer import MCMC, NUTS
46+
from pyro.infer.mcmc import MCMC, NUTS
4847
from torch import device, Tensor
4948
from torch.nn import Parameter
5049
from torch.utils.data import DataLoader
@@ -343,11 +342,9 @@ def fit_fully_bayesian_model_nuts(
343342
thinning: int = 16,
344343
disable_progbar: bool = False,
345344
jit_compile: bool = False,
346-
seed: int = 0,
347345
) -> None:
348346
r"""Fit a fully Bayesian model using the No-U-Turn-Sampler (NUTS)
349347
350-
Uses NumPyro's NUTS implementation (backed by JAX) for MCMC inference.
351348
352349
Args:
353350
model: Fully Bayesian GP to be fitted.
@@ -360,8 +357,7 @@ def fit_fully_bayesian_model_nuts(
360357
bar and diagnostics during MCMC.
361358
jit_compile: Whether to use jit. Using jit may be ~2X faster (rough estimate),
362359
but it will also increase the memory usage and sometimes result in runtime
363-
errors.
364-
seed: Random seed for JAX PRNG.
360+
errors, e.g., https://github.com/pyro-ppl/pyro/issues/3136.
365361
366362
Example:
367363
>>> gp = SaasFullyBayesianSingleTaskGP(train_X, train_Y)
@@ -372,18 +368,20 @@ def fit_fully_bayesian_model_nuts(
372368
# Do inference with NUTS
373369
nuts = NUTS(
374370
model.pyro_model.sample,
375-
dense_mass=True,
371+
jit_compile=jit_compile,
372+
full_mass=True,
373+
ignore_jit_warnings=True,
376374
max_tree_depth=max_tree_depth,
377375
)
378376
mcmc = MCMC(
379377
nuts,
380-
num_warmup=warmup_steps,
378+
warmup_steps=warmup_steps,
381379
num_samples=num_samples,
382-
progress_bar=not disable_progbar,
380+
disable_progbar=disable_progbar,
383381
)
384-
mcmc.run(jax.random.PRNGKey(seed))
382+
mcmc.run()
385383

386-
# Get final MCMC samples from the NumPyro model
384+
# Get final MCMC samples from the Pyro model
387385
mcmc_samples = model.pyro_model.postprocess_mcmc_samples(
388386
mcmc_samples=mcmc.get_samples()
389387
)

0 commit comments

Comments
 (0)