@@ -248,6 +248,33 @@ def test_vfinference_with_different_models(vector_field_type, model):
248248 check_c2st (samples , target_samples , alg = f"fmpe_{ model } " )
249249
250250
251+ @pytest .mark .parametrize ("vector_field_type" , ["fmpe" ])
252+ def test_fmpe_time_dependent_z_scoring_integration (vector_field_type ):
253+ num_dim = 2
254+ prior = BoxUniform (9.0 * ones (num_dim ), 11.0 * ones (num_dim ))
255+
256+ def simulator (theta ):
257+ return theta + torch .randn_like (theta ) * 0.1
258+
259+ inference = FMPE (prior , z_score_x = 'structured' , show_progress_bars = False )
260+ theta = prior .sample ((200 ,))
261+ x = simulator (theta )
262+ density_estimator = inference .append_simulations (theta , x ).train (max_num_epochs = 1 )
263+
264+ assert hasattr (density_estimator , "mean_1" )
265+ assert hasattr (density_estimator , "std_1" )
266+ assert torch .all (density_estimator .mean_1 > 8.0 )
267+
268+ batch_size = 10
269+ t = torch .rand (batch_size )
270+ theta_test = torch .randn (batch_size , num_dim )
271+ cond_test = zeros (batch_size , num_dim )
272+ v_pred = density_estimator .ode_fn (theta_test , cond_test , t )
273+
274+ assert v_pred .shape == (batch_size , num_dim )
275+ assert not torch .isnan (v_pred ).any ()
276+
277+
251278# ------------------------------------------------------------------------------
252279# -------------------------------- SLOW TESTS ----------------------------------
253280# ------------------------------------------------------------------------------
0 commit comments