diff --git a/sbi/inference/__init__.py b/sbi/inference/__init__.py index 6a8dd00fe..eb159aa34 100644 --- a/sbi/inference/__init__.py +++ b/sbi/inference/__init__.py @@ -33,8 +33,6 @@ _abc_family = ["ABC", "MCABC", "SMC", "SMCABC"] -__all__ = _npe_family + _nre_family + _nle_family + _abc_family + ["FMPE", "NPSE"] - from sbi.inference.posteriors import ( DirectPosterior, EnsemblePosterior, @@ -53,4 +51,27 @@ ) from sbi.utils.simulation_utils import simulate_for_sbi -__all__ = ["FMPE", "MarginalTrainer", "NLE", "NPE", "NPSE", "NRE", "simulate_for_sbi"] +__all__ = ( + _npe_family + + _nre_family + + _nle_family + + _abc_family + + [ + "FMPE", + "MarginalTrainer", + "NPSE", + "DirectPosterior", + "EnsemblePosterior", + "ImportanceSamplingPosterior", + "MCMCPosterior", + "RejectionPosterior", + "VIPosterior", + "VectorFieldPosterior", + "simulate_for_sbi", + "likelihood_estimator_based_potential", + "mixed_likelihood_estimator_based_potential", + "posterior_estimator_based_potential", + "ratio_estimator_based_potential", + "vector_field_estimator_based_potential", + ] +) diff --git a/sbi/inference/posteriors/posterior_parameters.py b/sbi/inference/posteriors/posterior_parameters.py index 6f10d4246..cd3a01742 100644 --- a/sbi/inference/posteriors/posterior_parameters.py +++ b/sbi/inference/posteriors/posterior_parameters.py @@ -7,7 +7,6 @@ Any, Callable, Dict, - Iterable, Literal, Optional, Union, @@ -17,7 +16,7 @@ ) from sbi.inference.posteriors.vi_posterior import VIPosterior -from sbi.sbi_types import PyroTransformedDistribution, TorchTransform +from sbi.sbi_types import TorchTransform, VariationalDistribution from sbi.utils.typechecks import ( is_nonnegative_int, is_positive_float, @@ -334,61 +333,73 @@ def validate(self): @dataclass(frozen=True) class VIPosteriorParameters(PosteriorParameters): """ - Parameters for initializing VIPosterior. + Parameters for VIPosterior, supporting both single-x and amortized VI. Fields: - q: Variational distribution, either string, `TransformedDistribution`, or a - `VIPosterior` object. This specifies a parametric class of distribution - over which the best possible posterior approximation is searched. For - string input, we currently support [nsf, scf, maf, mcf, gaussian, - gaussian_diag]. You can also specify your own variational family by - passing a pyro `TransformedDistribution`. - Additionally, we allow a `Callable`, which allows you the pass a - `builder` function, which if called returns a distribution. This may be - useful for setting the hyperparameters e.g. `num_transfroms` within the - `get_flow_builder` method specifying the number of transformations - within a normalizing flow. If q is already a `VIPosterior`, then the - arguments will be copied from it (relevant for multi-round training). - vi_method: This specifies the variational methods which are used to fit q to - the posterior. We currently support [rKL, fKL, IW, alpha]. Note that - some of the divergences are `mode seeking` i.e. they underestimate - variance and collapse on multimodal targets (`rKL`, `alpha` for alpha > - 1) and some are `mass covering` i.e. they overestimate variance but - typically cover all modes (`fKL`, `IW`, `alpha` for alpha < 1). - parameters: List of parameters of the variational posterior. This is only - required for user-defined q i.e. if q does not have a `parameters` - attribute. - modules: List of modules of the variational posterior. This is only - required for user-defined q i.e. if q does not have a `modules` - attribute. + q: Variational distribution. 
Either a string specifying the flow type + [nsf, maf, naf, unaf, nice, sospf, gaussian, gaussian_diag], a + `TransformedDistribution`, a `VIPosterior` object, or a `Callable` + builder function. For amortized VI, use string flow types only. + If q is already a `VIPosterior`, arguments are copied from it + (relevant for multi-round training). + vi_method: Variational method for fitting q to the posterior. Options: + [rKL, fKL, IW, alpha]. Some are "mode seeking" (rKL, alpha > 1) and + some are "mass covering" (fKL, IW, alpha < 1). Currently only used + for single-x VI; amortized VI uses ELBO (rKL). + num_transforms: Number of transforms in the normalizing flow. + hidden_features: Hidden layer size in the flow networks. + z_score_theta: Method for z-scoring θ (the parameters being modeled). + One of "none", "independent", "structured". Use "structured" for + parameters with correlations. + z_score_x: Method for z-scoring x (the conditioning variable, amortized + VI only). One of "none", "independent", "structured". Use + "structured" for structured data like images. + + Note: + For custom distributions that lack `parameters()` and `modules()` methods, + pass these via `VIPosterior.set_q(q, parameters=..., modules=...)` instead. """ q: Union[ - Literal["nsf", "scf", "maf", "mcf", "gaussian", "gaussian_diag"], - PyroTransformedDistribution, + Literal[ + "nsf", "maf", "naf", "unaf", "nice", "sospf", "gaussian", "gaussian_diag" + ], + VariationalDistribution, "VIPosterior", Callable, ] = "maf" vi_method: Literal["rKL", "fKL", "IW", "alpha"] = "rKL" - parameters: Optional[Iterable] = None - modules: Optional[Iterable] = None + num_transforms: int = 5 + hidden_features: int = 50 + z_score_theta: Literal["none", "independent", "structured"] = "independent" + z_score_x: Literal["none", "independent", "structured"] = "independent" def validate(self): """Validate VIPosteriorParameters fields.""" - - valid_q = {"nsf", "scf", "maf", "mcf", "gaussian", "gaussian_diag"} + valid_q = { + "nsf", + "maf", + "naf", + "unaf", + "nice", + "sospf", + "gaussian", + "gaussian_diag", + } if isinstance(self.q, str) and self.q not in valid_q: raise ValueError(f"If `q` is a string, it must be one of {valid_q}") elif not isinstance( - self.q, (PyroTransformedDistribution, VIPosterior, Callable, str) + self.q, (VariationalDistribution, VIPosterior, Callable, str) ): raise TypeError( - "q must be either of typr PyroTransformedDistribution," - " VIPosterioror or Callable" + "q must be either of type VariationalDistribution," + " VIPosterior or Callable" ) - if self.parameters is not None and not isinstance(self.parameters, Iterable): - raise TypeError("parameters must be iterable or None.") - if self.modules is not None and not isinstance(self.modules, Iterable): - raise TypeError("modules must be iterable or None.") + if self.num_transforms < 1: + raise ValueError(f"num_transforms must be >= 1, got {self.num_transforms}") + if self.hidden_features < 1: + raise ValueError( + f"hidden_features must be >= 1, got {self.hidden_features}" + ) diff --git a/sbi/inference/posteriors/vi_posterior.py b/sbi/inference/posteriors/vi_posterior.py index ae06aa60b..8cd36f7fa 100644 --- a/sbi/inference/posteriors/vi_posterior.py +++ b/sbi/inference/posteriors/vi_posterior.py @@ -2,28 +2,41 @@ # under the Apache License Version 2.0, see import copy +import warnings from copy import deepcopy -from typing import Callable, Dict, Iterable, Literal, Optional, Union +from typing import TYPE_CHECKING, Callable, Dict, Iterable, Literal, Optional, 
Union import numpy as np import torch -from torch import Tensor +from torch import Tensor, nn from torch.distributions import Distribution +from torch.optim import Adam +from torch.optim.lr_scheduler import ExponentialLR from tqdm.auto import tqdm from sbi.inference.posteriors.base_posterior import NeuralPosterior + +if TYPE_CHECKING: + from sbi.inference.posteriors.posterior_parameters import VIPosteriorParameters from sbi.inference.potentials.base_potential import BasePotential, CustomPotential +from sbi.neural_nets.estimators.base import ConditionalDensityEstimator +from sbi.neural_nets.estimators.zuko_flow import ZukoUnconditionalFlow +from sbi.neural_nets.factory import ZukoFlowType +from sbi.neural_nets.net_builders.flow import ( + build_zuko_flow, + build_zuko_unconditional_flow, +) from sbi.samplers.vi.vi_divergence_optimizers import get_VI_method -from sbi.samplers.vi.vi_pyro_flows import get_flow_builder from sbi.samplers.vi.vi_quality_control import get_quality_metric from sbi.samplers.vi.vi_utils import ( + LearnableGaussian, + TransformedZukoFlow, adapt_variational_distribution, check_variational_distribution, make_object_deepcopy_compatible, move_all_tensor_to_device, ) from sbi.sbi_types import ( - PyroTransformedDistribution, Shape, TorchDistribution, TorchTensor, @@ -32,6 +45,17 @@ from sbi.utils.sbiutils import mcmc_transform from sbi.utils.torchutils import atleast_2d_float32_tensor, ensure_theta_batched +# Supported Zuko flow types for VI (lowercase names) +_ZUKO_FLOW_TYPES = {"maf", "nsf", "naf", "unaf", "nice", "sospf"} + +# Type for supported variational family strings +VariationalFamily = Literal[ + "maf", "nsf", "naf", "unaf", "nice", "sospf", "gaussian", "gaussian_diag" +] + +# Type for the q parameter in VIPosterior +QType = Union[VariationalFamily, Distribution, "VIPosterior", Callable] + class VIPosterior(NeuralPosterior): r"""Provides VI (Variational Inference) to sample from the posterior. @@ -60,12 +84,7 @@ def __init__( self, potential_fn: Union[BasePotential, CustomPotential], prior: Optional[TorchDistribution] = None, # type: ignore - q: Union[ - Literal["nsf", "scf", "maf", "mcf", "gaussian", "gaussian_diag"], - PyroTransformedDistribution, - "VIPosterior", - Callable, - ] = "maf", + q: QType = "maf", theta_transform: Optional[TorchTransform] = None, vi_method: Literal["rKL", "fKL", "IW", "alpha"] = "rKL", device: Union[str, torch.device] = "cpu", @@ -82,18 +101,19 @@ def __init__( quality metrics. Please make sure that this matches with the prior within the potential_fn. If `None` is given, we will try to infer it from potential_fn or q, if this fails we raise an Error. - q: Variational distribution, either string, `TransformedDistribution`, or a + q: Variational distribution, either string, `Distribution`, or a `VIPosterior` object. This specifies a parametric class of distribution over which the best possible posterior approximation is searched. For - string input, we currently support [nsf, scf, maf, mcf, gaussian, - gaussian_diag]. You can also specify your own variational family by - passing a pyro `TransformedDistribution`. - Additionally, we allow a `Callable`, which allows you the pass a - `builder` function, which if called returns a distribution. This may be - useful for setting the hyperparameters e.g. `num_transfroms` within the - `get_flow_builder` method specifying the number of transformations - within a normalizing flow. If q is already a `VIPosterior`, then the - arguments will be copied from it (relevant for multi-round training). 
+ string input, we support normalizing flows [maf, nsf, naf, unaf, nice, + sospf] via Zuko, and Gaussian families [gaussian, gaussian_diag]. + You can also specify your own variational family by passing a + `torch.distributions.Distribution`. Additionally, we allow a `Callable` + with signature `(event_shape: torch.Size, link_transform: + TorchTransform, device: str) -> Distribution` for custom flow + configurations. The + callable should return a distribution with `sample()` and `log_prob()` + methods. If q is already a `VIPosterior`, then the arguments will be + copied from it (relevant for multi-round training). theta_transform: Maps form prior support to unconstrained space. The inverse is used here to ensure that the posterior support is equal to that of the prior. @@ -139,6 +159,12 @@ def __init__( move_all_tensor_to_device(self._prior, device) self._optimizer = None + # Mode tracking: None (not trained), "single_x", or "amortized" + self._mode: Optional[Literal["single_x", "amortized"]] = None + + # Amortized mode: conditional flow q(θ|x) + self._amortized_q: Optional[ConditionalDensityEstimator] = None + # In contrast to MCMC we want to project into constrained space. if theta_transform is None: self.link_transform = mcmc_transform(self._prior).inv @@ -162,7 +188,7 @@ def __init__( "can evaluate the _normalized_ posterior density with .log_prob()." ) - def to(self, device: Union[str, torch.device]) -> None: + def to(self, device: Union[str, torch.device]) -> "VIPosterior": """ Move potential_fn, _prior and x_o to device, and change the device attribute. @@ -170,6 +196,9 @@ def to(self, device: Union[str, torch.device]) -> None: Args: device: The device to move the posterior to. + + Returns: + self for method chaining. """ self.device = device self.potential_fn.to(device) # type: ignore @@ -189,44 +218,163 @@ def to(self, device: Union[str, torch.device]) -> None: else: self.link_transform = self.theta_transform.inv - @property - def q(self) -> Distribution: - """Returns the variational posterior.""" - return self._q + return self - @q.setter - def q( + def _build_zuko_flow( self, - q: Union[str, Distribution, "VIPosterior", Callable], - ) -> None: - """Sets the variational distribution. If the distribution does not admit access - through `parameters` and `modules` function, please use `set_q` if you want to - explicitly specify the parameters and modules. + flow_type: str, + num_transforms: int = 5, + hidden_features: int = 50, + z_score_x: Literal[ + "none", "independent", "structured", "transform_to_unconstrained" + ] = "independent", + ) -> TransformedZukoFlow: + """Build a Zuko unconditional flow for variational inference. + + The flow is wrapped with TransformedZukoFlow to handle the transformation + between unconstrained (flow) space and constrained (prior) space. This ensures + that samples from the flow match the prior's support and log_prob accounts + for the Jacobian of the transformation. + + Args: + flow_type: Type of flow, one of ["maf", "nsf", "naf", "unaf", "nice", + "sospf"]. For "gaussian" or "gaussian_diag", use LearnableGaussian. + num_transforms: Number of flow transforms. + hidden_features: Number of hidden features per layer. + z_score_x: Method for z-scoring input. One of "independent", "structured", + or "none". Use "structured" for data with correlations (e.g., images). + + Returns: + TransformedZukoFlow: The constructed flow wrapped with link_transform. + + Raises: + ValueError: If flow_type is not supported. 
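+
+        Example (illustrative sketch, not part of the test suite; assumes a 2D
+        prior on this posterior object):
+            >>> flow = self._build_zuko_flow("nsf", num_transforms=3)
+            >>> theta = flow.sample(torch.Size((100,)))  # constrained-space samples
+            >>> log_q = flow.log_prob(theta)  # includes the link-transform Jacobian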
+ """ + if flow_type in ("gaussian", "gaussian_diag"): + raise ValueError( + f"Flow type '{flow_type}' uses LearnableGaussian, not Zuko flows. " + f"This is handled automatically in set_q()." + ) + + if flow_type not in _ZUKO_FLOW_TYPES: + raise ValueError( + f"Unknown flow type '{flow_type}'. " + f"Supported types: {sorted(_ZUKO_FLOW_TYPES)} + " + f"['gaussian', 'gaussian_diag']." + ) + zuko_flow_type = flow_type.upper() + + # Get prior dimensionality + prior_dim = self._prior.event_shape[0] if self._prior.event_shape else 1 + + # Warn about 1D limitation + if prior_dim == 1: + warnings.warn( + f"Using {flow_type.upper()} flow for 1D parameter space. " + f"Normalizing flows may be unstable for 1D VI optimization. " + f"Consider using q='gaussian' for better results in 1D.", + UserWarning, + stacklevel=3, + ) + + # Sample from prior to get batch for dimensionality inference and z-scoring + # We apply link_transform.inv to map constrained prior samples to unconstrained + # space (link_transform.forward maps unconstrained -> constrained) + with torch.no_grad(): + prior_samples = self._prior.sample((1000,)) + batch_theta = self.link_transform.inv(prior_samples) + assert isinstance(batch_theta, Tensor) # Type narrowing for pyright + + flow = build_zuko_unconditional_flow( + which_nf=zuko_flow_type, + batch_x=batch_theta, + z_score_x=z_score_x, + hidden_features=hidden_features, + num_transforms=num_transforms, + ) + + # Wrap flow with link_transform to ensure samples are in constrained space + # The flow operates in unconstrained space, but we want samples/log_probs + # in constrained space (matching the prior's support) + transformed_flow = TransformedZukoFlow( + flow=flow.to(self._device), + link_transform=self.link_transform, + ) + + return transformed_flow.to(self._device) + + def _build_conditional_flow( + self, + theta: Tensor, + x: Tensor, + flow_type: Union[ZukoFlowType, str] = ZukoFlowType.NSF, + num_transforms: int = 2, + hidden_features: int = 32, + z_score_theta: Literal["none", "independent", "structured"] = "independent", + z_score_x: Literal["none", "independent", "structured"] = "independent", + ) -> ConditionalDensityEstimator: + """Build a conditional Zuko flow for amortized variational inference. Args: - q: Variational distribution, either string, distribution, or a VIPosterior - object. This specifies a parametric class of distribution over which - the best possible posterior approximation is searched. For string input, - we currently support [nsf, scf, maf, mcf, gaussian, gaussian_diag]. Of - course, you can also specify your own variational family by passing a - `parameterized` distribution object i.e. a torch.distributions - Distribution with methods `parameters` returning an iterable of all - parameters (you can pass them within the paramters/modules attribute). - Additionally, we allow a `Callable`, which allows you the pass a - `builder` function, which if called returns an distribution. This may be - useful for setting the hyperparameters e.g. `num_transfroms:int` by - using the `get_flow_builder` method specifying the hyperparameters. If q - is already a `VIPosterior`, then the arguments will be copied from it - (relevant for multi-round training). + theta: Sample of θ values for z-scoring (batch_size, θ_dim). + x: Sample of x values for z-scoring (batch_size, x_dim). + flow_type: Type of flow. Can be a ZukoFlowType enum or string. + num_transforms: Number of flow transforms. + hidden_features: Number of hidden features per layer. 
+ z_score_theta: Method for z-scoring θ (the parameters being modeled). + One of "none", "independent", "structured". + z_score_x: Method for z-scoring x (the conditioning variable). + One of "none", "independent", "structured". Use "structured" for + structured data like images. + Returns: + ConditionalDensityEstimator: The constructed conditional flow q(θ|x). + Raises: + ValueError: If flow_type is not supported. + """ + # Convert string to ZukoFlowType if needed + if isinstance(flow_type, str): + try: + flow_type = ZukoFlowType[flow_type.upper()] + except KeyError as e: + raise ValueError( + f"Unknown flow type '{flow_type}'. " + f"Supported types: {[t.name for t in ZukoFlowType]}." + ) from e + + return build_zuko_flow( + flow_type.value.upper(), + batch_x=theta, # θ is what we model + batch_y=x, # x is the condition + z_score_x=z_score_theta, # z-score for θ (naming mismatch) + z_score_y=z_score_x, # z-score for x condition + num_transforms=num_transforms, + hidden_features=hidden_features, + ).to(self._device) + + @property + def q( + self, + ) -> Union[ + Distribution, ZukoUnconditionalFlow, TransformedZukoFlow, LearnableGaussian + ]: + """Returns the variational posterior.""" + return self._q + + @q.setter + def q(self, q: QType) -> None: + """Sets the variational distribution. + + If the distribution does not admit access through `parameters` and `modules` + function, please use `set_q` to explicitly specify the parameters and modules. """ self.set_q(q) def set_q( self, - q: Union[str, PyroTransformedDistribution, "VIPosterior", Callable], + q: QType, parameters: Optional[Iterable] = None, modules: Optional[Iterable] = None, ) -> None: @@ -242,17 +390,20 @@ def set_q( q: Variational distribution, either string, distribution, or a VIPosterior object. This specifies a parametric class of distribution over which the best possible posterior approximation is searched. For string input, - we currently support [nsf, scf, maf, mcf, gaussian, gaussian_diag]. Of - course, you can also specify your own variational family by passing a + we support normalizing flows [maf, nsf, naf, unaf, nice, sospf] via + Zuko, and simple Gaussian families [gaussian, gaussian_diag] via pure + PyTorch. You can also specify your own variational family by passing a `parameterized` distribution object i.e. a torch.distributions Distribution with methods `parameters` returning an iterable of all - parameters (you can pass them within the paramters/modules attribute). - Additionally, we allow a `Callable`, which allows you the pass a - `builder` function, which if called returns an distribution. This may be - useful for setting the hyperparameters e.g. `num_transfroms:int` by - using the `get_flow_builder` method specifying the hyperparameters. If q - is already a `VIPosterior`, then the arguments will be copied from it - (relevant for multi-round training). + parameters (you can pass them within the parameters/modules attribute). + Additionally, we allow a `Callable` with signature + `(event_shape: torch.Size, link_transform: TorchTransform, device: str) + -> Distribution`, which builds a custom distribution. If q is already + a `VIPosterior`, then the arguments will be copied from it (relevant + for multi-round training). + + Note: For 1D parameter spaces, normalizing flows may be unstable. + Consider using `q='gaussian'` for 1D problems. parameters: List of parameters associated with the distribution object. modules: List of modules associated with the distribution object. 
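A minimal usage sketch for the documented `Callable` interface, reusing the
`LearnableGaussian` helper introduced by this patch (`potential_fn` and `prior`
are placeholders for the user's own objects):

    from sbi.inference import VIPosterior
    from sbi.samplers.vi.vi_utils import LearnableGaussian

    def my_q_builder(event_shape, link_transform, device="cpu"):
        # Full-covariance Gaussian in constrained space, matching the prior support.
        return LearnableGaussian(
            dim=event_shape[0],
            full_covariance=True,
            link_transform=link_transform,
            device=device,
        )

    posterior = VIPosterior(potential_fn, prior=prior, q=my_q_builder)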
@@ -262,7 +413,12 @@ def set_q( if modules is None: modules = [] self._q_arg = (q, parameters, modules) - if isinstance(q, Distribution): + _flow_types = (ZukoUnconditionalFlow, TransformedZukoFlow, LearnableGaussian) + if isinstance(q, _flow_types): + # Flow/Gaussian passed directly (e.g., from _q_build_fn during retrain) + make_object_deepcopy_compatible(q) + self._trained_on = None + elif isinstance(q, Distribution): q = adapt_variational_distribution( q, self._prior, @@ -274,22 +430,57 @@ def set_q( self_custom_q_init_cache = deepcopy(q) self._q_build_fn = lambda *args, **kwargs: self_custom_q_init_cache self._trained_on = None + self._zuko_flow_type = None elif isinstance(q, (str, Callable)): if isinstance(q, str): - self._q_build_fn = get_flow_builder(q) + if q in _ZUKO_FLOW_TYPES: + q_flow = self._build_zuko_flow(q) + self._zuko_flow_type = q + self._q_build_fn = lambda *args, ft=q, **kwargs: ( + self._build_zuko_flow(ft) + ) + q = q_flow + elif q in ("gaussian", "gaussian_diag"): + self._zuko_flow_type = None + full_cov = q == "gaussian" + dim = self._prior.event_shape[0] + q_dist = LearnableGaussian( + dim=dim, + full_covariance=full_cov, + link_transform=self.link_transform, + device=self._device, + ) + self._q_build_fn = lambda *args, fc=full_cov, d=dim, **kwargs: ( + LearnableGaussian( + dim=d, + full_covariance=fc, + link_transform=self.link_transform, + device=self._device, + ) + ) + q = q_dist + else: + supported = sorted(_ZUKO_FLOW_TYPES) + ["gaussian", "gaussian_diag"] + raise ValueError( + f"Unknown variational family '{q}'. " + f"Supported options: {supported}" + ) else: + # Callable provided - use as-is + self._zuko_flow_type = None self._q_build_fn = q - - q = self._q_build_fn( - self._prior.event_shape, - self.link_transform, - device=self._device, - ) + q = self._q_build_fn( + self._prior.event_shape, + self.link_transform, + device=self._device, + ) make_object_deepcopy_compatible(q) self._trained_on = None elif isinstance(q, VIPosterior): self._q_build_fn = q._q_build_fn self._trained_on = q._trained_on + self._mode = getattr(q, "_mode", None) # Copy mode from source + self._zuko_flow_type = getattr(q, "_zuko_flow_type", None) self.vi_method = q.vi_method # type: ignore self._device = q._device self._prior = q._prior @@ -298,11 +489,16 @@ def set_q( make_object_deepcopy_compatible(q.q) q = deepcopy(q.q) move_all_tensor_to_device(q, self._device) - assert isinstance( - q, Distribution - ), """Something went wrong when initializing the variational distribution. - Please create an issue on github https://github.com/mackelab/sbi/issues""" - check_variational_distribution(q, self._prior) + # Validate the variational distribution + if isinstance(q, _flow_types): + pass # These are validated during construction + elif isinstance(q, Distribution): + check_variational_distribution(q, self._prior) + else: + raise ValueError( + f"Variational distribution must be a Distribution, got {type(q)}. " + "Please create an issue on github https://github.com/mackelab/sbi/issues" + ) self._q = q @property @@ -336,24 +532,52 @@ def sample( ) -> Tensor: r"""Draw samples from the variational posterior distribution $p(\theta|x)$. + For single-x mode (trained via `train()`): samples from q(θ) trained on x_o. + For amortized mode (trained via `train_amortized()`): samples from q(θ|x). + Args: sample_shape: Desired shape of samples that are drawn from the posterior. - x: Conditioning observation $x_o$. If not provided, uses the default `x` - set via `.set_default_x()`. 
+ x: Conditioning observation. In single-x mode, must match trained x_o + (or be None to use default). In amortized mode, required and can be + any observation. For batched observations, shape should be + (batch_size, x_dim). show_progress_bars: Unused for `VIPosterior` since sampling from the variational distribution is fast. Included for API consistency. Returns: - Samples from posterior. + Samples from posterior with shape (*sample_shape, θ_dim) for single x, + or (*sample_shape, batch_size, θ_dim) for batched observations in + amortized mode. + + Raises: + ValueError: If mode requirements are not met. """ - x = self._x_else_default_x(x) - if self._trained_on is None or (x != self._trained_on).all(): - raise AttributeError( - f"The variational posterior was not fit on the specified `default_x` " - f"{x}. Please train using `posterior.train()`." - ) - samples = self.q.sample(torch.Size(sample_shape)) - return samples.reshape((*sample_shape, samples.shape[-1])) + if self._mode == "amortized": + # Amortized mode: sample from conditional flow q(θ|x) + x = self._x_else_default_x(x) + if x is None: + raise ValueError( + "x is required for amortized mode. Provide an observation or " + "set a default x with set_default_x()." + ) + x = atleast_2d_float32_tensor(x).to(self._device) + assert self._amortized_q is not None + # samples shape from flow: (*sample_shape, batch_size, θ_dim) + samples = self._amortized_q.sample(torch.Size(sample_shape), condition=x) + # Match base posterior behavior: drop singleton x batch dimension + if x.shape[0] == 1: + samples = samples.squeeze(-2) + return samples + else: + # Single-x mode: sample from unconditional flow q(θ) + x = self._x_else_default_x(x) + if self._trained_on is None or (x != self._trained_on).any(): + raise ValueError( + f"The variational posterior was not fit on the specified " + f"observation {x}. Please train using posterior.train()." + ) + samples = self.q.sample(torch.Size(sample_shape)) + return samples.reshape((*sample_shape, samples.shape[-1])) def sample_batched( self, @@ -362,11 +586,35 @@ def sample_batched( max_sampling_batch_size: int = 10000, show_progress_bars: bool = True, ) -> Tensor: - raise NotImplementedError( - "Batched sampling is not implemented for VIPosterior. " - "Alternatively you can use `sample` in a loop " - "[posterior.sample(theta, x_o) for x_o in x]." - ) + """Sample from posterior for a batch of observations. + + In amortized mode, this is efficient as all x values are processed in + parallel through the conditional flow. + + In single-x mode, this raises NotImplementedError since the unconditional + flow is trained for a specific x_o. + + Args: + sample_shape: Number of samples per observation. + x: Batch of observations (num_obs, x_dim). + max_sampling_batch_size: Unused for amortized mode (no batching needed). + show_progress_bars: Unused for amortized mode. + + Returns: + Samples of shape (*sample_shape, num_obs, θ_dim). + + Raises: + NotImplementedError: If called in single-x mode. + """ + if self._mode == "amortized": + # In amortized mode, sample() handles batched x directly + return self.sample(sample_shape, x=x, show_progress_bars=show_progress_bars) + else: + raise NotImplementedError( + "Batched sampling is not implemented for single-x VI mode. " + "Use train_amortized() to train an amortized posterior, or " + "call sample() in a loop: [posterior.sample(shape, x_o) for x_o in x]." 
+ ) def log_prob( self, @@ -376,24 +624,75 @@ def log_prob( ) -> Tensor: r"""Returns the log-probability of theta under the variational posterior. + For single-x mode: returns log q(θ). + For amortized mode: returns log q(θ|x). + Args: - theta: Parameters + theta: Parameters to evaluate, shape (batch_theta, θ_dim). + x: Observation. In single-x mode, must match trained x_o (or be None). + In amortized mode, required and can be any observation. + For single x, shape (1, x_dim) or (x_dim,). + For batched x, shape (batch_x, x_dim). track_gradients: Whether the returned tensor supports tracking gradients. This can be helpful for e.g. sensitivity analysis but increases memory consumption. Returns: - `len($\theta$)`-shaped log-probability. + Log-probability of shape (batch,) where batch is: + - batch_theta if x has batch size 1 (broadcast x) + - batch_x if theta has batch size 1 (broadcast theta) + - batch_theta if batch_theta == batch_x (paired evaluation) + + Raises: + ValueError: If mode requirements are not met or batch sizes incompatible. """ - x = self._x_else_default_x(x) - if self._trained_on is None or (x != self._trained_on).all(): - raise AttributeError( - f"The variational posterior was not fit using observation {x}.\ - Please train." - ) with torch.set_grad_enabled(track_gradients): - theta = ensure_theta_batched(torch.as_tensor(theta)) - return self.q.log_prob(theta) + theta = ensure_theta_batched(torch.as_tensor(theta)).to(self._device) + + if self._mode == "amortized": + # Amortized mode: evaluate log q(θ|x) + x = self._x_else_default_x(x) + if x is None: + raise ValueError( + "x is required for amortized mode. Provide an observation or " + "set a default x with set_default_x()." + ) + x = atleast_2d_float32_tensor(x).to(self._device) + assert self._amortized_q is not None + + # Handle broadcasting between theta and x + batch_theta = theta.shape[0] + batch_x = x.shape[0] + + if batch_theta != batch_x: + if batch_x == 1: + # Broadcast x to match theta + x = x.expand(batch_theta, -1) + elif batch_theta == 1: + # Broadcast theta to match x + theta = theta.expand(batch_x, -1) + else: + raise ValueError( + f"Batch sizes of theta ({batch_theta}) and x ({batch_x}) " + f"are incompatible. They must be equal, or one must be 1." + ) + + # ZukoFlow expects input shape (sample_dim, batch_dim, *event_shape) + # Add sample dimension, compute log_prob, then squeeze back + theta_with_sample_dim = theta.unsqueeze(0) + log_probs = self._amortized_q.log_prob( + theta_with_sample_dim, condition=x + ) + return log_probs.squeeze(0) + else: + # Single-x mode: evaluate log q(θ) + x = self._x_else_default_x(x) + if self._trained_on is None or (x != self._trained_on).any(): + raise ValueError( + f"The variational posterior was not fit on the specified " + f"observation {x}. Please train using posterior.train()." + ) + return self.q.log_prob(theta) def train( self, @@ -413,10 +712,10 @@ def train( quality_control_metric: str = "psis", **kwargs, ) -> "VIPosterior": - """This method trains the variational posterior. + """This method trains the variational posterior for a single observation. Args: - x: The observation. + x: The observation, optional, defaults to self._x. n_particles: Number of samples to approximate expectations within the variational bounds. The larger the more accurate are gradient estimates, but the computational cost per iteration increases. 
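For orientation, a hedged end-to-end sketch of the single-x workflow that
`train()` documents (`potential_fn`, `prior`, and `x_o` are placeholders):

    posterior = VIPosterior(potential_fn, prior=prior, q="nsf")
    posterior.set_default_x(x_o)
    posterior.train(n_particles=256, learning_rate=1e-3)
    samples = posterior.sample((1000,))
    log_probs = posterior.log_prob(samples)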
@@ -450,7 +749,24 @@ def train( weight_transform: Callable applied to importance weights (only for fKL) Returns: VIPosterior: `VIPosterior` (can be used to chain calls). + + Raises: + ValueError: If hyperparameters are invalid. """ + # Validate hyperparameters + if n_particles <= 0: + raise ValueError(f"n_particles must be positive, got {n_particles}") + if learning_rate <= 0: + raise ValueError(f"learning_rate must be positive, got {learning_rate}") + if not 0 < gamma <= 1: + raise ValueError(f"gamma must be in (0, 1], got {gamma}") + if max_num_iters <= 0: + raise ValueError(f"max_num_iters must be positive, got {max_num_iters}") + if min_num_iters < 0: + raise ValueError(f"min_num_iters must be non-negative, got {min_num_iters}") + if clip_value <= 0: + raise ValueError(f"clip_value must be positive, got {clip_value}") + # Update optimizer with current arguments. if self._optimizer is not None: self._optimizer.update({**locals(), **kwargs}) @@ -489,6 +805,8 @@ def train( x = atleast_2d_float32_tensor(self._x_else_default_x(x)).to( # type: ignore self._device ) + if not torch.isfinite(x).all(): + raise ValueError("x contains NaN or Inf values.") already_trained = self._trained_on is not None and (x == self._trained_on).all() @@ -517,7 +835,7 @@ def train( if show_progress_bar: assert isinstance(iters, tqdm) iters.set_description( # type: ignore - f"Loss: {np.round(float(mean_loss), 2)}" + f"Loss: {np.round(float(mean_loss), 2)}, " f"Std: {np.round(float(std_loss), 2)}" ) # Check for convergence @@ -527,6 +845,15 @@ def train( break # Training finished: self._trained_on = x + if self._mode == "amortized": + warnings.warn( + "Switching from amortized to single-x mode. " + "The previously trained amortized model will be discarded.", + UserWarning, + stacklevel=2, + ) + self._amortized_q = None + self._mode = "single_x" # Evaluate quality if quality_control: @@ -549,6 +876,311 @@ def train( return self + def train_amortized( + self, + theta: Tensor, + x: Tensor, + n_particles: int = 128, + learning_rate: float = 1e-3, + gamma: float = 0.999, + max_num_iters: int = 500, + clip_value: float = 5.0, + batch_size: int = 64, + validation_fraction: float = 0.1, + validation_batch_size: Optional[int] = None, + validation_n_particles: Optional[int] = None, + stop_after_iters: int = 20, + show_progress_bar: bool = True, + retrain_from_scratch: bool = False, + flow_type: Union[ZukoFlowType, str] = ZukoFlowType.NSF, + num_transforms: int = 2, + hidden_features: int = 32, + z_score_theta: Literal["none", "independent", "structured"] = "independent", + z_score_x: Literal["none", "independent", "structured"] = "independent", + params: Optional["VIPosteriorParameters"] = None, + ) -> "VIPosterior": + """Train a conditional flow q(θ|x) for amortized variational inference. + + This allows sampling from q(θ|x) for any observation x without retraining. + Uses the ELBO (Evidence Lower Bound) objective with early stopping based on + validation loss. + + Args: + theta: Training θ values from simulations (num_sims, θ_dim). + x: Training x values from simulations (num_sims, x_dim). + n_particles: Number of samples to estimate ELBO per x. + learning_rate: Learning rate for Adam optimizer. + gamma: Learning rate decay per iteration. + max_num_iters: Maximum training iterations. + clip_value: Gradient clipping threshold. + batch_size: Number of x values per training batch. + validation_fraction: Fraction of data to use for validation. + validation_batch_size: Batch size for validation loss. Defaults to + `batch_size`. 
+ validation_n_particles: Number of particles for validation loss. + Defaults to `n_particles`. + stop_after_iters: Stop training after this many iterations without + improvement in validation loss. + show_progress_bar: Whether to show progress. + retrain_from_scratch: If True, rebuild the flow from scratch. + flow_type: Flow architecture for the variational distribution. + Use ZukoFlowType.NSF, ZukoFlowType.MAF, etc., or a string. + num_transforms: Number of transforms in the flow. + hidden_features: Hidden layer size in the flow. + z_score_theta: Method for z-scoring θ (the parameters being modeled). + One of "none", "independent", "structured". + z_score_x: Method for z-scoring x (the conditioning variable). + One of "none", "independent", "structured". Use "structured" for + structured data like images with spatial correlations. + params: Optional VIPosteriorParameters dataclass. If provided, its values + for q (as flow_type), num_transforms, hidden_features, z_score_theta, + and z_score_x override the individual arguments. + + Returns: + self for method chaining. + """ + # Extract parameters from dataclass if provided + if params is not None: + # Amortized VI only supports string flow types (not VIPosterior or Callable) + if not isinstance(params.q, str): + raise ValueError( + "train_amortized() only supports string flow types " + f"(e.g., 'nsf', 'maf'), not {type(params.q).__name__}. " + "Use set_q() to pass custom distributions for single-x VI." + ) + flow_type = params.q + num_transforms = params.num_transforms + hidden_features = params.hidden_features + z_score_theta = params.z_score_theta + z_score_x = params.z_score_x + + theta = atleast_2d_float32_tensor(theta).to(self._device) + x = atleast_2d_float32_tensor(x).to(self._device) + + # Validate inputs + if theta.shape[0] != x.shape[0]: + raise ValueError( + f"Batch size mismatch: theta has {theta.shape[0]} samples, " + f"x has {x.shape[0]} samples. They must match." + ) + if len(theta) == 0: + raise ValueError("Training data cannot be empty.") + if not torch.isfinite(theta).all(): + raise ValueError("theta contains NaN or Inf values.") + if not torch.isfinite(x).all(): + raise ValueError("x contains NaN or Inf values.") + + # Validate theta dimension matches prior + prior_event_shape = self._prior.event_shape + if len(prior_event_shape) > 0: + expected_theta_dim = prior_event_shape[0] + if theta.shape[1] != expected_theta_dim: + raise ValueError( + f"theta dimension {theta.shape[1]} does not match prior " + f"event shape {expected_theta_dim}." + ) + + # Validate hyperparameters + if not 0 < validation_fraction < 1: + raise ValueError( + f"validation_fraction must be in (0, 1), got {validation_fraction}" + ) + if n_particles <= 0: + raise ValueError(f"n_particles must be positive, got {n_particles}") + if batch_size <= 0: + raise ValueError(f"batch_size must be positive, got {batch_size}") + + # Validate flow_type early to fail fast + if isinstance(flow_type, str): + try: + flow_type = ZukoFlowType[flow_type.upper()] + except KeyError: + raise ValueError( + f"Unknown flow type '{flow_type}'. " + f"Supported types: {[t.name for t in ZukoFlowType]}." 
+ ) from None + + if validation_batch_size is None: + validation_batch_size = batch_size + if validation_n_particles is None: + validation_n_particles = n_particles + + if validation_batch_size <= 0: + raise ValueError( + f"validation_batch_size must be positive, got {validation_batch_size}" + ) + if validation_n_particles <= 0: + raise ValueError( + f"validation_n_particles must be positive, got {validation_n_particles}" + ) + + # Split into training and validation sets + num_examples = len(theta) + num_val = int(validation_fraction * num_examples) + num_train = num_examples - num_val + + if num_val == 0: + raise ValueError( + "Validation set is empty. Increase validation_fraction or provide more " + "training data." + ) + if num_train < batch_size: + raise ValueError( + f"Training set size ({num_train}) is smaller than batch_size " + f"({batch_size}). Reduce validation_fraction or batch_size." + ) + + permuted_indices = torch.randperm(num_examples, device=self._device) + train_indices = permuted_indices[:num_train] + val_indices = permuted_indices[num_train:] + + theta_train, x_train = theta[train_indices], x[train_indices] + x_val = x[val_indices] # Only x needed for validation (θ sampled from q) + + if validation_batch_size < x_val.shape[0]: + val_batch_indices = torch.randperm(x_val.shape[0], device=self._device)[ + :validation_batch_size + ] + else: + val_batch_indices = None + + # Build or rebuild the conditional flow (z-score on training data only) + if self._amortized_q is None or retrain_from_scratch: + self._amortized_q = self._build_conditional_flow( + theta_train, + x_train, + flow_type=flow_type, + num_transforms=num_transforms, + hidden_features=hidden_features, + z_score_theta=z_score_theta, + z_score_x=z_score_x, + ) + + # Setup optimizer + optimizer = Adam(self._amortized_q.parameters(), lr=learning_rate) + scheduler = ExponentialLR(optimizer, gamma=gamma) + + # Training loop with validation-based early stopping + best_val_loss = float("inf") + iters_since_improvement = 0 + best_state_dict = deepcopy(self._amortized_q.state_dict()) + + if show_progress_bar: + iters = tqdm(range(max_num_iters), desc="Amortized VI (ELBO)") + else: + iters = range(max_num_iters) + + for iteration in iters: + # Training step + self._amortized_q.train() + optimizer.zero_grad() + + # Sample batch from training set + idx = torch.randint(0, num_train, (batch_size,), device=self._device) + x_batch = x_train[idx] + + train_loss = self._compute_amortized_elbo_loss(x_batch, n_particles) + + if not torch.isfinite(train_loss): + raise RuntimeError( + f"Training loss became non-finite at iteration {iteration}: " + f"{train_loss.item()}. This indicates numerical instability. 
Try:\n" + f" - Reducing learning_rate (currently {learning_rate})\n" + f" - Reducing n_particles (currently {n_particles})\n" + f" - Checking your potential_fn for numerical issues" + ) + + train_loss.backward() + nn.utils.clip_grad_norm_(self._amortized_q.parameters(), clip_value) + optimizer.step() + scheduler.step() + + # Compute validation loss + self._amortized_q.eval() + with torch.no_grad(): + if val_batch_indices is None: + x_val_batch = x_val + else: + x_val_batch = x_val[val_batch_indices] + val_loss = self._compute_amortized_elbo_loss( + x_val_batch, validation_n_particles + ).item() + + # Check for improvement + if val_loss < best_val_loss: + best_val_loss = val_loss + iters_since_improvement = 0 + best_state_dict = deepcopy(self._amortized_q.state_dict()) + else: + iters_since_improvement += 1 + + if show_progress_bar: + assert isinstance(iters, tqdm) + iters.set_postfix({ + "train": f"{train_loss.item():.3f}", + "val": f"{val_loss:.3f}", + }) + + # Early stopping + if iters_since_improvement >= stop_after_iters: + if show_progress_bar: + print(f"\nConverged at iteration {iteration}") + break + + # Restore best model + self._amortized_q.load_state_dict(best_state_dict) + self._amortized_q.eval() + if self._mode == "single_x": + warnings.warn( + "Switching from single-x to amortized mode. " + "The previously trained single-x model will not be usable.", + UserWarning, + stacklevel=2, + ) + self._mode = "amortized" + + return self + + def _compute_amortized_elbo_loss(self, x_batch: Tensor, n_particles: int) -> Tensor: + """Compute negative ELBO loss for a batch of x values. + + Args: + x_batch: Batch of observations (batch_size, x_dim). + n_particles: Number of θ samples per x. + + Returns: + Negative ELBO (scalar tensor). + """ + assert self._amortized_q is not None, "q must be built before computing ELBO" + batch_size = x_batch.shape[0] + + # Reparameterized samples from q(θ|x) with their log probabilities + # theta_samples shape: (n_particles, batch_size, θ_dim) + # log_q shape: (n_particles, batch_size) + theta_samples, log_q = self._amortized_q.sample_and_log_prob( + torch.Size((n_particles,)), condition=x_batch + ) + + # Vectorized evaluation of potential log p(θ|x) for all (θ, x) pairs + # Flatten: (n_particles, batch_size, θ_dim) -> (n_particles * batch_size, θ_dim) + theta_dim = theta_samples.shape[-1] + theta_flat = theta_samples.reshape(n_particles * batch_size, theta_dim) + + # Repeat x to match: (batch_size, x_dim) -> (n_particles * batch_size, x_dim) + # Each x[j] is repeated n_particles times to pair with theta[:, j, :] + x_expanded = x_batch.repeat(n_particles, 1) + + # Set x_o for batched evaluation (x_is_iid=False: each θ paired with its x) + self.potential_fn.set_x(x_expanded, x_is_iid=False) + log_potential_flat = self.potential_fn(theta_flat) + + # Reshape: (n_particles * batch_size,) -> (n_particles, batch_size) + log_potential = log_potential_flat.reshape(n_particles, batch_size) + + # ELBO = E_q[log p(θ|x) - log q(θ|x)] + elbo = (log_potential - log_q).mean() + return -elbo + def evaluate(self, quality_control_metric: str = "psis", N: int = int(5e4)) -> None: """This function will evaluate the quality of the variational posterior distribution. 
We currently support two different metrics of type `psis`, which @@ -654,6 +1286,7 @@ def __deepcopy__(self, memo: Optional[Dict] = None) -> "VIPosterior": """ if memo is None: memo = {} + # Create a new instance of the class cls = self.__class__ result = cls.__new__(cls) @@ -668,7 +1301,7 @@ def __getstate__(self) -> Dict: """This method is called when pickling the object. It defines what is pickled. We need to overwrite this method, since some parts - due not support pickle protocols (e.g. due to local functions, etc.). + do not support pickle protocols (e.g. due to local functions). Returns: Dict: All attributes of the VIPosterior. @@ -677,7 +1310,8 @@ def __getstate__(self) -> Dict: self.__deepcopy__ = None # type: ignore self._q_build_fn = None self._q.__deepcopy__ = None # type: ignore - return self.__dict__ + state = self.__dict__.copy() + return state def __setstate__(self, state_dict: Dict): """This method is called when unpickling the object. @@ -695,3 +1329,6 @@ def __setstate__(self, state_dict: Dict): self._q = q make_object_deepcopy_compatible(self) make_object_deepcopy_compatible(self.q) + # Handle amortized mode + if self._mode == "amortized" and self._amortized_q is not None: + make_object_deepcopy_compatible(self._amortized_q) diff --git a/sbi/neural_nets/factory.py b/sbi/neural_nets/factory.py index dd60624bc..85a9bea5e 100644 --- a/sbi/neural_nets/factory.py +++ b/sbi/neural_nets/factory.py @@ -62,9 +62,11 @@ class ZukoFlowType(Enum): """Enumeration of Zuko flow types.""" BPF = "bpf" + GF = "gf" MAF = "maf" NAF = "naf" NCSF = "ncsf" + NICE = "nice" NSF = "nsf" SOSPF = "sospf" UNAF = "unaf" diff --git a/sbi/neural_nets/net_builders/flow.py b/sbi/neural_nets/net_builders/flow.py index 9014a266b..9c6457656 100644 --- a/sbi/neural_nets/net_builders/flow.py +++ b/sbi/neural_nets/net_builders/flow.py @@ -1272,6 +1272,8 @@ def build_zuko_unconditional_flow( *base_transforms, standardizing_transform_zuko(batch_x, structured_x), ) + else: + transforms = base_transforms # Combine transforms. neural_net = zuko.flows.Flow(transforms, base) @@ -1378,3 +1380,225 @@ def get_base_dist( base = distributions_.StandardNormal((num_dims,)) base._log_z = base._log_z.to(dtype) return base + + +def build_zuko_vi_flow( + event_shape: torch.Size, + link_transform: Optional["zuko.transforms.Transform"] = None, + flow_type: str = "nsf", + hidden_features: Union[Sequence[int], int] = 50, + num_transforms: int = 5, + **kwargs, +) -> Flow: + """Build an unconditional Zuko normalizing flow for variational inference. + + This function creates a Zuko flow suitable for use with VI training. The flow + maps from a simple base distribution (standard normal) to a more complex + distribution that approximates the posterior. + + Args: + event_shape: Shape of the events generated by the distribution. For 1D + parameters, this is typically torch.Size([dim]). + link_transform: Optional bijective transform that constrains samples to + a specific support (e.g., transforms to positive reals or bounded + intervals). Applied as the final transform in the flow. + flow_type: The type of normalizing flow to build. Supported options: + - "nsf": Neural Spline Flow (default, flexible and expressive) + - "maf": Masked Autoregressive Flow (fast density evaluation) + - "gaussian": Full covariance Gaussian (single affine transform) + - "gaussian_diag": Diagonal covariance Gaussian + hidden_features: The number of hidden features in each transform layer. + Can be an int (same for all layers) or sequence. Defaults to 50. 
+ num_transforms: The number of transform layers. Defaults to 5. Ignored + for gaussian and gaussian_diag flow types. + **kwargs: Additional keyword arguments passed to the Zuko flow constructor. + Common options include `randperm` (bool) for permutation between layers. + + Returns: + A Zuko Flow object that can be used for VI training. The flow has + `log_prob()` and `rsample()` methods through its distribution interface. + + Raises: + ValueError: If an unsupported flow_type is specified. + + Example: + >>> import torch + >>> from sbi.neural_nets.net_builders.flow import build_zuko_vi_flow + >>> flow = build_zuko_vi_flow( + ... event_shape=torch.Size([2]), + ... flow_type="nsf", + ... num_transforms=3, + ... ) + >>> dist = flow() + >>> samples = dist.rsample((100,)) # Shape: (100, 2) + """ + # Convert event_shape to number of features + if len(event_shape) != 1: + raise ValueError( + f"event_shape must be 1D, got {event_shape}. " + "Multi-dimensional event shapes are not yet supported for VI flows." + ) + features = event_shape[0] + + # Handle hidden_features as sequence + if isinstance(hidden_features, int): + hidden_features_list = [hidden_features] * num_transforms + else: + hidden_features_list = list(hidden_features) + + # Keep only zuko-compatible kwargs + kwargs = {k: v for k, v in kwargs.items() if k not in nflow_specific_kwargs} + + # Build the base flow based on flow_type + flow_type_lower = flow_type.lower() + + if flow_type_lower == "nsf": + base_flow = zuko.flows.NSF( + features=features, + context=0, # Unconditional + hidden_features=hidden_features_list, + transforms=num_transforms, + **kwargs, + ) + elif flow_type_lower == "maf": + base_flow = zuko.flows.MAF( + features=features, + context=0, + hidden_features=hidden_features_list, + transforms=num_transforms, + **kwargs, + ) + elif flow_type_lower in ("gaussian", "gaussian_diag"): + # For Gaussian distributions, we create a simple affine transform + # The base is a standard normal, and we learn location and scale + base_flow = _build_zuko_gaussian_flow( + features=features, + diagonal=(flow_type_lower == "gaussian_diag"), + ) + else: + supported = ["nsf", "maf", "gaussian", "gaussian_diag"] + raise ValueError( + f"Unsupported flow_type '{flow_type}'. Supported types: {supported}" + ) + + # Apply link transform if provided + if link_transform is not None: + # Compose the base flow transforms with the link transform + # The link transform maps from unconstrained space to the target support + transforms = (*base_flow.transform.transforms, link_transform) + neural_net = zuko.flows.Flow(transforms, base_flow.base) + return neural_net + + return base_flow + + +def _build_zuko_gaussian_flow( + features: int, + diagonal: bool = False, +) -> Flow: + """Build a simple Gaussian flow (affine transform on standard normal). + + Args: + features: Number of features/dimensions. + diagonal: If True, use diagonal covariance (scale). If False, use + full covariance (lower triangular). + + Returns: + A Zuko Flow representing a Gaussian distribution. 
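+
+    Example (illustrative; mirrors the `build_zuko_vi_flow` example above):
+        >>> flow = _build_zuko_gaussian_flow(features=2, diagonal=True)
+        >>> dist = flow()
+        >>> samples = dist.rsample((100,))  # Shape: (100, 2)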
+ """ + # Create learnable location and scale parameters + loc = torch.nn.Parameter(torch.zeros(features)) + + if diagonal: + # Diagonal scale (log-space for positivity) + log_scale = torch.nn.Parameter(torch.zeros(features)) + + class DiagAffineTransform(zuko.transforms.Transform): + """Diagonal affine transform: y = loc + exp(log_scale) * x.""" + + domain = zuko.transforms.constraints.real_vector + codomain = zuko.transforms.constraints.real_vector + bijective = True + + def __init__(self, loc: torch.nn.Parameter, log_scale: torch.nn.Parameter): + super().__init__() + self.loc = loc + self.log_scale = log_scale + + def _call(self, x: torch.Tensor) -> torch.Tensor: + return self.loc + torch.exp(self.log_scale) * x + + def _inverse(self, y: torch.Tensor) -> torch.Tensor: + return (y - self.loc) / torch.exp(self.log_scale) + + def log_abs_det_jacobian( + self, x: torch.Tensor, y: torch.Tensor + ) -> torch.Tensor: + return self.log_scale.sum() + + transform = DiagAffineTransform(loc, log_scale) + + else: + # Lower triangular scale matrix (with positive diagonal via exp) + # We parameterize: L = tril(raw_L) with diag(L) = exp(diag(raw_L)) + tril_indices = torch.tril_indices(features, features) + num_tril = len(tril_indices[0]) + raw_tril = torch.nn.Parameter(torch.zeros(num_tril)) + + class TriLAffineTransform(zuko.transforms.Transform): + """Lower triangular affine transform: y = loc + L @ x.""" + + domain = zuko.transforms.constraints.real_vector + codomain = zuko.transforms.constraints.real_vector + bijective = True + + def __init__( + self, + loc: torch.nn.Parameter, + raw_tril: torch.nn.Parameter, + tril_indices: torch.Tensor, + features: int, + ): + super().__init__() + self.loc = loc + self.raw_tril = raw_tril + self.register_buffer("_tril_indices", tril_indices) + self._features = features + + def _get_scale_tril(self) -> torch.Tensor: + """Construct the lower triangular matrix from parameters.""" + L = torch.zeros( + self._features, + self._features, + device=self.raw_tril.device, + dtype=self.raw_tril.dtype, + ) + L[self._tril_indices[0], self._tril_indices[1]] = self.raw_tril + # Make diagonal positive via exp + diag_indices = torch.arange(self._features, device=L.device) + L[diag_indices, diag_indices] = torch.exp(L[diag_indices, diag_indices]) + return L + + def _call(self, x: torch.Tensor) -> torch.Tensor: + L = self._get_scale_tril() + return self.loc + x @ L.T + + def _inverse(self, y: torch.Tensor) -> torch.Tensor: + L = self._get_scale_tril() + return torch.linalg.solve_triangular( + L, (y - self.loc).unsqueeze(-1), upper=False + ).squeeze(-1) + + def log_abs_det_jacobian( + self, x: torch.Tensor, y: torch.Tensor + ) -> torch.Tensor: + L = self._get_scale_tril() + # Log det of lower triangular = sum of log of diagonal + return torch.log(torch.diag(L)).sum() + + transform = TriLAffineTransform(loc, raw_tril, tril_indices, features) + + # Standard normal base distribution + base = zuko.distributions.DiagNormal(torch.zeros(features), torch.ones(features)) + + return zuko.flows.Flow((transform,), base) diff --git a/sbi/samplers/vi/__init__.py b/sbi/samplers/vi/__init__.py index 0da81986f..1451683c3 100644 --- a/sbi/samplers/vi/__init__.py +++ b/sbi/samplers/vi/__init__.py @@ -5,5 +5,5 @@ get_VI_method, get_default_VI_method, ) -from sbi.samplers.vi.vi_pyro_flows import get_default_flows, get_flow_builder from sbi.samplers.vi.vi_quality_control import get_quality_metric +from sbi.samplers.vi.vi_utils import LearnableGaussian, TransformedZukoFlow diff --git 
a/sbi/samplers/vi/vi_divergence_optimizers.py b/sbi/samplers/vi/vi_divergence_optimizers.py index a8d1127d6..99047be5f 100644 --- a/sbi/samplers/vi/vi_divergence_optimizers.py +++ b/sbi/samplers/vi/vi_divergence_optimizers.py @@ -28,14 +28,24 @@ from torch.optim.rmsprop import RMSprop from torch.optim.sgd import SGD +from sbi.neural_nets.estimators import ZukoUnconditionalFlow from sbi.samplers.vi.vi_utils import ( + LearnableGaussian, + TransformedZukoFlow, filter_kwrags_for_func, make_object_deepcopy_compatible, move_all_tensor_to_device, ) -from sbi.sbi_types import Array, PyroTransformedDistribution +from sbi.sbi_types import Array from sbi.utils.user_input_checks import check_prior +# Type alias for variational distributions (Zuko-based flows and LearnableGaussian) +VariationalDistribution = Union[ + ZukoUnconditionalFlow, + TransformedZukoFlow, + LearnableGaussian, +] + _VI_method = {} @@ -50,7 +60,7 @@ class DivergenceOptimizer(ABC): def __init__( self, potential_fn: 'BasePotential', # noqa: F821 # type: ignore - q: PyroTransformedDistribution, + q: VariationalDistribution, prior: Optional[Distribution] = None, n_particles: int = 256, clip_value: float = 5.0, @@ -76,7 +86,8 @@ def __init__( Args: potential_fn: Potential function of the target i.e. the posterior density up to normalization constant. - q: Variational distribution + q: Variational distribution. Must be a Zuko-based flow + (ZukoUnconditionalFlow, TransformedZukoFlow) or LearnableGaussian. prior: Prior distribution, which will be used within the warmup, if given. Note that this will not affect the potential_fn, so make sure to have the same prior within it. @@ -109,16 +120,11 @@ def __init__( self.retain_graph = kwargs.get("retain_graph", False) self._kwargs = kwargs - # This prevents error that would stop optimization. - self.q.set_default_validate_args(False) if prior is not None: self.prior.set_default_validate_args(False) # type: ignore - # Manage modules if present. - if hasattr(self.q, "modules"): - self.modules = nn.ModuleList(self.q.modules()) - else: - self.modules = nn.ModuleList() + # Manage modules - all supported distributions are nn.Module-based + self.modules = nn.ModuleList([self.q]) self.modules.train() # Ensure that distribution has parameters and that these are on the right device @@ -167,37 +173,36 @@ def to(self, device: str) -> None: move_all_tensor_to_device(self.prior, self.device) def warm_up(self, num_steps: int, method: str = "prior") -> None: - """This initializes q, either to follow the prior or the base distribution - of the flow. This can increase training stability. + """This initializes q to follow the prior. This can increase training stability. Args: num_steps: Number of steps to train. - method: Method for warmup. + method: Method for warmup. Only "prior" is supported. + + Raises: + NotImplementedError: If an unsupported method is specified. """ - if method == "prior" and self.prior is not None: - inital_target = self.prior - elif method == "identity": - inital_target = torch.distributions.TransformedDistribution( - self.q.base_dist, self.q.transforms[-1] - ) - else: - NotImplementedError( - "The only implemented methods are `prior` and `identity`." + if method != "prior": + raise NotImplementedError( + f"Only 'prior' warmup method is supported. 
Got: {method}" ) + if self.prior is None: + raise ValueError("Prior must be provided for warmup.") + + initial_target = self.prior for _ in range(num_steps): self._optimizer.zero_grad() - if self.q.has_rsample: - samples = self.q.rsample((32,)) - logq = self.q.log_prob(samples) - logp = inital_target.log_prob(samples) # type: ignore - loss = -torch.mean(logp - logq) - else: - samples = inital_target.sample((256,)) # type: ignore - loss = -torch.mean(self.q.log_prob(samples)) + + # Use sample_and_log_prob for efficient reparameterized sampling + samples, logq = self.q.sample_and_log_prob(torch.Size((32,))) + logp = initial_target.log_prob(samples) + loss = -torch.mean(logp - logq) + loss.backward(retain_graph=self.retain_graph) self._optimizer.step() + # Denote that warmup was already done self.warm_up_was_done = True @@ -432,14 +437,9 @@ def __init__(self, *args, stick_the_landing: bool = False, **kwargs): self.HYPER_PARAMETERS += ["stick_the_landing"] def _loss(self, xo: Tensor) -> Tuple[Tensor, Tensor]: - if self.q.has_rsample: - return self.loss_rsample(xo) - else: - raise NotImplementedError( - "Currently only reparameterizable distributions are supported." - ) + return self._compute_elbo_loss(xo) - def loss_rsample(self, x_o: Tensor) -> Tuple[Tensor, Tensor]: + def _compute_elbo_loss(self, x_o: Tensor) -> Tuple[Tensor, Tensor]: """Computes the ELBO""" elbo_particles = self.generate_elbo_particles(x_o) loss = -elbo_particles.mean() @@ -451,12 +451,12 @@ def generate_elbo_particles( """Generates individual ELBO particles i.e. logp(theta, x_o) - logq(theta).""" if num_samples is None: num_samples = self.n_particles - samples = self.q.rsample((num_samples,)) + + samples, log_q = self.q.sample_and_log_prob(torch.Size((num_samples,))) if self.stick_the_landing: self.update_surrogate_q() log_q = self._surrogate_q.log_prob(samples) - else: - log_q = self.q.log_prob(samples) + self.potential_fn.x_o = x_o log_potential = self.potential_fn(samples) elbo = log_potential - log_q @@ -599,10 +599,8 @@ def __init__( self.eps = 1e-5 def _loss(self, xo: Tensor) -> Tuple[Tensor, Tensor]: - if self.q.has_rsample: - return self._loss_q_proposal(xo) - else: - raise NotImplementedError("Unknown loss.") + # Zuko flows always support reparameterized sampling via sample_and_log_prob + return self._loss_q_proposal(xo) def _loss_q_proposal(self, x_o: Tensor) -> Tuple[Tensor, Tensor]: """This gives an importance sampling estimate of the forward KL divergence. @@ -687,16 +685,11 @@ def __init__( self.stick_the_landing = True def _loss(self, xo: Tensor) -> Tuple[Tensor, Tensor]: - assert isinstance(self.alpha, float) - if self.q.has_rsample: - if not self.unbiased: - return self.loss_alpha(xo) - else: - return self.loss_alpha_unbiased(xo) + # Zuko flows always support reparameterized sampling via sample_and_log_prob + if not self.unbiased: + return self.loss_alpha(xo) else: - raise NotImplementedError( - "Currently we only support reparameterizable distributions" - ) + return self.loss_alpha_unbiased(xo) def loss_alpha_unbiased(self, x_o: Tensor) -> Tuple[Tensor, Tensor]: """Unbiased estimate of a surrogate RVB""" diff --git a/sbi/samplers/vi/vi_pyro_flows.py b/sbi/samplers/vi/vi_pyro_flows.py deleted file mode 100644 index 8b711feec..000000000 --- a/sbi/samplers/vi/vi_pyro_flows.py +++ /dev/null @@ -1,663 +0,0 @@ -# This file is part of sbi, a toolkit for simulation-based inference. 
sbi is licensed -# under the Apache License Version 2.0, see - -from __future__ import annotations - -from typing import Callable, Iterable, List, Optional, Type - -import torch -from pyro.distributions import transforms -from pyro.nn import AutoRegressiveNN, DenseNN -from torch import nn -from torch.distributions import Distribution, Independent, Normal - -from sbi.samplers.vi.vi_utils import filter_kwrags_for_func, get_modules, get_parameters -from sbi.sbi_types import TorchTransform - -# Supported transforms and flows are registered here i.e. associated with a name - -_TRANSFORMS = {} -_TRANSFORMS_INITS = {} -_FLOW_BUILDERS = {} - - -def register_transform( - cls: Optional[Type[TorchTransform]] = None, - name: Optional[str] = None, - inits: Callable = lambda *args, **kwargs: (args, kwargs), -) -> Callable: - """Decorator to register a learnable transformation. - - - Args: - cls: Class to register - name: Name of the transform. - inits: Function that provides initial args and kwargs. - - - """ - - def _register(cls): - cls_name = cls.__name__ if name is None else name - if cls_name in _TRANSFORMS: - raise ValueError(f"The transform {cls_name} is already registered") - else: - _TRANSFORMS[cls_name] = cls - _TRANSFORMS_INITS[cls_name] = inits - return cls - - if cls is None: - return _register - else: - return _register(cls) - - -def get_all_transforms() -> List[str]: - """Returns all registered transforms. - - Returns: - List[str]: List of names of all transforms. - - """ - return list(_TRANSFORMS.keys()) - - -def get_transform(name: str, dim: int, device: str = "cpu", **kwargs) -> TorchTransform: - """Returns an initialized transformation - - - - Args: - name: Name of the transform, must be one of [affine_diag, - affine_tril, affine_coupling, affine_autoregressive, spline_coupling, - spline_autoregressive]. - dim: Input dimension. - device: Device on which everythink is initialized. - kwargs: All associated parameters which will be passed through. - - Returns: - Transform: Invertible transformation. - - """ - name = name.lower() - transform = _TRANSFORMS[name] - overwritable_kwargs = filter_kwrags_for_func(transform.__init__, kwargs) - args, default_kwargs = _TRANSFORMS_INITS[name](dim, device=device, **kwargs) - kwargs = {**default_kwargs, **overwritable_kwargs} - return _TRANSFORMS[name](*args, **kwargs) - - -def register_flow_builder( - cls: Optional[Callable] = None, name: Optional[str] = None -) -> Callable: - """Registers a function that builds a normalizing flow. - - Args: - cls: Builder that is registered. - name: Name of the builder. - - - """ - - def _register(cls): - cls_name = cls.__name__ if name is None else name - if cls_name in _FLOW_BUILDERS: - raise ValueError(f"The flow {cls_name} is not registered as default.") - else: - _FLOW_BUILDERS[cls_name] = cls - return cls - - if cls is None: - return _register - else: - return _register(cls) - - -def get_default_flows() -> List[str]: - """Returns names of all registered flow builders. - - Returns: - List[str]: List of names. - - """ - return list(_FLOW_BUILDERS.keys()) - - -def get_flow_builder( - name: str, - **kwargs, -) -> Callable: - """Returns an normalizing flow, by instantiating the flow build with all arguments. - For details within the keyword arguments we refer to the actual builder class. Some - common arguments are listed here. - - Args: - name: Name of the flow. - kwargs: Hyperparameters for the flow. - num_transforms: Number of normalizing flows that are concatenated. 
- permute: Permute dimension after each layer. This may helpfull for - autoregressive or coupling nets. - batch_norm: Perform batch normalization. - base_dist: Base distribution. If `None` then a standard Gaussian is used. - hidden_dims: The dimensionality of the hidden units per layer. Given as a - list of integers. - - Returns: - Callable: A function that if called returns a initialized flow. - - """ - - def build_fn( - event_shape: torch.Size, link_flow: TorchTransform, device: str = "cpu" - ): - return _FLOW_BUILDERS[name](event_shape, link_flow, device=device, **kwargs) - - build_fn.__doc__ = _FLOW_BUILDERS[name].__doc__ - - return build_fn - - -# Initialization functions. - - -def init_affine_autoregressive(dim: int, device: str = "cpu", **kwargs): - """Provides the default initial arguments for an affine autoregressive transform.""" - hidden_dims: List[int] = kwargs.pop("hidden_dims", [3 * dim + 5, 3 * dim + 5]) - skip_connections: bool = kwargs.pop("skip_connections", False) - nonlinearity = kwargs.pop("nonlinearity", nn.ReLU()) - arn = AutoRegressiveNN( - dim, hidden_dims, nonlinearity=nonlinearity, skip_connections=skip_connections - ).to(device) - return [arn], {"log_scale_min_clip": -3.0} - - -def init_spline_autoregressive(dim: int, device: str = "cpu", **kwargs): - """Provides the default initial arguments for an spline autoregressive transform.""" - hidden_dims: List[int] = kwargs.pop("hidden_dims", [3 * dim + 5, 3 * dim + 5]) - skip_connections: bool = kwargs.pop("skip_connections", False) - nonlinearity = kwargs.pop("nonlinearity", nn.ReLU()) - count_bins: int = kwargs.get("count_bins", 10) - order: str = kwargs.get("order", "linear") - bound: int = kwargs.get("bound", 10) - if order == "linear": - param_dims = [count_bins, count_bins, (count_bins - 1), count_bins] - else: - param_dims = [count_bins, count_bins, (count_bins - 1)] - neural_net = AutoRegressiveNN( - dim, - hidden_dims, - param_dims=param_dims, - skip_connections=skip_connections, - nonlinearity=nonlinearity, - ).to(device) - return [dim, neural_net], {"count_bins": count_bins, "bound": bound, "order": order} - - -def init_affine_coupling(dim: int, device: str = "cpu", **kwargs): - """Provides the default initial arguments for an affine autoregressive transform.""" - assert dim > 1, "In 1d this would be equivalent to affine flows, use them." - nonlinearity = kwargs.pop("nonlinearity", nn.ReLU()) - split_dim: int = int(kwargs.get("split_dim", dim // 2)) - hidden_dims: List[int] = kwargs.pop("hidden_dims", [5 * dim + 20, 5 * dim + 20]) - params_dims: List[int] = [dim - split_dim, dim - split_dim] - arn = DenseNN( - split_dim, - hidden_dims, - params_dims, - nonlinearity=nonlinearity, - ).to(device) - return [split_dim, arn], {"log_scale_min_clip": -3.0} - - -def init_spline_coupling(dim: int, device: str = "cpu", **kwargs): - """Intitialize a spline coupling transform, by providing necessary args and - kwargs.""" - assert dim > 1, "In 1d this would be equivalent to affine flows, use them." 
- split_dim: int = kwargs.get("split_dim", dim // 2) - hidden_dims: List[int] = kwargs.pop("hidden_dims", [5 * dim + 30, 5 * dim + 30]) - nonlinearity = kwargs.pop("nonlinearity", nn.ReLU()) - count_bins: int = kwargs.get("count_bins", 15) - order: str = kwargs.get("order", "linear") - bound: int = kwargs.get("bound", 10) - if order == "linear": - param_dims = [ - (dim - split_dim) * count_bins, - (dim - split_dim) * count_bins, - (dim - split_dim) * (count_bins - 1), - (dim - split_dim) * count_bins, - ] - else: - param_dims = [ - (dim - split_dim) * count_bins, - (dim - split_dim) * count_bins, - (dim - split_dim) * (count_bins - 1), - ] - neural_net = DenseNN( - split_dim, hidden_dims, param_dims, nonlinearity=nonlinearity - ).to(device) - return [dim, split_dim, neural_net], { - "count_bins": count_bins, - "bound": bound, - "order": order, - } - - -# Register these directly from pyro - -register_transform( - transforms.AffineAutoregressive, - "affine_autoregressive", - inits=init_affine_autoregressive, -) -register_transform( - transforms.SplineAutoregressive, - "spline_autoregressive", - inits=init_spline_autoregressive, -) - -register_transform( - transforms.AffineCoupling, "affine_coupling", inits=init_affine_coupling -) - -register_transform( - transforms.SplineCoupling, "spline_coupling", inits=init_spline_coupling -) - - -# Register these very simple transforms. - - -@register_transform( - name="affine_diag", - inits=lambda dim, device="cpu", **kwargs: ( - [], - { - "loc": torch.zeros(dim, device=device), - "scale": torch.ones(dim, device=device), - }, - ), -) -class AffineTransform(transforms.AffineTransform): - """Trainable version of an Affine transform. This can be used to get diagonal - Gaussian approximations.""" - - __doc__ = transforms.AffineTransform.__doc__ - - def parameters(self): - self.loc.requires_grad_(True) - self.scale.requires_grad_(True) - yield self.loc - yield self.scale - - def with_cache(self, cache_size=1): - if self._cache_size == cache_size: - return self - return AffineTransform(self.loc, self.scale, cache_size=cache_size) - - -@register_transform( - name="affine_tril", - inits=lambda dim, device="cpu", **kwargs: ( - [], - { - "loc": torch.zeros(dim, device=device), - "scale_tril": torch.eye(dim, device=device), - }, - ), -) -class LowerCholeskyAffine(transforms.LowerCholeskyAffine): - """Trainable version of a Lower Cholesky Affine transform. 
This can be used to get - full Gaussian approximations.""" - - __doc__ = transforms.LowerCholeskyAffine.__doc__ - - def parameters(self): - self.loc.requires_grad_(True) - self.scale_tril.requires_grad_(True) - yield self.loc - yield self.scale_tril - - def with_cache(self, cache_size=1): - if self._cache_size == cache_size: - return self - return LowerCholeskyAffine(self.loc, self.scale_tril, cache_size=cache_size) - - def log_abs_det_jacobian(self, x, y): - """This modification allows batched scale_tril matrices.""" - return self.log_abs_jacobian_diag(x, y).sum(-1) - - def log_abs_jacobian_diag(self, x, y): - """This returns the full diagonal which is necessary to compute conditionals""" - dim = self.scale_tril.dim() - return torch.diagonal(self.scale_tril, dim1=dim - 2, dim2=dim - 1).log() - - -# Now build Normalizing flows - - -class TransformedDistribution(torch.distributions.TransformedDistribution): - """This is TransformedDistribution with the capability to return parameters.""" - - assert __doc__ is not None - assert torch.distributions.TransformedDistribution.__doc__ is not None - __doc__ = torch.distributions.TransformedDistribution.__doc__ - - def parameters(self): - if hasattr(self.base_dist, "parameters"): - yield from self.base_dist.parameters() # type: ignore - for t in self.transforms: - yield from get_parameters(t) - - def modules(self): - if hasattr(self.base_dist, "modules"): - yield from self.base_dist.modules() # type: ignore - for t in self.transforms: - yield from get_modules(t) - - -def build_flow( - event_shape: torch.Size, - link_flow: TorchTransform, - num_transforms: int = 5, - transform: str = "affine_autoregressive", - permute: bool = True, - batch_norm: bool = False, - base_dist: Optional[Distribution] = None, - device: str = "cpu", - **kwargs, -) -> TransformedDistribution: - """Generates a Transformed Distribution where the base_dist is transformed by - num_transforms bijective transforms of specified type. - - Args: - event_shape: Shape of the events generated by the distribution. - link_flow: Links to a specific support . - num_transforms: Number of normalizing flows that are concatenated. - transform: The type of normalizing flow. Should be one of [affine_diag, - affine_tril, affine_coupling, affine_autoregressive, spline_coupling, - spline_autoregressive]. - permute: Permute dimension after each layer. This may helpfull for - autoregressive or coupling nets. - batch_norm: Perform batch normalization. - base_dist: Base distribution. If `None` then a standard Gaussian is used. - device: Device on which we build everythink. - kwargs: Hyperparameters are added here. - Returns: - TransformedDistribution - - """ - # Some transforms increase dimension by decreasing the degrees of freedom e.g. - # SoftMax. - # `unsqueeze(0)` because the `link_flow` requires a batch dimension if the prior is - # a `MultipleIndependent`. 
- additional_dim = ( - len(link_flow(torch.zeros(event_shape, device=device).unsqueeze(0))[0]) # type: ignore # Since link flow should never be None - - torch.tensor(event_shape, device=device).item() - ) - event_shape = torch.Size( - (torch.tensor(event_shape, device=device) - additional_dim).tolist() - ) - # Base distribution is standard normal if not specified - if base_dist is None: - base_dist = Independent( - Normal( - torch.zeros(event_shape, device=device), - torch.ones(event_shape, device=device), - ), - 1, - ) - # Generate normalizing flow - if isinstance(event_shape, int): - dim = event_shape - elif isinstance(event_shape, Iterable): - dim = event_shape[-1] - else: - raise ValueError("The eventshape must either be an Integer or a Iterable.") - - flows = [] - for i in range(num_transforms): - flows.append( - get_transform(transform, dim, device=device, **kwargs).with_cache() - ) - if permute and i < num_transforms - 1: - permutation = torch.randperm(dim, device=device) - flows.append(transforms.Permute(permutation)) - if batch_norm and i < num_transforms - 1: - bn = transforms.BatchNorm(dim).to(device) - flows.append(bn) - flows.append(link_flow.with_cache()) - dist = TransformedDistribution(base_dist, flows) - return dist - - -@register_flow_builder(name="gaussian_diag") -def gaussian_diag_flow_builder( - event_shape: torch.Size, link_flow: TorchTransform, device: str = "cpu", **kwargs -): - """Generates a Gaussian distribution with diagonal covariance. - - Args: - event_shape: Shape of the events generated by the distribution. - link_flow: Links to a specific support . - kwargs: Hyperparameters are added here. - loc: Initial location. - scale: Initial triangular matrix. - - Returns: - TransformedDistribution - - """ - if "transform" in kwargs: - kwargs.pop("transform") - if "base_dist" in kwargs: - kwargs.pop("base_dist") - if "num_transforms" in kwargs: - kwargs.pop("num_transforms") - return build_flow( - event_shape, - link_flow, - device=device, - transform="affine_diag", - num_transforms=1, - shuffle=False, - **kwargs, - ) - - -@register_flow_builder(name="gaussian") -def gaussian_flow_builder( - event_shape: torch.Size, link_flow: TorchTransform, device: str = "cpu", **kwargs -) -> TransformedDistribution: - """Generates a Gaussian distribution. - - Args: - event_shape: Shape of the events generated by the distribution. - link_flow: Links to a specific support . - device: Device on which to build. - kwargs: Hyperparameters are added here. - loc: Initial location. - scale_tril: Initial triangular matrix. - - Returns: - TransformedDistribution - - """ - if "transform" in kwargs: - kwargs.pop("transform") - if "base_dist" in kwargs: - kwargs.pop("base_dist") - if "num_transforms" in kwargs: - kwargs.pop("num_transforms") - return build_flow( - event_shape, - link_flow, - device=device, - transform="affine_tril", - shuffle=False, - num_transforms=1, - **kwargs, - ) - - -@register_flow_builder(name="maf") -def masked_autoregressive_flow_builder( - event_shape: torch.Size, link_flow: TorchTransform, device: str = "cpu", **kwargs -) -> TransformedDistribution: - """Generates a masked autoregressive flow - - Args: - event_shape: Shape of the events generated by the distribution. - link_flow: Links to a specific support. - device: Device on which to build. - num_transforms: Number of normalizing flows that are concatenated. - permute: Permute dimension after each layer. This may helpfull for - autoregressive or coupling nets. - batch_norm: Perform batch normalization. 
- base_dist: Base distribution. If `None` then a standard Gaussian is used. - kwargs: Hyperparameters are added here. - hidden_dims: The dimensionality of the hidden units per layer. - skip_connections: Whether to add skip connections from the input to the - output. - nonlinearity: The nonlinearity to use in the feedforward network such as - torch.nn.ReLU(). - log_scale_min_clip: The minimum value for clipping the log(scale) from - the autoregressive NN - log_scale_max_clip: The maximum value for clipping the log(scale) from - the autoregressive NN - sigmoid_bias: A term to add the logit of the input when using the stable - tranform. - stable: When true, uses the alternative "stable" version of the transform. - Yet this version is also less expressive. - - Returns: - TransformedDistribution - - """ - if "transform" in kwargs: - kwargs.pop("transform") - return build_flow( - event_shape, - link_flow, - transform="affine_autoregressive", - device=device, - **kwargs, - ) - - -@register_flow_builder(name="nsf") -def spline_autoregressive_flow_builder( - event_shape: torch.Size, link_flow: TorchTransform, device: str = "cpu", **kwargs -) -> TransformedDistribution: - """Generates an autoregressive neural spline flow. - - Args: - event_shape: Shape of the events generated by the distribution. - link_flow: Links to a specific support . - num_transforms: Number of normalizing flows that are concatenated. - permute: Permute dimension after each layer. This may helpfull for - autoregressive or coupling nets. - batch_norm: Perform batch normalization. - base_dist: Base distribution. If `None` then a standard Gaussian is used. - kwargs: Hyperparameters are added here. - hidden_dims: The dimensionality of the hidden units per layer. - skip_connections: Whether to add skip connections from the input to the - output. - nonlinearity: The nonlinearity to use in the feedforward network such as - torch.nn.ReLU(). - count_bins: The number of segments comprising the spline. - bound: The quantity `K` determining the bounding box. - order: One of [`linear`, `quadratic`] specifying the order of the spline. - - Returns: - TransformedDistribution - - """ - if "transform" in kwargs: - kwargs.pop("transform") - return build_flow( - event_shape, - link_flow, - transform="spline_autoregressive", - device=device, - **kwargs, - ) - - -@register_flow_builder(name="mcf") -def coupling_flow_builder( - event_shape: torch.Size, link_flow: TorchTransform, device: str = "cpu", **kwargs -) -> TransformedDistribution: - """Generates a affine coupling flow. - - Args: - event_shape: Shape of the events generated by the distribution. - link_flow: Links to a specific support. - num_transforms: Number of normalizing flows that are concatenated. - permute: Permute dimension after each layer. This may helpfull for - autoregressive or coupling nets. - batch_norm: Perform batch normalization. - base_dist: Base distribution. If `None` then a standard Gaussian is used. - kwargs: Hyperparameters are added here. - hidden_dims: The dimensionality of the hidden units per layer. - skip_connections: Whether to add skip connections from the input to the - output. - nonlinearity: The nonlinearity to use in the feedforward network such as - torch.nn.ReLU(). - log_scale_min_clip: The minimum value for clipping the log(scale) from - the autoregressive NN - log_scale_max_clip: The maximum value for clipping the log(scale) from - the autoregressive NN - split_dim : The dimension to split the input on for the coupling transform. 
- - Returns: - TransformedDistribution - - """ - if "transform" in kwargs: - kwargs.pop("transform") - return build_flow( - event_shape, link_flow, device=device, transform="affine_coupling", **kwargs - ) - - -@register_flow_builder(name="scf") -def spline_coupling_flow_builder( - event_shape: torch.Size, link_flow: TorchTransform, device: str = "cpu", **kwargs -) -> TransformedDistribution: - """Generates an spline coupling flow. Implementation is based on [1], we do not - implement affine transformations using LU decomposition as proposed in [2]. - - Args: - event_shape: Shape of the events generated by the distribution. - link_flow: Links to a specific support . - num_transforms: Number of normalizing flows that are concatenated. - permute: Permute dimension after each layer. This may helpfull for - autoregressive or coupling nets. - batch_norm: Perform batch normalization. - base_dist: Base distribution. If `None` then a standard Gaussian is used. - kwargs: Hyperparameters are added here. - hidden_dims: The dimensionality of the hidden units per layer. - nonlinearity: The nonlinearity to use in the feedforward network such as - torch.nn.ReLU(). - count_bins: The number of segments comprising the spline. - bound: The quantity `K` determining the bounding box. - order: One of [`linear`, `quadratic`] specifying the order of the spline. - split_dim : The dimension to split the input on for the coupling transform. - - Returns: - TransformedDistribution - - References: - [1] Invertible Generative Modeling using Linear Rational Splines, Hadi M. - Dolatabadi, Sarah Erfani, Christopher Leckie, 2020, - https://arxiv.org/pdf/2001.05168.pdf. - [2] Neural Spline Flows, Conor Durkan, Artur Bekasov, Iain Murray, George - Papamakarios, 2019, https://arxiv.org/pdf/1906.04032.pdf. - - - """ - if "transform" in kwargs: - kwargs.pop("transform") - return build_flow( - event_shape, link_flow, device=device, transform="spline_coupling", **kwargs - ) diff --git a/sbi/samplers/vi/vi_utils.py b/sbi/samplers/vi/vi_utils.py index 70731504e..8b65641ff 100644 --- a/sbi/samplers/vi/vi_utils.py +++ b/sbi/samplers/vi/vi_utils.py @@ -15,13 +15,231 @@ ) import torch -from pyro.distributions.torch_transform import TransformModule -from torch import Tensor -from torch.distributions import Distribution, TransformedDistribution +from torch import Tensor, nn +from torch.distributions import ( + Distribution, + Independent, + MultivariateNormal, + Normal, + TransformedDistribution, +) from torch.distributions.transforms import ComposeTransform, IndependentTransform from torch.nn import Module -from sbi.sbi_types import PyroTransformedDistribution, TorchTransform +from sbi.neural_nets.estimators.zuko_flow import ZukoUnconditionalFlow +from sbi.sbi_types import Shape, TorchTransform, VariationalDistribution + + +class TransformedZukoFlow(nn.Module): + """Wrapper for Zuko flows that applies a link transform to samples. + + This wrapper ensures that: + 1. Samples from the flow (in unconstrained space) are transformed to constrained + space via link_transform + 2. log_prob accounts for the Jacobian of the transformation + + The underlying Zuko flow operates in unconstrained space, but this wrapper + provides an interface where samples and log_probs are in constrained space + (matching the prior's support). + """ + + def __init__( + self, + flow: ZukoUnconditionalFlow, + link_transform: TorchTransform, + ): + """Initialize the transformed flow wrapper. 
+ + Args: + flow: The underlying Zuko unconditional flow (operates in unconstrained + space). + link_transform: Transform from unconstrained to constrained space. + link_transform.forward maps unconstrained -> constrained. + link_transform.inv maps constrained -> unconstrained. + """ + super().__init__() + self._flow = flow + self._link_transform = link_transform + + @property + def net(self): + """Access the underlying flow's network (for compatibility).""" + return self._flow.net + + def parameters(self): + """Return the parameters of the underlying flow.""" + return self._flow.parameters() + + def sample(self, sample_shape: Shape) -> Tensor: + """Sample from the flow and transform to constrained space. + + Args: + sample_shape: Shape of samples to generate. + + Returns: + Samples in constrained space with shape (*sample_shape, event_dim). + """ + # Sample in unconstrained space + unconstrained_samples = self._flow.sample(sample_shape) + # Transform to constrained space + constrained_samples = self._link_transform(unconstrained_samples) + assert isinstance(constrained_samples, Tensor) # Type narrowing for pyright + return constrained_samples + + def log_prob(self, theta: Tensor) -> Tensor: + """Compute log probability of samples in constrained space. + + Uses change of variables: log p(θ) = log q(z) + log|det(dz/dθ)| + where z = link_transform.inv(θ) and q is the flow's distribution. + + Args: + theta: Samples in constrained space. + + Returns: + Log probabilities with shape (*batch_shape,). + """ + # Transform to unconstrained space + z = self._link_transform.inv(theta) + assert isinstance(z, Tensor) # Type narrowing for pyright + # Get flow log prob in unconstrained space + log_prob_z = self._flow.log_prob(z) + # Add Jacobian correction for the inverse transform + # log_abs_det_jacobian gives log|det(dz/dθ)| + log_det_jacobian = self._link_transform.inv.log_abs_det_jacobian(theta, z) + # Some transforms (e.g. identity) return per-dimension Jacobians, + # while IndependentTransform returns summed Jacobians. Sum if needed. + if log_det_jacobian.dim() > log_prob_z.dim(): + log_det_jacobian = log_det_jacobian.sum(dim=-1) + return log_prob_z + log_det_jacobian + + def sample_and_log_prob(self, sample_shape: Shape) -> tuple[Tensor, Tensor]: + """Sample from the flow and compute log probabilities efficiently. + + Args: + sample_shape: Shape of samples to generate. + + Returns: + Tuple of (samples, log_probs) where samples are in constrained space. + """ + # Sample in unconstrained space and get log prob + z, log_prob_z = self._flow.sample_and_log_prob(torch.Size(sample_shape)) + # Transform to constrained space + theta = self._link_transform(z) + assert isinstance(theta, Tensor) # Type narrowing for pyright + # Subtract Jacobian for forward transform (we want log p(θ) not log q(z)) + # log p(θ) = log q(z) - log|det(dθ/dz)| = log q(z) + log|det(dz/dθ)| + log_det_jacobian = self._link_transform.log_abs_det_jacobian(z, theta) + # Some transforms (e.g. identity) return per-dimension Jacobians, + # while IndependentTransform returns summed Jacobians. Sum if needed. + if log_det_jacobian.dim() > log_prob_z.dim(): + log_det_jacobian = log_det_jacobian.sum(dim=-1) + log_prob_theta = log_prob_z - log_det_jacobian + return theta, log_prob_theta + + +class LearnableGaussian(nn.Module): + """Learnable Gaussian distribution for variational inference. + + A simple parametric variational family with learnable mean and covariance. 
+ Supports both full covariance (gaussian) and diagonal covariance (gaussian_diag). + """ + + def __init__( + self, + dim: int, + full_covariance: bool = True, + link_transform: Optional[TorchTransform] = None, + device: Union[str, torch.device] = "cpu", + ): + """Initialize the learnable Gaussian. + + Args: + dim: Dimensionality of the distribution. + full_covariance: If True, use full covariance matrix. If False, use + diagonal covariance (faster, fewer parameters). + link_transform: Optional transform to apply to samples. Maps from + unconstrained to constrained space (matching prior support). + device: Device to create parameters on. + """ + super().__init__() + self._dim = dim + self._full_cov = full_covariance + self._link_transform = link_transform + + # Learnable parameters - create on correct device from the start + self.loc = nn.Parameter(torch.zeros(dim, device=device)) + if full_covariance: + # Lower triangular matrix for Cholesky parameterization + self.scale_tril = nn.Parameter(torch.eye(dim, device=device)) + else: + # Log scale for numerical stability + self.log_scale = nn.Parameter(torch.zeros(dim, device=device)) + + def _base_dist(self) -> Distribution: + """Get the base Gaussian distribution with current parameters.""" + if self._full_cov: + return MultivariateNormal(self.loc, scale_tril=self.scale_tril) + return Independent(Normal(self.loc, self.log_scale.exp()), 1) + + def sample(self, sample_shape: Shape) -> Tensor: + """Sample from the distribution. + + Args: + sample_shape: Shape of samples to generate. + + Returns: + Samples with shape (*sample_shape, dim). + """ + # Use sample() not rsample() - this is for inference, not training + samples = self._base_dist().sample(sample_shape) + if self._link_transform is not None: + samples = self._link_transform(samples) + assert isinstance(samples, Tensor) # Type narrowing for pyright + return samples + + def log_prob(self, theta: Tensor) -> Tensor: + """Compute log probability. + + Args: + theta: Points at which to evaluate log probability. + + Returns: + Log probabilities with shape (*batch_shape,). + """ + if self._link_transform is not None: + # Transform to unconstrained space + z = self._link_transform.inv(theta) + assert isinstance(z, Tensor) # Type narrowing for pyright + log_prob_z = self._base_dist().log_prob(z) + # Add Jacobian correction + log_det = self._link_transform.inv.log_abs_det_jacobian(theta, z) + if log_det.dim() > log_prob_z.dim(): + log_det = log_det.sum(dim=-1) + return log_prob_z + log_det + return self._base_dist().log_prob(theta) + + def sample_and_log_prob(self, sample_shape: Shape) -> tuple[Tensor, Tensor]: + """Sample and compute log probability efficiently. + + Args: + sample_shape: Shape of samples to generate. + + Returns: + Tuple of (samples, log_probs). 
+ """ + dist = self._base_dist() + z = dist.rsample(sample_shape) + log_prob_z = dist.log_prob(z) + + if self._link_transform is not None: + theta = self._link_transform(z) + assert isinstance(theta, Tensor) # Type narrowing for pyright + # Adjust log_prob for the transformation + log_det = self._link_transform.log_abs_det_jacobian(z, theta) + if log_det.dim() > log_prob_z.dim(): + log_det = log_det.sum(dim=-1) + return theta, log_prob_z - log_det + return z, log_prob_z def filter_kwrags_for_func(f: Callable, kwargs: Dict) -> Dict: @@ -41,7 +259,7 @@ def filter_kwrags_for_func(f: Callable, kwargs: Dict) -> Dict: return new_kwargs -def get_parameters(t: Union[TorchTransform, TransformModule]) -> Iterable: +def get_parameters(t: Union[TorchTransform, Module]) -> Iterable: """Recursive helper function which can be used to extract parameters from TransformedDistributions. @@ -62,7 +280,7 @@ def get_parameters(t: Union[TorchTransform, TransformModule]) -> Iterable: pass -def get_modules(t: Union[TorchTransform, TransformModule]) -> Iterable: +def get_modules(t: Union[TorchTransform, Module]) -> Iterable: """Recursive helper function which can be used to extract modules from TransformedDistributions. @@ -70,7 +288,7 @@ def get_modules(t: Union[TorchTransform, TransformModule]) -> Iterable: t: A TorchTransform object, which is scanned for the "modules" attribute. Yields: - Iterator[Iterable]: Generator of TransformModules + Iterator[Iterable]: Generator of Modules """ if isinstance(t, Module): yield t @@ -83,7 +301,7 @@ def get_modules(t: Union[TorchTransform, TransformModule]) -> Iterable: pass -def check_parameters_modules_attribute(q: PyroTransformedDistribution) -> None: +def check_parameters_modules_attribute(q: VariationalDistribution) -> None: """Checks a parameterized distribution object for valid `parameters` and `modules`. Args: @@ -196,7 +414,7 @@ def add_parameters_module_attributes( def add_parameter_attributes_to_transformed_distribution( - q: PyroTransformedDistribution, + q: VariationalDistribution, ) -> None: """A function that will add `parameters` and `modules` to q automatically, if q is a TransformedDistribution. @@ -225,7 +443,7 @@ def modules(): def adapt_variational_distribution( - q: PyroTransformedDistribution, + q: VariationalDistribution, prior: Distribution, link_transform: Callable, parameters: Optional[Iterable] = None, diff --git a/sbi/sbi_types.py b/sbi/sbi_types.py index 1d62e8eb4..f4b566085 100644 --- a/sbi/sbi_types.py +++ b/sbi/sbi_types.py @@ -32,7 +32,7 @@ TensorBoardSummaryWriter: TypeAlias = SummaryWriter TorchDistribution: TypeAlias = Distribution TorchTransform: TypeAlias = Transform -PyroTransformedDistribution: TypeAlias = TransformedDistribution +VariationalDistribution: TypeAlias = TransformedDistribution TorchTensor: TypeAlias = Tensor @@ -67,6 +67,6 @@ def __call__(self, theta: Tensor) -> Tensor: ... "TorchTransform", "transform_types", "TorchDistribution", - "PyroTransformedDistribution", + "VariationalDistribution", "TorchTensor", ] diff --git a/tests/vi_test.py b/tests/vi_test.py index 84be2a0a3..9825ee398 100644 --- a/tests/vi_test.py +++ b/tests/vi_test.py @@ -1,6 +1,14 @@ # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed # under the Apache License Version 2.0, see +""" +Tests for Variational Inference (VI) using VIPosterior. 
+ +This module tests both: +- Single-x VI mode: train() - trains q(θ) for a specific observation x_o +- Amortized VI mode: train_amortized() - trains q(θ|x) across observations +""" + from __future__ import annotations import tempfile @@ -13,39 +21,153 @@ from torch import eye, ones, zeros from torch.distributions import Beta, Binomial, Gamma, MultivariateNormal -from sbi.inference import NLE, likelihood_estimator_based_potential +from sbi.inference import NLE, NRE, likelihood_estimator_based_potential from sbi.inference.posteriors import VIPosterior from sbi.inference.potentials.base_potential import BasePotential -from sbi.samplers.vi.vi_pyro_flows import get_default_flows, get_flow_builder -from sbi.simulators.linear_gaussian import true_posterior_linear_gaussian_mvn_prior +from sbi.inference.potentials.ratio_based_potential import ( + ratio_estimator_based_potential, +) +from sbi.neural_nets.factory import ZukoFlowType +from sbi.simulators.linear_gaussian import ( + linear_gaussian, + true_posterior_linear_gaussian_mvn_prior, +) from sbi.utils import MultipleIndependent -from sbi.utils.metrics import check_c2st +from sbi.utils.metrics import c2st, check_c2st + +# Supported variational families for VI +FLOWS = ["maf", "nsf", "naf", "unaf", "nice", "sospf", "gaussian", "gaussian_diag"] + -# Tests should be run for all default flows -FLOWS = get_default_flows() +# ============================================================================= +# Shared Test Utilities +# ============================================================================= class FakePotential(BasePotential): + """A potential that returns the prior log probability. + + This makes the posterior equal to the prior, which is a trivial but + well-defined posterior that allows proper testing of VI machinery. + """ + def __call__(self, theta, **kwargs): - return torch.ones(theta.shape[0], dtype=torch.float32) + return self.prior.log_prob(theta) def allow_iid_x(self) -> bool: return True -@pytest.mark.slow -@pytest.mark.parametrize("num_dim", (1, 2)) -@pytest.mark.parametrize("vi_method", ("rKL", "fKL", "IW", "alpha")) -def test_c2st_vi_on_Gaussian(num_dim: int, vi_method: str): - """Test VI on Gaussian, comparing to ground truth target via c2st. +def make_tractable_potential(target_distribution, prior): + """Create a potential function from a known target distribution.""" + + class TractablePotential(BasePotential): + def __call__(self, theta, **kwargs): + return target_distribution.log_prob( + torch.as_tensor(theta, dtype=torch.float32) + ) + + def allow_iid_x(self) -> bool: + return True + + return TractablePotential(prior=prior) + + +# ============================================================================= +# Fixtures +# ============================================================================= + + +def _build_linear_gaussian_setup(trainer_type: str = "nle"): + """Helper to build linear Gaussian setup with specified trainer type. Args: - num_dim: parameter dimension of the gaussian model - vi_method: different vi methods + trainer_type: Either "nle" or "nre". 
+ + Returns a dict with: + - prior: MultivariateNormal prior + - potential_fn: Trained potential function (NLE or NRE based) + - theta, x: Simulation data + - likelihood_shift, likelihood_cov: Likelihood parameters + - num_dim: Dimensionality + - trainer_type: The trainer type used """ + torch.manual_seed(42) + + num_dim = 2 + num_simulations = 5000 + prior = MultivariateNormal(zeros(num_dim), eye(num_dim)) + likelihood_shift = -1.0 * ones(num_dim) + likelihood_cov = 0.25 * eye(num_dim) + + def simulator(theta): + return linear_gaussian(theta, likelihood_shift, likelihood_cov) + + # Generate simulation data + theta = prior.sample((num_simulations,)) + x = simulator(theta) + + # Train estimator and create potential based on trainer type + if trainer_type == "nle": + trainer = NLE(prior=prior, show_progress_bars=False, density_estimator="nsf") + trainer.append_simulations(theta, x) + estimator = trainer.train(max_num_epochs=200) + potential_fn, _ = likelihood_estimator_based_potential( + likelihood_estimator=estimator, + prior=prior, + x_o=None, + ) + elif trainer_type == "nre": + trainer = NRE(prior=prior, show_progress_bars=False, classifier="mlp") + trainer.append_simulations(theta, x) + estimator = trainer.train(max_num_epochs=200) + potential_fn, _ = ratio_estimator_based_potential( + ratio_estimator=estimator, + prior=prior, + x_o=None, + ) + else: + raise ValueError(f"Unknown trainer_type: {trainer_type}") + + return { + "prior": prior, + "potential_fn": potential_fn, + "theta": theta, + "x": x, + "likelihood_shift": likelihood_shift, + "likelihood_cov": likelihood_cov, + "num_dim": num_dim, + "trainer_type": trainer_type, + } + + +@pytest.fixture +def linear_gaussian_setup(): + """Setup for linear Gaussian test problem with trained NLE.""" + return _build_linear_gaussian_setup("nle") + + +@pytest.fixture(params=["nle", "nre"]) +def linear_gaussian_setup_trainers(request): + """Parametrized setup for linear Gaussian with NLE or NRE.""" + return _build_linear_gaussian_setup(request.param) + + +# ============================================================================= +# Single-x VI Tests: train() method +# ============================================================================= - num_samples = 2000 +@pytest.mark.slow +@pytest.mark.parametrize("num_dim", (1, 2)) +@pytest.mark.parametrize("vi_method", ("rKL", "fKL", "IW", "alpha")) +@pytest.mark.parametrize("sampling_method", ("naive", "sir")) +def test_c2st_vi_on_Gaussian(num_dim: int, vi_method: str, sampling_method: str): + """Test single-x VI on Gaussian, comparing to ground truth via c2st.""" + if sampling_method == "naive" and vi_method == "IW": + return # This combination is not meant to perform well + + num_samples = 2000 likelihood_shift = -1.0 * ones(num_dim) likelihood_cov = 0.3 * eye(num_dim) prior_mean = zeros(num_dim) @@ -57,20 +179,13 @@ def test_c2st_vi_on_Gaussian(num_dim: int, vi_method: str): ) target_samples = target_distribution.sample((num_samples,)) - class TractablePotential(BasePotential): - def __call__(self, theta, **kwargs): - return target_distribution.log_prob( - torch.as_tensor(theta, dtype=torch.float32) - ) - - def allow_iid_x(self) -> bool: - return True - prior = MultivariateNormal(prior_mean, prior_cov) - potential_fn = TractablePotential(prior=prior) + potential_fn = make_tractable_potential(target_distribution, prior) theta_transform = torch_tf.identity_transform - posterior = VIPosterior(potential_fn, prior, theta_transform=theta_transform) + # Use 'gaussian' for 1D (normalizing flows are 
unstable in 1D with Zuko) + q = "gaussian" if num_dim == 1 else "nsf" + posterior = VIPosterior(potential_fn, prior, theta_transform=theta_transform, q=q) posterior.set_default_x(torch.tensor(np.zeros((num_dim,)).astype(np.float32))) posterior.vi_method = vi_method posterior.train() @@ -84,20 +199,12 @@ def allow_iid_x(self) -> bool: @pytest.mark.parametrize("num_dim", (1, 2)) @pytest.mark.parametrize("q", FLOWS) def test_c2st_vi_flows_on_Gaussian(num_dim: int, q: str): - """Test VI on Gaussian, comparing to ground truth target via c2st. - - Args: - num_dim: parameter dimension of the gaussian model - vi_method: different vi methods - sampling_method: Different sampling methods - - """ - # Coupling flows undefined at 1d - if num_dim == 1 and q in ["mcf", "scf"]: + """Test different flow types on Gaussian via c2st.""" + # Normalizing flows (except gaussian families) are unstable in 1D with Zuko + if num_dim == 1 and q not in ["gaussian", "gaussian_diag"]: return num_samples = 2000 - likelihood_shift = -1.0 * ones(num_dim) likelihood_cov = 0.3 * eye(num_dim) prior_mean = zeros(num_dim) @@ -109,17 +216,8 @@ def test_c2st_vi_flows_on_Gaussian(num_dim: int, q: str): ) target_samples = target_distribution.sample((num_samples,)) - class TractablePotential(BasePotential): - def __call__(self, theta, **kwargs): - return target_distribution.log_prob( - torch.as_tensor(theta, dtype=torch.float32) - ) - - def allow_iid_x(self) -> bool: - return True - prior = MultivariateNormal(prior_mean, prior_cov) - potential_fn = TractablePotential(prior=prior) + potential_fn = make_tractable_potential(target_distribution, prior) theta_transform = torch_tf.identity_transform posterior = VIPosterior(potential_fn, prior, theta_transform=theta_transform, q=q) @@ -134,16 +232,8 @@ def allow_iid_x(self) -> bool: @pytest.mark.slow @pytest.mark.parametrize("num_dim", (1, 2)) def test_c2st_vi_external_distributions_on_Gaussian(num_dim: int): - """Test VI on Gaussian, comparing to ground truth target via c2st. - - Args: - num_dim: parameter dimension of the gaussian model - vi_method: different vi methods - sampling_method: Different sampling methods - - """ + """Test VI with user-provided external distribution.""" num_samples = 2000 - likelihood_shift = -1.0 * ones(num_dim) likelihood_cov = 0.3 * eye(num_dim) prior_mean = zeros(num_dim) @@ -155,17 +245,8 @@ def test_c2st_vi_external_distributions_on_Gaussian(num_dim: int): ) target_samples = target_distribution.sample((num_samples,)) - class TractablePotential(BasePotential): - def __call__(self, theta, **kwargs): - return target_distribution.log_prob( - torch.as_tensor(theta, dtype=torch.float32) - ) - - def allow_iid_x(self) -> bool: - return True - prior = MultivariateNormal(prior_mean, prior_cov) - potential_fn = TractablePotential(prior=prior) + potential_fn = make_tractable_potential(target_distribution, prior) theta_transform = torch_tf.identity_transform mu = zeros(num_dim, requires_grad=True) @@ -189,153 +270,120 @@ def allow_iid_x(self) -> bool: @pytest.mark.parametrize("q", FLOWS) def test_deepcopy_support(q: str): - """Tests if the variational does support deepcopy. - - Args: - q: Different variational posteriors. 
- """ - + """Test that VIPosterior supports deepcopy for all flow types.""" num_dim = 2 - prior = MultivariateNormal(zeros(num_dim), eye(num_dim)) potential_fn = FakePotential(prior=prior) theta_transform = torch_tf.identity_transform - posterior = VIPosterior( - potential_fn, - prior, - theta_transform=theta_transform, - q=q, - ) + posterior = VIPosterior(potential_fn, prior, theta_transform=theta_transform, q=q) posterior_copy = deepcopy(posterior) posterior.set_default_x(torch.tensor(np.zeros((num_dim,)).astype(np.float32))) - assert posterior._x != posterior_copy._x, ( - "Default x attributed of original and copied but modified VIPosterior must be\ - the different, on change (otherwise it is not a deep copy)." - ) + assert posterior._x != posterior_copy._x, "Deepcopy should create independent copy" + posterior_copy = deepcopy(posterior) - assert (posterior._x == posterior_copy._x).all(), ( - "Default x attributed of original and copied VIPosterior must be the same." - ) + assert (posterior._x == posterior_copy._x).all(), "Deepcopy should preserve values" - # Try if they are the same + # Verify samples are reproducible torch.manual_seed(0) - s1 = posterior._q.rsample() + if hasattr(posterior._q, "rsample"): + s1 = posterior._q.rsample() + else: + s1 = posterior._q.sample((1,)).squeeze(0) torch.manual_seed(0) - s2 = posterior_copy._q.rsample() - assert torch.allclose(s1, s2), ( - "Samples from original and unpickled VIPosterior must be close." - ) + if hasattr(posterior_copy._q, "rsample"): + s2 = posterior_copy._q.rsample() + else: + s2 = posterior_copy._q.sample((1,)).squeeze(0) + assert torch.allclose(s1, s2), "Samples should match after deepcopy" - # Produces nonleaf tensors in the cache... -> Can lead to failure of deepcopy. - posterior.q.rsample() - posterior_copy = deepcopy(posterior) + # Test deepcopy after sampling (can produce nonleaf tensors in cache) + if hasattr(posterior.q, "rsample"): + posterior.q.rsample() + else: + posterior.q.sample((1,)) + deepcopy(posterior) # Should not raise @pytest.mark.parametrize("q", FLOWS) def test_pickle_support(q: str): - """Tests if the VIPosterior can be saved and loaded via pickle. - - Args: - q: Different variational posteriors. 
- """ + """Test that VIPosterior can be saved and loaded via pickle.""" num_dim = 2 - prior = MultivariateNormal(zeros(num_dim), eye(num_dim)) potential_fn = FakePotential(prior=prior) theta_transform = torch_tf.identity_transform - posterior = VIPosterior( - potential_fn, - prior, - theta_transform=theta_transform, - q=q, - ) + posterior = VIPosterior(potential_fn, prior, theta_transform=theta_transform, q=q) posterior.set_default_x(torch.tensor(np.zeros((num_dim,)).astype(np.float32))) with tempfile.NamedTemporaryFile(suffix=".pt") as f: torch.save(posterior, f.name) posterior_loaded = torch.load(f.name, weights_only=False) - assert (posterior._x == posterior_loaded._x).all(), ( - "Mhh, something with the pickled is strange" - ) + assert (posterior._x == posterior_loaded._x).all() - # Try if they are the same + # Verify samples are reproducible torch.manual_seed(0) - s1 = posterior._q.rsample() + if hasattr(posterior._q, "rsample"): + s1 = posterior._q.rsample() + else: + s1 = posterior._q.sample((1,)).squeeze(0) torch.manual_seed(0) - s2 = posterior_loaded._q.rsample() + if hasattr(posterior_loaded._q, "rsample"): + s2 = posterior_loaded._q.rsample() + else: + s2 = posterior_loaded._q.sample((1,)).squeeze(0) - assert torch.allclose(s1, s2), "Mhh, something with the pickled is strange" + assert torch.allclose(s1, s2), "Samples should match after unpickling" -def test_vi_posterior_inferface(): +def test_vi_posterior_interface(): + """Test VIPosterior interface: hyperparameters, training, evaluation.""" num_dim = 2 - prior = MultivariateNormal(zeros(num_dim), eye(num_dim)) potential_fn = FakePotential(prior=prior) theta_transform = torch_tf.identity_transform - posterior = VIPosterior( - potential_fn, - theta_transform=theta_transform, - ) + posterior = VIPosterior(potential_fn, theta_transform=theta_transform) posterior.set_default_x(torch.zeros((1, num_dim))) posterior2 = VIPosterior(potential_fn) - # Raising errors if untrained - assert isinstance(posterior.q.support, type(posterior2.q.support)), ( - "The support indicated by 'theta_transform' is different than that of 'prior'." - ) + # Check support compatibility (if available) + if hasattr(posterior.q, "support") and hasattr(posterior2.q, "support"): + assert isinstance(posterior.q.support, type(posterior2.q.support)) + # Should raise if not trained with pytest.raises(Exception) as execinfo: posterior.sample() - - assert "The variational posterior was not fit" in execinfo.value.args[0], ( - "An expected error was raised but the error message is different than expected." - ) + assert "The variational posterior was not fit" in execinfo.value.args[0] with pytest.raises(Exception) as execinfo: posterior.log_prob(prior.sample()) + assert "The variational posterior was not fit" in execinfo.value.args[0] - assert "The variational posterior was not fit" in execinfo.value.args[0], ( - "An expected error was raised but the error message is different than expected." - ) - - # Passing Hyperparameters in train + # Test training hyperparameters posterior.train(max_num_iters=20) posterior.train(max_num_iters=20, optimizer=torch.optim.SGD) - assert isinstance(posterior._optimizer._optimizer, torch.optim.SGD), ( - "Assert chaning the optimizer base class did not work" - ) - posterior.train(max_num_iters=20, stick_the_landing=True) + assert isinstance(posterior._optimizer._optimizer, torch.optim.SGD) - assert posterior._optimizer.stick_the_landing, ( - "The sticking_the_landing argument is not correctly passed." 
- ) + posterior.train(max_num_iters=20, stick_the_landing=True) + assert posterior._optimizer.stick_the_landing posterior.vi_method = "alpha" posterior.train(max_num_iters=20) posterior.train(max_num_iters=20, alpha=0.9) - - assert posterior._optimizer.alpha == 0.9, ( - "The Hyperparameter alpha is not passed to the corresponding optmizer" - ) + assert posterior._optimizer.alpha == 0.9 posterior.vi_method = "IW" posterior.train(max_num_iters=20) posterior.train(max_num_iters=20, K=32) - - assert posterior._optimizer.K == 32, ( - "The Hyperparameter K is not passed to the corresponding optmizer" - ) + assert posterior._optimizer.K == 32 # Test sampling from trained posterior posterior.sample() - # Testing evaluate + # Test evaluation posterior.evaluate() posterior.evaluate("prop") posterior.evaluate("prop_prior") @@ -346,6 +394,7 @@ def test_vi_posterior_inferface(): def test_vi_with_multiple_independent_prior(): + """Test VI with MultipleIndependent prior (mixed distributions).""" prior = MultipleIndependent( [ Gamma(torch.tensor([1.0]), torch.tensor([0.5])), @@ -366,41 +415,286 @@ def simulator(theta): potential, transform = likelihood_estimator_based_potential(nle, prior, x[0]) posterior = VIPosterior( potential, - prior=prior, # type: ignore - theta_transform=transform, + prior=prior, + theta_transform=transform, # type: ignore ) posterior.set_default_x(x[0]) posterior.train() - posterior.sample( - sample_shape=(10,), - show_progress_bars=False, - ) + posterior.sample(sample_shape=(10,), show_progress_bars=False) @pytest.mark.parametrize("num_dim", (1, 2, 3, 4, 5, 10, 25, 33)) -@pytest.mark.parametrize("q", FLOWS) -def test_vi_flow_builders(num_dim: int, q: str): - """Test if the flow builder build the flows correctly, such that at least sampling - and log_prob works.""" - - try: - q = get_flow_builder(q)( - (num_dim,), torch.distributions.transforms.identity_transform - ) - except AssertionError: - # If the flow is not defined for the dimensionality, we pass the test +@pytest.mark.parametrize("q_type", FLOWS) +def test_vi_flow_builders(num_dim: int, q_type: str): + """Test variational families are built correctly with sampling and log_prob.""" + # Normalizing flows (except gaussian families) need >= 2 dimensions for Zuko + if num_dim == 1 and q_type not in ("gaussian", "gaussian_diag"): return - # Without sample_shape + prior = MultivariateNormal(zeros(num_dim), eye(num_dim)) + potential_fn = FakePotential(prior=prior) + theta_transform = torch_tf.identity_transform + + posterior = VIPosterior( + potential_fn, prior, theta_transform=theta_transform, q=q_type + ) - sample = q.sample() - assert sample.shape == (num_dim,), "The sample shape is not as expected" - log_prob = q.log_prob(sample) - assert log_prob.shape == (), "The log_prob shape is not as expected" + q = posterior.q - # With sample_shape + # Test sampling without sample_shape + sample = q.sample(()) + assert sample.shape == (num_dim,), f"Shape mismatch: {sample.shape}" + log_prob = q.log_prob(sample.unsqueeze(0)) + assert log_prob.shape == (1,), f"Log_prob shape mismatch: {log_prob.shape}" + + # Test sampling with sample_shape sample_batch = q.sample((10,)) - assert sample_batch.shape == (10, num_dim), "The sample shape is not as expected" + expected_shape = (10, num_dim) + assert sample_batch.shape == expected_shape, f"Shape mismatch: {sample_batch.shape}" log_prob_batch = q.log_prob(sample_batch) - assert log_prob_batch.shape == (10,), "The log_prob shape is not as expected" + assert log_prob_batch.shape == (10,), f"Shape 
mismatch: {log_prob_batch.shape}" + + +# ============================================================================= +# Amortized VI Tests: train_amortized() method +# ============================================================================= + + +@pytest.mark.slow +def test_amortized_vi_accuracy(linear_gaussian_setup_trainers): + """Test that amortized VI produces accurate posteriors (NLE and NRE).""" + setup = linear_gaussian_setup_trainers + + posterior = VIPosterior( + potential_fn=setup["potential_fn"], + prior=setup["prior"], + ) + + posterior.train_amortized( + theta=setup["theta"], + x=setup["x"], + max_num_iters=500, + show_progress_bar=False, + flow_type=ZukoFlowType.NSF, + num_transforms=2, + hidden_features=32, + ) + + # Verify training completed successfully + assert posterior._mode == "amortized" + + # Test on multiple observations + test_x_os = [ + zeros(1, setup["num_dim"]), + ones(1, setup["num_dim"]), + -ones(1, setup["num_dim"]), + ] + + for x_o in test_x_os: + true_posterior = true_posterior_linear_gaussian_mvn_prior( + x_o.squeeze(0), + setup["likelihood_shift"], + setup["likelihood_cov"], + zeros(setup["num_dim"]), + eye(setup["num_dim"]), + ) + true_samples = true_posterior.sample((1000,)) + vi_samples = posterior.sample((1000,), x=x_o) + + c2st_score = c2st(true_samples, vi_samples).item() + assert c2st_score < 0.65, ( + f"C2ST too high for {setup['trainer_type']}, " + f"x_o={x_o.squeeze().tolist()}: {c2st_score:.3f}" + ) + + +@pytest.mark.slow +def test_amortized_vi_batched_sampling(linear_gaussian_setup): + """Test batched sampling from amortized VIPosterior.""" + setup = linear_gaussian_setup + + posterior = VIPosterior( + potential_fn=setup["potential_fn"], + prior=setup["prior"], + ) + + posterior.train_amortized( + theta=setup["theta"], + x=setup["x"], + max_num_iters=500, + show_progress_bar=False, + flow_type=ZukoFlowType.NSF, + num_transforms=2, + hidden_features=32, + ) + + num_obs = 10 + num_samples = 100 + x_batch = torch.randn(num_obs, setup["num_dim"]) + samples = posterior.sample_batched((num_samples,), x=x_batch) + + assert samples.shape == (num_samples, num_obs, setup["num_dim"]) + + +@pytest.mark.slow +def test_amortized_vi_log_prob(linear_gaussian_setup): + """Test log_prob evaluation in amortized mode.""" + setup = linear_gaussian_setup + + posterior = VIPosterior( + potential_fn=setup["potential_fn"], + prior=setup["prior"], + ) + + posterior.train_amortized( + theta=setup["theta"], + x=setup["x"], + max_num_iters=500, + show_progress_bar=False, + flow_type=ZukoFlowType.NSF, + num_transforms=2, + hidden_features=32, + ) + + x_o = zeros(1, setup["num_dim"]) + theta_test = torch.randn(10, setup["num_dim"]) + + log_probs = posterior.log_prob(theta_test, x=x_o) + + assert log_probs.shape == (10,) + assert torch.isfinite(log_probs).all() + + +@pytest.mark.slow +def test_amortized_vi_default_x(linear_gaussian_setup): + """Test that amortized mode uses default_x when x not provided.""" + setup = linear_gaussian_setup + + posterior = VIPosterior( + potential_fn=setup["potential_fn"], + prior=setup["prior"], + ) + + posterior.train_amortized( + theta=setup["theta"], + x=setup["x"], + max_num_iters=100, + show_progress_bar=False, + flow_type=ZukoFlowType.NSF, + ) + + posterior.set_default_x(zeros(1, setup["num_dim"])) + samples = posterior.sample((100,)) + assert samples.shape == (100, setup["num_dim"]) + + +@pytest.mark.slow +def test_amortized_vi_requires_training(linear_gaussian_setup): + """Test that sampling before training raises an error.""" + 
setup = linear_gaussian_setup + + posterior = VIPosterior( + potential_fn=setup["potential_fn"], + prior=setup["prior"], + ) + + posterior.set_default_x(zeros(1, setup["num_dim"])) + with pytest.raises(ValueError): + posterior.sample((100,)) + + +@pytest.mark.slow +def test_amortized_vi_map(linear_gaussian_setup): + """Test that MAP estimation returns high-density region.""" + setup = linear_gaussian_setup + + posterior = VIPosterior( + potential_fn=setup["potential_fn"], + prior=setup["prior"], + ) + + posterior.train_amortized( + theta=setup["theta"], + x=setup["x"], + max_num_iters=500, + show_progress_bar=False, + flow_type=ZukoFlowType.NSF, + num_transforms=2, + hidden_features=32, + ) + + x_o = zeros(1, setup["num_dim"]) + posterior.set_default_x(x_o) + map_estimate = posterior.map(num_iter=500, num_to_optimize=50) + + # For linear Gaussian, MAP equals posterior mean + true_posterior = true_posterior_linear_gaussian_mvn_prior( + x_o.squeeze(0), + setup["likelihood_shift"], + setup["likelihood_cov"], + zeros(setup["num_dim"]), + eye(setup["num_dim"]), + ) + true_mean = true_posterior.mean + map_estimate_flat = map_estimate.squeeze(0) + + assert torch.allclose(map_estimate_flat, true_mean, atol=0.3), ( + f"MAP {map_estimate_flat.tolist()} not close to true mean {true_mean.tolist()}" + ) + + # MAP should have higher potential than random samples + map_log_prob = posterior.potential(map_estimate) + random_samples = posterior.sample((100,), x=x_o) + random_log_probs = posterior.potential(random_samples) + + assert map_log_prob > random_log_probs.median(), ( + f"MAP log_prob {map_log_prob.item():.3f} not better than " + f"median random {random_log_probs.median().item():.3f}" + ) + + +def test_amortized_vi_with_fake_potential(): + """Fast test for amortized VI using FakePotential (no NLE training required). + + This test runs in CI (not marked slow) to ensure amortized VI coverage. + Uses FakePotential where the posterior equals the prior. + """ + torch.manual_seed(42) + + num_dim = 2 + prior = MultivariateNormal(zeros(num_dim), eye(num_dim)) + potential_fn = FakePotential(prior=prior) + + # Generate mock simulation data (not actually used for training potential) + theta = prior.sample((500,)) + x = theta + 0.1 * torch.randn_like(theta) # Noisy observations + + posterior = VIPosterior( + potential_fn=potential_fn, + prior=prior, + ) + + # Train amortized VI + posterior.train_amortized( + theta=theta, + x=x, + max_num_iters=100, # Fewer iterations for speed + show_progress_bar=False, + flow_type=ZukoFlowType.NSF, + num_transforms=2, + hidden_features=16, # Smaller network for speed + ) + + # Verify training completed + assert posterior._mode == "amortized" + + # Test sampling works + x_test = torch.randn(1, num_dim) + samples = posterior.sample((100,), x=x_test) + assert samples.shape == (100, num_dim) + + # Test log_prob works + log_probs = posterior.log_prob(samples, x=x_test) + assert log_probs.shape == (100,) + assert torch.isfinite(log_probs).all()
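
The new tests exercise the reworked VIPosterior piecewise; condensed into one end-to-end sketch, the intended usage looks roughly as follows. This assumes the API introduced in this diff (string `q` values backed by Zuko, `train()` for single-x VI, `train_amortized()` for amortized VI). `SimplePotential` and the toy Gaussian data are illustrative stand-ins for the `FakePotential` helper and the simulated data used in the tests, not part of the library.

import torch
from torch.distributions import MultivariateNormal

from sbi.inference.posteriors import VIPosterior
from sbi.inference.potentials.base_potential import BasePotential
from sbi.neural_nets.factory import ZukoFlowType


class SimplePotential(BasePotential):
    """Illustrative potential: the posterior equals the prior (like FakePotential)."""

    def __call__(self, theta, **kwargs):
        return self.prior.log_prob(theta)

    def allow_iid_x(self) -> bool:
        return True


num_dim = 2
prior = MultivariateNormal(torch.zeros(num_dim), torch.eye(num_dim))
potential_fn = SimplePotential(prior=prior)

# Single-x VI: fit a Zuko-backed q(theta) for one observation.
posterior = VIPosterior(potential_fn, prior, q="nsf")
posterior.set_default_x(torch.zeros(num_dim))
posterior.vi_method = "rKL"
posterior.train(max_num_iters=200)
single_x_samples = posterior.sample((1000,))

# Amortized VI: fit q(theta | x) across (theta, x) pairs, then condition on any x.
theta = prior.sample((500,))
x = theta + 0.1 * torch.randn_like(theta)  # toy "simulations"
amortized = VIPosterior(potential_fn, prior)
amortized.train_amortized(
    theta=theta,
    x=x,
    max_num_iters=100,
    show_progress_bar=False,
    flow_type=ZukoFlowType.NSF,
    num_transforms=2,
    hidden_features=16,
)
x_o = torch.randn(1, num_dim)
amortized_samples = amortized.sample((100,), x=x_o)
log_probs = amortized.log_prob(amortized_samples, x=x_o)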
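
For reference, the change-of-variables bookkeeping that `TransformedZukoFlow` and `LearnableGaussian` perform in `sample_and_log_prob`, namely log p(theta) = log q(z) - log|det(d theta / d z)| with theta = link(z), can be sanity-checked with plain torch. The snippet below is a minimal sketch using an `ExpTransform` link and torch's own `TransformedDistribution` as the reference; it does not touch sbi code.

import torch
from torch.distributions import (
    ExpTransform,
    Independent,
    Normal,
    TransformedDistribution,
)
from torch.distributions.transforms import IndependentTransform

# q(z): standard normal in unconstrained space, event_dim = 1.
base = Independent(Normal(torch.zeros(2), torch.ones(2)), 1)
# Link transform: z -> theta = exp(z), i.e. constrained (positive) support.
link = IndependentTransform(ExpTransform(), 1)

# Sample in unconstrained space, push through the link, correct the log-density.
z = base.rsample((5,))
log_q_z = base.log_prob(z)
theta = link(z)
log_p_theta = log_q_z - link.log_abs_det_jacobian(z, theta)

# Cross-check against torch's TransformedDistribution.
reference = TransformedDistribution(base, [link])
assert torch.allclose(log_p_theta, reference.log_prob(theta), atol=1e-5)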