11# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
22# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
33
4- from typing import List
4+ from dataclasses import asdict
5+ from typing import List , Literal
56
67import numpy as np
78import pytest
2122 simulate_for_sbi ,
2223 vector_field_estimator_based_potential ,
2324)
25+ from sbi .inference .posteriors import MCMCPosteriorParameters
2426from sbi .inference .posteriors .posterior_parameters import VectorFieldPosteriorParameters
2527from sbi .neural_nets .factory import posterior_flow_nn
2628from sbi .simulators import linear_gaussian
5658 ],
5759)
5860def test_c2st_vector_field_on_linearGaussian (
59- vector_field_type , num_dim : int , prior_str : str , sample_with : List [str ]
61+ vector_field_type ,
62+ num_dim : int ,
63+ prior_str : str ,
64+ sample_with : List [Literal ["sde" , "ode" ]],
6065):
6166 """
6267 Test whether NPSE and FMPE infer well a simple example with available ground truth.
@@ -118,8 +123,12 @@ def test_c2st_vector_field_on_linearGaussian(
118123 # For the Gaussian prior, we compute the KLd between ground truth and
119124 # posterior.
120125
126+ # For type checking below.
127+ assert isinstance (posterior , VectorFieldPosterior )
128+
121129 # Disable exact integration for the ODE solver to speed up the computation.
122130 # But this gives stochastic results -> increase max_dkl a bit
131+
123132 posterior .potential_fn .neural_ode .update_params (
124133 exact = False ,
125134 atol = 1e-4 ,
@@ -244,23 +253,23 @@ def test_vfinference_with_different_models(vector_field_type, model):
244253# ------------------------------------------------------------------------------
245254
246255
247- @ pytest . fixture ( scope = "module" , params = [ "vp" , "ve" , "subvp" , "fmpe" ])
248- def vector_field_type ( request ):
249- """Module-scoped fixture for vector field type."""
250- return request . param
256+ # NOTE: Using a function with explicit caching instead of a parametrized fixture here to
257+ # make the test cases below more readable and maintainable.
258+
259+ _trained_models_cache = {}
251260
252261
253- @pytest .fixture (scope = "module" , params = ["gaussian" , "uniform" ])
254- def prior_type (request ):
255- """Module-scoped fixture for prior type."""
256- return request .param
262+ def train_vector_field_model (vector_field_type , prior_type ):
263 """Factory function that trains a vector field estimator for NPSE/FMPE tests, with caching."""
264+ cache_key = (vector_field_type , prior_type )
257265
266+ # Return cached model if available
267+ if cache_key in _trained_models_cache :
268+ return _trained_models_cache [cache_key ]
258269
259- @pytest .fixture (scope = "module" )
260- def vector_field_trained_model (vector_field_type , prior_type ):
261- """Module-scoped fixture that trains a score estimator for NPSE tests."""
270+ # Train the model
262271 num_dim = 2
263- num_simulations = 5000
272+ num_simulations = 6000
264273
265274 # likelihood_mean will be likelihood_shift+theta
266275 likelihood_shift = - 1.0 * ones (num_dim )
@@ -290,13 +299,9 @@ def vector_field_trained_model(vector_field_type, prior_type):
290299 theta = prior .sample ((num_simulations ,))
291300 x = linear_gaussian (theta , likelihood_shift , likelihood_cov )
292301
293- estimator = inference .append_simulations (theta , x ).train (
294- # stop_after_epochs=200,
295- # training_batch_size=100,
296- # max_num_epochs=50,
297- )
302+ estimator = inference .append_simulations (theta , x ).train ()
298303
299- return {
304+ result = {
300305 "estimator" : estimator ,
301306 "inference" : inference ,
302307 "prior" : prior ,
@@ -308,13 +313,22 @@ def vector_field_trained_model(vector_field_type, prior_type):
308313 "vector_field_type" : vector_field_type ,
309314 }
310315
316+ # Cache the result
317+ _trained_models_cache [cache_key ] = result
318+ return result
319+
311320
312321@pytest .mark .slow
313- def test_vector_field_sde_ode_sampling_equivalence (vector_field_trained_model ):
322+ @pytest .mark .parametrize (
323+ "vector_field_type, prior_type" , [("ve" , "gaussian" ), ("fmpe" , "gaussian" )]
324+ )
325+ def test_vector_field_sde_ode_sampling_equivalence (vector_field_type , prior_type ):
314326 """
315327 Test whether SDE and ODE sampling are equivalent
316328 for FMPE and NPSE.
317329 """
330+ vector_field_trained_model = train_vector_field_model (vector_field_type , prior_type )
331+
318332 num_samples = 1000
319333 x_o = zeros (1 , vector_field_trained_model ["num_dim" ])
320334
@@ -334,49 +348,42 @@ def test_vector_field_sde_ode_sampling_equivalence(vector_field_trained_model):
334348 )
335349
336350
337- # ------------------------------------------------------------------------------
338- # ------------------------------- SKIPPED TESTS --------------------------------
339- # ------------------------------------------------------------------------------
340-
341-
342- # TODO: Currently, c2st is too high for FMPE (e.g., > 3 number of observations),
343- # so some tests are skipped so far. This seems to be an issue with the
344- # neural network architecture and can be addressed in PR #1501
345351@pytest .mark .slow
352+ @pytest .mark .parametrize ("vector_field_type" , ["ve" , "fmpe" , "subvp" , "vp" ])
353+ @pytest .mark .parametrize ("prior_type" , ["gaussian" , "uniform" ])
346354@pytest .mark .parametrize (
347- "iid_method, num_trial " ,
355+ "iid_method, num_trials " ,
348356 [
349- pytest .param (
350- "fnpe" ,
351- 5 ,
352- id = "fnpe-5trials" ,
353- # marks=pytest.mark.skip(reason="fails randomly, see #1646"),
354- ),
355- # pytest.param("gauss", 5, id="gauss-5trials"),
356- # pytest.param("auto_gauss", 5, id="auto_gauss-5trials"),
357- # pytest.param("jac_gauss", 5, id="jac_gauss-5trials"),
357+ pytest .param ("fnpe" , 5 , id = "fnpe-5trials" ),
358+ pytest .param ("gauss" , 5 , id = "gauss-5trials" ),
359+ pytest .param ("auto_gauss" , 5 , id = "auto_gauss-5trials" ),
360+ pytest .param ("jac_gauss" , 5 , id = "jac_gauss-5trials" ),
358361 ],
359362)
360363def test_vector_field_iid_inference (
361- vector_field_trained_model , iid_method , num_trial , vector_field_type , prior_type
364+ vector_field_type , prior_type , iid_method , num_trials
362365):
363366 """
364367 Test whether NPSE and FMPE infer well a simple example with available ground truth.
365368
366369 Args:
367- vector_field_trained_model: The trained vector field model.
370+ vector_field_type: The type of vector field ("ve", "fmpe", etc.).
371+ prior_type: The type of prior distribution ("gaussian" or "uniform").
368372 iid_method: The IID method to use for sampling.
369- num_trial: The number of trials to run.
370- vector_field_type: fixture for vector_field_type (e.g., "fmpe", "vp", "ve").
371- prior_type: The type of prior distribution (e.g., "gaussian" or "uniform").
373+ num_trials: The number of trials to run.
372374 """
373- # if vector_field_type == "fmpe":
374- # # TODO: Remove on merge
375- # pytest.xfail(reason="c2st to high, fixed in PR #1501/1544", strict=True)
376375
377- num_samples = 1000
376+ if (
377+ vector_field_type == "fmpe"
378+ and prior_type == "uniform"
379+ and iid_method in ["gauss" , "auto_gauss" , "jac_gauss" ]
380+ ):
381+ # TODO: Predictor produces NaNs for these cases, see #1656
382+ pytest .skip ("Known issue with FMPE and IID methods with uniform priors" )
378383
379- # Extract data from fixture
384+ vector_field_trained_model = train_vector_field_model (vector_field_type , prior_type )
385+
386+ # Extract data from the trained model
380387 estimator = vector_field_trained_model ["estimator" ]
381388 inference = vector_field_trained_model ["inference" ]
382389 prior = vector_field_trained_model ["prior" ]
@@ -386,11 +393,13 @@ def test_vector_field_iid_inference(
386393 prior_cov = vector_field_trained_model ["prior_cov" ]
387394 num_dim = vector_field_trained_model ["num_dim" ]
388395
389- x_o = zeros (num_trial , num_dim )
396+ num_samples = 1000
397+
398+ x_o = zeros (num_trials , num_dim )
390399
391400 posterior = inference .build_posterior (
392401 estimator ,
393- sample_with = "sde" ,
402+ sample_with = "sde" , # iid works only with score-based SDEs.
394403 posterior_parameters = VectorFieldPosteriorParameters (iid_method = iid_method ),
395404 )
396405 posterior .set_default_x (x_o )
@@ -406,7 +415,7 @@ def test_vector_field_iid_inference(
406415 x_o ,
407416 likelihood_shift ,
408417 likelihood_cov ,
409- prior , # type: ignore
418+ prior ,
410419 )
411420 else :
412421 raise ValueError (f"Invalid prior type: { prior_type } " )
@@ -419,9 +428,9 @@ def test_vector_field_iid_inference(
419428 target_samples ,
420429 alg = (
421430 f"{ vector_field_type } -{ prior_type } -"
422- f"{ num_dim } -{ iid_method } -{ num_trial } iid-trials"
431+ f"{ num_dim } -{ iid_method } -{ num_trials } iid-trials"
423432 ),
424- tol = 0.05 * min ( num_trial , 8 ),
433+ tol = 0.07 * max ( num_trials , 2 ),
425434 )
426435
427436
@@ -465,7 +474,7 @@ def test_vector_field_map(vector_field_type):
465474# this will only work after implementing additional methods for vector fields,
466475# so it is skipped for now.
467476@pytest .mark .slow
468- # @pytest.mark.skip(reason="Potential evaluation is not implemented for iid yet.")
477+ @pytest .mark .skip (reason = "Potential evaluation is not implemented for iid yet." )
469478def test_sample_conditional ():
470479 """
471480 Test whether sampling from the conditional gives the same results as evaluating.
@@ -483,7 +492,7 @@ def test_sample_conditional():
483492 num_simulations = 6000
484493 num_conditional_samples = 500
485494
486- mcmc_parameters = dict (
495+ mcmc_parameters = MCMCPosteriorParameters (
487496 method = "slice_np_vectorized" , num_chains = 20 , warmup_steps = 50 , thin = 5
488497 )
489498
@@ -511,7 +520,9 @@ def simulator(theta):
511520 )
512521
513522 # Test whether fmpe works properly with structured z-scoring.
514- net = posterior_flow_nn ("mlp" , z_score_x = "structured" , hidden_features = [65 ] * 5 )
523+ net = posterior_flow_nn (
524+ "mlp" , z_score_x = "structured" , hidden_features = 65 , num_layers = 5
525+ )
515526
516527 inference = FMPE (prior , density_estimator = net , show_progress_bars = False )
517528 posterior_estimator = inference .append_simulations (theta , x ).train (
@@ -544,9 +555,9 @@ def simulator(theta):
544555 potential_fn = conditioned_potential_fn ,
545556 theta_transform = restricted_tf ,
546557 proposal = restricted_prior ,
547- ** mcmc_parameters ,
558+ ** asdict ( mcmc_parameters ) ,
548559 )
549- mcmc_posterior .set_default_x (x_o ) # TODO: This test has a bug? Needed to add this
560+ mcmc_posterior .set_default_x (x_o )
550561 cond_samples = mcmc_posterior .sample ((num_conditional_samples ,))
551562
552563 _ = analysis .pairplot (
0 commit comments