@@ -1656,6 +1656,61 @@ def test_online_thinning_with_mala_and_rto():
16561656 assert np .allclose (samples_rto_1 .samples [:,1 ], samples_rto_2 .samples [:,9 ], rtol = 1e-8 )
16571657 assert np .allclose (samples_rto_1 .samples [:,2 ], samples_rto_2 .samples [:,14 ], rtol = 1e-8 )
16581658
1659+ def test_online_thinning_with_hybrid_gibbs ():
1660+
1661+ # example adapted from https://cuqi-dtu.github.io/CUQI-Book/chapter04/gibbs.html
1662+
1663+ # Model and data
1664+ A , y_data , _ = cuqi .testproblem .Deconvolution1D (phantom = 'sinc' , noise_std = 0.005 , PSF_param = 6 ).get_components ()
1665+
1666+ # Get dimension of signal
1667+ n = A .domain_dim
1668+
1669+ d = cuqi .distribution .Gamma (1 , 1e-4 )
1670+ s = cuqi .distribution .Gamma (1 , 1e-4 )
1671+ x = cuqi .distribution .GMRF (np .zeros (n ), lambda d : d )
1672+ y = cuqi .distribution .Gaussian (A @x , cov = lambda s : 1 / s )
1673+
1674+ # Create joint distribution
1675+ joint = cuqi .distribution .JointDistribution (y , x , d , s )
1676+
1677+ # Define posterior by conditioning on the data
1678+ posterior = joint (y = y_data )
1679+
1680+ # Define sampling strategies
1681+ sampling_strategy_1 = {
1682+ 'x' : cuqi .sampler .LinearRTO (),
1683+ 'd' : cuqi .sampler .Conjugate (),
1684+ 's' : cuqi .sampler .Conjugate ()
1685+ }
1686+ sampling_strategy_2 = {
1687+ 'x' : cuqi .sampler .LinearRTO (),
1688+ 'd' : cuqi .sampler .Conjugate (),
1689+ 's' : cuqi .sampler .Conjugate ()
1690+ }
1691+
1692+ # Define Gibbs samplers
1693+ sampler_1 = cuqi .sampler .HybridGibbs (posterior , sampling_strategy_1 )
1694+ sampler_2 = cuqi .sampler .HybridGibbs (posterior , sampling_strategy_2 )
1695+
1696+ # Run sampler with different online thinnning Nt
1697+ np .random .seed (0 )
1698+ samples_1 = sampler_1 .sample (20 , Nt = 5 ).get_samples ()
1699+ np .random .seed (0 )
1700+ samples_2 = sampler_2 .sample (20 ).get_samples () # by default Nt=1
1701+
1702+ # Compare samples
1703+ assert np .allclose (samples_1 ['d' ].samples [:, 0 ], samples_2 ['d' ].samples [:, 4 ], rtol = 1e-5 )
1704+ assert np .allclose (samples_1 ['d' ].samples [:, 1 ], samples_2 ['d' ].samples [:, 9 ], rtol = 1e-5 )
1705+ assert np .allclose (samples_1 ['d' ].samples [:, 2 ], samples_2 ['d' ].samples [:, 14 ], rtol = 1e-5 )
1706+ assert np .allclose (samples_1 ['s' ].samples [:, 0 ], samples_2 ['s' ].samples [:, 4 ], rtol = 1e-5 )
1707+ assert np .allclose (samples_1 ['s' ].samples [:, 1 ], samples_2 ['s' ].samples [:, 9 ], rtol = 1e-5 )
1708+ assert np .allclose (samples_1 ['s' ].samples [:, 2 ], samples_2 ['s' ].samples [:, 14 ], rtol = 1e-5 )
1709+ assert np .allclose (samples_1 ['x' ].samples [:, 0 ], samples_2 ['x' ].samples [:, 4 ], rtol = 1e-5 )
1710+ assert np .allclose (samples_1 ['x' ].samples [:, 1 ], samples_2 ['x' ].samples [:, 9 ], rtol = 1e-5 )
1711+ assert np .allclose (samples_1 ['x' ].samples [:, 2 ], samples_2 ['x' ].samples [:, 14 ], rtol = 1e-5 )
1712+
1713+
16591714@pytest .mark .parametrize ("step_size" , [None , 0.1 ])
16601715@pytest .mark .parametrize ("num_sampling_steps_x" , [1 , 5 ])
16611716@pytest .mark .parametrize ("nb" , [5 , 20 ])
0 commit comments