2121from torch import eye , ones , zeros
2222from torch .distributions import Beta , Binomial , Gamma , MultivariateNormal
2323
24- from sbi .inference import NLE , likelihood_estimator_based_potential
24+ from sbi .inference import NLE , NRE , likelihood_estimator_based_potential
2525from sbi .inference .posteriors import VIPosterior
2626from sbi .inference .potentials .base_potential import BasePotential
27+ from sbi .inference .potentials .ratio_based_potential import (
28+ ratio_estimator_based_potential ,
29+ )
2730from sbi .neural_nets .factory import ZukoFlowType
2831from sbi .simulators .linear_gaussian import (
2932 linear_gaussian ,
@@ -75,16 +78,19 @@ def allow_iid_x(self) -> bool:
7578# =============================================================================
7679
7780
78- @pytest .fixture
79- def linear_gaussian_setup ():
80- """Setup for linear Gaussian test problem with trained NLE.
81+ def _build_linear_gaussian_setup (trainer_type : str = "nle" ):
82+ """Helper to build linear Gaussian setup with specified trainer type.
83+
84+ Args:
85+ trainer_type: Either "nle" or "nre".
8186
8287 Returns a dict with:
8388 - prior: MultivariateNormal prior
84- - potential_fn: Trained NLE-based potential function
89+ - potential_fn: Trained potential function (NLE or NRE based)
8590 - theta, x: Simulation data
8691 - likelihood_shift, likelihood_cov: Likelihood parameters
8792 - num_dim: Dimensionality
93+ - trainer_type: The trainer type used
8894 """
8995 torch .manual_seed (42 )
9096
@@ -101,17 +107,27 @@ def simulator(theta):
101107 theta = prior .sample ((num_simulations ,))
102108 x = simulator (theta )
103109
104- # Train NLE
105- trainer = NLE (prior = prior , show_progress_bars = False , density_estimator = "nsf" )
106- trainer .append_simulations (theta , x )
107- likelihood_estimator = trainer .train (max_num_epochs = 200 )
108-
109- # Create potential function
110- potential_fn , _ = likelihood_estimator_based_potential (
111- likelihood_estimator = likelihood_estimator ,
112- prior = prior ,
113- x_o = None ,
114- )
110+ # Train estimator and create potential based on trainer type
111+ if trainer_type == "nle" :
112+ trainer = NLE (prior = prior , show_progress_bars = False , density_estimator = "nsf" )
113+ trainer .append_simulations (theta , x )
114+ estimator = trainer .train (max_num_epochs = 200 )
115+ potential_fn , _ = likelihood_estimator_based_potential (
116+ likelihood_estimator = estimator ,
117+ prior = prior ,
118+ x_o = None ,
119+ )
120+ elif trainer_type == "nre" :
121+ trainer = NRE (prior = prior , show_progress_bars = False , classifier = "mlp" )
122+ trainer .append_simulations (theta , x )
123+ estimator = trainer .train (max_num_epochs = 200 )
124+ potential_fn , _ = ratio_estimator_based_potential (
125+ ratio_estimator = estimator ,
126+ prior = prior ,
127+ x_o = None ,
128+ )
129+ else :
130+ raise ValueError (f"Unknown trainer_type: { trainer_type } " )
115131
116132 return {
117133 "prior" : prior ,
@@ -121,9 +137,22 @@ def simulator(theta):
121137 "likelihood_shift" : likelihood_shift ,
122138 "likelihood_cov" : likelihood_cov ,
123139 "num_dim" : num_dim ,
140+ "trainer_type" : trainer_type ,
124141 }
125142
126143
144+ @pytest .fixture
145+ def linear_gaussian_setup ():
146+ """Setup for linear Gaussian test problem with trained NLE."""
147+ return _build_linear_gaussian_setup ("nle" )
148+
149+
150+ @pytest .fixture (params = ["nle" , "nre" ])
151+ def linear_gaussian_setup_trainers (request ):
152+ """Parametrized setup for linear Gaussian with NLE or NRE."""
153+ return _build_linear_gaussian_setup (request .param )
154+
155+
127156# =============================================================================
128157# Single-x VI Tests: train() method
129158# =============================================================================
@@ -435,9 +464,9 @@ def test_vi_flow_builders(num_dim: int, q_type: str):
435464
436465
437466@pytest .mark .slow
438- def test_amortized_vi_training ( linear_gaussian_setup ):
439- """Test that VIPosterior.train_amortized() trains successfully ."""
440- setup = linear_gaussian_setup
467+ def test_amortized_vi_accuracy ( linear_gaussian_setup_trainers ):
468+ """Test that amortized VI produces accurate posteriors (NLE and NRE) ."""
469+ setup = linear_gaussian_setup_trainers
441470
442471 posterior = VIPosterior (
443472 potential_fn = setup ["potential_fn" ],
@@ -454,29 +483,9 @@ def test_amortized_vi_training(linear_gaussian_setup):
454483 hidden_features = 32 ,
455484 )
456485
486+ # Verify training completed successfully
457487 assert posterior ._mode == "amortized"
458488
459-
460- @pytest .mark .slow
461- def test_amortized_vi_accuracy (linear_gaussian_setup ):
462- """Test that amortized VI produces accurate posteriors across observations."""
463- setup = linear_gaussian_setup
464-
465- posterior = VIPosterior (
466- potential_fn = setup ["potential_fn" ],
467- prior = setup ["prior" ],
468- )
469-
470- posterior .train_amortized (
471- theta = setup ["theta" ],
472- x = setup ["x" ],
473- max_num_iters = 500 ,
474- show_progress_bar = False ,
475- flow_type = ZukoFlowType .NSF ,
476- num_transforms = 2 ,
477- hidden_features = 32 ,
478- )
479-
480489 # Test on multiple observations
481490 test_x_os = [
482491 zeros (1 , setup ["num_dim" ]),
@@ -496,7 +505,10 @@ def test_amortized_vi_accuracy(linear_gaussian_setup):
496505 vi_samples = posterior .sample ((1000 ,), x = x_o )
497506
498507 c2st_score = c2st (true_samples , vi_samples ).item ()
499- assert c2st_score < 0.6 , f"C2ST too high for x_o={ x_o .squeeze ().tolist ()} "
508+ assert c2st_score < 0.65 , (
509+ f"C2ST too high for { setup ['trainer_type' ]} , "
510+ f"x_o={ x_o .squeeze ().tolist ()} : { c2st_score :.3f} "
511+ )
500512
501513
502514@pytest .mark .slow
@@ -594,119 +606,6 @@ def test_amortized_vi_requires_training(linear_gaussian_setup):
594606 posterior .sample ((100 ,))
595607
596608
597- @pytest .mark .slow
598- def test_amortized_vs_single_x_vi (linear_gaussian_setup ):
599- """Compare amortized VI against single-x VI for the same observation."""
600- setup = linear_gaussian_setup
601-
602- # Train amortized VI
603- amortized_posterior = VIPosterior (
604- potential_fn = setup ["potential_fn" ],
605- prior = setup ["prior" ],
606- )
607- amortized_posterior .train_amortized (
608- theta = setup ["theta" ],
609- x = setup ["x" ],
610- max_num_iters = 500 ,
611- show_progress_bar = False ,
612- flow_type = ZukoFlowType .NSF ,
613- num_transforms = 2 ,
614- hidden_features = 32 ,
615- )
616-
617- # Train single-x VI for a specific observation
618- x_o = torch .tensor ([[0.5 , 0.5 ]])
619- potential_fn_single , _ = likelihood_estimator_based_potential (
620- likelihood_estimator = setup ["potential_fn" ].likelihood_estimator ,
621- prior = setup ["prior" ],
622- x_o = x_o ,
623- )
624-
625- single_x_posterior = VIPosterior (
626- potential_fn = potential_fn_single ,
627- prior = setup ["prior" ],
628- q = "maf" ,
629- )
630- single_x_posterior .set_default_x (x_o )
631- single_x_posterior .train (max_num_iters = 500 , show_progress_bar = False )
632-
633- # Get ground truth
634- true_posterior = true_posterior_linear_gaussian_mvn_prior (
635- x_o .squeeze (0 ),
636- setup ["likelihood_shift" ],
637- setup ["likelihood_cov" ],
638- zeros (setup ["num_dim" ]),
639- eye (setup ["num_dim" ]),
640- )
641- true_samples = true_posterior .sample ((1000 ,))
642-
643- # Compare
644- amortized_samples = amortized_posterior .sample ((1000 ,), x = x_o )
645- single_x_samples = single_x_posterior .sample ((1000 ,))
646-
647- c2st_amortized = c2st (true_samples , amortized_samples ).item ()
648- c2st_single_x = c2st (true_samples , single_x_samples ).item ()
649-
650- assert c2st_amortized < 0.6 , f"Amortized VI C2ST too high: { c2st_amortized :.3f} "
651- assert abs (c2st_amortized - c2st_single_x ) < 0.15 , (
652- f"Amortized ({ c2st_amortized :.3f} ) much worse than "
653- f"single-x ({ c2st_single_x :.3f} )"
654- )
655-
656-
657- @pytest .mark .slow
658- def test_amortized_vi_gradient_flow (linear_gaussian_setup ):
659- """Verify gradients flow through ELBO to flow parameters."""
660- from torch .optim import Adam
661-
662- setup = linear_gaussian_setup
663-
664- posterior = VIPosterior (
665- potential_fn = setup ["potential_fn" ],
666- prior = setup ["prior" ],
667- )
668-
669- # Build the flow manually
670- theta = setup ["theta" ][:500 ].to ("cpu" )
671- x = setup ["x" ][:500 ].to ("cpu" )
672- posterior ._amortized_q = posterior ._build_conditional_flow (
673- theta [:100 ],
674- x [:100 ],
675- flow_type = ZukoFlowType .NSF ,
676- num_transforms = 2 ,
677- hidden_features = 32 ,
678- )
679- posterior ._amortized_q .to ("cpu" )
680- posterior ._mode = "amortized"
681-
682- initial_params = [p .clone () for p in posterior ._amortized_q .parameters ()]
683-
684- # Compute ELBO loss and backprop
685- optimizer = Adam (posterior ._amortized_q .parameters (), lr = 1e-3 )
686- optimizer .zero_grad ()
687-
688- x_batch = x [:8 ]
689- loss = posterior ._compute_amortized_elbo_loss (x_batch , n_particles = 16 )
690- loss .backward ()
691-
692- # Verify gradients exist and are non-zero
693- for i , p in enumerate (posterior ._amortized_q .parameters ()):
694- assert p .grad is not None , f"Gradient is None for parameter { i } "
695- assert torch .isfinite (p .grad ).all (), f"Gradient has NaN/Inf for param { i } "
696- assert p .grad .abs ().max () > 1e-10 , f"Gradient is zero for parameter { i } "
697-
698- # Verify parameters change after optimization step
699- optimizer .step ()
700- changed_count = sum (
701- 1
702- for p_init , p_new in zip (
703- initial_params , posterior ._amortized_q .parameters (), strict = True
704- )
705- if not torch .allclose (p_init , p_new .detach (), atol = 1e-8 )
706- )
707- assert changed_count > 0 , "No parameters changed after optimization step"
708-
709-
710609@pytest .mark .slow
711610def test_amortized_vi_map (linear_gaussian_setup ):
712611 """Test that MAP estimation returns high-density region."""
@@ -755,3 +654,49 @@ def test_amortized_vi_map(linear_gaussian_setup):
755654 f"MAP log_prob { map_log_prob .item ():.3f} not better than "
756655 f"median random { random_log_probs .median ().item ():.3f} "
757656 )
657+
658+
659+ def test_amortized_vi_with_fake_potential ():
660+ """Fast test for amortized VI using FakePotential (no NLE training required).
661+
662+ This test runs in CI (not marked slow) to ensure amortized VI coverage.
663+ Uses FakePotential where the posterior equals the prior.
664+ """
665+ torch .manual_seed (42 )
666+
667+ num_dim = 2
668+ prior = MultivariateNormal (zeros (num_dim ), eye (num_dim ))
669+ potential_fn = FakePotential (prior = prior )
670+
671+ # Generate mock simulation data (not actually used for training potential)
672+ theta = prior .sample ((500 ,))
673+ x = theta + 0.1 * torch .randn_like (theta ) # Noisy observations
674+
675+ posterior = VIPosterior (
676+ potential_fn = potential_fn ,
677+ prior = prior ,
678+ )
679+
680+ # Train amortized VI
681+ posterior .train_amortized (
682+ theta = theta ,
683+ x = x ,
684+ max_num_iters = 100 , # Fewer iterations for speed
685+ show_progress_bar = False ,
686+ flow_type = ZukoFlowType .NSF ,
687+ num_transforms = 2 ,
688+ hidden_features = 16 , # Smaller network for speed
689+ )
690+
691+ # Verify training completed
692+ assert posterior ._mode == "amortized"
693+
694+ # Test sampling works
695+ x_test = torch .randn (1 , num_dim )
696+ samples = posterior .sample ((100 ,), x = x_test )
697+ assert samples .shape == (100 , num_dim )
698+
699+ # Test log_prob works
700+ log_probs = posterior .log_prob (samples , x = x_test )
701+ assert log_probs .shape == (100 ,)
702+ assert torch .isfinite (log_probs ).all ()
0 commit comments