Skip to content

Commit 4fe9de3

Browse files
committed
update NUTS acc rate list update
1 parent 6051fec commit 4fe9de3

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

cuqi/experimental/mcmc/_hmc.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,6 @@ def __init__(self, target=None, initial_point=None, max_depth=15,
119119

120120
def _initialize(self):
121121

122-
# Arrays to store acceptance rate
123-
self._acc = [None] # Overwrites acc from Sampler. TODO. Check if this is necessary
124-
125122
self._alpha = 0 # check if meaningful value
126123
self._n_alpha = 0 # check if meaningful value
127124

@@ -263,9 +260,9 @@ def step(self):
263260
self.current_point = point_prime
264261
self.current_target_logd = logd_prime
265262
self.current_target_grad = np.copy(grad_prime)
266-
self._acc.append(1)
263+
acc = 1
267264
else:
268-
self._acc.append(0)
265+
acc = 0
269266

270267
# update number of particles, tree level, and stopping criterion
271268
n += n_prime
@@ -284,6 +281,8 @@ def step(self):
284281
if np.isnan(self.current_target_logd):
285282
raise NameError('NaN potential func')
286283

284+
return acc
285+
287286
def tune(self, skip_len, update_count):
288287
""" adapt epsilon during burn-in using dual averaging"""
289288
k = update_count+1

0 commit comments

Comments
 (0)