Skip to content

Commit 74c5ea1

Browse files
committed
add reject_outside_prior_support logic and update rejection sampling tests
1 parent d9608a0 commit 74c5ea1

File tree

5 files changed

+245
-144
lines changed

5 files changed

+245
-144
lines changed

sbi/inference/posteriors/direct_posterior.py

Lines changed: 50 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,8 @@ def sample(
139139
max_sampling_batch_size: int = 10_000,
140140
sample_with: Optional[str] = None,
141141
show_progress_bars: bool = True,
142+
reject_outside_prior: bool = True,
143+
max_sampling_time: Optional[float] = None,
142144
) -> Tensor:
143145
r"""Return samples from posterior distribution $p(\theta|x)$.
144146
@@ -149,6 +151,11 @@ def sample(
149151
sample_with: This argument only exists to keep backward-compatibility with
150152
`sbi` v0.17.2 or older. If it is set, we instantly raise an error.
151153
show_progress_bars: Whether to show sampling progress monitor.
154+
reject_outside_prior: If True (default), rejection sampling is used to
155+
ensure samples lie within the prior support. If False, samples are drawn
156+
directly from the proposal without rejection sampling.
157+
max_sampling_time: Optional maximum allowed sampling time in seconds.
158+
If exceeded, sampling is aborted and a RuntimeError is raised.
152159
"""
153160
num_samples = torch.Size(sample_shape).numel()
154161
x = self._x_else_default_x(x)
@@ -177,15 +184,24 @@ def sample(
177184
f"`.build_posterior(sample_with={sample_with}).`"
178185
)
179186

180-
samples = rejection.accept_reject_sample(
181-
proposal=self.posterior_estimator.sample,
182-
accept_reject_fn=lambda theta: within_support(self.prior, theta),
183-
num_samples=num_samples,
184-
show_progress_bars=show_progress_bars,
185-
max_sampling_batch_size=max_sampling_batch_size,
186-
proposal_sampling_kwargs={"condition": x},
187-
alternative_method="build_posterior(..., sample_with='mcmc')",
188-
)[0] # [0] to return only samples, not acceptance probabilities.
187+
if reject_outside_prior:
188+
# normal rejection behaviour
189+
samples = rejection.accept_reject_sample(
190+
proposal=self.posterior_estimator.sample,
191+
accept_reject_fn=lambda theta: within_support(self.prior, theta),
192+
num_samples=num_samples,
193+
show_progress_bars=show_progress_bars,
194+
max_sampling_batch_size=max_sampling_batch_size,
195+
proposal_sampling_kwargs={"condition": x},
196+
alternative_method="build_posterior(..., sample_with='mcmc')",
197+
max_sampling_time=max_sampling_time,
198+
)[0]
199+
else:
200+
# bypass rejection sampling entirely
201+
samples = self.posterior_estimator.sample(
202+
(num_samples,),
203+
condition=x,
204+
)
189205

190206
return samples[:, 0] # Remove batch dimension.
191207

