@@ -481,58 +481,68 @@ def test_z_score_parser(z_x, z_theta):
481481 "transform_to_unconstrained" ,
482482 ],
483483)
484- @pytest .mark .parametrize ("builder" , [likelihood_nn , posterior_nn , classifier_nn ])
485- def test_z_scoring_structured (z_x , z_theta , builder ):
486- """
487- Test that z-scoring string args don't break API.
488- """
489- # Generate some signals for test.
490- t = torch .arange (0 , 1 , 0.1 )
491- x_sin = torch .sin (t * 2 * torch .pi * 5 )
492- t_batch = torch .stack ([(x_sin * (i + 1 )) + (i * 2 ) for i in range (10 )])
493-
494- num_dim = t_batch .shape [1 ]
495- x_dist = BoxUniform (low = - 2 * torch .ones (num_dim ), high = 2 * torch .ones (num_dim ))
496-
497- # API tests
498- # TODO: Test breaks at "mnle"
499- if builder in [likelihood_nn , posterior_nn ]:
500- for model in [
501- "mdn" ,
502- "made" ,
503- "maf" ,
504- "nsf" ,
505- "zuko_nice" ,
506- "zuko_nsf" ,
507- "zuko_maf" ,
508- "zuko_ncsf" ,
509- "zuko_bpf" ,
510- "maf_rqs" ,
511- "zuko_sospf" ,
512- "zuko_naf" ,
513- "zuko_unaf" ,
514- "zuko_gf" ,
515- ]:
516- net = builder (
517- model ,
518- z_score_theta = z_theta ,
519- z_score_x = z_x ,
520- hidden_features = 2 ,
521- num_transforms = 1 ,
522- x_dist = x_dist ,
484+ @pytest .mark .parametrize ("build_fn" , [likelihood_nn , posterior_nn , classifier_nn ])
485+ def test_z_scoring_structured (z_x , z_theta , build_fn ):
486+ """Test z-scoring args across architectures and ensure correct input shapes."""
487+ batch_dim , num_dim = 10 , 3
488+ dist = BoxUniform (low = - 2 * torch .ones (num_dim ), high = 2 * torch .ones (num_dim ))
489+ theta = dist .sample ((batch_dim ,))
490+ x = dist .sample ((batch_dim ,))
491+
492+ models = [
493+ "mdn" ,
494+ "made" ,
495+ "maf" ,
496+ "nsf" ,
497+ "zuko_nice" ,
498+ "zuko_nsf" ,
499+ "zuko_maf" ,
500+ "zuko_ncsf" ,
501+ "zuko_bpf" ,
502+ "maf_rqs" ,
503+ "zuko_sospf" ,
504+ "zuko_naf" ,
505+ "zuko_unaf" ,
506+ "zuko_gf" ,
507+ ]
508+
509+ if build_fn == likelihood_nn :
510+ models .append ("mnle" )
511+ elif build_fn == classifier_nn :
512+ models = ["linear" , "mlp" , "resnet" ]
513+
514+ for model in models :
515+ if model == "mnle" :
516+ x_cont , x_disc = x [:, :- 1 ], torch .randint (0 , 2 , (batch_dim , 1 )).float ()
517+ model_x = torch .cat ([x_cont , x_disc ], dim = 1 )
518+ else :
519+ model_x = x
520+
521+ kwargs = {
522+ "model" : model ,
523+ "z_score_theta" : z_theta ,
524+ "z_score_x" : z_x ,
525+ "hidden_features" : 8 ,
526+ }
527+ if build_fn in [likelihood_nn , posterior_nn ]:
528+ kwargs .update ({"x_dist" : dist , "num_transforms" : 1 })
529+
530+ build_fun = build_fn (** kwargs )
531+ estimator = build_fun (theta , model_x )
532+
533+ if build_fn == posterior_nn :
534+ assert estimator .log_prob (theta .unsqueeze (0 ), model_x ).shape == (
535+ 1 ,
536+ batch_dim ,
523537 )
524- assert net (t_batch , t_batch )
525- else :
526- for model in ["linear" , "mlp" , "resnet" ]:
527- net = builder (
528- model ,
529- z_score_theta = z_theta ,
530- z_score_x = z_x ,
531- hidden_features = 2 ,
538+ elif build_fn == likelihood_nn :
539+ assert estimator .log_prob (model_x .unsqueeze (0 ), theta ).shape == (
540+ 1 ,
541+ batch_dim ,
532542 )
533- assert net (t_batch , t_batch )
543+ else :
544+ assert estimator (theta , model_x ).shape [0 ] == batch_dim
534545
535- # Test that it doesn't break what doesn't use structured z-scoring.
536546 assert sensitivity_analysis .Destandardize (0 , 1 )
537547
538548 # # Uncomment to plot the generated signal.
0 commit comments