Skip to content

Commit c53a83a

Browse files
committed
refactor tests: add fast amortized test; remove redundant.
1 parent 14abaf2 commit c53a83a

File tree

1 file changed

+99
-154
lines changed

1 file changed

+99
-154
lines changed

tests/vi_test.py

Lines changed: 99 additions & 154 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,12 @@
2121
from torch import eye, ones, zeros
2222
from torch.distributions import Beta, Binomial, Gamma, MultivariateNormal
2323

24-
from sbi.inference import NLE, likelihood_estimator_based_potential
24+
from sbi.inference import NLE, NRE, likelihood_estimator_based_potential
2525
from sbi.inference.posteriors import VIPosterior
2626
from sbi.inference.potentials.base_potential import BasePotential
27+
from sbi.inference.potentials.ratio_based_potential import (
28+
ratio_estimator_based_potential,
29+
)
2730
from sbi.neural_nets.factory import ZukoFlowType
2831
from sbi.simulators.linear_gaussian import (
2932
linear_gaussian,
@@ -75,16 +78,19 @@ def allow_iid_x(self) -> bool:
7578
# =============================================================================
7679

7780

78-
@pytest.fixture
79-
def linear_gaussian_setup():
80-
"""Setup for linear Gaussian test problem with trained NLE.
81+
def _build_linear_gaussian_setup(trainer_type: str = "nle"):
82+
"""Helper to build linear Gaussian setup with specified trainer type.
83+
84+
Args:
85+
trainer_type: Either "nle" or "nre".
8186
8287
Returns a dict with:
8388
- prior: MultivariateNormal prior
84-
- potential_fn: Trained NLE-based potential function
89+
- potential_fn: Trained potential function (NLE or NRE based)
8590
- theta, x: Simulation data
8691
- likelihood_shift, likelihood_cov: Likelihood parameters
8792
- num_dim: Dimensionality
93+
- trainer_type: The trainer type used
8894
"""
8995
torch.manual_seed(42)
9096

@@ -101,17 +107,27 @@ def simulator(theta):
101107
theta = prior.sample((num_simulations,))
102108
x = simulator(theta)
103109

104-
# Train NLE
105-
trainer = NLE(prior=prior, show_progress_bars=False, density_estimator="nsf")
106-
trainer.append_simulations(theta, x)
107-
likelihood_estimator = trainer.train(max_num_epochs=200)
108-
109-
# Create potential function
110-
potential_fn, _ = likelihood_estimator_based_potential(
111-
likelihood_estimator=likelihood_estimator,
112-
prior=prior,
113-
x_o=None,
114-
)
110+
# Train estimator and create potential based on trainer type
111+
if trainer_type == "nle":
112+
trainer = NLE(prior=prior, show_progress_bars=False, density_estimator="nsf")
113+
trainer.append_simulations(theta, x)
114+
estimator = trainer.train(max_num_epochs=200)
115+
potential_fn, _ = likelihood_estimator_based_potential(
116+
likelihood_estimator=estimator,
117+
prior=prior,
118+
x_o=None,
119+
)
120+
elif trainer_type == "nre":
121+
trainer = NRE(prior=prior, show_progress_bars=False, classifier="mlp")
122+
trainer.append_simulations(theta, x)
123+
estimator = trainer.train(max_num_epochs=200)
124+
potential_fn, _ = ratio_estimator_based_potential(
125+
ratio_estimator=estimator,
126+
prior=prior,
127+
x_o=None,
128+
)
129+
else:
130+
raise ValueError(f"Unknown trainer_type: {trainer_type}")
115131

116132
return {
117133
"prior": prior,
@@ -121,9 +137,22 @@ def simulator(theta):
121137
"likelihood_shift": likelihood_shift,
122138
"likelihood_cov": likelihood_cov,
123139
"num_dim": num_dim,
140+
"trainer_type": trainer_type,
124141
}
125142

126143

144+
@pytest.fixture
145+
def linear_gaussian_setup():
146+
"""Setup for linear Gaussian test problem with trained NLE."""
147+
return _build_linear_gaussian_setup("nle")
148+
149+
150+
@pytest.fixture(params=["nle", "nre"])
151+
def linear_gaussian_setup_trainers(request):
152+
"""Parametrized setup for linear Gaussian with NLE or NRE."""
153+
return _build_linear_gaussian_setup(request.param)
154+
155+
127156
# =============================================================================
128157
# Single-x VI Tests: train() method
129158
# =============================================================================
@@ -435,9 +464,9 @@ def test_vi_flow_builders(num_dim: int, q_type: str):
435464

436465

437466
@pytest.mark.slow
438-
def test_amortized_vi_training(linear_gaussian_setup):
439-
"""Test that VIPosterior.train_amortized() trains successfully."""
440-
setup = linear_gaussian_setup
467+
def test_amortized_vi_accuracy(linear_gaussian_setup_trainers):
468+
"""Test that amortized VI produces accurate posteriors (NLE and NRE)."""
469+
setup = linear_gaussian_setup_trainers
441470

442471
posterior = VIPosterior(
443472
potential_fn=setup["potential_fn"],
@@ -454,29 +483,9 @@ def test_amortized_vi_training(linear_gaussian_setup):
454483
hidden_features=32,
455484
)
456485

486+
# Verify training completed successfully
457487
assert posterior._mode == "amortized"
458488

459-
460-
@pytest.mark.slow
461-
def test_amortized_vi_accuracy(linear_gaussian_setup):
462-
"""Test that amortized VI produces accurate posteriors across observations."""
463-
setup = linear_gaussian_setup
464-
465-
posterior = VIPosterior(
466-
potential_fn=setup["potential_fn"],
467-
prior=setup["prior"],
468-
)
469-
470-
posterior.train_amortized(
471-
theta=setup["theta"],
472-
x=setup["x"],
473-
max_num_iters=500,
474-
show_progress_bar=False,
475-
flow_type=ZukoFlowType.NSF,
476-
num_transforms=2,
477-
hidden_features=32,
478-
)
479-
480489
# Test on multiple observations
481490
test_x_os = [
482491
zeros(1, setup["num_dim"]),
@@ -496,7 +505,10 @@ def test_amortized_vi_accuracy(linear_gaussian_setup):
496505
vi_samples = posterior.sample((1000,), x=x_o)
497506

498507
c2st_score = c2st(true_samples, vi_samples).item()
499-
assert c2st_score < 0.6, f"C2ST too high for x_o={x_o.squeeze().tolist()}"
508+
assert c2st_score < 0.65, (
509+
f"C2ST too high for {setup['trainer_type']}, "
510+
f"x_o={x_o.squeeze().tolist()}: {c2st_score:.3f}"
511+
)
500512

501513

502514
@pytest.mark.slow
@@ -594,119 +606,6 @@ def test_amortized_vi_requires_training(linear_gaussian_setup):
594606
posterior.sample((100,))
595607

596608

597-
@pytest.mark.slow
598-
def test_amortized_vs_single_x_vi(linear_gaussian_setup):
599-
"""Compare amortized VI against single-x VI for the same observation."""
600-
setup = linear_gaussian_setup
601-
602-
# Train amortized VI
603-
amortized_posterior = VIPosterior(
604-
potential_fn=setup["potential_fn"],
605-
prior=setup["prior"],
606-
)
607-
amortized_posterior.train_amortized(
608-
theta=setup["theta"],
609-
x=setup["x"],
610-
max_num_iters=500,
611-
show_progress_bar=False,
612-
flow_type=ZukoFlowType.NSF,
613-
num_transforms=2,
614-
hidden_features=32,
615-
)
616-
617-
# Train single-x VI for a specific observation
618-
x_o = torch.tensor([[0.5, 0.5]])
619-
potential_fn_single, _ = likelihood_estimator_based_potential(
620-
likelihood_estimator=setup["potential_fn"].likelihood_estimator,
621-
prior=setup["prior"],
622-
x_o=x_o,
623-
)
624-
625-
single_x_posterior = VIPosterior(
626-
potential_fn=potential_fn_single,
627-
prior=setup["prior"],
628-
q="maf",
629-
)
630-
single_x_posterior.set_default_x(x_o)
631-
single_x_posterior.train(max_num_iters=500, show_progress_bar=False)
632-
633-
# Get ground truth
634-
true_posterior = true_posterior_linear_gaussian_mvn_prior(
635-
x_o.squeeze(0),
636-
setup["likelihood_shift"],
637-
setup["likelihood_cov"],
638-
zeros(setup["num_dim"]),
639-
eye(setup["num_dim"]),
640-
)
641-
true_samples = true_posterior.sample((1000,))
642-
643-
# Compare
644-
amortized_samples = amortized_posterior.sample((1000,), x=x_o)
645-
single_x_samples = single_x_posterior.sample((1000,))
646-
647-
c2st_amortized = c2st(true_samples, amortized_samples).item()
648-
c2st_single_x = c2st(true_samples, single_x_samples).item()
649-
650-
assert c2st_amortized < 0.6, f"Amortized VI C2ST too high: {c2st_amortized:.3f}"
651-
assert abs(c2st_amortized - c2st_single_x) < 0.15, (
652-
f"Amortized ({c2st_amortized:.3f}) much worse than "
653-
f"single-x ({c2st_single_x:.3f})"
654-
)
655-
656-
657-
@pytest.mark.slow
658-
def test_amortized_vi_gradient_flow(linear_gaussian_setup):
659-
"""Verify gradients flow through ELBO to flow parameters."""
660-
from torch.optim import Adam
661-
662-
setup = linear_gaussian_setup
663-
664-
posterior = VIPosterior(
665-
potential_fn=setup["potential_fn"],
666-
prior=setup["prior"],
667-
)
668-
669-
# Build the flow manually
670-
theta = setup["theta"][:500].to("cpu")
671-
x = setup["x"][:500].to("cpu")
672-
posterior._amortized_q = posterior._build_conditional_flow(
673-
theta[:100],
674-
x[:100],
675-
flow_type=ZukoFlowType.NSF,
676-
num_transforms=2,
677-
hidden_features=32,
678-
)
679-
posterior._amortized_q.to("cpu")
680-
posterior._mode = "amortized"
681-
682-
initial_params = [p.clone() for p in posterior._amortized_q.parameters()]
683-
684-
# Compute ELBO loss and backprop
685-
optimizer = Adam(posterior._amortized_q.parameters(), lr=1e-3)
686-
optimizer.zero_grad()
687-
688-
x_batch = x[:8]
689-
loss = posterior._compute_amortized_elbo_loss(x_batch, n_particles=16)
690-
loss.backward()
691-
692-
# Verify gradients exist and are non-zero
693-
for i, p in enumerate(posterior._amortized_q.parameters()):
694-
assert p.grad is not None, f"Gradient is None for parameter {i}"
695-
assert torch.isfinite(p.grad).all(), f"Gradient has NaN/Inf for param {i}"
696-
assert p.grad.abs().max() > 1e-10, f"Gradient is zero for parameter {i}"
697-
698-
# Verify parameters change after optimization step
699-
optimizer.step()
700-
changed_count = sum(
701-
1
702-
for p_init, p_new in zip(
703-
initial_params, posterior._amortized_q.parameters(), strict=True
704-
)
705-
if not torch.allclose(p_init, p_new.detach(), atol=1e-8)
706-
)
707-
assert changed_count > 0, "No parameters changed after optimization step"
708-
709-
710609
@pytest.mark.slow
711610
def test_amortized_vi_map(linear_gaussian_setup):
712611
"""Test that MAP estimation returns high-density region."""
@@ -755,3 +654,49 @@ def test_amortized_vi_map(linear_gaussian_setup):
755654
f"MAP log_prob {map_log_prob.item():.3f} not better than "
756655
f"median random {random_log_probs.median().item():.3f}"
757656
)
657+
658+
659+
def test_amortized_vi_with_fake_potential():
660+
"""Fast test for amortized VI using FakePotential (no NLE training required).
661+
662+
This test runs in CI (not marked slow) to ensure amortized VI coverage.
663+
Uses FakePotential where the posterior equals the prior.
664+
"""
665+
torch.manual_seed(42)
666+
667+
num_dim = 2
668+
prior = MultivariateNormal(zeros(num_dim), eye(num_dim))
669+
potential_fn = FakePotential(prior=prior)
670+
671+
# Generate mock simulation data (not actually used for training potential)
672+
theta = prior.sample((500,))
673+
x = theta + 0.1 * torch.randn_like(theta) # Noisy observations
674+
675+
posterior = VIPosterior(
676+
potential_fn=potential_fn,
677+
prior=prior,
678+
)
679+
680+
# Train amortized VI
681+
posterior.train_amortized(
682+
theta=theta,
683+
x=x,
684+
max_num_iters=100, # Fewer iterations for speed
685+
show_progress_bar=False,
686+
flow_type=ZukoFlowType.NSF,
687+
num_transforms=2,
688+
hidden_features=16, # Smaller network for speed
689+
)
690+
691+
# Verify training completed
692+
assert posterior._mode == "amortized"
693+
694+
# Test sampling works
695+
x_test = torch.randn(1, num_dim)
696+
samples = posterior.sample((100,), x=x_test)
697+
assert samples.shape == (100, num_dim)
698+
699+
# Test log_prob works
700+
log_probs = posterior.log_prob(samples, x=x_test)
701+
assert log_probs.shape == (100,)
702+
assert torch.isfinite(log_probs).all()

0 commit comments

Comments
 (0)