1515from typing import Any
1616from warnings import catch_warnings , simplefilter , warn_explicit , WarningMessage
1717
18- import jax
1918from botorch .exceptions .errors import ModelFittingError , UnsupportedError
2019from botorch .exceptions .warnings import OptimizationWarning
2120from botorch .logging import logger
4443from gpytorch .mlls .marginal_log_likelihood import MarginalLogLikelihood
4544from gpytorch .mlls .sum_marginal_log_likelihood import SumMarginalLogLikelihood
4645from linear_operator .utils .errors import NotPSDError
47- from numpyro .infer import MCMC , NUTS
46+ from pyro .infer . mcmc import MCMC , NUTS
4847from torch import device , Tensor
4948from torch .nn import Parameter
5049from torch .utils .data import DataLoader
@@ -343,11 +342,9 @@ def fit_fully_bayesian_model_nuts(
343342 thinning : int = 16 ,
344343 disable_progbar : bool = False ,
345344 jit_compile : bool = False ,
346- seed : int = 0 ,
347345) -> None :
348346 r"""Fit a fully Bayesian model using the No-U-Turn-Sampler (NUTS)
349347
350- Uses NumPyro's NUTS implementation (backed by JAX) for MCMC inference.
351348
352349 Args:
353350 model: Fully Bayesian GP to be fitted.
@@ -360,8 +357,7 @@ def fit_fully_bayesian_model_nuts(
360357 bar and diagnostics during MCMC.
361358 jit_compile: Whether to use jit. Using jit may be ~2X faster (rough estimate),
362359 but it will also increase the memory usage and sometimes result in runtime
363- errors.
364- seed: Random seed for JAX PRNG.
360+ errors, e.g., https://github.com/pyro-ppl/pyro/issues/3136.
365361
366362 Example:
367363 >>> gp = SaasFullyBayesianSingleTaskGP(train_X, train_Y)
@@ -372,18 +368,20 @@ def fit_fully_bayesian_model_nuts(
372368 # Do inference with NUTS
373369 nuts = NUTS (
374370 model .pyro_model .sample ,
375- dense_mass = True ,
371+ jit_compile = jit_compile ,
372+ full_mass = True ,
373+ ignore_jit_warnings = True ,
376374 max_tree_depth = max_tree_depth ,
377375 )
378376 mcmc = MCMC (
379377 nuts ,
380- num_warmup = warmup_steps ,
378+ warmup_steps = warmup_steps ,
381379 num_samples = num_samples ,
382- progress_bar = not disable_progbar ,
380+ disable_progbar = disable_progbar ,
383381 )
384- mcmc .run (jax . random . PRNGKey ( seed ) )
382+ mcmc .run ()
385383
386- # Get final MCMC samples from the NumPyro model
384+ # Get final MCMC samples from the Pyro model
387385 mcmc_samples = model .pyro_model .postprocess_mcmc_samples (
388386 mcmc_samples = mcmc .get_samples ()
389387 )
0 commit comments