@@ -1255,6 +1255,142 @@ def test_UGLA_with_AffineModel_is_equivalent_to_LinearModel_and_shifted_data():
12551255 # Check that the samples are the same
12561256 assert np .allclose (samples_linear .samples , samples_affine .samples )
12571257
1258+ # ============ Test for sampling with RandomVariable prior against Distribution prior ============
1259+ samplers_for_rv_against_dist = [cuqi .experimental .mcmc .MALA ,
1260+ cuqi .experimental .mcmc .ULA ,
1261+ cuqi .experimental .mcmc .MH ,
1262+ cuqi .experimental .mcmc .PCN ,
1263+ cuqi .experimental .mcmc .CWMH ,
1264+ cuqi .experimental .mcmc .NUTS ,
1265+ cuqi .experimental .mcmc .LinearRTO ]
1266+
1267+ @pytest .mark .parametrize ("sampler" , samplers_for_rv_against_dist )
1268+ def test_RandomVariable_prior_against_Distribution_prior (sampler : cuqi .experimental .mcmc .Sampler ):
1269+ """ Test RandomVariable prior is equivalent to Distribution prior for
1270+ MALA, ULA, MH, PCN, CWMH, NUTS and LinearRTO.
1271+ """
1272+
1273+ # Set dim
1274+ dim = 32
1275+
1276+ # Extract model and data
1277+ A , y_data , info = cuqi .testproblem .Deconvolution1D (dim = 32 , phantom = 'square' ).get_components ()
1278+
1279+ # Set up RandomVariable prior and do posterior sampling
1280+ np .random .seed (0 )
1281+ x_rv = cuqi .distribution .Gaussian (0.5 * np .ones (dim ), 0.1 ).rv
1282+ y_rv = cuqi .distribution .Gaussian (A @x_rv , 0.001 ).rv
1283+ joint_rv = cuqi .distribution .JointDistribution (x_rv , y_rv )(y_rv = y_data )
1284+ sampler_rv = sampler (joint_rv )
1285+ sampler_rv .sample (10 )
1286+ samples_rv = sampler_rv .get_samples ()
1287+
1288+ # Set up Distribution prior and do posterior sampling
1289+ np .random .seed (0 )
1290+ x_dist = cuqi .distribution .Gaussian (0.5 * np .ones (dim ), 0.1 )
1291+ y_dist = cuqi .distribution .Gaussian (A @x_dist , 0.001 )
1292+ joint_dist = cuqi .distribution .JointDistribution (x_dist , y_dist )(y_dist = y_data )
1293+ sampler_dist = sampler (joint_dist )
1294+ sampler_dist .sample (10 )
1295+ samples_dist = sampler_dist .get_samples ()
1296+
1297+ assert np .allclose (samples_rv .samples , samples_dist .samples )
1298+
1299+ def test_RandomVariable_prior_against_Distribution_prior_regularized_RTO ():
1300+ """ Test RandomVariable prior is equivalent to Distribution prior for
1301+ RegularizedLinearRTO.
1302+ """
1303+
1304+ # Set dim
1305+ dim = 32
1306+
1307+ # Extract model and data
1308+ A , y_data , info = cuqi .testproblem .Deconvolution1D (dim = 32 , phantom = 'square' ).get_components ()
1309+
1310+ # Set up RandomVariable prior and do posterior sampling
1311+ np .random .seed (0 )
1312+ x_rv = cuqi .implicitprior .RegularizedGaussian (0.5 * np .ones (dim ), 0.1 , constraint = "nonnegativity" ).rv
1313+ y_rv = cuqi .distribution .Gaussian (A @x_rv , 0.001 ).rv
1314+ joint_rv = cuqi .distribution .JointDistribution (x_rv , y_rv )(y_rv = y_data )
1315+ sampler_rv = cuqi .experimental .mcmc .RegularizedLinearRTO (joint_rv )
1316+ sampler_rv .sample (10 )
1317+ samples_rv = sampler_rv .get_samples ()
1318+
1319+ # Set up Distribution prior and do posterior sampling
1320+ np .random .seed (0 )
1321+ x_dist = cuqi .implicitprior .RegularizedGaussian (0.5 * np .ones (dim ), 0.1 , constraint = "nonnegativity" )
1322+ y_dist = cuqi .distribution .Gaussian (A @x_dist , 0.001 )
1323+ joint_dist = cuqi .distribution .JointDistribution (x_dist , y_dist )(y_dist = y_data )
1324+ sampler_dist = cuqi .experimental .mcmc .RegularizedLinearRTO (joint_dist )
1325+ sampler_dist .sample (10 )
1326+ samples_dist = sampler_dist .get_samples ()
1327+
1328+ assert np .allclose (samples_rv .samples , samples_dist .samples )
1329+
1330+ def test_RandomVariable_prior_against_Distribution_prior_UGLA_Conjugate_ConjugateApprox_HybridGibbs ():
1331+ """ Test RandomVariable prior is equivalent to Distribution prior for
1332+ UGLA, Conjugate, ConjugateApprox and HybridGibbs samplers.
1333+ """
1334+
1335+ # Forward problem
1336+ A , y_data , info = cuqi .testproblem .Deconvolution1D (dim = 28 , phantom = 'square' , noise_std = 0.001 ).get_components ()
1337+
1338+ # Random seed
1339+ np .random .seed (0 )
1340+
1341+ # Bayesian Inverse Problem
1342+ d = cuqi .distribution .Gamma (1 , 1e-4 )
1343+ s = cuqi .distribution .Gamma (1 , 1e-4 )
1344+ x = cuqi .distribution .LMRF (0 , lambda d : 1 / d , geometry = A .domain_geometry )
1345+ y = cuqi .distribution .Gaussian (A @x , lambda s : 1 / s )
1346+
1347+ # Posterior
1348+ target = cuqi .distribution .JointDistribution (y , x , s , d )(y = y_data )
1349+
1350+ # Sampling strategy
1351+ sampling_strategy = {
1352+ "x" : cuqi .experimental .mcmc .UGLA (),
1353+ "s" : cuqi .experimental .mcmc .Conjugate (),
1354+ "d" : cuqi .experimental .mcmc .ConjugateApprox ()
1355+ }
1356+
1357+ # Gibbs sampler
1358+ sampler = cuqi .experimental .mcmc .HybridGibbs (target , sampling_strategy )
1359+
1360+ # Run sampler
1361+ sampler .warmup (50 )
1362+ sampler .sample (200 )
1363+ samples = sampler .get_samples ()
1364+
1365+ # Random seed
1366+ np .random .seed (0 )
1367+
1368+ # Bayesian Inverse Problem
1369+ d_rv = cuqi .distribution .Gamma (1 , 1e-4 ).rv
1370+ s_rv = cuqi .distribution .Gamma (1 , 1e-4 ).rv
1371+ x_rv = cuqi .distribution .LMRF (0 , lambda d_rv : 1 / d_rv , geometry = A .domain_geometry ).rv
1372+ y_rv = cuqi .distribution .Gaussian (A @x_rv , lambda s_rv : 1 / s_rv ).rv
1373+
1374+ # Posterior
1375+ target_rv = cuqi .distribution .JointDistribution (y_rv , x_rv , s_rv , d_rv )(y_rv = y_data )
1376+
1377+ # Sampling strategy
1378+ sampling_strategy_rv = {
1379+ "x_rv" : cuqi .experimental .mcmc .UGLA (),
1380+ "s_rv" : cuqi .experimental .mcmc .Conjugate (),
1381+ "d_rv" : cuqi .experimental .mcmc .ConjugateApprox ()
1382+ }
1383+
1384+ # Gibbs sampler
1385+ sampler_rv = cuqi .experimental .mcmc .HybridGibbs (target_rv , sampling_strategy_rv )
1386+
1387+ # Run sampler
1388+ sampler_rv .warmup (50 )
1389+ sampler_rv .sample (200 )
1390+ samples_rv = sampler_rv .get_samples ()
1391+
1392+ assert np .allclose (samples ['x' ].samples , samples_rv ['x_rv' ].samples )
1393+
12581394def Conjugate_GaussianGammaPair ():
12591395 """ Unit test whether Conjugacy Pair (Gaussian, Gamma) constructs the right distribution """
12601396 x = cuqi .distribution .Gamma (1.0 , 2.0 )
@@ -1373,4 +1509,4 @@ def test_RegularizedLinearRTO_ScipyLinearLSQ_option_invalid():
13731509 posterior = joint (y = y_data )
13741510
13751511 with pytest .raises (ValueError , match = "ScipyLinearLSQ" ):
1376- sampler = cuqi .experimental .mcmc .RegularizedLinearRTO (posterior , solver = "ScipyLinearLSQ" )
1512+ sampler = cuqi .experimental .mcmc .RegularizedLinearRTO (posterior , solver = "ScipyLinearLSQ" )
0 commit comments