Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
f0e5e5f
feat: implement AmortizedVIPosterior and corresponding tests (WIP).
janfb Jan 13, 2026
3178b0b
feat: enhance AmortizedVIPosterior with vectorization and early stopp…
janfb Jan 13, 2026
e3dcb59
refactoring: input validation and error handling; enhance test for ac…
janfb Jan 13, 2026
4808920
feat: add FlowType enum for flow architectures in AmortizedVIPosterio…
janfb Jan 13, 2026
72e2bd5
feat: align amortized VI flow types and MAP behavior
janfb Jan 17, 2026
5bd06a1
refactoring: improve API consistency and input validation
janfb Feb 2, 2026
ea51371
feat(vi): add Zuko unconditional flow builder for VI
janfb Feb 2, 2026
7af45c0
refactor(vi): adapt DivergenceOptimizer for Zuko flow support
janfb Feb 2, 2026
4a1e186
refactor(vi): extend Zuko support in divergence optimizer subclasses
janfb Feb 2, 2026
4468d0d
fix(vi): address code review issues in VI posterior
janfb Feb 2, 2026
6982637
feat(vi): integrate Zuko unconditional flows into VIPosterior
janfb Feb 2, 2026
bc13be7
feat(vi): add train_amortized() method for amortized variational infe…
janfb Feb 2, 2026
5922057
fix(vi): fix serialization for Zuko flows and update tests
janfb Feb 2, 2026
fb837dd
refactor(vi): migrate amortized VI tests to unified VIPosterior
janfb Feb 2, 2026
e20a19e
refactor(vi): remove AmortizedVIPosterior from public API exports
janfb Feb 2, 2026
9875bfe
wip: fix zuko 1D issue, replace pyro flows with gaussian in 1D case
janfb Feb 4, 2026
14abaf2
device fixes and combining into one test file
janfb Feb 4, 2026
c53a83a
refactor tests: add fast amortized test; remove redundant.
janfb Feb 6, 2026
8b48387
clean up pyro; fix types; fix posterior params cls
janfb Feb 6, 2026
15206dd
Merge main into add-amortized-vip
janfb Feb 6, 2026
a040502
linting
janfb Feb 6, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions sbi/inference/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@
_abc_family = ["ABC", "MCABC", "SMC", "SMCABC"]


__all__ = _npe_family + _nre_family + _nle_family + _abc_family + ["FMPE", "NPSE"]

from sbi.inference.posteriors import (
DirectPosterior,
EnsemblePosterior,
Expand All @@ -53,4 +51,27 @@
)
from sbi.utils.simulation_utils import simulate_for_sbi

__all__ = ["FMPE", "MarginalTrainer", "NLE", "NPE", "NPSE", "NRE", "simulate_for_sbi"]
__all__ = (
_npe_family
+ _nre_family
+ _nle_family
+ _abc_family
+ [
"FMPE",
"MarginalTrainer",
"NPSE",
"DirectPosterior",
"EnsemblePosterior",
"ImportanceSamplingPosterior",
"MCMCPosterior",
"RejectionPosterior",
"VIPosterior",
"VectorFieldPosterior",
"simulate_for_sbi",
"likelihood_estimator_based_potential",
"mixed_likelihood_estimator_based_potential",
"posterior_estimator_based_potential",
"ratio_estimator_based_potential",
"vector_field_estimator_based_potential",
]
)
91 changes: 51 additions & 40 deletions sbi/inference/posteriors/posterior_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
Any,
Callable,
Dict,
Iterable,
Literal,
Optional,
Union,
Expand All @@ -17,7 +16,7 @@
)

