Skip to content

Commit fa36b0f

Browse files
satwikspsjanfb
andauthored
Add max_sampling_time support to rejection samplers and corresponding tests (#1705)
* Add max_sampling_time timeout support to rejection samplers * Add timeout tests for rejection and accept_reject samplers * Add Licence header to new test file * refine timeout RuntimeError messages per reviewer suggestion * Add timeout handling and reject_outside_prior option to rejection sampling * add reject_outside_prior_support logic and update rejection sampling tests * warn if sampling without rejection * add tests for warnings * improve docstrings. --------- Co-authored-by: Jan <jan.boelts@mailbox.org>
1 parent 6e255ce commit fa36b0f

File tree

6 files changed

+420
-69
lines changed

6 files changed

+420
-69
lines changed

sbi/inference/posteriors/direct_posterior.py

Lines changed: 58 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
)
2020
from sbi.samplers.rejection import rejection
2121
from sbi.sbi_types import Shape
22-
from sbi.utils.sbiutils import within_support
22+
from sbi.utils.sbiutils import warn_if_outside_prior_support, within_support
2323
from sbi.utils.torchutils import ensure_theta_batched
2424
from sbi.utils.user_input_checks import check_prior
2525

@@ -139,6 +139,8 @@ def sample(
139139
max_sampling_batch_size: int = 10_000,
140140
sample_with: Optional[str] = None,
141141
show_progress_bars: bool = True,
142+
reject_outside_prior: bool = True,
143+
max_sampling_time: Optional[float] = None,
142144
) -> Tensor:
143145
r"""Return samples from posterior distribution $p(\theta|x)$.
144146
@@ -149,6 +151,14 @@ def sample(
149151
sample_with: This argument only exists to keep backward-compatibility with
150152
`sbi` v0.17.2 or older. If it is set, we instantly raise an error.
151153
show_progress_bars: Whether to show sampling progress monitor.
154+
reject_outside_prior: If True (default), rejection sampling is used to
155+
ensure samples lie within the prior support. If False, samples are drawn
156+
directly from the neural density estimator without rejection, which is
157+
faster but may include samples outside the prior support.
158+
max_sampling_time: Optional maximum allowed sampling time in seconds.
159+
If exceeded, sampling is aborted and a RuntimeError is raised. Only
160+
applies when `reject_outside_prior=True` (no effect otherwise since
161+
direct sampling is fast).
152162
"""
153163
num_samples = torch.Size(sample_shape).numel()
154164
x = self._x_else_default_x(x)
@@ -177,15 +187,25 @@ def sample(
177187
f"`.build_posterior(sample_with={sample_with}).`"
178188
)
179189

180-
samples = rejection.accept_reject_sample(
181-
proposal=self.posterior_estimator.sample,
182-
accept_reject_fn=lambda theta: within_support(self.prior, theta),
183-
num_samples=num_samples,
184-
show_progress_bars=show_progress_bars,
185-
max_sampling_batch_size=max_sampling_batch_size,
186-
proposal_sampling_kwargs={"condition": x},
187-
alternative_method="build_posterior(..., sample_with='mcmc')",
188-
)[0] # [0] to return only samples, not acceptance probabilities.
190+
if reject_outside_prior:
191+
# Normal rejection behavior.
192+
samples = rejection.accept_reject_sample(
193+
proposal=self.posterior_estimator.sample,
194+
accept_reject_fn=lambda theta: within_support(self.prior, theta),
195+
num_samples=num_samples,
196+
show_progress_bars=show_progress_bars,
197+
max_sampling_batch_size=max_sampling_batch_size,
198+
proposal_sampling_kwargs={"condition": x},
199+
alternative_method="build_posterior(..., sample_with='mcmc')",
200+
max_sampling_time=max_sampling_time,
201+
)[0]
202+
else:
203+
# Bypass rejection sampling entirely.
204+
samples = self.posterior_estimator.sample(
205+
torch.Size([num_samples]),
206+
condition=x,
207+
)
208+
warn_if_outside_prior_support(self.prior, samples[:, 0])
189209

190210
return samples[:, 0] # Remove batch dimension.
191211

