Skip to content

Commit 90850f2

Browse files
committed
Refactor vector field (vf) slow tests.
1 parent 6a3a80c commit 90850f2

File tree

1 file changed

+70
-59
lines changed

1 file changed

+70
-59
lines changed

tests/linearGaussian_vector_field_test.py

Lines changed: 70 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
22
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
33

4-
from typing import List
4+
from dataclasses import asdict
5+
from typing import List, Literal
56

67
import numpy as np
78
import pytest
@@ -21,6 +22,7 @@
2122
simulate_for_sbi,
2223
vector_field_estimator_based_potential,
2324
)
25+
from sbi.inference.posteriors import MCMCPosteriorParameters
2426
from sbi.inference.posteriors.posterior_parameters import VectorFieldPosteriorParameters
2527
from sbi.neural_nets.factory import posterior_flow_nn
2628
from sbi.simulators import linear_gaussian
@@ -56,7 +58,10 @@
5658
],
5759
)
5860
def test_c2st_vector_field_on_linearGaussian(
59-
vector_field_type, num_dim: int, prior_str: str, sample_with: List[str]
61+
vector_field_type,
62+
num_dim: int,
63+
prior_str: str,
64+
sample_with: List[Literal["sde", "ode"]],
6065
):
6166
"""
6267
Test whether NPSE and FMPE infer well a simple example with available ground truth.
@@ -118,8 +123,12 @@ def test_c2st_vector_field_on_linearGaussian(
118123
# For the Gaussian prior, we compute the KLd between ground truth and
119124
# posterior.
120125

126+
# For type checking below.
127+
assert isinstance(posterior, VectorFieldPosterior)
128+
121129
# Disable exact integration for the ODE solver to speed up the computation.
122130
# But this gives stochastic results -> increase max_dkl a bit
131+
123132
posterior.potential_fn.neural_ode.update_params(
124133
exact=False,
125134
atol=1e-4,
@@ -244,23 +253,23 @@ def test_vfinference_with_different_models(vector_field_type, model):
244253
# ------------------------------------------------------------------------------
245254

246255

247-
@pytest.fixture(scope="module", params=["vp", "ve", "subvp", "fmpe"])
248-
def vector_field_type(request):
249-
"""Module-scoped fixture for vector field type."""
250-
return request.param
256+
# NOTE: Using a function with explicit caching instead of a parametrized fixture here to
257+
# make the test cases below more readable and maintainable.
258+
259+
_trained_models_cache = {}
251260

252261

253-
@pytest.fixture(scope="module", params=["gaussian", "uniform"])
254-
def prior_type(request):
255-
"""Module-scoped fixture for prior type."""
256-
return request.param
262+
def train_vector_field_model(vector_field_type, prior_type):
263+
"""Factory function that trains a score estimator for NPSE tests with caching."""
264+
cache_key = (vector_field_type, prior_type)
257265

266+
# Return cached model if available
267+
if cache_key in _trained_models_cache:
268+
return _trained_models_cache[cache_key]
258269

259-
@pytest.fixture(scope="module")
260-
def vector_field_trained_model(vector_field_type, prior_type):
261-
"""Module-scoped fixture that trains a score estimator for NPSE tests."""
270+
# Train the model
262271
num_dim = 2
263-
num_simulations = 5000
272+
num_simulations = 6000
264273

265274
# likelihood_mean will be likelihood_shift+theta
266275
likelihood_shift = -1.0 * ones(num_dim)
@@ -290,13 +299,9 @@ def vector_field_trained_model(vector_field_type, prior_type):
290299
theta = prior.sample((num_simulations,))
291300
x = linear_gaussian(theta, likelihood_shift, likelihood_cov)
292301

293-
estimator = inference.append_simulations(theta, x).train(
294-
# stop_after_epochs=200,
295-
# training_batch_size=100,
296-
# max_num_epochs=50,
297-
)
302+
estimator = inference.append_simulations(theta, x).train()
298303

299-
return {
304+
result = {
300305
"estimator": estimator,
301306
"inference": inference,
302307
"prior": prior,
@@ -308,13 +313,22 @@ def vector_field_trained_model(vector_field_type, prior_type):
308313
"vector_field_type": vector_field_type,
309314
}
310315

316+
# Cache the result
317+
_trained_models_cache[cache_key] = result
318+
return result
319+
311320

312321
@pytest.mark.slow
313-
def test_vector_field_sde_ode_sampling_equivalence(vector_field_trained_model):
322+
@pytest.mark.parametrize(
323+
"vector_field_type, prior_type", [("ve", "gaussian"), ("fmpe", "gaussian")]
324+
)
325+
def test_vector_field_sde_ode_sampling_equivalence(vector_field_type, prior_type):
314326
"""
315327
Test whether SDE and ODE sampling are equivalent
316328
for FMPE and NPSE.
317329
"""
330+
vector_field_trained_model = train_vector_field_model(vector_field_type, prior_type)
331+
318332
num_samples = 1000
319333
x_o = zeros(1, vector_field_trained_model["num_dim"])
320334

@@ -334,49 +348,42 @@ def test_vector_field_sde_ode_sampling_equivalence(vector_field_trained_model):
334348
)
335349

336350

337-
# ------------------------------------------------------------------------------
338-
# ------------------------------- SKIPPED TESTS --------------------------------
339-
# ------------------------------------------------------------------------------
340-
341-
342-
# TODO: Currently, c2st is too high for FMPE (e.g., > 3 number of observations),
343-
# so some tests are skipped so far. This seems to be an issue with the
344-
# neural network architecture and can be addressed in PR #1501
345351
@pytest.mark.slow
352+
@pytest.mark.parametrize("vector_field_type", ["ve", "fmpe", "subvp", "vp"])
353+
@pytest.mark.parametrize("prior_type", ["gaussian", "uniform"])
346354
@pytest.mark.parametrize(
347-
"iid_method, num_trial",
355+
"iid_method, num_trials",
348356
[
349-
pytest.param(
350-
"fnpe",
351-
5,
352-
id="fnpe-5trials",
353-
# marks=pytest.mark.skip(reason="fails randomly, see #1646"),
354-
),
355-
# pytest.param("gauss", 5, id="gauss-5trials"),
356-
# pytest.param("auto_gauss", 5, id="auto_gauss-5trials"),
357-
# pytest.param("jac_gauss", 5, id="jac_gauss-5trials"),
357+
pytest.param("fnpe", 5, id="fnpe-5trials"),
358+
pytest.param("gauss", 5, id="gauss-5trials"),
359+
pytest.param("auto_gauss", 5, id="auto_gauss-5trials"),
360+
pytest.param("jac_gauss", 5, id="jac_gauss-5trials"),
358361
],
359362
)
360363
def test_vector_field_iid_inference(
361-
vector_field_trained_model, iid_method, num_trial, vector_field_type, prior_type
364+
vector_field_type, prior_type, iid_method, num_trials
362365
):
363366
"""
364367
Test whether NPSE and FMPE infer well a simple example with available ground truth.
365368
366369
Args:
367-
vector_field_trained_model: The trained vector field model.
370+
vector_field_type: The type of vector field ("ve", "fmpe", etc.).
371+
prior_type: The type of prior distribution ("gaussian" or "uniform").
368372
iid_method: The IID method to use for sampling.
369-
num_trial: The number of trials to run.
370-
vector_field_type: fixture for vector_field_type (e.g., "fmpe", "vp", "ve").
371-
prior_type: The type of prior distribution (e.g., "gaussian" or "uniform").
373+
num_trials: The number of trials to run.
372374
"""
373-
# if vector_field_type == "fmpe":
374-
# # TODO: Remove on merge
375-
# pytest.xfail(reason="c2st to high, fixed in PR #1501/1544", strict=True)
376375

377-
num_samples = 1000
376+
if (
377+
vector_field_type == "fmpe"
378+
and prior_type == "uniform"
379+
and iid_method in ["gauss", "auto_gauss", "jac_gauss"]
380+
):
381+
# TODO: Predictor produces NaNs for these cases, see #1656
382+
pytest.skip("Known issue with FMPE and IID methods with uniform priors")
378383

379-
# Extract data from fixture
384+
vector_field_trained_model = train_vector_field_model(vector_field_type, prior_type)
385+
386+
# Extract data from the trained model
380387
estimator = vector_field_trained_model["estimator"]
381388
inference = vector_field_trained_model["inference"]
382389
prior = vector_field_trained_model["prior"]
@@ -386,11 +393,13 @@ def test_vector_field_iid_inference(
386393
prior_cov = vector_field_trained_model["prior_cov"]
387394
num_dim = vector_field_trained_model["num_dim"]
388395

389-
x_o = zeros(num_trial, num_dim)
396+
num_samples = 1000
397+
398+
x_o = zeros(num_trials, num_dim)
390399

391400
posterior = inference.build_posterior(
392401
estimator,
393-
sample_with="sde",
402+
sample_with="sde", # iid works only with score-based SDEs.
394403
posterior_parameters=VectorFieldPosteriorParameters(iid_method=iid_method),
395404
)
396405
posterior.set_default_x(x_o)
@@ -406,7 +415,7 @@ def test_vector_field_iid_inference(
406415
x_o,
407416
likelihood_shift,
408417
likelihood_cov,
409-
prior, # type: ignore
418+
prior,
410419
)
411420
else:
412421
raise ValueError(f"Invalid prior type: {prior_type}")
@@ -419,9 +428,9 @@ def test_vector_field_iid_inference(
419428
target_samples,
420429
alg=(
421430
f"{vector_field_type}-{prior_type}-"
422-
f"{num_dim}-{iid_method}-{num_trial}iid-trials"
431+
f"{num_dim}-{iid_method}-{num_trials}iid-trials"
423432
),
424-
tol=0.05 * min(num_trial, 8),
433+
tol=0.07 * max(num_trials, 2),
425434
)
426435

427436

@@ -465,7 +474,7 @@ def test_vector_field_map(vector_field_type):
465474
# this will only work after implementing additional methods for vector fields,
466475
# so it is skipped for now.
467476
@pytest.mark.slow
468-
# @pytest.mark.skip(reason="Potential evaluation is not implemented for iid yet.")
477+
@pytest.mark.skip(reason="Potential evaluation is not implemented for iid yet.")
469478
def test_sample_conditional():
470479
"""
471480
Test whether sampling from the conditional gives the same results as evaluating.
@@ -483,7 +492,7 @@ def test_sample_conditional():
483492
num_simulations = 6000
484493
num_conditional_samples = 500
485494

