Skip to content

Commit 2cf72ec

Browse files
authored
Merge pull request #683 from CUQI-DTU/improve_NUTS_statefulness_new
Ensure NUTS statefulness within Gibbs
2 parents 37e1d44 + 227b664 commit 2cf72ec

File tree

5 files changed

+105
-27
lines changed

5 files changed

+105
-27
lines changed

cuqi/distribution/_posterior.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from cuqi.geometry import _DefaultGeometry, _get_identity_geometries
22
from cuqi.distribution import Distribution
3+
from cuqi.density import Density
34

45
# ========================================================================
56
class Posterior(Distribution):
@@ -25,6 +26,14 @@ def __init__(self, likelihood, prior, **kwargs):
2526
self.prior = prior
2627
super().__init__(**kwargs)
2728

29+
def get_density(self, name) -> Density:
30+
""" Return a density with the given name. """
31+
if name == self.likelihood.name:
32+
return self.likelihood
33+
if name == self.prior.name:
34+
return self.prior
35+
raise ValueError(f"No density with name {name}.")
36+
2837
@property
2938
def data(self):
3039
return self.likelihood.data

cuqi/experimental/mcmc/_gibbs.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
from cuqi.distribution import JointDistribution
1+
from cuqi.distribution import JointDistribution, Posterior
22
from cuqi.experimental.mcmc import Sampler
33
from cuqi.samples import Samples, JointSamples
4-
from cuqi.experimental.mcmc import NUTS
54
from typing import Dict
65
import numpy as np
76
import warnings
@@ -36,11 +35,10 @@ class HybridGibbs:
3635
Gelman et al. "Bayesian Data Analysis" (2014), Third Edition
3736
for more details.
3837
39-
In each Gibbs step, the corresponding sampler has the initial_point
40-
and initial_scale (if applicable) set to the value of the previous step
41-
and the sampler is reinitialized. This means that the sampling is not
42-
fully stateful at this point. This means samplers like NUTS will lose
43-
their internal state between Gibbs steps.
38+
In each Gibbs step, the corresponding sampler state and history are stored,
39+
then the sampler is reinitialized. After reinitialization, the sampler state
40+
and history are set back to the stored values. This preserves the
41+
statefulness of the samplers.
4442
4543
The order in which the conditionals are sampled is the order of the
4644
variables in the sampling strategy, unless a different sampling order
@@ -177,8 +175,8 @@ def scan_order(self):
177175
# ------------ Public methods ------------
178176
def validate_targets(self):
179177
""" Validate each of the conditional targets used in the Gibbs steps """
180-
if not isinstance(self.target, JointDistribution):
181-
raise ValueError('Target distribution must be a JointDistribution.')
178+
if not isinstance(self.target, (JointDistribution, Posterior)):
179+
raise ValueError('Target distribution must be a JointDistribution or Posterior.')
182180
for sampler in self.samplers.values():
183181
sampler.validate_target()
184182

@@ -257,19 +255,15 @@ def step(self):
257255
# before reinitializing the sampler and then set the state and history back to the sampler
258256

259257
# Extract state and history from sampler
260-
if isinstance(sampler, NUTS): # Special case for NUTS as it is not playing nice with get_state and get_history
261-
sampler.initial_point = sampler.current_point
262-
else:
263-
sampler_state = sampler.get_state()
264-
sampler_history = sampler.get_history()
258+
sampler_state = sampler.get_state()
259+
sampler_history = sampler.get_history()
265260

266261
# Reinitialize sampler
267262
sampler.reinitialize()
268263

269264
# Set state and history back to sampler
270-
if not isinstance(sampler, NUTS): # Again, special case for NUTS.
271-
sampler.set_state(sampler_state)
272-
sampler.set_history(sampler_history)
265+
sampler.set_state(sampler_state)
266+
sampler.set_history(sampler_history)
273267