@@ -195,6 +215,8 @@ def sample_batched(
195215
x: Tensor,
196216
max_sampling_batch_size: int = 10_000,
197217
show_progress_bars: bool = True,
218+
reject_outside_prior: bool = True,
219+
max_sampling_time: Optional[float] = None,
198220
) -> Tensor:
199221
r"""Given a batch of observations [x_1, ..., x_B] this function samples from
200222
posteriors $p(\theta|x_1)$, ... ,$p(\theta|x_B)$, in a batched (i.e. vectorized)
@@ -207,6 +229,13 @@ def sample_batched(
207229
`batch_dim` corresponds to the number of observations to be drawn.
208230
max_sampling_batch_size: Maximum batch size for rejection sampling.
209231
show_progress_bars: Whether to show sampling progress monitor.
232+
reject_outside_prior: If True (default), rejection sampling is used to
233+
ensure samples lie within the prior support. If False, samples are drawn
234+
directly from the neural density estimator without rejection, which is
235+
faster but may include samples outside the prior support.
236+
max_sampling_time: Optional maximum allowed sampling time in seconds.
237+
If exceeded, sampling is aborted and a RuntimeError is raised. Only
238+
applies when `reject_outside_prior=True`.
210239
211240
Returns:
212241
Samples from the posteriors of shape (*sample_shape, B, *input_shape)
@@ -242,15 +271,25 @@ def sample_batched(
242271
)
243272
max_sampling_batch_size = capped
244273

245-
samples = rejection.accept_reject_sample(
246-
proposal=self.posterior_estimator.sample,
247-
accept_reject_fn=lambda theta: within_support(self.prior, theta),
248-
num_samples=num_samples,
249-
show_progress_bars=show_progress_bars,
250-
max_sampling_batch_size=max_sampling_batch_size,
251-
proposal_sampling_kwargs={"condition": x},
252-
alternative_method="build_posterior(..., sample_with='mcmc')",
253-
)[0]
274+
if reject_outside_prior:
275+
# Normal rejection behavior.
276+
samples = rejection.accept_reject_sample(
277+
proposal=self.posterior_estimator.sample,
278+
accept_reject_fn=lambda theta: within_support(self.prior, theta),
279+
num_samples=num_samples,
280+
show_progress_bars=show_progress_bars,
281+
max_sampling_batch_size=max_sampling_batch_size,
282+
proposal_sampling_kwargs={"condition": x},
283+
alternative_method="build_posterior(..., sample_with='mcmc')",
284+
max_sampling_time=max_sampling_time,
285+
)[0]
286+
else:
287+
# Bypass rejection sampling entirely.
288+
samples = self.posterior_estimator.sample(
289+
torch.Size([num_samples]),
290+
condition=x,
291+
)
292+
warn_if_outside_prior_support(self.prior, samples)
254293

255294
return samples
256295

sbi/inference/posteriors/rejection_posterior.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@ def sample(
137137
m: Optional[float] = None,
138138
sample_with: Optional[str] = None,
139139
show_progress_bars: bool = True,
140+
reject_outside_prior: bool = True,
141+
max_sampling_time: Optional[float] = None,
140142
):
141143
r"""Return samples from posterior $p(\theta|x)$ via rejection sampling.
142144
@@ -147,6 +149,14 @@ def sample(
147149
sample_with: This argument only exists to keep backward-compatibility with
148150
`sbi` v0.17.2 or older. If it is set, we instantly raise an error.
149151
show_progress_bars: Whether to show sampling progress monitor.
152+
reject_outside_prior: If True (default), rejection sampling is used to
153+
ensure samples lie within the prior support. If False, samples are drawn
154+
directly from the proposal without rejection, which is faster but may
155+
include samples outside the prior support.
156+
max_sampling_time: Optional maximum allowed sampling time in seconds.
157+
If exceeded, sampling is aborted and a RuntimeError is raised. Only
158+
applies when `reject_outside_prior=True` (no effect otherwise since
159+
direct sampling from the proposal is fast).
150160
151161
Returns:
152162
Samples from posterior.
@@ -180,18 +190,29 @@ def sample(
180190
)
181191
m = self.m if m is None else m
182192

183-
samples, _ = rejection_sample(
184-
potential,
185-
proposal=self.proposal,
186-
num_samples=num_samples,
187-
show_progress_bars=show_progress_bars,
188-
warn_acceptance=0.01,
189-
max_sampling_batch_size=max_sampling_batch_size,
190-
num_samples_to_find_max=num_samples_to_find_max,
191-
num_iter_to_find_max=num_iter_to_find_max,
192-
m=m,
193-
device=self._device,
194-
)
193+
if reject_outside_prior:
194+
samples, _ = rejection_sample(
195+
potential,
196+
proposal=self.proposal,
197+
num_samples=num_samples,
198+
show_progress_bars=show_progress_bars,
199+
warn_acceptance=0.01,
200+
max_sampling_batch_size=max_sampling_batch_size,
201+
num_samples_to_find_max=num_samples_to_find_max,
202+
num_iter_to_find_max=num_iter_to_find_max,
203+
m=m,
204+
max_sampling_time=max_sampling_time,
205+
device=self._device,
206+
)
207+
else:
208+
# Bypass rejection sampling entirely.
209+
samples = self.proposal.sample((num_samples,))
210+
warn(
211+
"Samples drawn with reject_outside_prior=False are taken directly "
212+
"from the proposal without rejection sampling. These samples may lie "
213+
"outside the prior support, which could lead to incorrect inference.",
214+
stacklevel=2,
215+
)
195216

196217
return samples.reshape((*sample_shape, -1))
197218

sbi/inference/posteriors/vector_field_posterior.py

Lines changed: 87 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@
2525
from sbi.samplers.score.predictors import Predictor
2626
from sbi.sbi_types import Shape
2727
from sbi.utils import check_prior
28-
from sbi.utils.sbiutils import gradient_ascent, within_support
28+
from sbi.utils.sbiutils import (
29+
gradient_ascent,
30+
warn_if_outside_prior_support,
31+
within_support,
32+
)
2933
from sbi.utils.torchutils import ensure_theta_batched
3034

3135

@@ -158,6 +162,8 @@ def sample(
158162
max_sampling_batch_size: int = 10_000,
159163
sample_with: Optional[str] = None,
160164
show_progress_bars: bool = True,
165+
reject_outside_prior: bool = True,
166+
max_sampling_time: Optional[float] = None,
161167
) -> Tensor:
162168
r"""Return samples from posterior distribution $p(\theta|x)$.
163169
@@ -195,6 +201,14 @@ def sample(
195201
use the 'sde' sampling method, the vector field estimator must support
196202
it and have the SCORE_DEFINED class attribute set to True.
197203
show_progress_bars: Whether to show a progress bar during sampling.
204+
reject_outside_prior: If True (default), rejection sampling is used to
205+
ensure samples lie within the prior support. If False, samples are drawn
206+
directly from the ODE/SDE sampler without rejection, which is faster but
207+
may include samples outside the prior support.
208+
max_sampling_time: Optional maximum allowed sampling time in seconds.
209+
If exceeded, sampling is aborted and a RuntimeError is raised. Only
210+
applies when `reject_outside_prior=True` (no effect otherwise since
211+
direct sampling does not use rejection).
198212
"""
199213

200214
if sample_with is None:
@@ -213,13 +227,18 @@ def sample(
213227
num_samples = torch.Size(sample_shape).numel()
214228

215229
if sample_with == "ode":
216-
samples, _ = rejection.accept_reject_sample(
217-
proposal=self.sample_via_ode,
218-
accept_reject_fn=lambda theta: within_support(self.prior, theta),
219-
num_samples=num_samples,
220-
show_progress_bars=show_progress_bars,
221-
max_sampling_batch_size=max_sampling_batch_size,
222-
)
230+
if reject_outside_prior:
231+
samples, _ = rejection.accept_reject_sample(
232+
proposal=self.sample_via_ode,
233+
accept_reject_fn=lambda theta: within_support(self.prior, theta),
234+
num_samples=num_samples,
235+
show_progress_bars=show_progress_bars,
236+
max_sampling_batch_size=max_sampling_batch_size,
237+
max_sampling_time=max_sampling_time,
238+
)
239+
else:
240+
# Bypass rejection sampling entirely.
241+
samples = self.sample_via_ode(torch.Size([num_samples]))
223242
elif sample_with == "sde":
224243
proposal_sampling_kwargs = {
225244
"predictor": predictor,
@@ -231,19 +250,30 @@ def sample(
231250
"max_sampling_batch_size": max_sampling_batch_size,
232251
"show_progress_bars": show_progress_bars,
233252
}
234-
samples, _ = rejection.accept_reject_sample(
235-
proposal=self._sample_via_diffusion,
236-
accept_reject_fn=lambda theta: within_support(self.prior, theta),
237-
num_samples=num_samples,
238-
show_progress_bars=show_progress_bars,
239-
max_sampling_batch_size=max_sampling_batch_size,
240-
proposal_sampling_kwargs=proposal_sampling_kwargs,
241-
)
253+
if reject_outside_prior:
254+
samples, _ = rejection.accept_reject_sample(
255+
proposal=self._sample_via_diffusion,
256+
accept_reject_fn=lambda theta: within_support(self.prior, theta),
257+
num_samples=num_samples,
258+
show_progress_bars=show_progress_bars,
259+
max_sampling_batch_size=max_sampling_batch_size,
260+
proposal_sampling_kwargs=proposal_sampling_kwargs,
261+
max_sampling_time=max_sampling_time,
262+
)
263+
else:
264+
# Bypass rejection sampling entirely.
265+
samples = self._sample_via_diffusion(
266+
(num_samples,),
267+
**proposal_sampling_kwargs,
268+
)
242269
else:
243270
raise ValueError(
244271
f"Expected sample_with to be 'ode' or 'sde', but got {sample_with}."
245272
)
246273

274+
if not reject_outside_prior:
275+
warn_if_outside_prior_support(self.prior, samples)
276+
247277
samples = samples.reshape(
248278
sample_shape + self.vector_field_estimator.input_shape
249279
)
@@ -427,6 +457,8 @@ def sample_batched(
427457
ts: Optional[Tensor] = None,
428458
max_sampling_batch_size: int = 10000,
429459
show_progress_bars: bool = True,
460+
reject_outside_prior: bool = True,
461+
max_sampling_time: Optional[float] = None,
430462
) -> Tensor:
431463
r"""Given a batch of observations [x_1, ..., x_B] this function samples from
432464
posteriors $p(\theta|x_1)$, ... ,$p(\theta|x_B)$, in a batched (i.e. vectorized)
@@ -449,6 +481,13 @@ def sample_batched(
449481
linear grid between t_max and t_min is used.
450482
max_sampling_batch_size: Maximum batch size for sampling.
451483
show_progress_bars: Whether to show sampling progress monitor.
484+
reject_outside_prior: If True (default), rejection sampling is used to
485+
ensure samples lie within the prior support. If False, samples are drawn
486+
directly from the ODE/SDE sampler without rejection, which is faster but
487+
may include samples outside the prior support.
488+
max_sampling_time: Optional maximum allowed sampling time in seconds.
489+
If exceeded, sampling is aborted and a RuntimeError is raised. Only
490+
applies when `reject_outside_prior=True`.
452491
453492
Returns:
454493
Samples from the posteriors of shape (*sample_shape, B, *input_shape)
@@ -477,14 +516,19 @@ def sample_batched(
477516
max_sampling_batch_size = capped
478517

479518
if self.sample_with == "ode":
480-
samples, _ = rejection.accept_reject_sample(
481-
proposal=self.sample_via_ode,
482-
accept_reject_fn=lambda theta: within_support(self.prior, theta),
483-
num_samples=num_samples,
484-
num_xos=batch_size,
485-
show_progress_bars=show_progress_bars,
486-
max_sampling_batch_size=max_sampling_batch_size,
487-
)
519+
if reject_outside_prior:
520+
samples, _ = rejection.accept_reject_sample(
521+
proposal=self.sample_via_ode,
522+
accept_reject_fn=lambda theta: within_support(self.prior, theta),
523+
num_samples=num_samples,
524+
num_xos=batch_size,
525+
show_progress_bars=show_progress_bars,
526+
max_sampling_batch_size=max_sampling_batch_size,
527+
max_sampling_time=max_sampling_time,
528+
)
529+
else:
530+
# Bypass rejection sampling.
531+
samples = self.sample_via_ode(torch.Size([num_samples]))
488532
samples = samples.reshape(
489533
sample_shape + batch_shape + self.vector_field_estimator.input_shape
490534
)
@@ -499,19 +543,29 @@ def sample_batched(
499543
"max_sampling_batch_size": max_sampling_batch_size,
500544
"show_progress_bars": show_progress_bars,
501545
}
502-
samples, _ = rejection.accept_reject_sample(
503-
proposal=self._sample_via_diffusion,
504-
accept_reject_fn=lambda theta: within_support(self.prior, theta),
505-
num_samples=num_samples,
506-
num_xos=batch_size,
507-
show_progress_bars=show_progress_bars,
508-
max_sampling_batch_size=max_sampling_batch_size,
509-
proposal_sampling_kwargs=proposal_sampling_kwargs,
510-
)
546+
if reject_outside_prior:
547+
samples, _ = rejection.accept_reject_sample(
548+
proposal=self._sample_via_diffusion,
549+
accept_reject_fn=lambda theta: within_support(self.prior, theta),
550+
num_samples=num_samples,
551+
num_xos=batch_size,
552+
show_progress_bars=show_progress_bars,
553+
max_sampling_batch_size=max_sampling_batch_size,
554+
proposal_sampling_kwargs=proposal_sampling_kwargs,
555+
max_sampling_time=max_sampling_time,
556+
)
557+
else:
558+
# Bypass rejection sampling.
559+
samples = self._sample_via_diffusion(
560+
(num_samples,), **proposal_sampling_kwargs
561+
)
511562
samples = samples.reshape(
512563
sample_shape + batch_shape + self.vector_field_estimator.input_shape
513564
)
514565

566+
if not reject_outside_prior:
567+
warn_if_outside_prior_support(self.prior, samples)
568+
515569
return samples
516570

517571
def map(

0 commit comments

Comments
 (0)