@@ -252,7 +252,6 @@ def create_lmrf_prior_target(dim=16):
252252 return cuqi .distribution .JointDistribution (x , y )(y = y_data )
253253
254254
255-
256255@pytest .mark .parametrize ("target_dim" , [16 , 128 ])
257256def test_UGLA_regression_sample (target_dim ):
258257 """Test the UGLA sampler regression."""
@@ -324,7 +323,7 @@ def test_NUTS_regression_warmup(target: cuqi.density.Density):
324323 Ns = Ns ,
325324 Nb = Nb ,
326325 strategy = "NUTS" )
327-
326+
328327# ============= MYULA ==============
329328def create_myula_target (dim = 16 ):
330329 """Create a target for MYULA."""
@@ -419,7 +418,7 @@ def create_conjugate_target(type:str):
419418 cuqi .experimental .mcmc .ConjugateApprox (create_conjugate_target ("LMRF-Gamma" )),
420419 cuqi .experimental .mcmc .NUTS (cuqi .testproblem .Deconvolution1D (dim = 10 ).posterior , max_depth = 4 )
421420]
422-
421+
423422# List of samplers from cuqi.experimental.mcmc that should be skipped for checkpoint testing
424423skip_checkpoint = [
425424 cuqi .experimental .mcmc .Sampler ,
@@ -967,8 +966,6 @@ def HybridGibbs_target_1():
967966def test_NUTS_within_HybridGibbs_regression_sample_and_warmup (copy_reference ):
968967 """ Test that using NUTS sampler within HybridGibbs sampler works as
969968 expected."""
970- #TODO: This test might break in the future if the NUTS within HybridGibbs
971- # is changed to be fully stateful.
972969
973970 Nb = 10
974971 Ns = 10
@@ -982,7 +979,7 @@ def test_NUTS_within_HybridGibbs_regression_sample_and_warmup(copy_reference):
982979
983980 # Here we do 1 internal steps with NUTS for each Gibbs step
984981 num_sampling_steps = {
985- "x" : 1 ,
982+ "x" : 2 ,
986983 "s" : 1
987984 }
988985
@@ -1080,7 +1077,7 @@ def test_nuts_acceptance_rate(sampler: cuqi.experimental.mcmc.Sampler):
10801077 acc_rate_sum = sum (sampler ._acc [2 :])
10811078
10821079 assert np .isclose (counter , acc_rate_sum ), "NUTS sampler does not update acceptance rate correctly: " + str (counter )+ " != " + str (acc_rate_sum )
1083-
1080+
10841081# ============ Testing of AffineModel with RTO-type samplers ============
10851082
10861083def test_LinearRTO_with_AffineModel_is_equivalent_to_LinearModel_and_shifted_data ():
@@ -1623,3 +1620,81 @@ def test_gibbs_scan_order():
16231620
16241621 sampler = cuqi .experimental .mcmc .HybridGibbs (target , sampling_strategy , scan_order = ['x' , 's' ])
16251622 assert sampler .scan_order == ['x' , 's' ]
1623+
1624+ @pytest .mark .parametrize ("step_size" , [None , 0.1 ])
1625+ @pytest .mark .parametrize ("num_sampling_steps_x" , [1 , 5 ])
1626+ @pytest .mark .parametrize ("nb" , [5 , 20 ])
1627+ def test_NUTS_within_Gibbs_consistant_with_NUTS (step_size , num_sampling_steps_x , nb ):
1628+ """ Test that using NUTS sampler within HybridGibbs sampler is consistant
1629+ with using NUTS sampler alone for sampling and tuning. This test ensures
1630+ NUTS within HybridGibbs statefulness.
1631+ """
1632+
1633+ ns = 15 # number of sampling steps
1634+ tune_freq = 0.1
1635+
1636+ np .random .seed (0 )
1637+ # Forward problem
1638+ A , y_data , info = cuqi .testproblem .Deconvolution1D (
1639+ dim = 5 , phantom = 'sinc' , noise_std = 0.001 ).get_components ()
1640+
1641+ # Bayesian Inverse Problem
1642+ x = cuqi .distribution .GMRF (np .zeros (A .domain_dim ), 50 )
1643+ y = cuqi .distribution .Gaussian (A @x , 0.001 ** 2 )
1644+
1645+ # Posterior
1646+ target = cuqi .distribution .JointDistribution (y , x )(y = y_data )
1647+
1648+ # Sample with NUTS within HybridGibbs
1649+ np .random .seed (0 )
1650+ sampling_strategy = {
1651+ "x" : cuqi .experimental .mcmc .NUTS (max_depth = 4 , step_size = step_size )
1652+ }
1653+
1654+ num_sampling_steps = {
1655+ "x" : num_sampling_steps_x
1656+ }
1657+
1658+ sampler_gibbs = cuqi .experimental .mcmc .HybridGibbs (target ,
1659+ sampling_strategy ,
1660+ num_sampling_steps )
1661+ sampler_gibbs .warmup (nb , tune_freq = tune_freq )
1662+ sampler_gibbs .sample (ns )
1663+ samples_gibbs = sampler_gibbs .get_samples ()["x" ].samples
1664+
1665+ # Sample with NUTS alone
1666+ np .random .seed (0 )
1667+ sampler_nuts = cuqi .experimental .mcmc .NUTS (target ,
1668+ max_depth = 4 ,
1669+ step_size = step_size )
1670+ # Warm up (when num_sampling_steps_x>0, we do not using built-in warmup
1671+ # in order to control number of steps between tuning steps to
1672+ # match Gibbs sampling behavior)
1673+ if num_sampling_steps_x == 1 :
1674+ sampler_nuts .warmup (nb , tune_freq = tune_freq )
1675+ else :
1676+ tune_interval = max (int (tune_freq * nb ), 1 )
1677+ for count in range (nb ):
1678+ for _ in range (num_sampling_steps_x ):
1679+ sampler_nuts .sample (1 )
1680+ if (count + 1 ) % tune_interval == 0 :
1681+ sampler_nuts .tune (None , count // tune_interval )
1682+ # Sample
1683+ sampler_nuts .sample (ns * num_sampling_steps_x )
1684+ samples_nuts = sampler_nuts .get_samples ().samples
1685+ # skip every num_sampling_steps_x samples to match Gibbs samples
1686+ samples_nuts_skip = samples_nuts [:, num_sampling_steps_x - 1 ::num_sampling_steps_x ]
1687+
1688+ # assert warmup samples are correct:
1689+ assert np .allclose (
1690+ samples_gibbs [:, :nb ],
1691+ samples_nuts_skip [:, :nb ],
1692+ rtol = 1e-5 ,
1693+ )
1694+
1695+ # assert samples are correct:
1696+ assert np .allclose (
1697+ samples_gibbs [:, nb :],
1698+ samples_nuts_skip [:, nb :],
1699+ rtol = 1e-5 ,
1700+ )
0 commit comments