@@ -1621,6 +1621,41 @@ def test_gibbs_scan_order():
16211621 sampler = cuqi .experimental .mcmc .HybridGibbs (target , sampling_strategy , scan_order = ['x' , 's' ])
16221622 assert sampler .scan_order == ['x' , 's' ]
16231623
1624+ def test_online_thinning_with_mala_and_rto ():
1625+
1626+ # Define LinearModel and data
1627+ A , y_obs , _ = cuqi .testproblem .Deconvolution1D ().get_components ()
1628+
1629+ # Define Bayesian Problem
1630+ x = cuqi .distribution .GMRF (np .zeros (A .domain_dim ), 100 )
1631+ y = cuqi .distribution .Gaussian (A @x , 0.01 ** 2 )
1632+ posterior = cuqi .distribution .JointDistribution (x , y )(y = y_obs )
1633+
1634+ # Set up MALA and RTO samplers
1635+ sampler_mala_1 = cuqi .experimental .mcmc .MALA (posterior , scale = 0.01 )
1636+ sampler_mala_2 = cuqi .experimental .mcmc .MALA (posterior , scale = 0.01 )
1637+ sampler_rto_1 = cuqi .experimental .mcmc .LinearRTO (posterior , maxit = 1000 , tol = 1e-8 )
1638+ sampler_rto_2 = cuqi .experimental .mcmc .LinearRTO (posterior , maxit = 1000 , tol = 1e-8 )
1639+
1640+ # Sample MALA and RTO with fixed seed, but different online thinning Nt
1641+ np .random .seed (0 )
1642+ samples_mala_1 = sampler_mala_1 .sample (100 ,Nt = 5 ).get_samples ()
1643+ np .random .seed (0 )
1644+ samples_mala_2 = sampler_mala_2 .sample (100 ,Nt = 1 ).get_samples ()
1645+ np .random .seed (0 )
1646+ samples_rto_1 = sampler_rto_1 .sample (100 ,Nt = 5 ).get_samples ()
1647+ np .random .seed (0 )
1648+ samples_rto_2 = sampler_rto_2 .sample (100 ,Nt = 1 ).get_samples ()
1649+
1650+ # Check that the samples are the same for MALA
1651+ assert np .allclose (samples_mala_1 .samples [:,0 ], samples_mala_2 .samples [:,4 ], rtol = 1e-8 )
1652+ assert np .allclose (samples_mala_1 .samples [:,1 ], samples_mala_2 .samples [:,9 ], rtol = 1e-8 )
1653+ assert np .allclose (samples_mala_1 .samples [:,2 ], samples_mala_2 .samples [:,14 ], rtol = 1e-8 )
1654+ # Check that the samples are the same for RTO
1655+ assert np .allclose (samples_rto_1 .samples [:,0 ], samples_rto_2 .samples [:,4 ], rtol = 1e-8 )
1656+ assert np .allclose (samples_rto_1 .samples [:,1 ], samples_rto_2 .samples [:,9 ], rtol = 1e-8 )
1657+ assert np .allclose (samples_rto_1 .samples [:,2 ], samples_rto_2 .samples [:,14 ], rtol = 1e-8 )
1658+
16241659@pytest .mark .parametrize ("step_size" , [None , 0.1 ])
16251660@pytest .mark .parametrize ("num_sampling_steps_x" , [1 , 5 ])
16261661@pytest .mark .parametrize ("nb" , [5 , 20 ])
0 commit comments