@@ -252,7 +252,6 @@ def create_lmrf_prior_target(dim=16):
252252 return cuqi .distribution .JointDistribution (x , y )(y = y_data )
253253
254254
255-
256255@pytest .mark .parametrize ("target_dim" , [16 , 128 ])
257256def test_UGLA_regression_sample (target_dim ):
258257 """Test the UGLA sampler regression."""
@@ -324,7 +323,7 @@ def test_NUTS_regression_warmup(target: cuqi.density.Density):
324323 Ns = Ns ,
325324 Nb = Nb ,
326325 strategy = "NUTS" )
327-
326+
328327# ============= MYULA ==============
329328def create_myula_target (dim = 16 ):
330329 """Create a target for MYULA."""
@@ -419,7 +418,7 @@ def create_conjugate_target(type:str):
419418 cuqi .experimental .mcmc .ConjugateApprox (create_conjugate_target ("LMRF-Gamma" )),
420419 cuqi .experimental .mcmc .NUTS (cuqi .testproblem .Deconvolution1D (dim = 10 ).posterior , max_depth = 4 )
421420]
422-
421+
423422# List of samplers from cuqi.experimental.mcmc that should be skipped for checkpoint testing
424423skip_checkpoint = [
425424 cuqi .experimental .mcmc .Sampler ,
@@ -967,8 +966,6 @@ def HybridGibbs_target_1():
967966def test_NUTS_within_HybridGibbs_regression_sample_and_warmup (copy_reference ):
968967 """ Test that using NUTS sampler within HybridGibbs sampler works as
969968 expected."""
970- #TODO: This test might break in the future if the NUTS within HybridGibbs
971- # is changed to be fully stateful.
972969
973970 Nb = 10
974971 Ns = 10
@@ -982,7 +979,7 @@ def test_NUTS_within_HybridGibbs_regression_sample_and_warmup(copy_reference):
982979
983980 # Here we do 1 internal steps with NUTS for each Gibbs step
984981 num_sampling_steps = {
985- "x" : 1 ,
982+ "x" : 2 ,
986983 "s" : 1
987984 }
988985
@@ -1080,7 +1077,7 @@ def test_nuts_acceptance_rate(sampler: cuqi.experimental.mcmc.Sampler):
10801077 acc_rate_sum = sum (sampler ._acc [2 :])
10811078
10821079 assert np .isclose (counter , acc_rate_sum ), "NUTS sampler does not update acceptance rate correctly: " + str (counter )+ " != " + str (acc_rate_sum )
1083-
1080+
10841081# ============ Testing of AffineModel with RTO-type samplers ============
10851082
10861083def test_LinearRTO_with_AffineModel_is_equivalent_to_LinearModel_and_shifted_data ():
@@ -1657,4 +1654,82 @@ def test_online_thinning_with_mala_and_rto():
16571654 # Check that the samples are the same for RTO
16581655 assert np .allclose (samples_rto_1 .samples [:,0 ], samples_rto_2 .samples [:,4 ], rtol = 1e-8 )
16591656 assert np .allclose (samples_rto_1 .samples [:,1 ], samples_rto_2 .samples [:,9 ], rtol = 1e-8 )
1660- assert np .allclose (samples_rto_1 .samples [:,2 ], samples_rto_2 .samples [:,14 ], rtol = 1e-8 )
1657+ assert np .allclose (samples_rto_1 .samples [:,2 ], samples_rto_2 .samples [:,14 ], rtol = 1e-8 )
1658+
1659+ @pytest .mark .parametrize ("step_size" , [None , 0.1 ])
1660+ @pytest .mark .parametrize ("num_sampling_steps_x" , [1 , 5 ])
1661+ @pytest .mark .parametrize ("nb" , [5 , 20 ])
1662+ def test_NUTS_within_Gibbs_consistant_with_NUTS (step_size , num_sampling_steps_x , nb ):
1663+ """ Test that using NUTS sampler within HybridGibbs sampler is consistant
1664+ with using NUTS sampler alone for sampling and tuning. This test ensures
1665+ NUTS within HybridGibbs statefulness.
1666+ """
1667+
1668+ ns = 15 # number of sampling steps
1669+ tune_freq = 0.1
1670+
1671+ np .random .seed (0 )
1672+ # Forward problem
1673+ A , y_data , info = cuqi .testproblem .Deconvolution1D (
1674+ dim = 5 , phantom = 'sinc' , noise_std = 0.001 ).get_components ()
1675+
1676+ # Bayesian Inverse Problem
1677+ x = cuqi .distribution .GMRF (np .zeros (A .domain_dim ), 50 )
1678+ y = cuqi .distribution .Gaussian (A @x , 0.001 ** 2 )
1679+
1680+ # Posterior
1681+ target = cuqi .distribution .JointDistribution (y , x )(y = y_data )
1682+
1683+ # Sample with NUTS within HybridGibbs
1684+ np .random .seed (0 )
1685+ sampling_strategy = {
1686+ "x" : cuqi .experimental .mcmc .NUTS (max_depth = 4 , step_size = step_size )
1687+ }
1688+
1689+ num_sampling_steps = {
1690+ "x" : num_sampling_steps_x
1691+ }
1692+
1693+ sampler_gibbs = cuqi .experimental .mcmc .HybridGibbs (target ,
1694+ sampling_strategy ,
1695+ num_sampling_steps )
1696+ sampler_gibbs .warmup (nb , tune_freq = tune_freq )
1697+ sampler_gibbs .sample (ns )
1698+ samples_gibbs = sampler_gibbs .get_samples ()["x" ].samples
1699+
1700+ # Sample with NUTS alone
1701+ np .random .seed (0 )
1702+ sampler_nuts = cuqi .experimental .mcmc .NUTS (target ,
1703+ max_depth = 4 ,
1704+ step_size = step_size )
1705+ # Warm up (when num_sampling_steps_x>0, we do not using built-in warmup
1706+ # in order to control number of steps between tuning steps to
1707+ # match Gibbs sampling behavior)
1708+ if num_sampling_steps_x == 1 :
1709+ sampler_nuts .warmup (nb , tune_freq = tune_freq )
1710+ else :
1711+ tune_interval = max (int (tune_freq * nb ), 1 )
1712+ for count in range (nb ):
1713+ for _ in range (num_sampling_steps_x ):
1714+ sampler_nuts .sample (1 )
1715+ if (count + 1 ) % tune_interval == 0 :
1716+ sampler_nuts .tune (None , count // tune_interval )
1717+ # Sample
1718+ sampler_nuts .sample (ns * num_sampling_steps_x )
1719+ samples_nuts = sampler_nuts .get_samples ().samples
1720+ # skip every num_sampling_steps_x samples to match Gibbs samples
1721+ samples_nuts_skip = samples_nuts [:, num_sampling_steps_x - 1 ::num_sampling_steps_x ]
1722+
1723+ # assert warmup samples are correct:
1724+ assert np .allclose (
1725+ samples_gibbs [:, :nb ],
1726+ samples_nuts_skip [:, :nb ],
1727+ rtol = 1e-5 ,
1728+ )
1729+
1730+ # assert samples are correct:
1731+ assert np .allclose (
1732+ samples_gibbs [:, nb :],
1733+ samples_nuts_skip [:, nb :],
1734+ rtol = 1e-5 ,
1735+ )
0 commit comments