Skip to content

Commit f5a7447

Browse files
authored
Merge pull request #533 from CUQI-DTU/add_info_to_progress_bar
Display acc rate on progress bar
2 parents 59a6730 + 55d3878 commit f5a7447

File tree

3 files changed

+45
-9
lines changed

3 files changed

+45
-9
lines changed

cuqi/experimental/mcmc/_hmc.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,7 @@ def __init__(self, target=None, initial_point=None, max_depth=None,
109109

110110
def _initialize(self):
111111

112-
# Arrays to store acceptance rate
113-
self._acc = [None] # Overwrites acc from Sampler. TODO. Check if this is necessary
114-
115-
self._current_alpha_ratio = np.nan # Current alpha ratio is set to some
112+
self._current_alpha_ratio = np.nan # Current alpha ratio will be set to some
116113
# value (other than np.nan) before
117114
# being used
118115

@@ -233,6 +230,7 @@ def step(self):
233230
r_minus, r_plus = np.copy(r_k), np.copy(r_k)
234231

235232
# run NUTS
233+
acc = 0
236234
while (s == 1) and (j <= self.max_depth):
237235
# sample a direction
238236
v = int(2*(np.random.rand() < 0.5)-1)
@@ -260,9 +258,8 @@ def step(self):
260258
self.current_point = point_prime
261259
self.current_target_logd = logd_prime
262260
self.current_target_grad = np.copy(grad_prime)
263-
self._acc.append(1)
264-
else:
265-
self._acc.append(0)
261+
acc = 1
262+
266263

267264
# update number of particles, tree level, and stopping criterion
268265
n += n_prime
@@ -280,6 +277,8 @@ def step(self):
280277
if np.isnan(self.current_target_logd):
281278
raise NameError('NaN potential func')
282279

280+
return acc
281+
283282
def tune(self, skip_len, update_count):
284283
""" adapt epsilon during burn-in using dual averaging"""
285284
k = update_count+1

cuqi/experimental/mcmc/_sampler.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,8 @@ def sample(self, Ns, batch_size=0, sample_path='./CUQI_samples/') -> 'Sampler':
220220
if hasattr(self, "_pre_sample"): self._pre_sample()
221221

222222
# Draw samples
223-
for _ in tqdm( range(Ns), "Sample: "):
223+
pbar = tqdm(range(Ns), "Sample: ")
224+
for idx in pbar:
224225

225226
# Perform one step of the sampler
226227
acc = self.step()
@@ -229,6 +230,9 @@ def sample(self, Ns, batch_size=0, sample_path='./CUQI_samples/') -> 'Sampler':
229230
self._acc.append(acc)
230231
self._samples.append(self.current_point)
231232

233+
# display acc rate at progress bar
234+
pbar.set_postfix_str(f"acc rate: {np.mean(self._acc[-1-idx:]):.2%}")
235+
232236
# Add sample to batch
233237
if batch_size > 0:
234238
batch_handler.add_sample(self.current_point)
@@ -260,7 +264,8 @@ def warmup(self, Nb, tune_freq=0.1) -> 'Sampler':
260264
if hasattr(self, "_pre_warmup"): self._pre_warmup()
261265

262266
# Draw warmup samples with tuning
263-
for idx in tqdm(range(Nb), "Warmup: "):
267+
pbar = tqdm(range(Nb), "Warmup: ")
268+
for idx in pbar:
264269

265270
# Perform one step of the sampler
266271
acc = self.step()
@@ -273,6 +278,9 @@ def warmup(self, Nb, tune_freq=0.1) -> 'Sampler':
273278
self._acc.append(acc)
274279
self._samples.append(self.current_point)
275280

281+
# display acc rate at progress bar
282+
pbar.set_postfix_str(f"acc rate: {np.mean(self._acc[-1-idx:]):.2%}")
283+
276284
# Call callback function if specified
277285
self._call_callback(self.current_point, len(self._samples)-1)
278286

tests/zexperimental/test_mcmc.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -962,3 +962,32 @@ def test_if_invalid_sample_accepted(sampler: cuqi.experimental.mcmc.Sampler):
962962
assert (
963963
samples.min() > 0.0 - tol and samples.max() < 1.0 + tol
964964
), f"Invalid samples accepted for sampler {sampler.__class__.__name__}."
965+
966+
967+
# Test NUTS acceptance rate
968+
@pytest.mark.parametrize(
969+
"sampler",
970+
[
971+
cuqi.experimental.mcmc.NUTS(cuqi.distribution.Gaussian(0, 1)),
972+
cuqi.experimental.mcmc.NUTS(cuqi.distribution.DistributionGallery('donut'))
973+
],
974+
)
975+
def test_nuts_acceptance_rate(sampler: cuqi.experimental.mcmc.Sampler):
976+
""" Test that the NUTS sampler correctly updates the acceptance rate. """
977+
# Fix random seed for reproducibility, but the test should be robust to seed
978+
np.random.seed(0)
979+
980+
# Sample:
981+
sampler.warmup(100).sample(100)
982+
983+
# Compute number of times samples were updated:
984+
samples = sampler.get_samples().samples
985+
counter = 0
986+
for i in range(1, samples.shape[1]):
987+
if np.any(samples[:, i] != samples[:, i - 1]):
988+
counter += 1
989+
990+
# Compute the number of accepted samples according to the sampler
991+
acc_rate_sum = sum(sampler._acc[2:])
992+
993+
assert np.isclose(counter, acc_rate_sum), "NUTS sampler does not update acceptance rate correctly: "+str(counter)+" != "+str(acc_rate_sum)

0 commit comments

Comments
 (0)