@@ -195,6 +211,8 @@ def sample_batched(
195211
x: Tensor,
196212
max_sampling_batch_size: int = 10_000,
197213
show_progress_bars: bool = True,
214+
reject_outside_prior: bool = True,
215+
max_sampling_time: Optional[float] = None,
198216
) -> Tensor:
199217
r"""Given a batch of observations [x_1, ..., x_B] this function samples from
200218
posteriors $p(\theta|x_1)$, ... ,$p(\theta|x_B)$, in a batched (i.e. vectorized)
@@ -207,6 +225,11 @@ def sample_batched(
207225
`batch_dim` corresponds to the number of observations to be drawn.
208226
max_sampling_batch_size: Maximum batch size for rejection sampling.
209227
show_progress_bars: Whether to show sampling progress monitor.
228+
reject_outside_prior: If True (default), rejection sampling is used to
229+
ensure samples lie within the prior support. If False, samples are drawn
230+
directly from the proposal without rejection sampling.
231+
max_sampling_time: Optional maximum allowed sampling time in seconds.
232+
If exceeded, sampling is aborted and a RuntimeError is raised.
210233
211234
Returns:
212235
Samples from the posteriors of shape (*sample_shape, B, *input_shape)
@@ -242,15 +265,24 @@ def sample_batched(
242265
)
243266
max_sampling_batch_size = capped
244267

245-
samples = rejection.accept_reject_sample(
246-
proposal=self.posterior_estimator.sample,
247-
accept_reject_fn=lambda theta: within_support(self.prior, theta),
248-
num_samples=num_samples,
249-
show_progress_bars=show_progress_bars,
250-
max_sampling_batch_size=max_sampling_batch_size,
251-
proposal_sampling_kwargs={"condition": x},
252-
alternative_method="build_posterior(..., sample_with='mcmc')",
253-
)[0]
268+
if reject_outside_prior:
269+
# normal rejection behaviour
270+
samples = rejection.accept_reject_sample(
271+
proposal=self.posterior_estimator.sample,
272+
accept_reject_fn=lambda theta: within_support(self.prior, theta),
273+
num_samples=num_samples,
274+
show_progress_bars=show_progress_bars,
275+
max_sampling_batch_size=max_sampling_batch_size,
276+
proposal_sampling_kwargs={"condition": x},
277+
alternative_method="build_posterior(..., sample_with='mcmc')",
278+
max_sampling_time=max_sampling_time,
279+
)[0]
280+
else:
281+
# bypass rejection sampling entirely
282+
samples = self.posterior_estimator.sample(
283+
(num_samples,),
284+
condition=x,
285+
)
254286

255287
return samples
256288

sbi/inference/posteriors/vector_field_posterior.py

Lines changed: 67 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,8 @@ def sample(
158158
max_sampling_batch_size: int = 10_000,
159159
sample_with: Optional[str] = None,
160160
show_progress_bars: bool = True,
161+
reject_outside_prior: bool = True,
162+
max_sampling_time: Optional[float] = None,
161163
) -> Tensor:
162164
r"""Return samples from posterior distribution $p(\theta|x)$.
163165
@@ -195,6 +197,11 @@ def sample(
195197
use the 'sde' sampling method, the vector field estimator must support
196198
it and have the SCORE_DEFINED class attribute set to True.
197199
show_progress_bars: Whether to show a progress bar during sampling.
200+
reject_outside_prior: If True (default), rejection sampling is used to
201+
ensure samples lie within the prior support. If False, samples are drawn
202+
directly from the proposal without rejection sampling.
203+
max_sampling_time: Optional maximum allowed sampling time in seconds.
204+
If exceeded, sampling is aborted and a RuntimeError is raised.
198205
"""
199206

200207
if sample_with is None:
@@ -213,13 +220,17 @@ def sample(
213220
num_samples = torch.Size(sample_shape).numel()
214221

215222
if sample_with == "ode":
216-
samples, _ = rejection.accept_reject_sample(
217-
proposal=self.sample_via_ode,
218-
accept_reject_fn=lambda theta: within_support(self.prior, theta),
219-
num_samples=num_samples,
220-
show_progress_bars=show_progress_bars,
221-
max_sampling_batch_size=max_sampling_batch_size,
222-
)
223+
if reject_outside_prior:
224+
samples, _ = rejection.accept_reject_sample(
225+
proposal=self.sample_via_ode,
226+
accept_reject_fn=lambda theta: within_support(self.prior, theta),
227+
num_samples=num_samples,
228+
show_progress_bars=show_progress_bars,
229+
max_sampling_batch_size=max_sampling_batch_size,
230+
max_sampling_time=max_sampling_time,
231+
)
232+
else:
233+
samples = self.sample_via_ode((num_samples,))
223234
elif sample_with == "sde":
224235
proposal_sampling_kwargs = {
225236
"predictor": predictor,
@@ -231,14 +242,21 @@ def sample(
231242
"max_sampling_batch_size": max_sampling_batch_size,
232243
"show_progress_bars": show_progress_bars,
233244
}
234-
samples, _ = rejection.accept_reject_sample(
235-
proposal=self._sample_via_diffusion,
236-
accept_reject_fn=lambda theta: within_support(self.prior, theta),
237-
num_samples=num_samples,
238-
show_progress_bars=show_progress_bars,
239-
max_sampling_batch_size=max_sampling_batch_size,
240-
proposal_sampling_kwargs=proposal_sampling_kwargs,
241-
)
245+
if reject_outside_prior:
246+
samples, _ = rejection.accept_reject_sample(
247+
proposal=self._sample_via_diffusion,
248+
accept_reject_fn=lambda theta: within_support(self.prior, theta),
249+
num_samples=num_samples,
250+
show_progress_bars=show_progress_bars,
251+
max_sampling_batch_size=max_sampling_batch_size,
252+
proposal_sampling_kwargs=proposal_sampling_kwargs,
253+
max_sampling_time=max_sampling_time,
254+
)
255+
else:
256+
samples = self._sample_via_diffusion(
257+
(num_samples,),
258+
**proposal_sampling_kwargs,
259+
)
242260
else:
243261
raise ValueError(
244262
f"Expected sample_with to be 'ode' or 'sde', but got {sample_with}."
@@ -427,6 +445,8 @@ def sample_batched(
427445
ts: Optional[Tensor] = None,
428446
max_sampling_batch_size: int = 10000,
429447
show_progress_bars: bool = True,
448+
reject_outside_prior: bool = True,
449+
max_sampling_time: Optional[float] = None,
430450
) -> Tensor:
431451
r"""Given a batch of observations [x_1, ..., x_B] this function samples from
432452
posteriors $p(\theta|x_1)$, ... ,$p(\theta|x_B)$, in a batched (i.e. vectorized)
@@ -449,6 +469,11 @@ def sample_batched(
449469
linear grid between t_max and t_min is used.
450470
max_sampling_batch_size: Maximum batch size for sampling.
451471
show_progress_bars: Whether to show sampling progress monitor.
472+
reject_outside_prior: If True (default), rejection sampling is used to
473+
ensure samples lie within the prior support. If False, samples are drawn
474+
directly from the proposal without rejection sampling.
475+
max_sampling_time: Optional maximum allowed sampling time in seconds.
476+
If exceeded, sampling is aborted and a RuntimeError is raised.
452477
453478
Returns:
454479
Samples from the posteriors of shape (*sample_shape, B, *input_shape)
@@ -477,14 +502,18 @@ def sample_batched(
477502
max_sampling_batch_size = capped
478503

479504
if self.sample_with == "ode":
480-
samples, _ = rejection.accept_reject_sample(
481-
proposal=self.sample_via_ode,
482-
accept_reject_fn=lambda theta: within_support(self.prior, theta),
483-
num_samples=num_samples,
484-
num_xos=batch_size,
485-
show_progress_bars=show_progress_bars,
486-
max_sampling_batch_size=max_sampling_batch_size,
487-
)
505+
if reject_outside_prior:
506+
samples, _ = rejection.accept_reject_sample(
507+
proposal=self.sample_via_ode,
508+
accept_reject_fn=lambda theta: within_support(self.prior, theta),
509+
num_samples=num_samples,
510+
num_xos=batch_size,
511+
show_progress_bars=show_progress_bars,
512+
max_sampling_batch_size=max_sampling_batch_size,
513+
max_sampling_time=max_sampling_time,
514+
)
515+
else:
516+
samples = self.sample_via_ode((num_samples,))
488517
samples = samples.reshape(
489518
sample_shape + batch_shape + self.vector_field_estimator.input_shape
490519
)
@@ -499,15 +528,21 @@ def sample_batched(
499528
"max_sampling_batch_size": max_sampling_batch_size,
500529
"show_progress_bars": show_progress_bars,
501530
}
502-
samples, _ = rejection.accept_reject_sample(
503-
proposal=self._sample_via_diffusion,
504-
accept_reject_fn=lambda theta: within_support(self.prior, theta),
505-
num_samples=num_samples,
506-
num_xos=batch_size,
507-
show_progress_bars=show_progress_bars,
508-
max_sampling_batch_size=max_sampling_batch_size,
509-
proposal_sampling_kwargs=proposal_sampling_kwargs,
510-
)
531+
if reject_outside_prior:
532+
samples, _ = rejection.accept_reject_sample(
533+
proposal=self._sample_via_diffusion,
534+
accept_reject_fn=lambda theta: within_support(self.prior, theta),
535+
num_samples=num_samples,
536+
num_xos=batch_size,
537+
show_progress_bars=show_progress_bars,
538+
max_sampling_batch_size=max_sampling_batch_size,
539+
proposal_sampling_kwargs=proposal_sampling_kwargs,
540+
max_sampling_time=max_sampling_time,
541+
)
542+
else:
543+
samples = self._sample_via_diffusion(
544+
(num_samples,), **proposal_sampling_kwargs
545+
)
511546
samples = samples.reshape(
512547
sample_shape + batch_shape + self.vector_field_estimator.input_shape
513548
)

sbi/samplers/rejection/rejection.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,10 @@ def log_prob(self, theta: Tensor, **kwargs) -> Tensor:
146146
raise RuntimeError(
147147
"Sampling aborted early because rejection sampling exceeded "
148148
"max_sampling_time. This is likely due to extremely low "
149-
"acceptance. Consider switching to MCMC or VI, or checking "
150-
"for model misspecification."
149+
"acceptance. You can disable rejection sampling using "
150+
"`reject_outside_prior=False` to draw samples directly from "
151+
"the trained estimator. Consider switching to MCMC or VI, or "
152+
"checking for model misspecification."
151153
)
152154

153155
# Sample and reject.
@@ -319,8 +321,10 @@ def accept_reject_sample(
319321
raise RuntimeError(
320322
"Sampling aborted early because rejection sampling exceeded "
321323
"max_sampling_time. This is likely due to extremely low "
322-
"acceptance. Consider switching to MCMC or VI, or checking "
323-
"for model misspecification."
324+
"acceptance. You can disable rejection sampling using "
325+
"`reject_outside_prior=False` to draw samples directly from "
326+
"the trained estimator. Consider switching to MCMC or VI, or "
327+
"checking for model misspecification."
324328
)
325329

326330
# Sample and reject.

0 commit comments

Comments
 (0)