Skip to content

Commit 5e8ccf8

Browse files
authored
Merge pull request #707 from CUQI-DTU/gibbs_online_thinning
Add online thinning to HybridGibbs
2 parents e8fd90f + b634cab commit 5e8ccf8

File tree

2 files changed

+69
-7
lines changed

2 files changed

+69
-7
lines changed

cuqi/sampler/_gibbs.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -180,33 +180,39 @@ def validate_targets(self):
180180
for sampler in self.samplers.values():
181181
sampler.validate_target()
182182

183-
def sample(self, Ns) -> 'HybridGibbs':
183+
def sample(self, Ns, Nt=1) -> 'HybridGibbs':
184184
""" Sample from the joint distribution using Gibbs sampling
185185
186186
Parameters
187187
----------
188188
Ns : int
189189
The number of samples to draw.
190-
190+
Nt : int, optional, default=1
191+
The thinning interval. If Nt >= 1, every Nt'th sample is stored. The larger Nt, the fewer samples are stored.
192+
191193
"""
192194
for idx in tqdm(range(Ns), "Sample: "):
193195

194196
self.step()
195197

196-
self._store_samples()
198+
if (Nt > 0) and ((idx + 1) % Nt == 0):
199+
self._store_samples()
197200

198201
# Call callback function if specified
199202
self._call_callback(idx, Ns)
200203

201204
return self
202205

203-
def warmup(self, Nb, tune_freq=0.1) -> 'HybridGibbs':
206+
def warmup(self, Nb, Nt=1, tune_freq=0.1) -> 'HybridGibbs':
204207
""" Warmup (tune) the samplers in the Gibbs sampling scheme
205208
206209
Parameters
207210
----------
208211
Nb : int
209212
The number of samples to draw during warmup.
213+
214+
Nt : int, optional, default=1
215+
The thinning interval. If Nt >= 1, every Nt'th sample is stored. The larger Nt, the fewer samples are stored.
210216
211217
tune_freq : float, optional
212218
Frequency of tuning the samplers. Tuning is performed every tune_freq*Nb steps.
@@ -221,9 +227,10 @@ def warmup(self, Nb, tune_freq=0.1) -> 'HybridGibbs':
221227

222228
# Tune the sampler at tuning intervals (matching behavior of Sampler class)
223229
if (idx + 1) % tune_interval == 0:
224-
self.tune(tune_interval, idx // tune_interval)
225-
226-
self._store_samples()
230+
self.tune(tune_interval, idx // tune_interval)
231+
232+
if (Nt > 0) and ((idx + 1) % Nt == 0):
233+
self._store_samples()
227234

228235
# Call callback function if specified
229236
self._call_callback(idx, Nb)

tests/zexperimental/test_mcmc.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1656,6 +1656,61 @@ def test_online_thinning_with_mala_and_rto():
16561656
assert np.allclose(samples_rto_1.samples[:,1], samples_rto_2.samples[:,9], rtol=1e-8)
16571657
assert np.allclose(samples_rto_1.samples[:,2], samples_rto_2.samples[:,14], rtol=1e-8)
16581658

1659+
def test_online_thinning_with_hybrid_gibbs():
1660+
1661+
# example adapted from https://cuqi-dtu.github.io/CUQI-Book/chapter04/gibbs.html
1662+
1663+
# Model and data
1664+
A, y_data, _ = cuqi.testproblem.Deconvolution1D(phantom='sinc', noise_std=0.005, PSF_param=6).get_components()
1665+
1666+
# Get dimension of signal
1667+
n = A.domain_dim
1668+
1669+
d = cuqi.distribution.Gamma(1, 1e-4)
1670+
s = cuqi.distribution.Gamma(1, 1e-4)
1671+
x = cuqi.distribution.GMRF(np.zeros(n), lambda d: d)
1672+
y = cuqi.distribution.Gaussian(A@x, cov=lambda s: 1/s)
1673+
1674+
# Create joint distribution
1675+
joint = cuqi.distribution.JointDistribution(y, x, d, s)
1676+
1677+
# Define posterior by conditioning on the data
1678+
posterior = joint(y=y_data)
1679+
1680+
# Define sampling strategies
1681+
sampling_strategy_1 = {
1682+
'x': cuqi.sampler.LinearRTO(),
1683+
'd': cuqi.sampler.Conjugate(),
1684+
's': cuqi.sampler.Conjugate()
1685+
}
1686+
sampling_strategy_2 = {
1687+
'x': cuqi.sampler.LinearRTO(),
1688+
'd': cuqi.sampler.Conjugate(),
1689+
's': cuqi.sampler.Conjugate()
1690+
}
1691+
1692+
# Define Gibbs samplers
1693+
sampler_1 = cuqi.sampler.HybridGibbs(posterior, sampling_strategy_1)
1694+
sampler_2 = cuqi.sampler.HybridGibbs(posterior, sampling_strategy_2)
1695+
1696+
# Run sampler with different online thinnning Nt
1697+
np.random.seed(0)
1698+
samples_1 = sampler_1.sample(20, Nt=5).get_samples()
1699+
np.random.seed(0)
1700+
samples_2 = sampler_2.sample(20).get_samples() # by default Nt=1
1701+
1702+
# Compare samples
1703+
assert np.allclose(samples_1['d'].samples[:, 0], samples_2['d'].samples[:, 4], rtol=1e-5)
1704+
assert np.allclose(samples_1['d'].samples[:, 1], samples_2['d'].samples[:, 9], rtol=1e-5)
1705+
assert np.allclose(samples_1['d'].samples[:, 2], samples_2['d'].samples[:, 14], rtol=1e-5)
1706+
assert np.allclose(samples_1['s'].samples[:, 0], samples_2['s'].samples[:, 4], rtol=1e-5)
1707+
assert np.allclose(samples_1['s'].samples[:, 1], samples_2['s'].samples[:, 9], rtol=1e-5)
1708+
assert np.allclose(samples_1['s'].samples[:, 2], samples_2['s'].samples[:, 14], rtol=1e-5)
1709+
assert np.allclose(samples_1['x'].samples[:, 0], samples_2['x'].samples[:, 4], rtol=1e-5)
1710+
assert np.allclose(samples_1['x'].samples[:, 1], samples_2['x'].samples[:, 9], rtol=1e-5)
1711+
assert np.allclose(samples_1['x'].samples[:, 2], samples_2['x'].samples[:, 14], rtol=1e-5)
1712+
1713+
16591714
@pytest.mark.parametrize("step_size", [None, 0.1])
16601715
@pytest.mark.parametrize("num_sampling_steps_x", [1, 5])
16611716
@pytest.mark.parametrize("nb", [5, 20])

0 commit comments

Comments
 (0)