274268
# Allow for multiple sampling steps in each Gibbs step
275269
for _ in range(self.num_sampling_steps[par_name]):
@@ -309,8 +303,6 @@ def _call_callback(self, sample_index, num_of_samples):
309303
def _initialize_samplers(self):
310304
""" Initialize samplers """
311305
for sampler in self.samplers.values():
312-
if isinstance(sampler, NUTS):
313-
print(f'Warning: NUTS sampler is not fully stateful in HybridGibbs. Sampler will be reinitialized in each Gibbs step.')
314306
sampler.initialize()
315307

316308
def _initialize_num_sampling_steps(self):

cuqi/experimental/mcmc/_hmc.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,16 +118,18 @@ def _initialize(self):
118118
# to epsilon_bar for the remaining sampling steps.
119119
if self.step_size is None:
120120
self._epsilon = self._FindGoodEpsilon()
121+
self.step_size = self._epsilon
121122
else:
122123
self._epsilon = self.step_size
124+
123125
self._epsilon_bar = "unset"
124126

125127
# Parameter mu, does not change during the run
126128
self._mu = np.log(10*self._epsilon)
127129

128130
self._H_bar = 0
129131

130-
# NUTS run diagnostic:
132+
# NUTS run diagnostics
131133
# number of tree nodes created each NUTS iteration
132134
self._num_tree_node = 0
133135

0 Bytes
Binary file not shown.

tests/zexperimental/test_mcmc.py

Lines changed: 82 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,6 @@ def create_lmrf_prior_target(dim=16):
252252
return cuqi.distribution.JointDistribution(x, y)(y=y_data)
253253

254254

