1111from sbi .inference .posteriors .base_posterior import NeuralPosterior
1212from sbi .inference .trainers .npe import NPE_C
1313from sbi .inference .trainers .nre import BNRE , NRE_A , NRE_B , NRE_C
14- from sbi .neural_nets .factory import posterior_score_nn
1514from sbi .utils .metrics import c2st
1615
1716from .mini_sbibm import get_task
3433 "transformer" ,
3534]
3635
37- VP_SCHEDULES = [
38- {"name" : "vp_default" , "beta_min" : 0.01 , "beta_max" : 10.0 },
39- {"name" : "vp_wide" , "beta_min" : 0.1 , "beta_max" : 20.0 },
40- ]
41- SUBVP_SCHEDULES = [
42- {"name" : "subvp_default" , "beta_min" : 0.01 , "beta_max" : 10.0 },
43- {"name" : "subvp_wide" , "beta_min" : 0.1 , "beta_max" : 20.0 },
44- ]
45- VE_SCHEDULES = [
46- {"name" : "ve_default" , "sigma_min" : 1e-4 , "sigma_max" : 10.0 },
47- {"name" : "ve_wide" , "sigma_min" : 1e-3 , "sigma_max" : 50.0 },
48- ]
49-
50-
51- def _make_npse_builder (
52- vf_estimator : Literal ["mlp" , "ada_mlp" , "transformer" ],
53- sde_type : Literal ["vp" , "subvp" , "ve" ],
54- schedule_kwargs : dict [str , float | str ],
55- ):
56- builder = posterior_score_nn (model = vf_estimator , sde_type = sde_type )
57-
58- def build_fn (batch_theta , batch_x ):
59- estimator = builder (batch_theta , batch_x )
60- if sde_type in {"vp" , "subvp" }:
61- if "beta_min" in schedule_kwargs :
62- estimator .beta_min = schedule_kwargs ["beta_min" ]
63- if "beta_max" in schedule_kwargs :
64- estimator .beta_max = schedule_kwargs ["beta_max" ]
65- elif sde_type == "ve" :
66- if "sigma_min" in schedule_kwargs :
67- estimator .sigma_min = schedule_kwargs ["sigma_min" ]
68- if "sigma_max" in schedule_kwargs :
69- estimator .sigma_max = schedule_kwargs ["sigma_max" ]
70-
71- t_tensor = torch .as_tensor ([estimator .t_max ], device = estimator .mean_0 .device )
72- mean_t = estimator .approx_marginal_mean (t_tensor )
73- std_t = estimator .approx_marginal_std (t_tensor )
74- mean_t = torch .broadcast_to (mean_t , (1 , * estimator .input_shape ))
75- std_t = torch .broadcast_to (std_t , (1 , * estimator .input_shape ))
76- estimator ._mean_base = mean_t
77- estimator ._std_base = std_t
78- return estimator
79-
80- return build_fn
81-
82-
83- def _npse_variants ():
84- variants = []
85- for vf_estimator in VF_ESTIMATORS :
86- for schedule in VP_SCHEDULES :
87- label = f"{ vf_estimator } /vp/{ schedule ['name' ]} "
88- variants .append ({
89- "vf_estimator" : _make_npse_builder (vf_estimator , "vp" , schedule ),
90- "sde_type" : "vp" ,
91- "_bm_label" : label ,
92- })
93- for schedule in SUBVP_SCHEDULES :
94- label = f"{ vf_estimator } /subvp/{ schedule ['name' ]} "
95- variants .append ({
96- "vf_estimator" : _make_npse_builder (vf_estimator , "subvp" , schedule ),
97- "sde_type" : "subvp" ,
98- "_bm_label" : label ,
99- })
100- for schedule in VE_SCHEDULES :
101- label = f"{ vf_estimator } /ve/{ schedule ['name' ]} "
102- variants .append ({
103- "vf_estimator" : _make_npse_builder (vf_estimator , "ve" , schedule ),
104- "sde_type" : "ve" ,
105- "_bm_label" : label ,
106- })
107- return variants
108-
10936
11037# Benchmarking method groups i.e. what to run for different --bm-mode
11138METHOD_GROUPS = {
@@ -126,7 +53,11 @@ def _npse_variants():
12653 "nle" : [{"density_estimator" : de } for de in ["maf" , "nsf" ]],
12754 "nre" : [{"classifier" : cl } for cl in CLASSIFIERS ],
12855 "fmpe" : [{"vf_estimator" : nn } for nn in VF_ESTIMATORS ],
129- "npse" : _npse_variants (),
56+ "npse" : [
57+ {"vf_estimator" : nn , "sde_type" : sde }
58+ for nn in VF_ESTIMATORS
59+ for sde in ["ve" , "vp" ]
60+ ],
13061 "vfpe" : [{"vf_estimator" : nn } for nn in VF_ESTIMATORS ],
13162 "snpe" : [{}],
13263 "snle" : [{}],
@@ -260,8 +191,6 @@ def train_and_eval_amortized_inference(
260191 thetas , xs = task .get_data (benchmark_num_simulations )
261192 prior = task .get_prior ()
262193
263- extra_kwargs = dict (extra_kwargs )
264- bm_label = extra_kwargs .pop ("_bm_label" , None )
265194 inference = inference_class (prior , ** extra_kwargs )
266195 _ = inference .append_simulations (thetas , xs ).train (** TRAIN_KWARGS )
267196
@@ -273,10 +202,6 @@ def train_and_eval_amortized_inference(
273202 results_bag .metric = mean_c2st
274203 results_bag .num_simulations = benchmark_num_simulations
275204 results_bag .task_name = task_name
276- if bm_label is None :
277- results_bag .method = inference_class .__name__ + str (extra_kwargs )
278- else :
279- results_bag .method = f"{ inference_class .__name__ } ({ bm_label } )"
280205
281206
282207def train_and_eval_sequential_inference (
@@ -305,8 +230,6 @@ def train_and_eval_sequential_inference(
305230 simulator = task .get_simulator ()
306231
307232 # Round 1
308- extra_kwargs = dict (extra_kwargs )
309- bm_label = extra_kwargs .pop ("_bm_label" , None )
310233 inference = inference_class (prior , ** extra_kwargs )
311234 _ = inference .append_simulations (thetas , xs ).train (** TRAIN_KWARGS )
312235
@@ -330,10 +253,6 @@ def train_and_eval_sequential_inference(
330253 results_bag .metric = c2st_val
331254 results_bag .num_simulations = benchmark_num_simulations
332255 results_bag .task_name = task_name
333- if bm_label is None :
334- results_bag .method = inference_class .__name__ + str (extra_kwargs )
335- else :
336- results_bag .method = f"{ inference_class .__name__ } ({ bm_label } )"
337256
338257
339258@pytest .mark .benchmark
0 commit comments