11# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
22# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
33
4+ import warnings
45from abc import ABC , abstractmethod
56from copy import deepcopy
7+ from dataclasses import asdict
68from datetime import datetime
79from pathlib import Path
810from typing import Any , Callable , Dict , Literal , Optional , Tuple , Union
2022from sbi .inference .posteriors .direct_posterior import DirectPosterior
2123from sbi .inference .posteriors .importance_posterior import ImportanceSamplingPosterior
2224from sbi .inference .posteriors .mcmc_posterior import MCMCPosterior
25+ from sbi .inference .posteriors .posterior_parameters import (
26+ DirectPosteriorParameters ,
27+ ImportanceSamplingPosteriorParameters ,
28+ MCMCPosteriorParameters ,
29+ PosteriorParameters ,
30+ RejectionPosteriorParameters ,
31+ VIPosteriorParameters ,
32+ VectorFieldPosteriorParameters ,
33+ )
2334from sbi .inference .posteriors .rejection_posterior import RejectionPosterior
2435from sbi .inference .posteriors .vector_field_posterior import VectorFieldPosterior
2536from sbi .inference .posteriors .vi_posterior import VIPosterior
@@ -401,6 +412,7 @@ def build_posterior(
401412 sample_with : Literal [
402413 "mcmc" , "rejection" , "vi" , "importance" , "direct" , "sde" , "ode"
403414 ],
415+ posterior_parameters : Optional [PosteriorParameters ],
404416 ** kwargs ,
405417 ) -> NeuralPosterior :
406418 r"""Method for building posteriors.
@@ -422,6 +434,8 @@ def build_posterior(
422434 - "direct"
423435 - "sde"
424436 - "ode"
437+ posterior_parameters: Configuration passed to the init method for the
438+ posterior. Must be of type PosteriorParameters.
425439 **kwargs: Additional method-specific parameters.
426440
427441 Returns:
@@ -431,12 +445,16 @@ def build_posterior(
431445 prior = self ._resolve_prior (prior )
432446 estimator , device = self ._resolve_estimator (estimator )
433447
448+ posterior_parameters = self ._resolve_posterior_parameters (
449+ sample_with , posterior_parameters , ** kwargs
450+ )
451+
434452 self ._posterior = self ._create_posterior (
435453 estimator ,
436454 prior ,
437455 sample_with ,
438456 device ,
439- ** kwargs ,
457+ posterior_parameters ,
440458 )
441459
442460 # Store models at end of each round.
@@ -508,6 +526,144 @@ def _resolve_estimator(
508526
509527 return estimator , device
510528
529+ def _resolve_posterior_parameters (
530+ self ,
531+ sample_with : Literal [
532+ "mcmc" , "rejection" , "vi" , "importance" , "direct" , "sde" , "ode"
533+ ],
534+ posterior_parameters : Optional [PosteriorParameters ],
535+ ** kwargs ,
536+ ) -> PosteriorParameters :
537+ """
538+ Resolve posterior parameters based on the sampling strategy.
539+
540+ If `posterior_parameters` is provided, it is returned directly.
541+
542+ If `posterior_parameters` is not provided, this method extracts
543+ sampling-specific parameters from `kwargs` using predefined keys
544+ to instantiate the appropriate posterior parameters dataclass.
545+
546+ Raises:
547+ NotImplementedError: If an unsupported `sample_with` method is provided.
548+ ValueError: If posterior_parameter and a configuration dictionary are passed
549+ together.
550+
551+ Args:
552+ sample_with: The posterior sampling method to use.
553+ posterior_parameters: Optional preconstructed posterior parameter object.
554+ **kwargs: Additional parameters to construct the posterior parameters.
555+
556+ Returns:
557+ A dataclass instance containing the resolved posterior
558+ parameters.
559+ """
560+
561+ if posterior_parameters is not None :
562+ self ._validate_no_duplicate_parameters (** kwargs )
563+ self ._validate_posterior_parameters_consistency (
564+ posterior_parameters , ** kwargs
565+ )
566+ else :
567+ # Resolve parameters passed through kwargs and convert
568+ # into a subclass of PosteriorParameters
569+ if sample_with == "direct" :
570+ params = kwargs .get ("direct_sampling_parameters" , {}) or {}
571+ posterior_parameters = DirectPosteriorParameters (** params )
572+ elif sample_with == "mcmc" :
573+ params = kwargs .get ("mcmc_parameters" , {}) or {}
574+ posterior_parameters = MCMCPosteriorParameters (
575+ method = kwargs .get ("mcmc_method" , "slice_np_vectorized" ), ** params
576+ )
577+ elif sample_with in ("ode" , "sde" ):
578+ params = kwargs .get ("vectorfield_sampling_parameters" , {}) or {}
579+ posterior_parameters = VectorFieldPosteriorParameters (** params )
580+ elif sample_with == "rejection" :
581+ params = kwargs .get ("rejection_sampling_parameters" , {}) or {}
582+ posterior_parameters = RejectionPosteriorParameters (** params )
583+ elif sample_with == "vi" :
584+ params = kwargs .get ("vi_parameters" , {}) or {}
585+ posterior_parameters = VIPosteriorParameters (
586+ vi_method = kwargs .get ("vi_method" , "rKL" ), ** params
587+ )
588+ elif sample_with == "importance" :
589+ params = kwargs .get ("importance_sampling_parameters" , {}) or {}
590+ posterior_parameters = ImportanceSamplingPosteriorParameters (** params )
591+ else :
592+ raise NotImplementedError (
593+ "Posterior parameter construction not implemented for" ,
594+ f"'{ sample_with } '" ,
595+ )
596+
597+ return posterior_parameters
598+
599+ def _validate_no_duplicate_parameters (self , ** kwargs ) -> None :
600+ """
601+ Ensure parameters aren't specified in both posterior_parameters and the
602+ posterior parameter dictionaries in the build_posterior method.
603+
604+ Args:
605+ **kwargs: Additional parameters to construct the posterior parameters
606+ """
607+
608+ old_style_params = {
609+ "direct_sampling_parameters" ,
610+ "mcmc_parameters" ,
611+ "vectorfield_sampling_parameters" ,
612+ "rejection_sampling_parameters" ,
613+ "vi_parameters" ,
614+ "importance_sampling_parameters" ,
615+ }
616+
617+ # Check if any old-style parameters were provided
618+ provided_old_params = [
619+ param for param in old_style_params if kwargs .get (param ) is not None
620+ ]
621+
622+ if provided_old_params :
623+ raise ValueError (
624+ f"Cannot use both old-style parameters { provided_old_params } "
625+ f"and new-style posterior_parameters. Please use only one approach."
626+ )
627+
628+ def _validate_posterior_parameters_consistency (
629+ self , posterior_parameters : PosteriorParameters , ** kwargs
630+ ) -> None :
631+ """
632+ This method raises a warning for mismatches between values passed in
633+ mcmc_method and MCMCPosteriorParameters.method, or vi_method and
634+ VIPosteriorParameters.vi_method.
635+
636+ Args:
637+ posterior_parameters: Configuration passed to the init method for the
638+ posterior.
639+ kwargs: keyword arguments passed from build_posterior method.
640+ """
641+
642+ if not isinstance (posterior_parameters , PosteriorParameters ):
643+ raise TypeError (
644+ "posterior_parameters must be PosteriorParameters,"
645+ f" got { type (posterior_parameters ).__name__ } " ,
646+ )
647+ elif isinstance (posterior_parameters , MCMCPosteriorParameters ):
648+ mcmc_method = kwargs .get ("mcmc_method" )
649+ if (
650+ mcmc_method != "slice_np_vectorized"
651+ and posterior_parameters .method != mcmc_method
652+ ):
653+ warnings .warn (
654+ f"Conflicting mcmc_method='{ mcmc_method } ' ignored in favor of "
655+ f"posterior_parameters.method='{ posterior_parameters .method } '" ,
656+ stacklevel = 2 ,
657+ )
658+ elif isinstance (posterior_parameters , VIPosteriorParameters ):
659+ vi_method = kwargs .get ("vi_method" )
660+ if vi_method != "rKL" and posterior_parameters .vi_method != vi_method :
661+ warnings .warn (
662+ f"Conflicting vi_method='{ vi_method } ' ignored in favor of "
663+ f"posterior_parameters.vi_method='{ posterior_parameters .vi_method } '" ,
664+ stacklevel = 2 ,
665+ )
666+
511667 def _create_posterior (
512668 self ,
513669 estimator : Union [RatioEstimator , ConditionalEstimator ],
@@ -516,7 +672,7 @@ def _create_posterior(
516672 "mcmc" , "rejection" , "vi" , "importance" , "direct" , "sde" , "ode"
517673 ],
518674 device : Union [str , torch .device ],
519- ** kwargs ,
675+ posterior_parameters : PosteriorParameters ,
520676 ) -> NeuralPosterior :
521677 """
522678 Create a posterior object using the specified inference method.
@@ -539,83 +695,89 @@ def _create_posterior(
539695 - "ode"
540696 device: torch device on which to train the neural net and on which to
541697 perform all posterior operations, e.g. gpu or cpu.
542- **kwargs: Additional method-specific parameters.
698+ posterior_parameters: Configuration passed to the init method for the
699+ posterior. Must be of type PosteriorParameters.
543700
544701 Returns:
545702 NeuralPosterior object.
546703 """
547704
548- if sample_with == "direct" :
705+ if isinstance ( posterior_parameters , DirectPosteriorParameters ) :
549706 posterior_estimator = estimator
550- assert isinstance (posterior_estimator , ConditionalDensityEstimator ), (
551- f"Expected posterior_estimator to be an instance of "
552- " ConditionalDensityEstimator, "
553- f"but got { type (posterior_estimator ).__name__ } instead."
554- )
707+ if not isinstance (posterior_estimator , ConditionalDensityEstimator ):
708+ raise TypeError (
709+ f"Expected posterior_estimator to be an instance of "
710+ " ConditionalDensityEstimator, "
711+ f"but got { type (posterior_estimator ).__name__ } instead."
712+ )
555713 posterior = DirectPosterior (
556714 posterior_estimator = posterior_estimator ,
557715 prior = prior ,
558716 device = device ,
559- ** ( kwargs . get ( "direct_sampling_parameters" ) or {} ),
717+ ** asdict ( posterior_parameters ),
560718 )
561- elif sample_with in ( "sde" , "ode" ):
719+ elif isinstance ( posterior_parameters , VectorFieldPosteriorParameters ):
562720 vector_field_estimator = estimator
563- assert isinstance (
564- vector_field_estimator , ConditionalVectorFieldEstimator
565- ), (
566- f"Expected vector_field_estimator to be an instance of "
567- " ConditionalVectorFieldEstimator, "
568- f"but got { type (vector_field_estimator ).__name__ } instead."
569- )
721+ if not isinstance (vector_field_estimator , ConditionalVectorFieldEstimator ):
722+ raise TypeError (
723+ f"Expected vector_field_estimator to be an instance of "
724+ " ConditionalVectorFieldEstimator, "
725+ f"but got { type (vector_field_estimator ).__name__ } instead."
726+ )
727+ if sample_with not in ("ode" , "sde" ):
728+ raise ValueError (
729+ "`sample_with` must be either" ,
730+ f" 'ode' or 'sde', got '{ sample_with } '" ,
731+ )
570732 posterior = VectorFieldPosterior (
571- vector_field_estimator ,
572- prior ,
733+ vector_field_estimator = vector_field_estimator ,
734+ prior = prior ,
573735 device = device ,
574736 sample_with = sample_with ,
575- ** ( kwargs . get ( "vectorfield_sampling_parameters" ) or {} ),
737+ ** asdict ( posterior_parameters ),
576738 )
577739 else :
578740 # Posteriors requiring potential_fn and theta_transform
579741 potential_fn , theta_transform = self ._get_potential_function (
580742 prior , estimator
581743 )
582- if sample_with == "mcmc" :
744+ if isinstance ( posterior_parameters , MCMCPosteriorParameters ) :
583745 posterior = MCMCPosterior (
584746 potential_fn = potential_fn ,
585747 theta_transform = theta_transform ,
586748 proposal = prior ,
587- method = kwargs .get ("mcmc_method" , "slice_np_vectorized" ),
588749 device = device ,
589- ** ( kwargs . get ( "mcmc_parameters" ) or {} ),
750+ ** asdict ( posterior_parameters ),
590751 )
591- elif sample_with == "rejection" :
752+ elif isinstance ( posterior_parameters , RejectionPosteriorParameters ) :
592753 posterior = RejectionPosterior (
593754 potential_fn = potential_fn ,
594755 proposal = prior ,
595756 device = device ,
596- ** ( kwargs . get ( "rejection_sampling_parameters" ) or {} ),
757+ ** asdict ( posterior_parameters ),
597758 )
598- elif sample_with == "vi" :
759+ elif isinstance ( posterior_parameters , VIPosteriorParameters ) :
599760 posterior = VIPosterior (
600761 potential_fn = potential_fn ,
601762 theta_transform = theta_transform ,
602763 prior = prior ,
603- vi_method = kwargs .get ("vi_method" , "rKL" ),
604764 device = device ,
605- ** ( kwargs . get ( "vi_parameters" ) or {} ),
765+ ** asdict ( posterior_parameters ),
606766 )
607- elif sample_with == "importance" :
767+ elif isinstance (
768+ posterior_parameters , ImportanceSamplingPosteriorParameters
769+ ):
608770 posterior = ImportanceSamplingPosterior (
609771 potential_fn = potential_fn ,
610772 proposal = prior ,
611773 device = device ,
612- ** ( kwargs . get ( "importance_sampling_parameters" ) or {} ),
774+ ** asdict ( posterior_parameters ),
613775 )
614776 else :
615777 raise NotImplementedError (
616- f"Sampling method '{ sample_with } ' is not supported."
778+ "Sampling method not implemented for" ,
779+ f"'{ posterior_parameters } '" ,
617780 )
618-
619781 return posterior
620782
621783 def _converged (self , epoch : int , stop_after_epochs : int ) -> bool :
0 commit comments