486-
mcmc_parameters = dict(
495+
mcmc_parameters = MCMCPosteriorParameters(
487496
method="slice_np_vectorized", num_chains=20, warmup_steps=50, thin=5
488497
)
489498

@@ -511,7 +520,9 @@ def simulator(theta):
511520
)
512521

513522
# Test whether fmpe works properly with structured z-scoring.
514-
net = posterior_flow_nn("mlp", z_score_x="structured", hidden_features=[65] * 5)
523+
net = posterior_flow_nn(
524+
"mlp", z_score_x="structured", hidden_features=65, num_layers=5
525+
)
515526

516527
inference = FMPE(prior, density_estimator=net, show_progress_bars=False)
517528
posterior_estimator = inference.append_simulations(theta, x).train(
@@ -544,9 +555,9 @@ def simulator(theta):
544555
potential_fn=conditioned_potential_fn,
545556
theta_transform=restricted_tf,
546557
proposal=restricted_prior,
547-
**mcmc_parameters,
558+
**asdict(mcmc_parameters),
548559
)
549-
mcmc_posterior.set_default_x(x_o) # TODO: This test has a bug? Needed to add this
560+
mcmc_posterior.set_default_x(x_o)
550561
cond_samples = mcmc_posterior.sample((num_conditional_samples,))
551562

552563
_ = analysis.pairplot(

0 commit comments

Comments
 (0)