Skip to content

Commit a1f0343

Browse files
authored
tests: refactor z-scoring tests (#1711)
1 parent 1d69d28 commit a1f0343

File tree

1 file changed

+59
-49
lines changed

1 file changed

+59
-49
lines changed

tests/sbiutils_test.py

Lines changed: 59 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -481,58 +481,68 @@ def test_z_score_parser(z_x, z_theta):
481481
"transform_to_unconstrained",
482482
],
483483
)
484-
@pytest.mark.parametrize("builder", [likelihood_nn, posterior_nn, classifier_nn])
485-
def test_z_scoring_structured(z_x, z_theta, builder):
486-
"""
487-
Test that z-scoring string args don't break API.
488-
"""
489-
# Generate some signals for test.
490-
t = torch.arange(0, 1, 0.1)
491-
x_sin = torch.sin(t * 2 * torch.pi * 5)
492-
t_batch = torch.stack([(x_sin * (i + 1)) + (i * 2) for i in range(10)])
493-
494-
num_dim = t_batch.shape[1]
495-
x_dist = BoxUniform(low=-2 * torch.ones(num_dim), high=2 * torch.ones(num_dim))
496-
497-
# API tests
498-
# TODO: Test breaks at "mnle"
499-
if builder in [likelihood_nn, posterior_nn]:
500-
for model in [
501-
"mdn",
502-
"made",
503-
"maf",
504-
"nsf",
505-
"zuko_nice",
506-
"zuko_nsf",
507-
"zuko_maf",
508-
"zuko_ncsf",
509-
"zuko_bpf",
510-
"maf_rqs",
511-
"zuko_sospf",
512-
"zuko_naf",
513-
"zuko_unaf",
514-
"zuko_gf",
515-
]:
516-
net = builder(
517-
model,
518-
z_score_theta=z_theta,
519-
z_score_x=z_x,
520-
hidden_features=2,
521-
num_transforms=1,
522-
x_dist=x_dist,
484+
@pytest.mark.parametrize("build_fn", [likelihood_nn, posterior_nn, classifier_nn])
485+
def test_z_scoring_structured(z_x, z_theta, build_fn):
486+
"""Test z-scoring args across architectures and ensure correct input shapes."""
487+
batch_dim, num_dim = 10, 3
488+
dist = BoxUniform(low=-2 * torch.ones(num_dim), high=2 * torch.ones(num_dim))
489+
theta = dist.sample((batch_dim,))
490+
x = dist.sample((batch_dim,))
491+
492+
models = [
493+
"mdn",
494+
"made",
495+
"maf",
496+
"nsf",
497+
"zuko_nice",
498+
"zuko_nsf",
499+
"zuko_maf",
500+
"zuko_ncsf",
501+
"zuko_bpf",
502+
"maf_rqs",
503+
"zuko_sospf",
504+
"zuko_naf",
505+
"zuko_unaf",
506+
"zuko_gf",
507+
]
508+
509+
if build_fn == likelihood_nn:
510+
models.append("mnle")
511+
elif build_fn == classifier_nn:
512+
models = ["linear", "mlp", "resnet"]
513+
514+
for model in models:
515+
if model == "mnle":
516+
x_cont, x_disc = x[:, :-1], torch.randint(0, 2, (batch_dim, 1)).float()
517+
model_x = torch.cat([x_cont, x_disc], dim=1)
518+
else:
519+
model_x = x
520+
521+
kwargs = {
522+
"model": model,
523+
"z_score_theta": z_theta,
524+
"z_score_x": z_x,
525+
"hidden_features": 8,
526+
}
527+
if build_fn in [likelihood_nn, posterior_nn]:
528+
kwargs.update({"x_dist": dist, "num_transforms": 1})
529+
530+
build_fun = build_fn(**kwargs)
531+
estimator = build_fun(theta, model_x)
532+
533+
if build_fn == posterior_nn:
534+
assert estimator.log_prob(theta.unsqueeze(0), model_x).shape == (
535+
1,
536+
batch_dim,
523537
)
524-
assert net(t_batch, t_batch)
525-
else:
526-
for model in ["linear", "mlp", "resnet"]:
527-
net = builder(
528-
model,
529-
z_score_theta=z_theta,
530-
z_score_x=z_x,
531-
hidden_features=2,
538+
elif build_fn == likelihood_nn:
539+
assert estimator.log_prob(model_x.unsqueeze(0), theta).shape == (
540+
1,
541+
batch_dim,
532542
)
533-
assert net(t_batch, t_batch)
543+
else:
544+
assert estimator(theta, model_x).shape[0] == batch_dim
534545

535-
# Test that it doesn't break what doesn't use structured z-scoring.
536546
assert sensitivity_analysis.Destandardize(0, 1)
537547

538548
# # Uncomment to plot the generated signal.

0 commit comments

Comments
 (0)