2525from sbi .samplers .score .predictors import Predictor
2626from sbi .sbi_types import Shape
2727from sbi .utils import check_prior
28- from sbi .utils .sbiutils import gradient_ascent , within_support
28+ from sbi .utils .sbiutils import (
29+ gradient_ascent ,
30+ warn_if_outside_prior_support ,
31+ within_support ,
32+ )
2933from sbi .utils .torchutils import ensure_theta_batched
3034
3135
@@ -158,6 +162,8 @@ def sample(
158162 max_sampling_batch_size : int = 10_000 ,
159163 sample_with : Optional [str ] = None ,
160164 show_progress_bars : bool = True ,
165+ reject_outside_prior : bool = True ,
166+ max_sampling_time : Optional [float ] = None ,
161167 ) -> Tensor :
162168 r"""Return samples from posterior distribution $p(\theta|x)$.
163169
@@ -195,6 +201,14 @@ def sample(
195201 use the 'sde' sampling method, the vector field estimator must support
196202 it and have the SCORE_DEFINED class attribute set to True.
197203 show_progress_bars: Whether to show a progress bar during sampling.
204+ reject_outside_prior: If True (default), rejection sampling is used to
205+ ensure samples lie within the prior support. If False, samples are drawn
206+ directly from the ODE/SDE sampler without rejection, which is faster but
207+ may include samples outside the prior support.
208+ max_sampling_time: Optional maximum allowed sampling time in seconds.
209+ If exceeded, sampling is aborted and a RuntimeError is raised. Only
210+ applies when `reject_outside_prior=True` (no effect otherwise since
211+ direct sampling does not use rejection).
198212 """
199213
200214 if sample_with is None :
@@ -213,13 +227,18 @@ def sample(
213227 num_samples = torch .Size (sample_shape ).numel ()
214228
215229 if sample_with == "ode" :
216- samples , _ = rejection .accept_reject_sample (
217- proposal = self .sample_via_ode ,
218- accept_reject_fn = lambda theta : within_support (self .prior , theta ),
219- num_samples = num_samples ,
220- show_progress_bars = show_progress_bars ,
221- max_sampling_batch_size = max_sampling_batch_size ,
222- )
230+ if reject_outside_prior :
231+ samples , _ = rejection .accept_reject_sample (
232+ proposal = self .sample_via_ode ,
233+ accept_reject_fn = lambda theta : within_support (self .prior , theta ),
234+ num_samples = num_samples ,
235+ show_progress_bars = show_progress_bars ,
236+ max_sampling_batch_size = max_sampling_batch_size ,
237+ max_sampling_time = max_sampling_time ,
238+ )
239+ else :
240+ # Bypass rejection sampling entirely.
241+ samples = self .sample_via_ode (torch .Size ([num_samples ]))
223242 elif sample_with == "sde" :
224243 proposal_sampling_kwargs = {
225244 "predictor" : predictor ,
@@ -231,19 +250,30 @@ def sample(
231250 "max_sampling_batch_size" : max_sampling_batch_size ,
232251 "show_progress_bars" : show_progress_bars ,
233252 }
234- samples , _ = rejection .accept_reject_sample (
235- proposal = self ._sample_via_diffusion ,
236- accept_reject_fn = lambda theta : within_support (self .prior , theta ),
237- num_samples = num_samples ,
238- show_progress_bars = show_progress_bars ,
239- max_sampling_batch_size = max_sampling_batch_size ,
240- proposal_sampling_kwargs = proposal_sampling_kwargs ,
241- )
253+ if reject_outside_prior :
254+ samples , _ = rejection .accept_reject_sample (
255+ proposal = self ._sample_via_diffusion ,
256+ accept_reject_fn = lambda theta : within_support (self .prior , theta ),
257+ num_samples = num_samples ,
258+ show_progress_bars = show_progress_bars ,
259+ max_sampling_batch_size = max_sampling_batch_size ,
260+ proposal_sampling_kwargs = proposal_sampling_kwargs ,
261+ max_sampling_time = max_sampling_time ,
262+ )
263+ else :
264+ # Bypass rejection sampling entirely.
265+ samples = self ._sample_via_diffusion (
266+ (num_samples ,),
267+ ** proposal_sampling_kwargs ,
268+ )
242269 else :
243270 raise ValueError (
244271 f"Expected sample_with to be 'ode' or 'sde', but got { sample_with } ."
245272 )
246273
274+ if not reject_outside_prior :
275+ warn_if_outside_prior_support (self .prior , samples )
276+
247277 samples = samples .reshape (
248278 sample_shape + self .vector_field_estimator .input_shape
249279 )
@@ -427,6 +457,8 @@ def sample_batched(
427457 ts : Optional [Tensor ] = None ,
428458 max_sampling_batch_size : int = 10000 ,
429459 show_progress_bars : bool = True ,
460+ reject_outside_prior : bool = True ,
461+ max_sampling_time : Optional [float ] = None ,
430462 ) -> Tensor :
431463 r"""Given a batch of observations [x_1, ..., x_B] this function samples from
432464 posteriors $p(\theta|x_1)$, ... ,$p(\theta|x_B)$, in a batched (i.e. vectorized)
@@ -449,6 +481,13 @@ def sample_batched(
449481 linear grid between t_max and t_min is used.
450482 max_sampling_batch_size: Maximum batch size for sampling.
451483 show_progress_bars: Whether to show sampling progress monitor.
484+ reject_outside_prior: If True (default), rejection sampling is used to
485+ ensure samples lie within the prior support. If False, samples are drawn
486+ directly from the ODE/SDE sampler without rejection, which is faster but
487+ may include samples outside the prior support.
488+ max_sampling_time: Optional maximum allowed sampling time in seconds.
489+ If exceeded, sampling is aborted and a RuntimeError is raised. Only
490+ applies when `reject_outside_prior=True`.
452491
453492 Returns:
454493 Samples from the posteriors of shape (*sample_shape, B, *input_shape)
@@ -477,14 +516,19 @@ def sample_batched(
477516 max_sampling_batch_size = capped
478517
479518 if self .sample_with == "ode" :
480- samples , _ = rejection .accept_reject_sample (
481- proposal = self .sample_via_ode ,
482- accept_reject_fn = lambda theta : within_support (self .prior , theta ),
483- num_samples = num_samples ,
484- num_xos = batch_size ,
485- show_progress_bars = show_progress_bars ,
486- max_sampling_batch_size = max_sampling_batch_size ,
487- )
519+ if reject_outside_prior :
520+ samples , _ = rejection .accept_reject_sample (
521+ proposal = self .sample_via_ode ,
522+ accept_reject_fn = lambda theta : within_support (self .prior , theta ),
523+ num_samples = num_samples ,
524+ num_xos = batch_size ,
525+ show_progress_bars = show_progress_bars ,
526+ max_sampling_batch_size = max_sampling_batch_size ,
527+ max_sampling_time = max_sampling_time ,
528+ )
529+ else :
530+ # Bypass rejection sampling.
531+ samples = self .sample_via_ode (torch .Size ([num_samples ]))
488532 samples = samples .reshape (
489533 sample_shape + batch_shape + self .vector_field_estimator .input_shape
490534 )
@@ -499,19 +543,29 @@ def sample_batched(
499543 "max_sampling_batch_size" : max_sampling_batch_size ,
500544 "show_progress_bars" : show_progress_bars ,
501545 }
502- samples , _ = rejection .accept_reject_sample (
503- proposal = self ._sample_via_diffusion ,
504- accept_reject_fn = lambda theta : within_support (self .prior , theta ),
505- num_samples = num_samples ,
506- num_xos = batch_size ,
507- show_progress_bars = show_progress_bars ,
508- max_sampling_batch_size = max_sampling_batch_size ,
509- proposal_sampling_kwargs = proposal_sampling_kwargs ,
510- )
546+ if reject_outside_prior :
547+ samples , _ = rejection .accept_reject_sample (
548+ proposal = self ._sample_via_diffusion ,
549+ accept_reject_fn = lambda theta : within_support (self .prior , theta ),
550+ num_samples = num_samples ,
551+ num_xos = batch_size ,
552+ show_progress_bars = show_progress_bars ,
553+ max_sampling_batch_size = max_sampling_batch_size ,
554+ proposal_sampling_kwargs = proposal_sampling_kwargs ,
555+ max_sampling_time = max_sampling_time ,
556+ )
557+ else :
558+ # Bypass rejection sampling.
559+ samples = self ._sample_via_diffusion (
560+ (num_samples ,), ** proposal_sampling_kwargs
561+ )
511562 samples = samples .reshape (
512563 sample_shape + batch_shape + self .vector_field_estimator .input_shape
513564 )
514565
566+ if not reject_outside_prior :
567+ warn_if_outside_prior_support (self .prior , samples )
568+
515569 return samples
516570
517571 def map (
0 commit comments