|
17 | 17 | MCMCPosterior, |
18 | 18 | likelihood_estimator_based_potential, |
19 | 19 | ) |
| 20 | +from sbi.inference.posteriors.mcmc_posterior import build_from_potential |
20 | 21 | from sbi.neural_nets import likelihood_nn |
21 | 22 | from sbi.samplers.mcmc.pymc_wrapper import PyMCSampler |
22 | 23 | from sbi.samplers.mcmc.slice_numpy import ( |
|
28 | 29 | diagonal_linear_gaussian, |
29 | 30 | true_posterior_linear_gaussian_mvn_prior, |
30 | 31 | ) |
| 32 | +from sbi.utils import BoxUniform |
31 | 33 | from sbi.utils.user_input_checks import process_prior |
32 | 34 | from tests.test_utils import check_c2st |
33 | 35 |
|
@@ -252,3 +254,93 @@ def test_getting_inference_diagnostics(method, mcmc_params_fast: dict): |
252 | 254 | f"MCMC samples for method {method} have incorrect shape (n_samples, n_dims). " |
253 | 255 | f"Expected {(num_samples, num_dim)}, got {samples.shape}" |
254 | 256 | ) |
| 257 | + |
| 258 | + |
| 259 | +@pytest.mark.mcmc |
| 260 | +def test_direct_mcmc_unconditional(): |
| 261 | + "Test MCMCPosterior from user defined potential (unconditional)" |
| 262 | + num_samples = 100 |
| 263 | + theta_dim = 2 |
| 264 | + |
| 265 | + prior = BoxUniform(low=-2 * torch.ones(theta_dim), high=2 * torch.ones(theta_dim)) |
| 266 | + |
| 267 | + def potential_fn(theta: np.ndarray) -> np.ndarray: |
| 268 | + # Example: a 2D Gaussian with mean=[0,0], identity covariance |
| 269 | + return -0.5 * (theta**2).sum(axis=-1) |
| 270 | + |
| 271 | + mcmc_posterior = build_from_potential(potential_fn, prior) |
| 272 | + |
| 273 | + # test sampling |
| 274 | + samples = mcmc_posterior.sample( |
| 275 | + (num_samples,), num_chains=10, warmup_steps=50, thin=10 |
| 276 | + ) |
| 277 | + |
| 278 | + assert samples.shape == (num_samples, theta_dim), ( |
| 279 | + f"MCMC samples have incorrect shape (n_samples, n_dims). " |
| 280 | + f"Expected {(num_samples, theta_dim)}, got {samples.shape}" |
| 281 | + ) |
| 282 | + |
| 283 | + # test potential evaluation |
| 284 | + dist = torch.distributions.MultivariateNormal( |
| 285 | + torch.zeros(theta_dim), torch.eye(theta_dim) |
| 286 | + ) |
| 287 | + samples = dist.sample((num_samples,)) |
| 288 | + log_p = mcmc_posterior.potential(samples) |
| 289 | + |
| 290 | + assert log_p.shape == (num_samples,), ( |
| 291 | + f"Potential evals have incorrect shape. " |
| 292 | + f"Expected ({num_samples}), got {log_p.shape}" |
| 293 | + ) |
| 294 | + |
| 295 | + |
| 296 | +@pytest.mark.mcmc |
| 297 | +def test_direct_mcmc_conditional(): |
| 298 | + "Test MCMCPosterior from user defined potential (conditional)" |
| 299 | + theta_dim = 2 |
| 300 | + num_samples = 100 |
| 301 | + num_batches = 5 |
| 302 | + num_samples_batch = num_samples // num_batches |
| 303 | + |
| 304 | + prior = BoxUniform(low=-2 * torch.ones(theta_dim), high=2 * torch.ones(theta_dim)) |
| 305 | + |
| 306 | + def potential_fn(theta: np.ndarray, x: np.ndarray) -> np.ndarray: |
| 307 | + # Example: a 2D Gaussian with mean=[0,0], variance conditioned on x |
| 308 | + return -x * (theta**2).sum(axis=-1) |
| 309 | + |
| 310 | + # test sampling |
| 311 | + x = torch.tensor([0.5]) |
| 312 | + mcmc_posterior = build_from_potential(potential_fn, prior, x=x) |
| 313 | + samples = mcmc_posterior.sample( |
| 314 | + (num_samples,), num_chains=10, warmup_steps=50, thin=10 |
| 315 | + ) |
| 316 | + |
| 317 | + assert samples.shape == (num_samples, theta_dim), ( |
| 318 | + f"MCMC samples have incorrect shape (n_chains, n_samples, n_dims). " |
| 319 | + f"Expected {(num_samples, theta_dim)}, got {samples.shape}" |
| 320 | + ) |
| 321 | + |
| 322 | + # test batched sampling |
| 323 | + x_batch = torch.linspace(0.1, 0.9, num_batches).unsqueeze(1) |
| 324 | + samples_batched = mcmc_posterior.sample_batched( |
| 325 | + (num_samples_batch,), x=x_batch, num_chains=10, warmup_steps=50, thin=10 |
| 326 | + ) |
| 327 | + assert samples_batched.shape == (num_samples_batch, num_batches, theta_dim), ( |
| 328 | + f"MCMC samples have incorrect shape (n_samples, n_batches, n_dims). " |
| 329 | + f"Expected {(num_samples, num_batches, theta_dim)}, got {samples.shape}" |
| 330 | + ) |
| 331 | + |
| 332 | + # test potential evaluation |
| 333 | + dist = torch.distributions.MultivariateNormal( |
| 334 | + torch.zeros(theta_dim), torch.eye(theta_dim) |
| 335 | + ) |
| 336 | + theta_samples = dist.sample((num_samples,)) |
| 337 | + x_samples = torch.rand((num_samples,)) |
| 338 | + log_p = mcmc_posterior.potential(theta_samples, x_samples) |
| 339 | + |
| 340 | + assert log_p.shape == ( |
| 341 | + 1, |
| 342 | + num_samples, |
| 343 | + ), ( |
| 344 | + f"Potential evals have incorrect shape. " |
| 345 | + f"Expected (1, {num_samples}), got {log_p.shape}" |
| 346 | + ) |
0 commit comments