Commit d1e2f68

remove bm_test changes
1 parent d2e9d53 commit d1e2f68

File tree

1 file changed (+5, -86 lines)


tests/bm_test.py

Lines changed: 5 additions & 86 deletions
@@ -11,7 +11,6 @@
 from sbi.inference.posteriors.base_posterior import NeuralPosterior
 from sbi.inference.trainers.npe import NPE_C
 from sbi.inference.trainers.nre import BNRE, NRE_A, NRE_B, NRE_C
-from sbi.neural_nets.factory import posterior_score_nn
 from sbi.utils.metrics import c2st

 from .mini_sbibm import get_task
@@ -34,78 +33,6 @@
     "transformer",
 ]

-VP_SCHEDULES = [
-    {"name": "vp_default", "beta_min": 0.01, "beta_max": 10.0},
-    {"name": "vp_wide", "beta_min": 0.1, "beta_max": 20.0},
-]
-SUBVP_SCHEDULES = [
-    {"name": "subvp_default", "beta_min": 0.01, "beta_max": 10.0},
-    {"name": "subvp_wide", "beta_min": 0.1, "beta_max": 20.0},
-]
-VE_SCHEDULES = [
-    {"name": "ve_default", "sigma_min": 1e-4, "sigma_max": 10.0},
-    {"name": "ve_wide", "sigma_min": 1e-3, "sigma_max": 50.0},
-]
-
-
-def _make_npse_builder(
-    vf_estimator: Literal["mlp", "ada_mlp", "transformer"],
-    sde_type: Literal["vp", "subvp", "ve"],
-    schedule_kwargs: dict[str, float | str],
-):
-    builder = posterior_score_nn(model=vf_estimator, sde_type=sde_type)
-
-    def build_fn(batch_theta, batch_x):
-        estimator = builder(batch_theta, batch_x)
-        if sde_type in {"vp", "subvp"}:
-            if "beta_min" in schedule_kwargs:
-                estimator.beta_min = schedule_kwargs["beta_min"]
-            if "beta_max" in schedule_kwargs:
-                estimator.beta_max = schedule_kwargs["beta_max"]
-        elif sde_type == "ve":
-            if "sigma_min" in schedule_kwargs:
-                estimator.sigma_min = schedule_kwargs["sigma_min"]
-            if "sigma_max" in schedule_kwargs:
-                estimator.sigma_max = schedule_kwargs["sigma_max"]
-
-        t_tensor = torch.as_tensor([estimator.t_max], device=estimator.mean_0.device)
-        mean_t = estimator.approx_marginal_mean(t_tensor)
-        std_t = estimator.approx_marginal_std(t_tensor)
-        mean_t = torch.broadcast_to(mean_t, (1, *estimator.input_shape))
-        std_t = torch.broadcast_to(std_t, (1, *estimator.input_shape))
-        estimator._mean_base = mean_t
-        estimator._std_base = std_t
-        return estimator
-
-    return build_fn
-
-
-def _npse_variants():
-    variants = []
-    for vf_estimator in VF_ESTIMATORS:
-        for schedule in VP_SCHEDULES:
-            label = f"{vf_estimator}/vp/{schedule['name']}"
-            variants.append({
-                "vf_estimator": _make_npse_builder(vf_estimator, "vp", schedule),
-                "sde_type": "vp",
-                "_bm_label": label,
-            })
-        for schedule in SUBVP_SCHEDULES:
-            label = f"{vf_estimator}/subvp/{schedule['name']}"
-            variants.append({
-                "vf_estimator": _make_npse_builder(vf_estimator, "subvp", schedule),
-                "sde_type": "subvp",
-                "_bm_label": label,
-            })
-        for schedule in VE_SCHEDULES:
-            label = f"{vf_estimator}/ve/{schedule['name']}"
-            variants.append({
-                "vf_estimator": _make_npse_builder(vf_estimator, "ve", schedule),
-                "sde_type": "ve",
-                "_bm_label": label,
-            })
-    return variants
-

 # Benchmarking method groups i.e. what to run for different --bm-mode
 METHOD_GROUPS = {
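Note: the deleted _npse_variants() expanded every vector-field estimator against every noise-schedule preset and tagged each combination with a private _bm_label. A minimal sketch of the grid it enumerated, assuming VF_ESTIMATORS matches the Literal hint in the deleted signature (an assumption; the full list is outside this diff):

    # Sketch of the removed variant grid; VF_ESTIMATORS is assumed from the
    # Literal["mlp", "ada_mlp", "transformer"] hint, not shown in this diff.
    VF_ESTIMATORS = ["mlp", "ada_mlp", "transformer"]
    SCHEDULES = {
        "vp": ["vp_default", "vp_wide"],
        "subvp": ["subvp_default", "subvp_wide"],
        "ve": ["ve_default", "ve_wide"],
    }
    labels = [
        f"{vf}/{sde}/{name}"
        for vf in VF_ESTIMATORS
        for sde, names in SCHEDULES.items()
        for name in names
    ]
    assert len(labels) == 18  # 3 estimators x 3 SDE types x 2 schedules each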
@@ -126,7 +53,11 @@ def _npse_variants():
     "nle": [{"density_estimator": de} for de in ["maf", "nsf"]],
     "nre": [{"classifier": cl} for cl in CLASSIFIERS],
     "fmpe": [{"vf_estimator": nn} for nn in VF_ESTIMATORS],
-    "npse": _npse_variants(),
+    "npse": [
+        {"vf_estimator": nn, "sde_type": sde}
+        for nn in VF_ESTIMATORS
+        for sde in ["ve", "vp"]
+    ],
     "vfpe": [{"vf_estimator": nn} for nn in VF_ESTIMATORS],
     "snpe": [{}],
     "snle": [{}],
@@ -260,8 +191,6 @@ def train_and_eval_amortized_inference(
     thetas, xs = task.get_data(benchmark_num_simulations)
     prior = task.get_prior()

-    extra_kwargs = dict(extra_kwargs)
-    bm_label = extra_kwargs.pop("_bm_label", None)
     inference = inference_class(prior, **extra_kwargs)
     _ = inference.append_simulations(thetas, xs).train(**TRAIN_KWARGS)

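Note: the two deleted lines copied extra_kwargs before popping the private _bm_label key, so the pop never mutated the shared dicts held in METHOD_GROUPS; with the label gone from every group, the defensive copy has nothing left to guard. A minimal illustration of the aliasing hazard (values hypothetical):

    group = [{"vf_estimator": "mlp", "_bm_label": "mlp/vp/vp_default"}]

    # Without the copy, popping mutates the shared entry in the group.
    kwargs = group[0]
    kwargs.pop("_bm_label", None)
    assert "_bm_label" not in group[0]

    # The removed pattern copied first, leaving the group entry intact.
    group = [{"vf_estimator": "mlp", "_bm_label": "mlp/vp/vp_default"}]
    kwargs = dict(group[0])
    bm_label = kwargs.pop("_bm_label", None)
    assert "_bm_label" in group[0]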
@@ -273,10 +202,6 @@ def train_and_eval_amortized_inference(
     results_bag.metric = mean_c2st
     results_bag.num_simulations = benchmark_num_simulations
     results_bag.task_name = task_name
-    if bm_label is None:
-        results_bag.method = inference_class.__name__ + str(extra_kwargs)
-    else:
-        results_bag.method = f"{inference_class.__name__}({bm_label})"


 def train_and_eval_sequential_inference(
@@ -305,8 +230,6 @@ def train_and_eval_sequential_inference(
     simulator = task.get_simulator()

     # Round 1
-    extra_kwargs = dict(extra_kwargs)
-    bm_label = extra_kwargs.pop("_bm_label", None)
     inference = inference_class(prior, **extra_kwargs)
     _ = inference.append_simulations(thetas, xs).train(**TRAIN_KWARGS)

@@ -330,10 +253,6 @@ def train_and_eval_sequential_inference(
     results_bag.metric = c2st_val
     results_bag.num_simulations = benchmark_num_simulations
     results_bag.task_name = task_name
-    if bm_label is None:
-        results_bag.method = inference_class.__name__ + str(extra_kwargs)
-    else:
-        results_bag.method = f"{inference_class.__name__}({bm_label})"


 @pytest.mark.benchmark
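Net effect: the NPSE benchmark grid shrinks from 18 labelled schedule variants (3 estimators x 3 SDE types x 2 schedules) back to 6 plain (estimator, SDE) combinations, and results_bag.method no longer carries a per-variant label. For reference, the deleted branches named runs roughly as follows (the class name NPSE is illustrative, not shown in this diff):

    # label present:  f"{inference_class.__name__}({bm_label})"
    #                  -> "NPSE(mlp/vp/vp_default)"
    # label absent:    inference_class.__name__ + str(extra_kwargs)
    #                  -> "NPSE{'vf_estimator': 'mlp', 'sde_type': 'vp'}"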
