11# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
22# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
33
4- from typing import Any , Callable , Dict , Optional , Union
4+ from typing import Any , Callable , Dict , Literal , Optional , Union
55
66from torch .distributions import Distribution
77
@@ -94,9 +94,19 @@ def build_posterior(
9494 self ,
9595 density_estimator : Optional [MixedDensityEstimator ] = None ,
9696 prior : Optional [Distribution ] = None ,
97- sample_with : str = "direct" ,
98- mcmc_method : str = "slice_np_vectorized" ,
99- vi_method : str = "rKL" ,
97+ sample_with : Literal [
98+ "mcmc" , "rejection" , "vi" , "importance" , "direct"
99+ ] = "direct" ,
100+ mcmc_method : Literal [
101+ "slice_np" ,
102+ "slice_np_vectorized" ,
103+ "hmc_pyro" ,
104+ "nuts_pyro" ,
105+ "slice_pymc" ,
106+ "hmc_pymc" ,
107+ "nuts_pymc" ,
108+ ] = "slice_np_vectorized" ,
109+ vi_method : Literal ["rKL" , "fKL" , "IW" , "alpha" ] = "rKL" ,
100110 direct_sampling_parameters : Optional [Dict [str , Any ]] = None ,
101111 mcmc_parameters : Optional [Dict [str , Any ]] = None ,
102112 vi_parameters : Optional [Dict [str , Any ]] = None ,
@@ -117,10 +127,14 @@ def build_posterior(
117127 prior: Prior distribution.
118128 sample_with: Method to use for sampling from the posterior. Must be one of
119129 [`direct` | `mcmc` | `rejection` | `vi` | `importance`].
120- mcmc_method: Method used for MCMC sampling, one of `slice_np`, `slice`,
121- `hmc`, `nuts`. Currently defaults to `slice_np` for a custom numpy
122- implementation of slice sampling; select `hmc`, `nuts` or `slice` for
123- Pyro-based sampling.
130+ mcmc_method: Method used for MCMC sampling, one of `slice_np`,
131+ `slice_np_vectorized`, `hmc_pyro`, `nuts_pyro`, `slice_pymc`,
132+ `hmc_pymc`, `nuts_pymc`. `slice_np` is a custom
133+ numpy implementation of slice sampling. `slice_np_vectorized` is
134+ identical to `slice_np`, but if `num_chains>1`, the chains are
135+ vectorized for `slice_np_vectorized` whereas they are run sequentially
136+ for `slice_np`. The samplers ending on `_pyro` are using Pyro, and
137+ likewise the samplers ending on `_pymc` are using PyMC.
124138 vi_method: Method used for VI, one of [`rKL`, `fKL`, `IW`, `alpha`].
125139 direct_sampling_parameters: Additional kwargs passed to `DirectPosterior`.
126140 mcmc_parameters: Additional kwargs passed to `MCMCPosterior`.
0 commit comments