Skip to content

Commit 227b664

Browse files
committed
update test to use NUTS warm up directly when possible
1 parent 9c9287a commit 227b664

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

tests/zexperimental/test_mcmc.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1641,14 +1641,18 @@ def test_NUTS_within_Gibbs_consistant_with_NUTS(step_size, num_sampling_steps_x,
16411641
sampler_nuts = cuqi.experimental.mcmc.NUTS(target,
16421642
max_depth=4,
16431643
step_size=step_size)
1644-
# Warm up (not using built-in warmup to control number of steps
1645-
# between tuning steps to match Gibbs sampling)
1646-
tune_interval = max(int(tune_freq * nb), 1)
1647-
for count in range(nb):
1648-
for _ in range(num_sampling_steps_x):
1649-
sampler_nuts.sample(1)
1650-
if (count+1) % tune_interval == 0:
1651-
sampler_nuts.tune(None, count//tune_interval)
1644+
# Warm up (when num_sampling_steps_x>0, we do not using built-in warmup
1645+
# in order to control number of steps between tuning steps to
1646+
# match Gibbs sampling behavior)
1647+
if num_sampling_steps_x == 1:
1648+
sampler_nuts.warmup(nb, tune_freq=tune_freq)
1649+
else:
1650+
tune_interval = max(int(tune_freq * nb), 1)
1651+
for count in range(nb):
1652+
for _ in range(num_sampling_steps_x):
1653+
sampler_nuts.sample(1)
1654+
if (count+1) % tune_interval == 0:
1655+
sampler_nuts.tune(None, count//tune_interval)
16521656
# Sample
16531657
sampler_nuts.sample(ns * num_sampling_steps_x)
16541658
samples_nuts = sampler_nuts.get_samples().samples

0 commit comments

Comments
 (0)