Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sbi/inference/posteriors/vector_field_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ def _sample_via_diffusion(
max_sampling_batch_size: int = 10_000,
show_progress_bars: bool = True,
save_intermediate: bool = False,
**kwargs,
) -> Tensor:
r"""Return samples from posterior distribution $p(\theta|x)$.

Expand Down
19 changes: 10 additions & 9 deletions sbi/samplers/rejection/rejection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from torch import Tensor, as_tensor
from tqdm.auto import tqdm

from sbi.sbi_types import AcceptRejectFn, SampleProposal
from sbi.utils.sbiutils import gradient_ascent


Expand Down Expand Up @@ -214,8 +215,8 @@ def log_prob(self, theta: Tensor, **kwargs) -> Tensor:

@torch.no_grad()
def accept_reject_sample(
proposal: Callable,
accept_reject_fn: Callable,
proposal: SampleProposal,
accept_reject_fn: AcceptRejectFn,
num_samples: int,
num_xos: int = 1,
show_progress_bars: bool = False,
Expand All @@ -242,12 +243,12 @@ def accept_reject_sample(
density during evaluation of the posterior.

Args:
proposal: A callable that takes `sample_shape` as arguments (and kwargs as
needed). Returns samples from the proposal distribution with shape
(*sample_shape, event_dim).
accept_reject_fn: Function that evaluates which samples are accepted or
rejected. Must take a batch of parameters and return a boolean tensor which
indicates which parameters get accepted.
proposal: A callable following the `SampleProposal` protocol, i.e., takes
`sample_shape` as first argument (and kwargs as needed). Returns samples
from the proposal distribution with shape (*sample_shape, event_dim).
accept_reject_fn: A callable following the `AcceptRejectFn` protocol that
evaluates which samples are accepted or rejected. Takes a batch of
parameters and returns a boolean tensor indicating which are accepted.
num_samples: Desired number of samples.
num_xos: Number of conditions for batched_sampling (currently only accepting
one batch dimension for the condition).
Expand Down Expand Up @@ -329,7 +330,7 @@ def accept_reject_sample(

# Sample and reject.
candidates = proposal(
(sampling_batch_size,), # type: ignore
torch.Size((sampling_batch_size,)),
**proposal_sampling_kwargs,
)
# SNPE-style rejection-sampling when the proposal is the neural net.
Expand Down
25 changes: 24 additions & 1 deletion sbi/sbi_types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from typing import Optional, Sequence, Tuple, TypeVar, Union
from typing import Optional, Protocol, Sequence, Tuple, TypeVar, Union

import numpy as np
import torch
Expand Down Expand Up @@ -35,10 +35,33 @@
PyroTransformedDistribution: TypeAlias = TransformedDistribution
TorchTensor: TypeAlias = Tensor


class SampleProposal(Protocol):
"""Protocol for sample proposal callables used in rejection sampling.

Any callable that takes a sample shape and optional keyword arguments
and returns a Tensor of samples satisfies this protocol.
"""

def __call__(self, sample_shape: torch.Size, **kwargs) -> Tensor: ...


class AcceptRejectFn(Protocol):
"""Protocol for accept/reject functions used in rejection sampling.

Any callable that takes a batch of parameters (theta) and returns a boolean
Tensor indicating which samples are accepted satisfies this protocol.
"""

def __call__(self, theta: Tensor) -> Tensor: ...


__all__ = [
"AcceptRejectFn",
"Array",
"Shape",
"OneOrMore",
"SampleProposal",
"ScalarFloat",
"TensorBoardSummaryWriter",
"TorchTransform",
Expand Down
4 changes: 3 additions & 1 deletion sbi/utils/restriction_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,9 @@ def sample(

if sample_with == "rejection":
samples, acceptance_rate = rejection.accept_reject_sample(
proposal=self._prior.sample,
proposal=lambda sample_shape, **kwargs: self._prior.sample(
sample_shape
),
accept_reject_fn=self._accept_reject_fn,
num_samples=num_samples,
show_progress_bars=show_progress_bars,
Expand Down
Loading