|
4 | 4 | import inspect |
5 | 5 | from numbers import Number |
6 | 6 |
|
| 7 | +from demos.dev import JointDistribution |
| 8 | + |
7 | 9 | def assert_true_if_sampling_is_equivalent( |
8 | 10 | sampler_old: cuqi.legacy.sampler.Sampler, |
9 | 11 | sampler_new: cuqi.sampler.Sampler, |
@@ -1656,6 +1658,61 @@ def test_online_thinning_with_mala_and_rto(): |
1656 | 1658 | assert np.allclose(samples_rto_1.samples[:,1], samples_rto_2.samples[:,9], rtol=1e-8) |
1657 | 1659 | assert np.allclose(samples_rto_1.samples[:,2], samples_rto_2.samples[:,14], rtol=1e-8) |
1658 | 1660 |
|
| 1661 | +def test_online_thinning_with_hybrid_gibbs(): |
| 1662 | + |
| 1663 | + # example adapted from https://cuqi-dtu.github.io/CUQI-Book/chapter04/gibbs.html |
| 1664 | + |
| 1665 | + # Model and data |
| 1666 | + A, y_data, _ = cuqi.testproblem.Deconvolution1D(phantom='sinc', noise_std=0.005, PSF_param=6).get_components() |
| 1667 | + |
| 1668 | + # Get dimension of signal |
| 1669 | + n = A.domain_dim |
| 1670 | + |
| 1671 | + d = cuqi.distribution.Gamma(1, 1e-4) |
| 1672 | + s = cuqi.distribution.Gamma(1, 1e-4) |
| 1673 | + x = cuqi.distribution.GMRF(np.zeros(n), lambda d: d) |
| 1674 | + y = cuqi.distribution.Gaussian(A@x, cov=lambda s: 1/s) |
| 1675 | + |
| 1676 | + # Create joint distribution |
| 1677 | + joint = cuqi.distribution.JointDistribution(y, x, d, s) |
| 1678 | + |
| 1679 | + # Define posterior by conditioning on the data |
| 1680 | + posterior = joint(y=y_data) |
| 1681 | + |
| 1682 | + # Define sampling strategies |
| 1683 | + sampling_strategy_1 = { |
| 1684 | + 'x': cuqi.sampler.LinearRTO(), |
| 1685 | + 'd': cuqi.sampler.Conjugate(), |
| 1686 | + 's': cuqi.sampler.Conjugate() |
| 1687 | + } |
| 1688 | + sampling_strategy_2 = { |
| 1689 | + 'x': cuqi.sampler.LinearRTO(), |
| 1690 | + 'd': cuqi.sampler.Conjugate(), |
| 1691 | + 's': cuqi.sampler.Conjugate() |
| 1692 | + } |
| 1693 | + |
| 1694 | + # Define Gibbs samplers |
| 1695 | + sampler_1 = cuqi.sampler.HybridGibbs(posterior, sampling_strategy_1) |
| 1696 | + sampler_2 = cuqi.sampler.HybridGibbs(posterior, sampling_strategy_2) |
| 1697 | + |
| 1698 | + # Run sampler with different online thinnning Nt |
| 1699 | + np.random.seed(0) |
| 1700 | + samples_1 = sampler_1.sample(20, Nt=5).get_samples() |
| 1701 | + np.random.seed(0) |
| 1702 | + samples_2 = sampler_2.sample(20).get_samples() # by default Nt=1 |
| 1703 | + |
| 1704 | + # Compare samples |
| 1705 | + assert np.allclose(samples_1['d'].samples[:, 0], samples_2['d'].samples[:, 4], rtol=1e-5) |
| 1706 | + assert np.allclose(samples_1['d'].samples[:, 1], samples_2['d'].samples[:, 9], rtol=1e-5) |
| 1707 | + assert np.allclose(samples_1['d'].samples[:, 2], samples_2['d'].samples[:, 14], rtol=1e-5) |
| 1708 | + assert np.allclose(samples_1['s'].samples[:, 0], samples_2['s'].samples[:, 4], rtol=1e-5) |
| 1709 | + assert np.allclose(samples_1['s'].samples[:, 1], samples_2['s'].samples[:, 9], rtol=1e-5) |
| 1710 | + assert np.allclose(samples_1['s'].samples[:, 2], samples_2['s'].samples[:, 14], rtol=1e-5) |
| 1711 | + assert np.allclose(samples_1['x'].samples[:, 0], samples_2['x'].samples[:, 4], rtol=1e-5) |
| 1712 | + assert np.allclose(samples_1['x'].samples[:, 1], samples_2['x'].samples[:, 9], rtol=1e-5) |
| 1713 | + assert np.allclose(samples_1['x'].samples[:, 2], samples_2['x'].samples[:, 14], rtol=1e-5) |
| 1714 | + |
| 1715 | + |
1659 | 1716 | @pytest.mark.parametrize("step_size", [None, 0.1]) |
1660 | 1717 | @pytest.mark.parametrize("num_sampling_steps_x", [1, 5]) |
1661 | 1718 | @pytest.mark.parametrize("nb", [5, 20]) |
|
0 commit comments