Skip to content

Commit 4f277f3

Browse files
committed
add unit test
1 parent 0df6e45 commit 4f277f3

File tree

1 file changed

+57
-0
lines changed

1 file changed

+57
-0
lines changed

tests/zexperimental/test_mcmc.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import inspect
55
from numbers import Number
66

7+
from demos.dev import JointDistribution
8+
79
def assert_true_if_sampling_is_equivalent(
810
sampler_old: cuqi.legacy.sampler.Sampler,
911
sampler_new: cuqi.sampler.Sampler,
@@ -1656,6 +1658,61 @@ def test_online_thinning_with_mala_and_rto():
16561658
assert np.allclose(samples_rto_1.samples[:,1], samples_rto_2.samples[:,9], rtol=1e-8)
16571659
assert np.allclose(samples_rto_1.samples[:,2], samples_rto_2.samples[:,14], rtol=1e-8)
16581660

1661+
def test_online_thinning_with_hybrid_gibbs():
1662+
1663+
# example adapted from https://cuqi-dtu.github.io/CUQI-Book/chapter04/gibbs.html
1664+
1665+
# Model and data
1666+
A, y_data, _ = cuqi.testproblem.Deconvolution1D(phantom='sinc', noise_std=0.005, PSF_param=6).get_components()
1667+
1668+
# Get dimension of signal
1669+
n = A.domain_dim
1670+
1671+
d = cuqi.distribution.Gamma(1, 1e-4)
1672+
s = cuqi.distribution.Gamma(1, 1e-4)
1673+
x = cuqi.distribution.GMRF(np.zeros(n), lambda d: d)
1674+
y = cuqi.distribution.Gaussian(A@x, cov=lambda s: 1/s)
1675+
1676+
# Create joint distribution
1677+
joint = cuqi.distribution.JointDistribution(y, x, d, s)
1678+
1679+
# Define posterior by conditioning on the data
1680+
posterior = joint(y=y_data)
1681+
1682+
# Define sampling strategies
1683+
sampling_strategy_1 = {
1684+
'x': cuqi.sampler.LinearRTO(),
1685+
'd': cuqi.sampler.Conjugate(),
1686+
's': cuqi.sampler.Conjugate()
1687+
}
1688+
sampling_strategy_2 = {
1689+
'x': cuqi.sampler.LinearRTO(),
1690+
'd': cuqi.sampler.Conjugate(),
1691+
's': cuqi.sampler.Conjugate()
1692+
}
1693+
1694+
# Define Gibbs samplers
1695+
sampler_1 = cuqi.sampler.HybridGibbs(posterior, sampling_strategy_1)
1696+
sampler_2 = cuqi.sampler.HybridGibbs(posterior, sampling_strategy_2)
1697+
1698+
# Run sampler with different online thinnning Nt
1699+
np.random.seed(0)
1700+
samples_1 = sampler_1.sample(20, Nt=5).get_samples()
1701+
np.random.seed(0)
1702+
samples_2 = sampler_2.sample(20).get_samples() # by default Nt=1
1703+
1704+
# Compare samples
1705+
assert np.allclose(samples_1['d'].samples[:, 0], samples_2['d'].samples[:, 4], rtol=1e-5)
1706+
assert np.allclose(samples_1['d'].samples[:, 1], samples_2['d'].samples[:, 9], rtol=1e-5)
1707+
assert np.allclose(samples_1['d'].samples[:, 2], samples_2['d'].samples[:, 14], rtol=1e-5)
1708+
assert np.allclose(samples_1['s'].samples[:, 0], samples_2['s'].samples[:, 4], rtol=1e-5)
1709+
assert np.allclose(samples_1['s'].samples[:, 1], samples_2['s'].samples[:, 9], rtol=1e-5)
1710+
assert np.allclose(samples_1['s'].samples[:, 2], samples_2['s'].samples[:, 14], rtol=1e-5)
1711+
assert np.allclose(samples_1['x'].samples[:, 0], samples_2['x'].samples[:, 4], rtol=1e-5)
1712+
assert np.allclose(samples_1['x'].samples[:, 1], samples_2['x'].samples[:, 9], rtol=1e-5)
1713+
assert np.allclose(samples_1['x'].samples[:, 2], samples_2['x'].samples[:, 14], rtol=1e-5)
1714+
1715+
16591716
@pytest.mark.parametrize("step_size", [None, 0.1])
16601717
@pytest.mark.parametrize("num_sampling_steps_x", [1, 5])
16611718
@pytest.mark.parametrize("nb", [5, 20])

0 commit comments

Comments
 (0)