Skip to content

Commit d9608a0

Browse files
committed
Add timeout handling and reject_outside_prior option to rejection sampling
1 parent d735f97 commit d9608a0

File tree

3 files changed

+77
-22
lines changed

3 files changed

+77
-22
lines changed

sbi/inference/posteriors/rejection_posterior.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@ def sample(
137137
m: Optional[float] = None,
138138
sample_with: Optional[str] = None,
139139
show_progress_bars: bool = True,
140+
reject_outside_prior: bool = True,
141+
max_sampling_time: Optional[float] = None,
140142
):
141143
r"""Return samples from posterior $p(\theta|x)$ via rejection sampling.
142144
@@ -147,7 +149,13 @@ def sample(
147149
sample_with: This argument only exists to keep backward-compatibility with
148150
`sbi` v0.17.2 or older. If it is set, we instantly raise an error.
149151
show_progress_bars: Whether to show sampling progress monitor.
150-
152+
reject_outside_prior:
153+
If True (default), rejection sampling is used to ensure samples lie
154+
within the prior support. If False, samples are drawn directly from
155+
the proposal without rejection sampling.
156+
max_sampling_time:
157+
Optional maximum allowed sampling time in seconds. If exceeded,
158+
sampling is aborted and a RuntimeError is raised.
151159
Returns:
152160
Samples from posterior.
153161
"""
@@ -180,18 +188,22 @@ def sample(
180188
)
181189
m = self.m if m is None else m
182190

183-
samples, _ = rejection_sample(
184-
potential,
185-
proposal=self.proposal,
186-
num_samples=num_samples,
187-
show_progress_bars=show_progress_bars,
188-
warn_acceptance=0.01,
189-
max_sampling_batch_size=max_sampling_batch_size,
190-
num_samples_to_find_max=num_samples_to_find_max,
191-
num_iter_to_find_max=num_iter_to_find_max,
192-
m=m,
193-
device=self._device,
194-
)
191+
if reject_outside_prior:
192+
samples, _ = rejection_sample(
193+
potential,
194+
proposal=self.proposal,
195+
num_samples=num_samples,
196+
show_progress_bars=show_progress_bars,
197+
warn_acceptance=0.01,
198+
max_sampling_batch_size=max_sampling_batch_size,
199+
num_samples_to_find_max=num_samples_to_find_max,
200+
num_iter_to_find_max=num_iter_to_find_max,
201+
m=m,
202+
max_sampling_time=max_sampling_time,
203+
device=self._device,
204+
)
205+
else:
206+
samples = self.proposal.sample((num_samples,))
195207

196208
return samples.reshape((*sample_shape, -1))
197209

sbi/samplers/rejection/rejection.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -188,8 +188,12 @@ def log_prob(self, theta: Tensor, **kwargs) -> Tensor:
188188
logging.warning(
189189
f"""Only {acceptance_rate:.3%} proposal samples were accepted. It
190190
may take a long time to collect the remaining {num_remaining}
191-
samples. Consider interrupting (Ctrl-C) and switching to a
192-
different sampling method with
191+
samples. You can prevent long runtimes by
192+
setting `max_sampling_time` to limit runtime, or disabling
193+
rejection sampling (e.g. via `reject_outside_prior=False` in
194+
`posterior.sample()` when available).
195+
Alternatively, consider interrupting (Ctrl-C) and switching
196+
to a different sampling method with
193197
`build_posterior(..., sample_with='mcmc')`. or
194198
`build_posterior(..., sample_with='vi')`."""
195199
)
@@ -387,10 +391,13 @@ def accept_reject_sample(
387391
else:
388392
warn_msg = f"""Only {min_acceptance_rate:.3%} proposal samples are
389393
accepted. It may take a long time to collect the remaining
390-
{num_remaining} samples. """
394+
{num_remaining} samples. You can prevent very long runtimes by
395+
setting `max_sampling_time` to limit runtime, or disabling
396+
rejection sampling (e.g. via `reject_outside_prior=False` in
397+
`posterior.sample()` when available)."""
391398
if alternative_method is not None:
392-
warn_msg += f"""Consider interrupting (Ctrl-C) and switching to
393-
`{alternative_method}`."""
399+
warn_msg += f"""Alternatively, consider interrupting (Ctrl-C)
400+
and switching to `{alternative_method}`."""
394401
logging.warning(warn_msg)
395402

396403
leakage_warning_raised = True # Ensure warning is raised just once.

tests/rejection_timeout_test.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55
import torch
66

7+
from sbi.inference.posteriors.rejection_posterior import RejectionPosterior
78
from sbi.samplers.rejection import accept_reject_sample, rejection_sample
89

910

@@ -28,12 +29,12 @@ def always_reject_fn(x):
2829
def test_accept_reject_sample_timeout():
2930
proposal = DummyProposal()
3031

31-
with pytest.raises(RuntimeError):
32+
with pytest.raises(RuntimeError, match="rejection sampling exceeded"):
3233
accept_reject_sample(
3334
proposal=proposal,
3435
accept_reject_fn=always_reject_fn,
3536
num_samples=5,
36-
max_sampling_time=0.2,
37+
max_sampling_time=0.01,
3738
)
3839

3940

@@ -44,11 +45,46 @@ def test_rejection_sample_timeout():
4445
def dummy_potential_fn(x):
4546
return torch.full((x.shape[0],), -1e6)
4647

47-
with pytest.raises(RuntimeError):
48+
with pytest.raises(RuntimeError, match="rejection sampling exceeded"):
4849
rejection_sample(
4950
potential_fn=dummy_potential_fn,
5051
proposal=proposal,
5152
num_samples=5,
52-
max_sampling_time=0.2,
53+
max_sampling_time=0.01,
5354
m=1e12,
5455
)
56+
57+
58+
@pytest.mark.slow
59+
def test_rejection_posterior_timeout():
60+
prior = torch.distributions.MultivariateNormal(torch.zeros(1), torch.eye(1))
61+
62+
class DummyPotential:
63+
"""
64+
Minimal compliant potential implementing CustomPotential interface.
65+
Forces rejection by returning very low log-probability.
66+
"""
67+
68+
device = "cpu"
69+
70+
def __call__(
71+
self, theta: torch.Tensor, x_o: torch.Tensor = None
72+
) -> torch.Tensor:
73+
return torch.full((theta.shape[0],), -1e6)
74+
75+
def set_x(self, x):
76+
pass
77+
78+
def to(self, device):
79+
self.device = device
80+
return self
81+
82+
posterior = RejectionPosterior(
83+
potential_fn=DummyPotential(),
84+
proposal=prior,
85+
)
86+
87+
posterior.set_default_x(torch.zeros(1))
88+
89+
with pytest.raises(RuntimeError, match="max_sampling_time"):
90+
posterior.sample((5,), max_sampling_time=0.01)

0 commit comments

Comments
 (0)