2323 build_mnpe ,
2424 build_nsf ,
2525 build_resnet_flowmatcher ,
26+ build_score_estimator ,
2627 build_zuko_bpf ,
2728 build_zuko_gf ,
2829 build_zuko_maf ,
5253 build_zuko_unaf ,
5354]
5455
55- flowmatching_build_functions = [
56+ diffusion_builders = [
5657 build_mlp_flowmatcher ,
5758 build_resnet_flowmatcher ,
59+ build_score_estimator ,
5860]
5961
6062
@@ -136,7 +138,13 @@ def test_shape_handling_utility_for_density_estimator(
136138
137139
138140@pytest .mark .parametrize (
139- "density_estimator_build_fn" , model_builders + flowmatching_build_functions
141+ "density_estimator_build_fn" ,
142+ [
143+ build_nsf ,
144+ build_zuko_nsf ,
145+ build_mlp_flowmatcher ,
146+ build_score_estimator ,
147+ ], # just test nflows, zuko and flowmatching
140148)
141149@pytest .mark .parametrize ("input_sample_dim" , (1 , 2 ))
142150@pytest .mark .parametrize ("input_event_shape" , ((1 ,), (4 ,)))
@@ -241,10 +249,17 @@ def test_correctness_of_density_estimator_log_prob(
241249
242250
243251@pytest .mark .parametrize (
244- "density_estimator_build_fn" , model_builders + flowmatching_build_functions
252+ "density_estimator_build_fn" ,
253+ [
254+ build_nsf ,
255+ build_zuko_nsf ,
256+ build_mlp_flowmatcher ,
257+ ], # just test nflows, zuko and flowmatching
245258)
246- @pytest .mark .parametrize ("input_event_shape" , ((1 ,), (4 ,)))
247- @pytest .mark .parametrize ("condition_event_shape" , ((1 ,), (7 ,)))
259+ @pytest .mark .parametrize (
260+ "input_event_shape" , ((1 ,), pytest .param ((2 ,), marks = pytest .mark .slow ))
261+ )
262+ @pytest .mark .parametrize ("condition_event_shape" , ((1 ,), (2 ,)))
248263@pytest .mark .parametrize ("sample_shape" , ((1000 ,), (500 , 2 )))
249264def test_correctness_of_batched_vs_seperate_sample_and_log_prob (
250265 density_estimator_build_fn : Callable ,
@@ -267,11 +282,17 @@ def test_correctness_of_batched_vs_seperate_sample_and_log_prob(
267282 samples = density_estimator .sample (sample_shape , condition = condition )
268283 samples = samples .reshape (- 1 , batch_dim , * input_event_shape ) # Flat for comp.
269284
285+ # Flatten sample_shape to (B*E,) if it is (B, E)
286+ if len (sample_shape ) > 1 :
287+ flat_sample_shape = (torch .prod (torch .tensor (sample_shape )).item (),)
288+ else :
289+ flat_sample_shape = sample_shape
290+
270291 samples_separate1 = density_estimator .sample (
271- ( 1000 ,) , condition = condition [0 ][None , ...]
292+ flat_sample_shape , condition = condition [0 ][None , ...]
272293 )
273294 samples_separate2 = density_estimator .sample (
274- ( 1000 ,) , condition = condition [1 ][None , ...]
295+ flat_sample_shape , condition = condition [1 ][None , ...]
275296 )
276297
277298 # Check if means are approx. same
@@ -310,12 +331,14 @@ def _build_density_estimator_and_tensors(
310331 """Helper function for all tests that deal with shapes of density
311332 estimators."""
312333
334+ batch_size = 1000
313335 # Use positive random values for continuous dims (log transform)
314- batch_input = torch .rand ((1000 , * input_event_shape ), dtype = torch .float32 ) * 10.0
336+ batch_input = (
337+ torch .rand ((batch_size , * input_event_shape ), dtype = torch .float32 ) * 10.0
338+ )
315339 # make last dim discrete for mixed density estimators
316- batch_input [:, - 1 ] = torch .randint (0 , 4 , (1000 ,))
317- batch_condition = torch .randn ((1000 , * condition_event_shape ))
318-
340+ batch_input [:, - 1 ] = torch .randint (0 , 4 , (batch_size ,))
341+ batch_condition = torch .randn ((batch_size , * condition_event_shape ))
319342 if len (condition_event_shape ) > 1 :
320343 embedding_net = CNNEmbedding (condition_event_shape , kernel_size = 1 )
321344 z_score_y = "structured"
@@ -335,11 +358,16 @@ def _build_density_estimator_and_tensors(
335358 z_score_y = z_score_y ,
336359 )
337360 else :
361+ embedding_net_kwarg = (
362+ dict (embedding_net_y = embedding_net )
363+ if "score" in density_estimator_build_fn .__name__
364+ else dict (embedding_net = embedding_net )
365+ )
338366 density_estimator = density_estimator_build_fn (
339367 torch .randn_like (batch_input ),
340368 torch .randn_like (batch_condition ),
341- embedding_net = embedding_net ,
342369 z_score_y = z_score_y ,
370+ ** embedding_net_kwarg ,
343371 )
344372
345373 inputs = batch_input [:batch_dim ]
0 commit comments