diff --git a/sbi/inference/posteriors/vector_field_posterior.py b/sbi/inference/posteriors/vector_field_posterior.py index 8d81fc2b8..cac80df73 100644 --- a/sbi/inference/posteriors/vector_field_posterior.py +++ b/sbi/inference/posteriors/vector_field_posterior.py @@ -291,6 +291,7 @@ def _sample_via_diffusion( max_sampling_batch_size: int = 10_000, show_progress_bars: bool = True, save_intermediate: bool = False, + **kwargs, ) -> Tensor: r"""Return samples from posterior distribution $p(\theta|x)$. diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index 892334050..13ebcd929 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -11,6 +11,7 @@ from torch import Tensor, as_tensor from tqdm.auto import tqdm +from sbi.sbi_types import AcceptRejectFn, SampleProposal from sbi.utils.sbiutils import gradient_ascent @@ -214,8 +215,8 @@ def log_prob(self, theta: Tensor, **kwargs) -> Tensor: @torch.no_grad() def accept_reject_sample( - proposal: Callable, - accept_reject_fn: Callable, + proposal: SampleProposal, + accept_reject_fn: AcceptRejectFn, num_samples: int, num_xos: int = 1, show_progress_bars: bool = False, @@ -242,12 +243,12 @@ def accept_reject_sample( density during evaluation of the posterior. Args: - proposal: A callable that takes `sample_shape` as arguments (and kwargs as - needed). Returns samples from the proposal distribution with shape - (*sample_shape, event_dim). - accept_reject_fn: Function that evaluates which samples are accepted or - rejected. Must take a batch of parameters and return a boolean tensor which - indicates which parameters get accepted. + proposal: A callable following the `SampleProposal` protocol, i.e., takes + `sample_shape` as first argument (and kwargs as needed). Returns samples + from the proposal distribution with shape (*sample_shape, event_dim). + accept_reject_fn: A callable following the `AcceptRejectFn` protocol that + evaluates which samples are accepted or rejected. Takes a batch of + parameters and returns a boolean tensor indicating which are accepted. num_samples: Desired number of samples. num_xos: Number of conditions for batched_sampling (currently only accepting one batch dimension for the condition). @@ -329,7 +330,7 @@ def accept_reject_sample( # Sample and reject. candidates = proposal( - (sampling_batch_size,), # type: ignore + torch.Size((sampling_batch_size,)), **proposal_sampling_kwargs, ) # SNPE-style rejection-sampling when the proposal is the neural net. diff --git a/sbi/sbi_types.py b/sbi/sbi_types.py index 425a900ad..1d62e8eb4 100644 --- a/sbi/sbi_types.py +++ b/sbi/sbi_types.py @@ -1,7 +1,7 @@ # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed # under the Apache License Version 2.0, see -from typing import Optional, Sequence, Tuple, TypeVar, Union +from typing import Optional, Protocol, Sequence, Tuple, TypeVar, Union import numpy as np import torch @@ -35,10 +35,33 @@ PyroTransformedDistribution: TypeAlias = TransformedDistribution TorchTensor: TypeAlias = Tensor + +class SampleProposal(Protocol): + """Protocol for sample proposal callables used in rejection sampling. + + Any callable that takes a sample shape and optional keyword arguments + and returns a Tensor of samples satisfies this protocol. + """ + + def __call__(self, sample_shape: torch.Size, **kwargs) -> Tensor: ... + + +class AcceptRejectFn(Protocol): + """Protocol for accept/reject functions used in rejection sampling. + + Any callable that takes a batch of parameters (theta) and returns a boolean + Tensor indicating which samples are accepted satisfies this protocol. + """ + + def __call__(self, theta: Tensor) -> Tensor: ... + + __all__ = [ + "AcceptRejectFn", "Array", "Shape", "OneOrMore", + "SampleProposal", "ScalarFloat", "TensorBoardSummaryWriter", "TorchTransform", diff --git a/sbi/utils/restriction_estimator.py b/sbi/utils/restriction_estimator.py index ffbaf9665..1a0de4182 100644 --- a/sbi/utils/restriction_estimator.py +++ b/sbi/utils/restriction_estimator.py @@ -688,7 +688,9 @@ def sample( if sample_with == "rejection": samples, acceptance_rate = rejection.accept_reject_sample( - proposal=self._prior.sample, + proposal=lambda sample_shape, **kwargs: self._prior.sample( + sample_shape + ), accept_reject_fn=self._accept_reject_fn, num_samples=num_samples, show_progress_bars=show_progress_bars,