from sbi.inference.posteriors.vi_posterior import VIPosterior
from sbi.sbi_types import PyroTransformedDistribution, TorchTransform
from sbi.sbi_types import TorchTransform, VariationalDistribution
from sbi.utils.typechecks import (
is_nonnegative_int,
is_positive_float,
Expand Down Expand Up @@ -334,61 +333,73 @@ def validate(self):
@dataclass(frozen=True)
class VIPosteriorParameters(PosteriorParameters):
"""
Parameters for initializing VIPosterior.
Parameters for VIPosterior, supporting both single-x and amortized VI.

Fields:
q: Variational distribution, either string, `TransformedDistribution`, or a
`VIPosterior` object. This specifies a parametric class of distribution
over which the best possible posterior approximation is searched. For
string input, we currently support [nsf, scf, maf, mcf, gaussian,
gaussian_diag]. You can also specify your own variational family by
passing a pyro `TransformedDistribution`.
Additionally, we allow a `Callable`, which allows you the pass a
`builder` function, which if called returns a distribution. This may be
useful for setting the hyperparameters e.g. `num_transfroms` within the
`get_flow_builder` method specifying the number of transformations
within a normalizing flow. If q is already a `VIPosterior`, then the
arguments will be copied from it (relevant for multi-round training).
vi_method: This specifies the variational methods which are used to fit q to
the posterior. We currently support [rKL, fKL, IW, alpha]. Note that
some of the divergences are `mode seeking` i.e. they underestimate
variance and collapse on multimodal targets (`rKL`, `alpha` for alpha >
1) and some are `mass covering` i.e. they overestimate variance but
typically cover all modes (`fKL`, `IW`, `alpha` for alpha < 1).
parameters: List of parameters of the variational posterior. This is only
required for user-defined q i.e. if q does not have a `parameters`
attribute.
modules: List of modules of the variational posterior. This is only
required for user-defined q i.e. if q does not have a `modules`
attribute.
q: Variational distribution. Either a string specifying the flow type
[nsf, maf, naf, unaf, nice, sospf, gaussian, gaussian_diag], a
`TransformedDistribution`, a `VIPosterior` object, or a `Callable`
builder function. For amortized VI, use string flow types only.
If q is already a `VIPosterior`, arguments are copied from it
(relevant for multi-round training).
vi_method: Variational method for fitting q to the posterior. Options:
[rKL, fKL, IW, alpha]. Some are "mode seeking" (rKL, alpha > 1) and
some are "mass covering" (fKL, IW, alpha < 1). Currently only used
for single-x VI; amortized VI uses ELBO (rKL).
num_transforms: Number of transforms in the normalizing flow.
hidden_features: Hidden layer size in the flow networks.
z_score_theta: Method for z-scoring θ (the parameters being modeled).
One of "none", "independent", "structured". Use "structured" for
parameters with correlations.
z_score_x: Method for z-scoring x (the conditioning variable, amortized
VI only). One of "none", "independent", "structured". Use
"structured" for structured data like images.

Note:
For custom distributions that lack `parameters()` and `modules()` methods,
pass these via `VIPosterior.set_q(q, parameters=..., modules=...)` instead.
"""

q: Union[
Literal["nsf", "scf", "maf", "mcf", "gaussian", "gaussian_diag"],
PyroTransformedDistribution,
Literal[
"nsf", "maf", "naf", "unaf", "nice", "sospf", "gaussian", "gaussian_diag"
],
VariationalDistribution,
"VIPosterior",
Callable,
] = "maf"
vi_method: Literal["rKL", "fKL", "IW", "alpha"] = "rKL"
parameters: Optional[Iterable] = None
modules: Optional[Iterable] = None
num_transforms: int = 5
hidden_features: int = 50
z_score_theta: Literal["none", "independent", "structured"] = "independent"
z_score_x: Literal["none", "independent", "structured"] = "independent"

def validate(self):
"""Validate VIPosteriorParameters fields."""

valid_q = {"nsf", "scf", "maf", "mcf", "gaussian", "gaussian_diag"}
valid_q = {
"nsf",
"maf",
"naf",
"unaf",
"nice",
"sospf",
"gaussian",
"gaussian_diag",
}

if isinstance(self.q, str) and self.q not in valid_q:
raise ValueError(f"If `q` is a string, it must be one of {valid_q}")
elif not isinstance(
self.q, (PyroTransformedDistribution, VIPosterior, Callable, str)
self.q, (VariationalDistribution, VIPosterior, Callable, str)
):
raise TypeError(
"q must be either of typr PyroTransformedDistribution,"
" VIPosterioror or Callable"
"q must be either of type VariationalDistribution,"
" VIPosterior or Callable"
)

if self.parameters is not None and not isinstance(self.parameters, Iterable):
raise TypeError("parameters must be iterable or None.")
if self.modules is not None and not isinstance(self.modules, Iterable):
raise TypeError("modules must be iterable or None.")
if self.num_transforms < 1:
raise ValueError(f"num_transforms must be >= 1, got {self.num_transforms}")
if self.hidden_features < 1:
raise ValueError(
f"hidden_features must be >= 1, got {self.hidden_features}"
)
Loading
Loading