Skip to content

Commit a406410

Browse files
authored
refactor: build_posterior Posterior Configuration Using Dataclasses (#1619)
* refactor: add posterior parameters dataclass
* refactor: add posterior_parameters parameter to build_posterior method
* test: add unit test to check signature consistency
* refactor: update posterior_parameters initialization
* refactor: add posterior_parameters resolution method
* refactor(build_posterior): update posterior_parameter resolution method
* test: add VectorFieldBasedPotential to test argument
* refactor: add fields to VectorFieldPosteriorParameters and update _resolve_posterior_parameters method
* refactor: add warning and value error for conflicting posterior_parameter resolution
* test: add test for passing conflicting parameters to build_posterior
* doc(posterior_parameters): add fields to every dataclass docstring
* refactor: remove x_shape field
* refactor(mnle): add importance_sampling_parameters and ImportanceSamplingPosteriorParameters
* refactor: update str to Literal for neural_ode_backend
* refactor: modularize functions for checking posterior parameter value mismatches
* refactor: add PosteriorParameters base class and add field validations
* refactor(PosteriorParameters): update validations
* test(PosteriorParameters): add new tests
* add license header
* chore: update spacing between docstring and code implementation
* refactor: update posterior_parameters resolution flow
* test(posterior_parameters): update tests and add test fixture
* refactor(npe): validate build_posterior prior argument for rejection_sampling
* refactor(NPE): update error message for build_posterior rejection sampling
* test(NPE): check build_posterior with rejection sampling fails for empty prior
1 parent 5cec1ea commit a406410

File tree

14 files changed

+1180
-43
lines changed

14 files changed

+1180
-43
lines changed

sbi/inference/posteriors/posterior_parameters.py

Lines changed: 394 additions & 0 deletions
Large diffs are not rendered by default.

sbi/inference/potentials/vector_field_potential.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def __init__(
6666
device: Union[str, torch.device] = "cpu",
6767
iid_method: Literal["fnpe", "gauss", "auto_gauss", "jac_gauss"] = "auto_gauss",
6868
iid_params: Optional[Dict[str, Any]] = None,
69-
neural_ode_backend: str = "zuko",
69+
neural_ode_backend: Literal["zuko"] = "zuko",
7070
neural_ode_kwargs: Optional[Dict[str, Any]] = None,
7171
):
7272
r"""

sbi/inference/trainers/base.py

Lines changed: 195 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
22
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
33

4+
import warnings
45
from abc import ABC, abstractmethod
56
from copy import deepcopy
7+
from dataclasses import asdict
68
from datetime import datetime
79
from pathlib import Path
810
from typing import Any, Callable, Dict, Literal, Optional, Tuple, Union
@@ -20,6 +22,15 @@
2022
from sbi.inference.posteriors.direct_posterior import DirectPosterior
2123
from sbi.inference.posteriors.importance_posterior import ImportanceSamplingPosterior
2224
from sbi.inference.posteriors.mcmc_posterior import MCMCPosterior
25+
from sbi.inference.posteriors.posterior_parameters import (
26+
DirectPosteriorParameters,
27+
ImportanceSamplingPosteriorParameters,
28+
MCMCPosteriorParameters,
29+
PosteriorParameters,
30+
RejectionPosteriorParameters,
31+
VIPosteriorParameters,
32+
VectorFieldPosteriorParameters,
33+
)
2334
from sbi.inference.posteriors.rejection_posterior import RejectionPosterior
2435
from sbi.inference.posteriors.vector_field_posterior import VectorFieldPosterior
2536
from sbi.inference.posteriors.vi_posterior import VIPosterior
@@ -401,6 +412,7 @@ def build_posterior(
401412
sample_with: Literal[
402413
"mcmc", "rejection", "vi", "importance", "direct", "sde", "ode"
403414
],
415+
posterior_parameters: Optional[PosteriorParameters],
404416
**kwargs,
405417
) -> NeuralPosterior:
406418
r"""Method for building posteriors.
@@ -422,6 +434,8 @@ def build_posterior(
422434
- "direct"
423435
- "sde"
424436
- "ode"
437+
posterior_parameters: Configuration passed to the init method for the
438+
posterior. Must be of type PosteriorParameters.
425439
**kwargs: Additional method-specific parameters.
426440
427441
Returns:
@@ -431,12 +445,16 @@ def build_posterior(
431445
prior = self._resolve_prior(prior)
432446
estimator, device = self._resolve_estimator(estimator)
433447

448+
posterior_parameters = self._resolve_posterior_parameters(
449+
sample_with, posterior_parameters, **kwargs
450+
)
451+
434452
self._posterior = self._create_posterior(
435453
estimator,
436454
prior,
437455
sample_with,
438456
device,
439-
**kwargs,
457+
posterior_parameters,
440458
)
441459

442460
# Store models at end of each round.
@@ -508,6 +526,144 @@ def _resolve_estimator(
508526

509527
return estimator, device
510528

529+
def _resolve_posterior_parameters(
    self,
    sample_with: Literal[
        "mcmc", "rejection", "vi", "importance", "direct", "sde", "ode"
    ],
    posterior_parameters: Optional[PosteriorParameters],
    **kwargs,
) -> PosteriorParameters:
    """
    Resolve posterior parameters based on the sampling strategy.

    If `posterior_parameters` is provided, it is validated against `kwargs`
    (no old-style parameter dictionaries may be passed alongside it, and a
    conflicting `mcmc_method` / `vi_method` triggers a warning) and is then
    returned unchanged.

    If `posterior_parameters` is not provided, this method extracts
    sampling-specific parameters from `kwargs` using predefined keys
    to instantiate the appropriate posterior parameters dataclass.

    Args:
        sample_with: The posterior sampling method to use.
        posterior_parameters: Optional preconstructed posterior parameter object.
        **kwargs: Additional parameters to construct the posterior parameters.

    Returns:
        A dataclass instance containing the resolved posterior
        parameters.

    Raises:
        NotImplementedError: If an unsupported `sample_with` method is provided.
        ValueError: If posterior_parameters and a configuration dictionary are
            passed together.
        TypeError: If posterior_parameters is not a PosteriorParameters
            instance.
    """

    if posterior_parameters is not None:
        self._validate_no_duplicate_parameters(**kwargs)
        self._validate_posterior_parameters_consistency(
            posterior_parameters, **kwargs
        )
        return posterior_parameters

    # Resolve parameters passed through kwargs and convert
    # into a subclass of PosteriorParameters. `or {}` maps an explicit
    # `None` dictionary onto "no extra parameters".
    if sample_with == "direct":
        params = kwargs.get("direct_sampling_parameters", {}) or {}
        return DirectPosteriorParameters(**params)
    if sample_with == "mcmc":
        params = kwargs.get("mcmc_parameters", {}) or {}
        return MCMCPosteriorParameters(
            method=kwargs.get("mcmc_method", "slice_np_vectorized"), **params
        )
    if sample_with in ("ode", "sde"):
        params = kwargs.get("vectorfield_sampling_parameters", {}) or {}
        return VectorFieldPosteriorParameters(**params)
    if sample_with == "rejection":
        params = kwargs.get("rejection_sampling_parameters", {}) or {}
        return RejectionPosteriorParameters(**params)
    if sample_with == "vi":
        params = kwargs.get("vi_parameters", {}) or {}
        return VIPosteriorParameters(
            vi_method=kwargs.get("vi_method", "rKL"), **params
        )
    if sample_with == "importance":
        params = kwargs.get("importance_sampling_parameters", {}) or {}
        return ImportanceSamplingPosteriorParameters(**params)

    # Single message string: a stray comma between the two literals would
    # turn them into separate exception args and print as a tuple.
    raise NotImplementedError(
        f"Posterior parameter construction not implemented for '{sample_with}'"
    )
598+
599+
def _validate_no_duplicate_parameters(self, **kwargs) -> None:
    """
    Ensure parameters aren't specified in both posterior_parameters and the
    posterior parameter dictionaries in the build_posterior method.

    A dictionary key counts as "provided" only when its value is not None.

    Args:
        **kwargs: Additional parameters to construct the posterior parameters

    Raises:
        ValueError: If any old-style parameter dictionary is present in
            `kwargs` with a non-None value.
    """

    # Ordered tuple rather than a set, so the names listed in the error
    # message come out in the same order on every run.
    old_style_params = (
        "direct_sampling_parameters",
        "mcmc_parameters",
        "vectorfield_sampling_parameters",
        "rejection_sampling_parameters",
        "vi_parameters",
        "importance_sampling_parameters",
    )

    # Check if any old-style parameters were provided
    provided_old_params = [
        param for param in old_style_params if kwargs.get(param) is not None
    ]

    if provided_old_params:
        raise ValueError(
            f"Cannot use both old-style parameters {provided_old_params} "
            "and new-style posterior_parameters. Please use only one approach."
        )
627+
628+
def _validate_posterior_parameters_consistency(
    self, posterior_parameters: PosteriorParameters, **kwargs
) -> None:
    """
    This method raises a warning for mismatches between values passed in
    mcmc_method and MCMCPosteriorParameters.method, or vi_method and
    VIPosteriorParameters.vi_method.

    No warning is issued when the keyword is absent (or None), or when it
    equals the fallback value used elsewhere in this class
    ("slice_np_vectorized" for mcmc_method, "rKL" for vi_method), since such
    a value is presumably just the caller's default rather than an explicit
    choice.

    Args:
        posterior_parameters: Configuration passed to the init method for the
            posterior.
        kwargs: keyword arguments passed from build_posterior method.

    Raises:
        TypeError: If posterior_parameters is not a PosteriorParameters
            instance.
    """

    if not isinstance(posterior_parameters, PosteriorParameters):
        raise TypeError(
            "posterior_parameters must be PosteriorParameters,"
            f" got {type(posterior_parameters).__name__}",
        )

    if isinstance(posterior_parameters, MCMCPosteriorParameters):
        mcmc_method = kwargs.get("mcmc_method")
        # `is not None` guard: a missing keyword must not be reported as a
        # conflicting mcmc_method='None'.
        if (
            mcmc_method is not None
            and mcmc_method != "slice_np_vectorized"
            and posterior_parameters.method != mcmc_method
        ):
            warnings.warn(
                f"Conflicting mcmc_method='{mcmc_method}' ignored in favor of "
                f"posterior_parameters.method='{posterior_parameters.method}'",
                stacklevel=2,
            )
    elif isinstance(posterior_parameters, VIPosteriorParameters):
        vi_method = kwargs.get("vi_method")
        # Same guard as above for an absent vi_method.
        if (
            vi_method is not None
            and vi_method != "rKL"
            and posterior_parameters.vi_method != vi_method
        ):
            warnings.warn(
                f"Conflicting vi_method='{vi_method}' ignored in favor of "
                f"posterior_parameters.vi_method='{posterior_parameters.vi_method}'",
                stacklevel=2,
            )
666+
511667
def _create_posterior(
512668
self,
513669
estimator: Union[RatioEstimator, ConditionalEstimator],
@@ -516,7 +672,7 @@ def _create_posterior(
516672
"mcmc", "rejection", "vi", "importance", "direct", "sde", "ode"
517673
],
518674
device: Union[str, torch.device],
519-
**kwargs,
675+
posterior_parameters: PosteriorParameters,
520676
) -> NeuralPosterior:
521677
"""
522678
Create a posterior object using the specified inference method.
@@ -539,83 +695,89 @@ def _create_posterior(
539695
- "ode"
540696
device: torch device on which to train the neural net and on which to
541697
perform all posterior operations, e.g. gpu or cpu.
542-
**kwargs: Additional method-specific parameters.
698+
posterior_parameters: Configuration passed to the init method for the
699+
posterior. Must be of type PosteriorParameters.
543700
544701
Returns:
545702
NeuralPosterior object.
546703
"""
547704

548-
if sample_with == "direct":
705+
if isinstance(posterior_parameters, DirectPosteriorParameters):
549706
posterior_estimator = estimator
550-
assert isinstance(posterior_estimator, ConditionalDensityEstimator), (
551-
f"Expected posterior_estimator to be an instance of "
552-
" ConditionalDensityEstimator, "
553-
f"but got {type(posterior_estimator).__name__} instead."
554-
)
707+
if not isinstance(posterior_estimator, ConditionalDensityEstimator):
708+
raise TypeError(
709+
f"Expected posterior_estimator to be an instance of "
710+
" ConditionalDensityEstimator, "
711+
f"but got {type(posterior_estimator).__name__} instead."
712+
)
555713
posterior = DirectPosterior(
556714
posterior_estimator=posterior_estimator,
557715
prior=prior,
558716
device=device,
559-
**(kwargs.get("direct_sampling_parameters") or {}),
717+
**asdict(posterior_parameters),
560718
)
561-
elif sample_with in ("sde", "ode"):
719+
elif isinstance(posterior_parameters, VectorFieldPosteriorParameters):
562720
vector_field_estimator = estimator
563-
assert isinstance(
564-
vector_field_estimator, ConditionalVectorFieldEstimator
565-
), (
566-
f"Expected vector_field_estimator to be an instance of "
567-
" ConditionalVectorFieldEstimator, "
568-
f"but got {type(vector_field_estimator).__name__} instead."
569-
)
721+
if not isinstance(vector_field_estimator, ConditionalVectorFieldEstimator):
722+
raise TypeError(
723+
f"Expected vector_field_estimator to be an instance of "
724+
" ConditionalVectorFieldEstimator, "
725+
f"but got {type(vector_field_estimator).__name__} instead."
726+
)
727+
if sample_with not in ("ode", "sde"):
728+
raise ValueError(
729+
"`sample_with` must be either",
730+
f" 'ode' or 'sde', got '{sample_with}'",
731+
)
570732
posterior = VectorFieldPosterior(
571-
vector_field_estimator,
572-
prior,
733+
vector_field_estimator=vector_field_estimator,
734+
prior=prior,
573735
device=device,
574736
sample_with=sample_with,
575-
**(kwargs.get("vectorfield_sampling_parameters") or {}),
737+
**asdict(posterior_parameters),
576738
)
577739
else:
578740
# Posteriors requiring potential_fn and theta_transform
579741
potential_fn, theta_transform = self._get_potential_function(
580742
prior, estimator
581743
)
582-
if sample_with == "mcmc":
744+
if isinstance(posterior_parameters, MCMCPosteriorParameters):
583745
posterior = MCMCPosterior(
584746
potential_fn=potential_fn,
585747
theta_transform=theta_transform,
586748
proposal=prior,
587-
method=kwargs.get("mcmc_method", "slice_np_vectorized"),
588749
device=device,
589-
**(kwargs.get("mcmc_parameters") or {}),
750+
**asdict(posterior_parameters),
590751
)
591-
elif sample_with == "rejection":
752+
elif isinstance(posterior_parameters, RejectionPosteriorParameters):
592753
posterior = RejectionPosterior(
593754
potential_fn=potential_fn,
594755
proposal=prior,
595756
device=device,
596-
**(kwargs.get("rejection_sampling_parameters") or {}),
757+
**asdict(posterior_parameters),
597758
)
598-
elif sample_with == "vi":
759+
elif isinstance(posterior_parameters, VIPosteriorParameters):
599760
posterior = VIPosterior(
600761
potential_fn=potential_fn,
601762
theta_transform=theta_transform,
602763
prior=prior,
603-
vi_method=kwargs.get("vi_method", "rKL"),
604764
device=device,
605-
**(kwargs.get("vi_parameters") or {}),
765+
**asdict(posterior_parameters),
606766
)
607-
elif sample_with == "importance":
767+
elif isinstance(
768+
posterior_parameters, ImportanceSamplingPosteriorParameters
769+
):
608770
posterior = ImportanceSamplingPosterior(
609771
potential_fn=potential_fn,
610772
proposal=prior,
611773
device=device,
612-
**(kwargs.get("importance_sampling_parameters") or {}),
774+
**asdict(posterior_parameters),
613775
)
614776
else:
615777
raise NotImplementedError(
616-
f"Sampling method '{sample_with}' is not supported."
778+
"Sampling method not implemented for",
779+
f"'{posterior_parameters}'",
617780
)
618-
619781
return posterior
620782

621783
def _converged(self, epoch: int, stop_after_epochs: int) -> bool:

sbi/inference/trainers/fmpe/fmpe.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from sbi import utils as utils
1111
from sbi.inference.posteriors.base_posterior import NeuralPosterior
12+
from sbi.inference.posteriors.posterior_parameters import VectorFieldPosteriorParameters
1213
from sbi.inference.trainers.npse.vector_field_inference import (
1314
VectorFieldEstimatorBuilder,
1415
VectorFieldTrainer,
@@ -65,6 +66,7 @@ def build_posterior(
6566
prior: Optional[Distribution] = None,
6667
sample_with: Literal["ode", "sde"] = "ode",
6768
vectorfield_sampling_parameters: Optional[Dict[str, Any]] = None,
69+
posterior_parameters: Optional[VectorFieldPosteriorParameters] = None,
6870
) -> NeuralPosterior:
6971
r"""Build posterior from the flow matching estimator.
7072
@@ -89,16 +91,19 @@ def build_posterior(
8991
a numerical ODE solver.
9092
vectorfield_sampling_parameters: Additional keyword arguments passed to
9193
`VectorFieldPosterior`.
92-
94+
posterior_parameters: Configuration passed to the init method for
95+
VectorFieldPosterior.
9396
9497
Returns:
9598
Posterior $p(\theta|x)$ with `.sample()` and `.log_prob()` methods.
9699
"""
100+
97101
return super().build_posterior(
98102
estimator=vector_field_estimator,
99103
prior=prior,
100104
sample_with=sample_with,
101105
vectorfield_sampling_parameters=vectorfield_sampling_parameters,
106+
posterior_parameters=posterior_parameters,
102107
)
103108

104109
def _build_default_nn_fn(self, **kwargs) -> VectorFieldEstimatorBuilder:

0 commit comments

Comments
 (0)