
Commit 9c59a89

#1405 add util to generate mcmc samples from user defined potential (#1483)
* add util to generate mcmc samples directly from potential
* refactor to leverage MCMCPosterior class and return sampler
* move from utils to posteriors and return full posterior
* add tests for mcmc posterior from user defined potential
* fix formatting
1 parent 570ab18 commit 9c59a89
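For orientation, here is a minimal sketch of the unconditional workflow this commit enables, based on the new `build_from_potential` helper and the tests added below. The potential name `my_potential` and the concrete numbers are illustrative, and the sampling keyword arguments are the ones used in the new tests:

    import torch
    from sbi.inference.posteriors.mcmc_posterior import build_from_potential
    from sbi.utils import BoxUniform

    # Illustrative unconditional potential: standard 2D Gaussian log-density
    # (up to an additive constant), taking only theta.
    def my_potential(theta: torch.Tensor) -> torch.Tensor:
        return -0.5 * (theta**2).sum(dim=-1)

    prior = BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))

    # With a single-argument potential, build_from_potential warns that no x was
    # given and wraps the potential so that x is ignored internally.
    posterior = build_from_potential(my_potential, prior)
    samples = posterior.sample((100,), num_chains=10, warmup_steps=50, thin=10)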

File tree

sbi/inference/posteriors/mcmc_posterior.py
tests/mcmc_test.py

2 files changed: +171, -0 lines changed

sbi/inference/posteriors/mcmc_posterior.py

Lines changed: 79 additions & 0 deletions
@@ -1,5 +1,6 @@
 # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
 # under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
+import inspect
 import warnings
 from copy import deepcopy
 from functools import partial
@@ -32,6 +33,7 @@
     sir_init,
 )
 from sbi.sbi_types import Shape, TorchTransform
+from sbi.utils import mcmc_transform
 from sbi.utils.potentialutils import pyro_potential_wrapper, transformed_potential
 from sbi.utils.torchutils import ensure_theta_batched, tensor2numpy

@@ -1110,3 +1112,80 @@ def _maybe_use_dict_entry(default: Any, key: str, dict_to_check: Dict) -> Any:
     """
     attribute = dict_to_check.get(key, default)
     return attribute
+
+
+def _num_required_args(func: Callable) -> int:
+    """
+    Utility for counting the number of required positional args of a function.
+
+    This counts every parameter in the signature that is positional, i.e. that
+    (1) can be passed as a positional argument (not keyword-only or variadic) and
+    (2) has no default value.
+
+    Args:
+        func: A callable function.
+
+    Returns:
+        Number of required positional arguments.
+    """
+    sig = inspect.signature(func)
+    return sum(
+        1
+        for param in sig.parameters.values()
+        if param.kind
+        in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
+        and param.default is inspect.Parameter.empty
+    )
+
+
+def build_from_potential(
+    potential_fn: Callable, prior: Any, x: Optional[Tensor] = None, **kwargs
+) -> MCMCPosterior:
+    """
+    Returns an `MCMCPosterior` built from a user-defined potential function and prior.
+
+    The user-defined potential can be conditional (accepts theta and x as positional
+    arguments) or unconditional (accepts only theta).
+
+    Args:
+        potential_fn: User-defined potential function. Must be a Callable.
+        prior: Prior distribution used for parameter transformation and
+            initialization.
+        x: Conditioning value. Must be provided if the potential function is
+            conditional.
+        kwargs: Additional keyword arguments passed to `MCMCPosterior`.
+
+    Returns:
+        `MCMCPosterior` object whose `.sample()` method draws MCMC samples.
+    """
+    # Build transformation to unrestricted space for sampling.
+    transform = mcmc_transform(prior)
+
+    # potential_fn must take 1 or 2 required arguments: (theta) or (theta, x).
+    num_args = _num_required_args(potential_fn)
+    assert 0 < num_args < 3, "potential_fn must take 1-2 required arguments"
+    is_conditional = num_args == 2
+
+    if is_conditional:
+        # You could remove this and require the user to set x before calling sample.
+        assert x is not None, "x must be provided if potential_fn is conditional"
+        posterior = MCMCPosterior(
+            potential_fn, prior, theta_transform=transform, **kwargs
+        )
+        posterior.set_default_x(x)
+
+    else:
+        warnings.warn(
+            "x has not been provided. Using unconditional potential function.",
+            UserWarning,
+            stacklevel=2,
+        )
+
+        # Define an unconditional potential function (ignores x).
+        def unconditional_potential_fn(theta, x):
+            return potential_fn(theta)
+
+        posterior = MCMCPosterior(
+            unconditional_potential_fn, prior, theta_transform=transform, **kwargs
+        )
+        posterior.set_default_x(torch.zeros(1))  # set default_x to dummy value
+
+    return posterior
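A complementary sketch for the conditional case, mirroring the test added below. The potential name `my_conditional_potential` is illustrative; the conditioning value `x` is stored via `set_default_x`, and the sampling keyword arguments are those used in the tests:

    import torch
    from sbi.inference.posteriors.mcmc_posterior import build_from_potential
    from sbi.utils import BoxUniform

    # Illustrative conditional potential: Gaussian over theta whose spread depends on x.
    def my_conditional_potential(theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        return -x * (theta**2).sum(dim=-1)

    prior = BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))
    posterior = build_from_potential(
        my_conditional_potential, prior, x=torch.tensor([0.5])
    )

    # Draw samples at the default x, or batched over several conditioning values.
    samples = posterior.sample((100,), num_chains=10, warmup_steps=50, thin=10)
    samples_batched = posterior.sample_batched(
        (20,), x=torch.linspace(0.1, 0.9, 5).unsqueeze(1), num_chains=10
    )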

tests/mcmc_test.py

Lines changed: 92 additions & 0 deletions
@@ -17,6 +17,7 @@
     MCMCPosterior,
     likelihood_estimator_based_potential,
 )
+from sbi.inference.posteriors.mcmc_posterior import build_from_potential
 from sbi.neural_nets import likelihood_nn
 from sbi.samplers.mcmc.pymc_wrapper import PyMCSampler
 from sbi.samplers.mcmc.slice_numpy import (
@@ -28,6 +29,7 @@
     diagonal_linear_gaussian,
     true_posterior_linear_gaussian_mvn_prior,
 )
+from sbi.utils import BoxUniform
 from sbi.utils.user_input_checks import process_prior
 from tests.test_utils import check_c2st

@@ -252,3 +254,93 @@ def test_getting_inference_diagnostics(method, mcmc_params_fast: dict):
         f"MCMC samples for method {method} have incorrect shape (n_samples, n_dims). "
         f"Expected {(num_samples, num_dim)}, got {samples.shape}"
     )
+
+
+@pytest.mark.mcmc
+def test_direct_mcmc_unconditional():
+    """Test MCMCPosterior from a user-defined potential (unconditional)."""
+    num_samples = 100
+    theta_dim = 2
+
+    prior = BoxUniform(low=-2 * torch.ones(theta_dim), high=2 * torch.ones(theta_dim))
+
+    def potential_fn(theta: torch.Tensor) -> torch.Tensor:
+        # Example: a 2D Gaussian with mean=[0,0], identity covariance.
+        return -0.5 * (theta**2).sum(axis=-1)
+
+    mcmc_posterior = build_from_potential(potential_fn, prior)
+
+    # test sampling
+    samples = mcmc_posterior.sample(
+        (num_samples,), num_chains=10, warmup_steps=50, thin=10
+    )
+
+    assert samples.shape == (num_samples, theta_dim), (
+        f"MCMC samples have incorrect shape (n_samples, n_dims). "
+        f"Expected {(num_samples, theta_dim)}, got {samples.shape}"
+    )
+
+    # test potential evaluation
+    dist = torch.distributions.MultivariateNormal(
+        torch.zeros(theta_dim), torch.eye(theta_dim)
+    )
+    samples = dist.sample((num_samples,))
+    log_p = mcmc_posterior.potential(samples)
+
+    assert log_p.shape == (num_samples,), (
+        f"Potential evals have incorrect shape. "
+        f"Expected ({num_samples},), got {log_p.shape}"
+    )
+
+
+@pytest.mark.mcmc
+def test_direct_mcmc_conditional():
+    """Test MCMCPosterior from a user-defined potential (conditional)."""
+    theta_dim = 2
+    num_samples = 100
+    num_batches = 5
+    num_samples_batch = num_samples // num_batches
+
+    prior = BoxUniform(low=-2 * torch.ones(theta_dim), high=2 * torch.ones(theta_dim))
+
+    def potential_fn(theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
+        # Example: a 2D Gaussian with mean=[0,0] and x-dependent variance.
+        return -x * (theta**2).sum(axis=-1)
+
+    # test sampling
+    x = torch.tensor([0.5])
+    mcmc_posterior = build_from_potential(potential_fn, prior, x=x)
+    samples = mcmc_posterior.sample(
+        (num_samples,), num_chains=10, warmup_steps=50, thin=10
+    )
+
+    assert samples.shape == (num_samples, theta_dim), (
+        f"MCMC samples have incorrect shape (n_samples, n_dims). "
+        f"Expected {(num_samples, theta_dim)}, got {samples.shape}"
+    )
+
+    # test batched sampling
+    x_batch = torch.linspace(0.1, 0.9, num_batches).unsqueeze(1)
+    samples_batched = mcmc_posterior.sample_batched(
+        (num_samples_batch,), x=x_batch, num_chains=10, warmup_steps=50, thin=10
+    )
+    assert samples_batched.shape == (num_samples_batch, num_batches, theta_dim), (
+        f"MCMC samples have incorrect shape (n_samples, n_batches, n_dims). "
+        f"Expected {(num_samples_batch, num_batches, theta_dim)}, "
+        f"got {samples_batched.shape}"
+    )
+
+    # test potential evaluation
+    dist = torch.distributions.MultivariateNormal(
+        torch.zeros(theta_dim), torch.eye(theta_dim)
+    )
+    theta_samples = dist.sample((num_samples,))
+    x_samples = torch.rand((num_samples,))
+    log_p = mcmc_posterior.potential(theta_samples, x_samples)
+
+    assert log_p.shape == (
+        1,
+        num_samples,
+    ), (
+        f"Potential evals have incorrect shape. "
+        f"Expected (1, {num_samples}), got {log_p.shape}"
+    )
