Skip to content

Commit f08eb2a

Browse files
authored
Merge pull request #690 from CUQI-DTU/feature-online-thinning
add online thinning option to samplers
2 parents 2cf72ec + e515974 commit f08eb2a

File tree

2 files changed

+47
-4
lines changed

2 files changed

+47
-4
lines changed

cuqi/experimental/mcmc/_sampler.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -203,13 +203,16 @@ def load_checkpoint(self, path):
203203

204204
self.set_state(state)
205205

206-
def sample(self, Ns, batch_size=0, sample_path='./CUQI_samples/') -> 'Sampler':
206+
def sample(self, Ns, Nt=1, batch_size=0, sample_path='./CUQI_samples/') -> 'Sampler':
207207
""" Sample Ns samples from the target density.
208208
209209
Parameters
210210
----------
211211
Ns : int
212212
The number of samples to draw.
213+
214+
Nt : int, optional, default=1
215+
The thinning interval. If Nt >= 1, every Nt'th sample is stored. The larger Nt, the fewer samples are stored.
213216
214217
batch_size : int, optional
215218
The batch size for saving samples to disk. If 0, no batching is used. If positive, samples are saved to disk in batches of the specified size.
@@ -233,7 +236,8 @@ def sample(self, Ns, batch_size=0, sample_path='./CUQI_samples/') -> 'Sampler':
233236

234237
# Store samples
235238
self._acc.append(acc)
236-
self._samples.append(self.current_point)
239+
if (Nt > 0) and ((idx + 1) % Nt == 0):
240+
self._samples.append(self.current_point)
237241

238242
# display acc rate at progress bar
239243
pbar.set_postfix_str(f"acc rate: {np.mean(self._acc[-1-idx:]):.2%}")
@@ -248,14 +252,17 @@ def sample(self, Ns, batch_size=0, sample_path='./CUQI_samples/') -> 'Sampler':
248252
return self
249253

250254

251-
def warmup(self, Nb, tune_freq=0.1) -> 'Sampler':
255+
def warmup(self, Nb, Nt=1, tune_freq=0.1) -> 'Sampler':
252256
""" Warmup the sampler by drawing Nb samples.
253257
254258
Parameters
255259
----------
256260
Nb : int
257261
The number of samples to draw during warmup.
258262
263+
Nt : int, optional, default=1
264+
The thinning interval. If Nt >= 1, every Nt'th sample is stored. The larger Nt, the fewer samples are stored.
265+
259266
tune_freq : float, optional
260267
The frequency of tuning. Tuning is performed every tune_freq*Nb samples.
261268
@@ -278,7 +285,8 @@ def warmup(self, Nb, tune_freq=0.1) -> 'Sampler':
278285

279286
# Store samples
280287
self._acc.append(acc)
281-
self._samples.append(self.current_point)
288+
if (Nt > 0) and ((idx + 1) % Nt == 0):
289+
self._samples.append(self.current_point)
282290

283291
# display acc rate at progress bar
284292
pbar.set_postfix_str(f"acc rate: {np.mean(self._acc[-1-idx:]):.2%}")

tests/zexperimental/test_mcmc.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1621,6 +1621,41 @@ def test_gibbs_scan_order():
16211621
sampler = cuqi.experimental.mcmc.HybridGibbs(target, sampling_strategy, scan_order=['x', 's'])
16221622
assert sampler.scan_order == ['x', 's']
16231623

1624+
def test_online_thinning_with_mala_and_rto():
1625+
1626+
# Define LinearModel and data
1627+
A, y_obs, _ = cuqi.testproblem.Deconvolution1D().get_components()
1628+
1629+
# Define Bayesian Problem
1630+
x = cuqi.distribution.GMRF(np.zeros(A.domain_dim), 100)
1631+
y = cuqi.distribution.Gaussian(A@x, 0.01**2)
1632+
posterior = cuqi.distribution.JointDistribution(x, y)(y=y_obs)
1633+
1634+
# Set up MALA and RTO samplers
1635+
sampler_mala_1 = cuqi.experimental.mcmc.MALA(posterior, scale=0.01)
1636+
sampler_mala_2 = cuqi.experimental.mcmc.MALA(posterior, scale=0.01)
1637+
sampler_rto_1 = cuqi.experimental.mcmc.LinearRTO(posterior, maxit=1000, tol=1e-8)
1638+
sampler_rto_2 = cuqi.experimental.mcmc.LinearRTO(posterior, maxit=1000, tol=1e-8)
1639+
1640+
# Sample MALA and RTO with fixed seed, but different online thinning Nt
1641+
np.random.seed(0)
1642+
samples_mala_1 = sampler_mala_1.sample(100,Nt=5).get_samples()
1643+
np.random.seed(0)
1644+
samples_mala_2 = sampler_mala_2.sample(100,Nt=1).get_samples()
1645+
np.random.seed(0)
1646+
samples_rto_1 = sampler_rto_1.sample(100,Nt=5).get_samples()
1647+
np.random.seed(0)
1648+
samples_rto_2 = sampler_rto_2.sample(100,Nt=1).get_samples()
1649+
1650+
# Check that the samples are the same for MALA
1651+
assert np.allclose(samples_mala_1.samples[:,0], samples_mala_2.samples[:,4], rtol=1e-8)
1652+
assert np.allclose(samples_mala_1.samples[:,1], samples_mala_2.samples[:,9], rtol=1e-8)
1653+
assert np.allclose(samples_mala_1.samples[:,2], samples_mala_2.samples[:,14], rtol=1e-8)
1654+
# Check that the samples are the same for RTO
1655+
assert np.allclose(samples_rto_1.samples[:,0], samples_rto_2.samples[:,4], rtol=1e-8)
1656+
assert np.allclose(samples_rto_1.samples[:,1], samples_rto_2.samples[:,9], rtol=1e-8)
1657+
assert np.allclose(samples_rto_1.samples[:,2], samples_rto_2.samples[:,14], rtol=1e-8)
1658+
16241659
@pytest.mark.parametrize("step_size", [None, 0.1])
16251660
@pytest.mark.parametrize("num_sampling_steps_x", [1, 5])
16261661
@pytest.mark.parametrize("nb", [5, 20])

0 commit comments

Comments
 (0)