255-
256255
@pytest.mark.parametrize("target_dim", [16, 128])
257256
def test_UGLA_regression_sample(target_dim):
258257
"""Test the UGLA sampler regression."""
@@ -324,7 +323,7 @@ def test_NUTS_regression_warmup(target: cuqi.density.Density):
324323
Ns=Ns,
325324
Nb=Nb,
326325
strategy="NUTS")
327-
326+
328327
# ============= MYULA ==============
329328
def create_myula_target(dim=16):
330329
"""Create a target for MYULA."""
@@ -419,7 +418,7 @@ def create_conjugate_target(type:str):
419418
cuqi.experimental.mcmc.ConjugateApprox(create_conjugate_target("LMRF-Gamma")),
420419
cuqi.experimental.mcmc.NUTS(cuqi.testproblem.Deconvolution1D(dim=10).posterior, max_depth=4)
421420
]
422-
421+
423422
# List of samplers from cuqi.experimental.mcmc that should be skipped for checkpoint testing
424423
skip_checkpoint = [
425424
cuqi.experimental.mcmc.Sampler,
@@ -967,8 +966,6 @@ def HybridGibbs_target_1():
967966
def test_NUTS_within_HybridGibbs_regression_sample_and_warmup(copy_reference):
968967
""" Test that using NUTS sampler within HybridGibbs sampler works as
969968
expected."""
970-
#TODO: This test might break in the future if the NUTS within HybridGibbs
971-
# is changed to be fully stateful.
972969

973970
Nb=10
974971
Ns=10
@@ -982,7 +979,7 @@ def test_NUTS_within_HybridGibbs_regression_sample_and_warmup(copy_reference):
982979

983980
# Here we set the number of internal sampler steps (e.g. NUTS) for each Gibbs step
984981
num_sampling_steps = {
985-
"x" : 1,
982+
"x" : 2,
986983
"s" : 1
987984
}
988985

@@ -1080,7 +1077,7 @@ def test_nuts_acceptance_rate(sampler: cuqi.experimental.mcmc.Sampler):
10801077
acc_rate_sum = sum(sampler._acc[2:])
10811078

10821079
assert np.isclose(counter, acc_rate_sum), "NUTS sampler does not update acceptance rate correctly: "+str(counter)+" != "+str(acc_rate_sum)
1083-
1080+
10841081
# ============ Testing of AffineModel with RTO-type samplers ============
10851082

10861083
def test_LinearRTO_with_AffineModel_is_equivalent_to_LinearModel_and_shifted_data():
@@ -1623,3 +1620,81 @@ def test_gibbs_scan_order():
16231620

16241621
sampler = cuqi.experimental.mcmc.HybridGibbs(target, sampling_strategy, scan_order=['x', 's'])
16251622
assert sampler.scan_order == ['x', 's']
1623+
1624+
@pytest.mark.parametrize("step_size", [None, 0.1])
1625+
@pytest.mark.parametrize("num_sampling_steps_x", [1, 5])
1626+
@pytest.mark.parametrize("nb", [5, 20])
1627+
def test_NUTS_within_Gibbs_consistant_with_NUTS(step_size, num_sampling_steps_x, nb):
1628+
""" Test that using NUTS sampler within HybridGibbs sampler is consistent
1629+
with using NUTS sampler alone for sampling and tuning. This test ensures
1630+
NUTS within HybridGibbs statefulness.
1631+
"""
1632+
1633+
ns = 15 # number of sampling steps
1634+
tune_freq = 0.1
1635+
1636+
np.random.seed(0)
1637+
# Forward problem
1638+
A, y_data, info = cuqi.testproblem.Deconvolution1D(
1639+
dim=5, phantom='sinc', noise_std=0.001).get_components()
1640+
1641+
# Bayesian Inverse Problem
1642+
x = cuqi.distribution.GMRF(np.zeros(A.domain_dim), 50)
1643+
y = cuqi.distribution.Gaussian(A@x, 0.001**2)
1644+
1645+
# Posterior
1646+
target = cuqi.distribution.JointDistribution(y, x)(y=y_data)
1647+
1648+
# Sample with NUTS within HybridGibbs
1649+
np.random.seed(0)
1650+
sampling_strategy = {
1651+
"x" : cuqi.experimental.mcmc.NUTS(max_depth=4, step_size=step_size)
1652+
}
1653+
1654+
num_sampling_steps = {
1655+
"x" : num_sampling_steps_x
1656+
}
1657+
1658+
sampler_gibbs = cuqi.experimental.mcmc.HybridGibbs(target,
1659+
sampling_strategy,
1660+
num_sampling_steps)
1661+
sampler_gibbs.warmup(nb, tune_freq=tune_freq)
1662+
sampler_gibbs.sample(ns)
1663+
samples_gibbs = sampler_gibbs.get_samples()["x"].samples
1664+
1665+
# Sample with NUTS alone
1666+
np.random.seed(0)
1667+
sampler_nuts = cuqi.experimental.mcmc.NUTS(target,
1668+
max_depth=4,
1669+
step_size=step_size)
1670+
# Warm up (when num_sampling_steps_x>1, we do not use the built-in warmup
1671+
# in order to control number of steps between tuning steps to
1672+
# match Gibbs sampling behavior)
1673+
if num_sampling_steps_x == 1:
1674+
sampler_nuts.warmup(nb, tune_freq=tune_freq)
1675+
else:
1676+
tune_interval = max(int(tune_freq * nb), 1)
1677+
for count in range(nb):
1678+
for _ in range(num_sampling_steps_x):
1679+
sampler_nuts.sample(1)
1680+
if (count+1) % tune_interval == 0:
1681+
sampler_nuts.tune(None, count//tune_interval)
1682+
# Sample
1683+
sampler_nuts.sample(ns * num_sampling_steps_x)
1684+
samples_nuts = sampler_nuts.get_samples().samples
1685+
# keep every num_sampling_steps_x-th sample to match Gibbs samples
1686+
samples_nuts_skip = samples_nuts[:, num_sampling_steps_x - 1::num_sampling_steps_x]
1687+
1688+
# assert warmup samples are correct:
1689+
assert np.allclose(
1690+
samples_gibbs[:, :nb],
1691+
samples_nuts_skip[:, :nb],
1692+
rtol=1e-5,
1693+
)
1694+
1695+
# assert samples are correct:
1696+
assert np.allclose(
1697+
samples_gibbs[:, nb:],
1698+
samples_nuts_skip[:, nb:],
1699+
rtol=1e-5,
1700+
)

0 commit comments

Comments
 (0)