Skip to content

Commit 0bcbb40

Browse files
authored
tests: refactor "not slow" tests to be not so slow (#1495)
* wip: speed batched sampling test * tests: refactor density estimator tests for speed * speed up mnle api tests
1 parent 312e9ef commit 0bcbb40

File tree

3 files changed

+46
-18
lines changed

3 files changed

+46
-18
lines changed

tests/density_estimator_test.py

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
build_mnpe,
2424
build_nsf,
2525
build_resnet_flowmatcher,
26+
build_score_estimator,
2627
build_zuko_bpf,
2728
build_zuko_gf,
2829
build_zuko_maf,
@@ -52,9 +53,10 @@
5253
build_zuko_unaf,
5354
]
5455

55-
flowmatching_build_functions = [
56+
diffusion_builders = [
5657
build_mlp_flowmatcher,
5758
build_resnet_flowmatcher,
59+
build_score_estimator,
5860
]
5961

6062

@@ -136,7 +138,13 @@ def test_shape_handling_utility_for_density_estimator(
136138

137139

138140
@pytest.mark.parametrize(
139-
"density_estimator_build_fn", model_builders + flowmatching_build_functions
141+
"density_estimator_build_fn",
142+
[
143+
build_nsf,
144+
build_zuko_nsf,
145+
build_mlp_flowmatcher,
146+
build_score_estimator,
147+
], # just test nflows, zuko and flowmatching
140148
)
141149
@pytest.mark.parametrize("input_sample_dim", (1, 2))
142150
@pytest.mark.parametrize("input_event_shape", ((1,), (4,)))
@@ -241,10 +249,17 @@ def test_correctness_of_density_estimator_log_prob(
241249

242250

243251
@pytest.mark.parametrize(
244-
"density_estimator_build_fn", model_builders + flowmatching_build_functions
252+
"density_estimator_build_fn",
253+
[
254+
build_nsf,
255+
build_zuko_nsf,
256+
build_mlp_flowmatcher,
257+
], # just test nflows, zuko and flowmatching
245258
)
246-
@pytest.mark.parametrize("input_event_shape", ((1,), (4,)))
247-
@pytest.mark.parametrize("condition_event_shape", ((1,), (7,)))
259+
@pytest.mark.parametrize(
260+
"input_event_shape", ((1,), pytest.param((2,), marks=pytest.mark.slow))
261+
)
262+
@pytest.mark.parametrize("condition_event_shape", ((1,), (2,)))
248263
@pytest.mark.parametrize("sample_shape", ((1000,), (500, 2)))
249264
def test_correctness_of_batched_vs_seperate_sample_and_log_prob(
250265
density_estimator_build_fn: Callable,
@@ -267,11 +282,17 @@ def test_correctness_of_batched_vs_seperate_sample_and_log_prob(
267282
samples = density_estimator.sample(sample_shape, condition=condition)
268283
samples = samples.reshape(-1, batch_dim, *input_event_shape) # Flat for comp.
269284

285+
# Flatten sample_shape to (B*E,) if it is (B, E)
286+
if len(sample_shape) > 1:
287+
flat_sample_shape = (torch.prod(torch.tensor(sample_shape)).item(),)
288+
else:
289+
flat_sample_shape = sample_shape
290+
270291
samples_separate1 = density_estimator.sample(
271-
(1000,), condition=condition[0][None, ...]
292+
flat_sample_shape, condition=condition[0][None, ...]
272293
)
273294
samples_separate2 = density_estimator.sample(
274-
(1000,), condition=condition[1][None, ...]
295+
flat_sample_shape, condition=condition[1][None, ...]
275296
)
276297

277298
# Check if means are approx. same
@@ -310,12 +331,14 @@ def _build_density_estimator_and_tensors(
310331
"""Helper function for all tests that deal with shapes of density
311332
estimators."""
312333

334+
batch_size = 1000
313335
# Use positive random values for continuous dims (log transform)
314-
batch_input = torch.rand((1000, *input_event_shape), dtype=torch.float32) * 10.0
336+
batch_input = (
337+
torch.rand((batch_size, *input_event_shape), dtype=torch.float32) * 10.0
338+
)
315339
# make last dim discrete for mixed density estimators
316-
batch_input[:, -1] = torch.randint(0, 4, (1000,))
317-
batch_condition = torch.randn((1000, *condition_event_shape))
318-
340+
batch_input[:, -1] = torch.randint(0, 4, (batch_size,))
341+
batch_condition = torch.randn((batch_size, *condition_event_shape))
319342
if len(condition_event_shape) > 1:
320343
embedding_net = CNNEmbedding(condition_event_shape, kernel_size=1)
321344
z_score_y = "structured"
@@ -335,11 +358,16 @@ def _build_density_estimator_and_tensors(
335358
z_score_y=z_score_y,
336359
)
337360
else:
361+
embedding_net_kwarg = (
362+
dict(embedding_net_y=embedding_net)
363+
if "score" in density_estimator_build_fn.__name__
364+
else dict(embedding_net=embedding_net)
365+
)
338366
density_estimator = density_estimator_build_fn(
339367
torch.randn_like(batch_input),
340368
torch.randn_like(batch_condition),
341-
embedding_net=embedding_net,
342369
z_score_y=z_score_y,
370+
**embedding_net_kwarg,
343371
)
344372

345373
inputs = batch_input[:batch_dim]

tests/mnle_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,12 @@ def test_mnle_on_device(
9292
@pytest.mark.parametrize(
9393
"sampler", (pytest.param("mcmc", marks=pytest.mark.mcmc), "rejection", "vi")
9494
)
95-
@pytest.mark.parametrize("flow_model", ("mdn", "maf", "nsf", "zuko_nsf", "zuko_bpf"))
95+
@pytest.mark.parametrize("flow_model", ("mdn", "nsf", "zuko_nsf"))
9696
@pytest.mark.parametrize("z_score_theta", ("independent", "none"))
9797
def test_mnle_api(flow_model: str, sampler, mcmc_params_fast: dict, z_score_theta: str):
9898
"""Test MNLE API."""
9999
# Generate mixed data.
100-
num_simulations = 100
100+
num_simulations = 10
101101
theta = torch.rand(num_simulations, 2)
102102
x = torch.cat(
103103
(
@@ -119,7 +119,7 @@ def test_mnle_api(flow_model: str, sampler, mcmc_params_fast: dict, z_score_thet
119119
embedding_net=theta_embedding,
120120
)
121121
trainer = MNLE(density_estimator=density_estimator)
122-
trainer.append_simulations(theta, x).train(max_num_epochs=5)
122+
trainer.append_simulations(theta, x).train(max_num_epochs=1)
123123

124124
# Test different samplers.
125125
posterior = trainer.build_posterior(prior=prior, sample_with=sampler)
@@ -132,7 +132,7 @@ def test_mnle_api(flow_model: str, sampler, mcmc_params_fast: dict, z_score_thet
132132
posterior.sample(
133133
(1,),
134134
init_strategy="proposal",
135-
method="slice_np_vectorized",
135+
method="hmc_pyro",
136136
**mcmc_params_fast,
137137
)
138138

tests/posterior_nn_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def test_batched_sample_log_prob_with_different_x(
151151

152152

153153
@pytest.mark.mcmc
154-
@pytest.mark.parametrize("snlre_method", [NLE_A, NRE_A, NRE_B, NRE_C, NPE_C])
154+
@pytest.mark.parametrize("snlre_method", [NRE_C]) # it's independent of the method
155155
@pytest.mark.parametrize("x_o_batch_dim", (0, 1, 2))
156156
@pytest.mark.parametrize("init_strategy", ["proposal", "resample"])
157157
@pytest.mark.parametrize(
@@ -179,7 +179,7 @@ def test_batched_mcmc_sample_log_prob_with_different_x(
179179
inference = snlre_method(prior=prior)
180180
theta = prior.sample((num_simulations,))
181181
x = simulator(theta)
182-
inference.append_simulations(theta, x).train(max_num_epochs=2)
182+
inference.append_simulations(theta, x).train(max_num_epochs=1)
183183

184184
x_o = ones(num_dim) if x_o_batch_dim == 0 else ones(x_o_batch_dim, num_dim)
185185

0 commit comments

Comments
 (0)