Skip to content

Commit dc8dc86

Browse files
committed
add integration test for FMPE time-dependent z-scoring
1 parent 3883de2 commit dc8dc86

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

tests/linearGaussian_vector_field_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,33 @@ def test_vfinference_with_different_models(vector_field_type, model):
248248
check_c2st(samples, target_samples, alg=f"fmpe_{model}")
249249

250250

251+
@pytest.mark.parametrize("vector_field_type", ["fmpe"])
252+
def test_fmpe_time_dependent_z_scoring_integration(vector_field_type):
253+
num_dim = 2
254+
prior = BoxUniform(9.0 * ones(num_dim), 11.0 * ones(num_dim))
255+
256+
def simulator(theta):
257+
return theta + torch.randn_like(theta) * 0.1
258+
259+
inference = FMPE(prior, z_score_x='structured', show_progress_bars=False)
260+
theta = prior.sample((200,))
261+
x = simulator(theta)
262+
density_estimator = inference.append_simulations(theta, x).train(max_num_epochs=1)
263+
264+
assert hasattr(density_estimator, "mean_1")
265+
assert hasattr(density_estimator, "std_1")
266+
assert torch.all(density_estimator.mean_1 > 8.0)
267+
268+
batch_size = 10
269+
t = torch.rand(batch_size)
270+
theta_test = torch.randn(batch_size, num_dim)
271+
cond_test = zeros(batch_size, num_dim)
272+
v_pred = density_estimator.ode_fn(theta_test, cond_test, t)
273+
274+
assert v_pred.shape == (batch_size, num_dim)
275+
assert not torch.isnan(v_pred).any()
276+
277+
251278
# ------------------------------------------------------------------------------
252279
# -------------------------------- SLOW TESTS ----------------------------------
253280
# ------------------------------------------------------------------------------

0 commit comments

Comments
 (0)