1515from torch .distributions import MultivariateNormal
1616
1717from sbi import utils as utils
18- from sbi .inference import (
19- ABC ,
20- FMPE ,
21- NLE ,
22- NPE ,
23- NPE_A ,
24- NPE_C ,
25- NRE_A ,
26- NRE_B ,
27- NRE_C ,
28- VIPosterior ,
29- likelihood_estimator_based_potential ,
30- ratio_estimator_based_potential ,
31- )
18+ from sbi .inference .abc import MCABC as ABC
3219from sbi .inference .posteriors .ensemble_posterior import (
3320 EnsemblePotential ,
3421)
3724from sbi .inference .posteriors .posterior_parameters import (
3825 MCMCPosteriorParameters ,
3926)
27+ from sbi .inference .posteriors .vi_posterior import VIPosterior
4028from sbi .inference .potentials .base_potential import BasePotential
41- from sbi .inference .potentials .likelihood_based_potential import LikelihoodBasedPotential
29+ from sbi .inference .potentials .likelihood_based_potential import (
30+ LikelihoodBasedPotential ,
31+ likelihood_estimator_based_potential ,
32+ )
4233from sbi .inference .potentials .posterior_based_potential import PosteriorBasedPotential
43- from sbi .inference .potentials .ratio_based_potential import RatioBasedPotential
34+ from sbi .inference .potentials .ratio_based_potential import (
35+ RatioBasedPotential ,
36+ ratio_estimator_based_potential ,
37+ )
38+ from sbi .inference .trainers .nle import NLE
39+ from sbi .inference .trainers .npe import NPE , NPE_A , NPE_C
40+ from sbi .inference .trainers .nre import NRE_A , NRE_B , NRE_C
41+ from sbi .inference .trainers .vfpe import FMPE , NPSE
4442from sbi .neural_nets .embedding_nets import FCEmbedding
4543from sbi .neural_nets .factory import (
4644 classifier_nn ,
4745 embedding_net_warn_msg ,
4846 likelihood_nn ,
4947 posterior_nn ,
5048)
51- from sbi .simulators import diagonal_linear_gaussian , linear_gaussian
52- from sbi .utils .torchutils import (
53- BoxUniform ,
54- gpu_available ,
55- process_device ,
56- )
57- from sbi .utils .user_input_checks import (
58- validate_theta_and_x ,
59- )
49+ from sbi .simulators .linear_gaussian import diagonal_linear_gaussian , linear_gaussian
50+ from sbi .utils import BoxUniform
51+ from sbi .utils .torchutils import gpu_available , process_device
52+ from sbi .utils .user_input_checks import validate_theta_and_x
6053
61- # tests in this file are skipped if there is GPU device available
6254pytestmark = pytest .mark .skipif (
6355 not gpu_available (), reason = "No CUDA or MPS device available."
6456)
@@ -718,27 +710,43 @@ def test_to_method_on_posteriors(device: str, sampling_method: str):
718710@pytest .mark .gpu
719711@pytest .mark .parametrize ("device" , ["cpu" , "gpu" ])
720712@pytest .mark .parametrize ("device_inference" , ["cpu" , "gpu" ])
721- @pytest .mark .parametrize (
722- "iid_method" , ["fnpe" , "gauss" , "auto_gauss" , "jac_gauss" , None ]
723- )
724- def test_VectorFieldPosterior_device_handling (
725- device : str , device_inference : str , iid_method : str
713+ @pytest .mark .parametrize ("num_trials" , [1 , 2 ])
714+ @pytest .mark .parametrize ("vf_trainer" , [FMPE , NPSE ])
715+ def test_vector_field_methods_device_handling (
716+ vf_trainer , device : str , device_inference : str , num_trials : int
726717):
727718 """Test VectorFieldPosterior on different devices training and inference devices.
728719
720+ Tests both ode and sde sampling for both FMPE and NPSE.
721+
722+ Tests iid methods for num_trials = 2.
723+
729724 Args:
725+ vf_trainer: vector field trainer class to use.
730726 device: device to train the model on.
731727 device_inference: device to run the inference on.
732728 iid_method: method to sample from the posterior.
733729 """
730+
731+ num_dims = 2
732+ num_simulations = 1000
733+ if vf_trainer == NPSE :
734+ iid_methods = ["fnpe" , "gauss" , "auto_gauss" , "jac_gauss" ]
735+ else :
736+ iid_methods = ["fnpe" ]
737+
734738 device = process_device (device )
735739 device_inference = process_device (device_inference )
736- prior = BoxUniform (torch .zeros (3 ), torch .ones (3 ), device = device )
737- inference = FMPE (score_estimator = "mlp" , prior = prior , device = device )
738- density_estimator = inference .append_simulations (
739- torch .randn ((100 , 3 )), torch .randn ((100 , 2 ))
740- ).train (max_num_epochs = 1 )
741- posterior = inference .build_posterior (density_estimator , prior )
740+
741+ prior = BoxUniform (torch .zeros (num_dims ), torch .ones (num_dims ), device = device )
742+ theta = prior .sample ((num_simulations ,))
743+ x = theta + 0.1 * torch .randn_like (theta )
744+
745+ inference = vf_trainer (prior = prior , device = device )
746+ _ = inference .append_simulations (theta , x ).train (max_num_epochs = 10 )
747+ posterior = inference .build_posterior (
748+ sample_with = "sde" if num_trials > 1 else "ode"
749+ )
742750
743751 # faster but inaccurate log_prob computation
744752 posterior .potential_fn .neural_ode .update_params (exact = False , atol = 1e-4 , rtol = 1e-4 )
@@ -748,13 +756,23 @@ def test_VectorFieldPosterior_device_handling(
748756 f"VectorFieldPosterior is not in device { device_inference } ."
749757 )
750758
751- x_o = torch .ones (2 ).to (device_inference )
752- samples = posterior .sample ((2 ,), x = x_o , iid_method = iid_method )
753- assert samples .device .type == device_inference .split (":" )[0 ], (
754- f"Samples are not on device { device_inference } ."
755- )
759+ x_o = torch .ones (num_trials , num_dims ).to (device_inference )
760+ if num_trials > 1 :
761+ for iid_method in iid_methods :
762+ samples = posterior .sample ((2 ,), x = x_o , iid_method = iid_method )
763+ assert samples .device .type == device_inference .split (":" )[0 ], (
764+ f"Samples are not on device { device_inference } . "
765+ f"{ vf_trainer .__name__ } with { iid_method } "
766+ )
767+ else :
768+ samples = posterior .sample ((2 ,), x = x_o )
769+ assert samples .device .type == device_inference .split (":" )[0 ], (
770+ f"Samples are not on device { device_inference } . "
771+ f"{ vf_trainer .__name__ } with { iid_method } "
772+ )
756773
757- log_probs = posterior .log_prob (samples , x = x_o )
758- assert log_probs .device .type == device_inference .split (":" )[0 ], (
759- f"log_prob was not correctly moved to { device_inference } ."
760- )
774+ log_probs = posterior .log_prob (samples , x = x_o )
775+ assert log_probs .device .type == device_inference .split (":" )[0 ], (
776+ f"log_prob was not correctly moved to { device_inference } . "
777+ f"{ vf_trainer .__name__ } with { iid_method } "
778+ )
0 commit comments