@@ -158,6 +158,8 @@ def sample(
158158 max_sampling_batch_size : int = 10_000 ,
159159 sample_with : Optional [str ] = None ,
160160 show_progress_bars : bool = True ,
161+ reject_outside_prior : bool = True ,
162+ max_sampling_time : Optional [float ] = None ,
161163 ) -> Tensor :
162164 r"""Return samples from posterior distribution $p(\theta|x)$.
163165
@@ -195,6 +197,11 @@ def sample(
195197 use the 'sde' sampling method, the vector field estimator must support
196198 it and have the SCORE_DEFINED class attribute set to True.
197199 show_progress_bars: Whether to show a progress bar during sampling.
200+ reject_outside_prior: If True (default), rejection sampling is used to
201+ ensure samples lie within the prior support. If False, samples are drawn
202+ directly from the proposal without rejection sampling.
203+ max_sampling_time: Optional maximum allowed sampling time in seconds.
204+ If exceeded, sampling is aborted and a RuntimeError is raised.
198205 """
199206
200207 if sample_with is None :
@@ -213,13 +220,17 @@ def sample(
213220 num_samples = torch .Size (sample_shape ).numel ()
214221
215222 if sample_with == "ode" :
216- samples , _ = rejection .accept_reject_sample (
217- proposal = self .sample_via_ode ,
218- accept_reject_fn = lambda theta : within_support (self .prior , theta ),
219- num_samples = num_samples ,
220- show_progress_bars = show_progress_bars ,
221- max_sampling_batch_size = max_sampling_batch_size ,
222- )
223+ if reject_outside_prior :
224+ samples , _ = rejection .accept_reject_sample (
225+ proposal = self .sample_via_ode ,
226+ accept_reject_fn = lambda theta : within_support (self .prior , theta ),
227+ num_samples = num_samples ,
228+ show_progress_bars = show_progress_bars ,
229+ max_sampling_batch_size = max_sampling_batch_size ,
230+ max_sampling_time = max_sampling_time ,
231+ )
232+ else :
233+ samples = self .sample_via_ode ((num_samples ,))
223234 elif sample_with == "sde" :
224235 proposal_sampling_kwargs = {
225236 "predictor" : predictor ,
@@ -231,14 +242,21 @@ def sample(
231242 "max_sampling_batch_size" : max_sampling_batch_size ,
232243 "show_progress_bars" : show_progress_bars ,
233244 }
234- samples , _ = rejection .accept_reject_sample (
235- proposal = self ._sample_via_diffusion ,
236- accept_reject_fn = lambda theta : within_support (self .prior , theta ),
237- num_samples = num_samples ,
238- show_progress_bars = show_progress_bars ,
239- max_sampling_batch_size = max_sampling_batch_size ,
240- proposal_sampling_kwargs = proposal_sampling_kwargs ,
241- )
245+ if reject_outside_prior :
246+ samples , _ = rejection .accept_reject_sample (
247+ proposal = self ._sample_via_diffusion ,
248+ accept_reject_fn = lambda theta : within_support (self .prior , theta ),
249+ num_samples = num_samples ,
250+ show_progress_bars = show_progress_bars ,
251+ max_sampling_batch_size = max_sampling_batch_size ,
252+ proposal_sampling_kwargs = proposal_sampling_kwargs ,
253+ max_sampling_time = max_sampling_time ,
254+ )
255+ else :
256+ samples = self ._sample_via_diffusion (
257+ (num_samples ,),
258+ ** proposal_sampling_kwargs ,
259+ )
242260 else :
243261 raise ValueError (
244262 f"Expected sample_with to be 'ode' or 'sde', but got { sample_with } ."
@@ -427,6 +445,8 @@ def sample_batched(
427445 ts : Optional [Tensor ] = None ,
428446 max_sampling_batch_size : int = 10000 ,
429447 show_progress_bars : bool = True ,
448+ reject_outside_prior : bool = True ,
449+ max_sampling_time : Optional [float ] = None ,
430450 ) -> Tensor :
431451 r"""Given a batch of observations [x_1, ..., x_B] this function samples from
432452 posteriors $p(\theta|x_1)$, ... ,$p(\theta|x_B)$, in a batched (i.e. vectorized)
@@ -449,6 +469,11 @@ def sample_batched(
449469 linear grid between t_max and t_min is used.
450470 max_sampling_batch_size: Maximum batch size for sampling.
451471 show_progress_bars: Whether to show sampling progress monitor.
472+ reject_outside_prior: If True (default), rejection sampling is used to
473+ ensure samples lie within the prior support. If False, samples are drawn
474+ directly from the proposal without rejection sampling.
475+ max_sampling_time: Optional maximum allowed sampling time in seconds.
476+ If exceeded, sampling is aborted and a RuntimeError is raised.
452477
453478 Returns:
454479 Samples from the posteriors of shape (*sample_shape, B, *input_shape)
@@ -477,14 +502,18 @@ def sample_batched(
477502 max_sampling_batch_size = capped
478503
479504 if self .sample_with == "ode" :
480- samples , _ = rejection .accept_reject_sample (
481- proposal = self .sample_via_ode ,
482- accept_reject_fn = lambda theta : within_support (self .prior , theta ),
483- num_samples = num_samples ,
484- num_xos = batch_size ,
485- show_progress_bars = show_progress_bars ,
486- max_sampling_batch_size = max_sampling_batch_size ,
487- )
505+ if reject_outside_prior :
506+ samples , _ = rejection .accept_reject_sample (
507+ proposal = self .sample_via_ode ,
508+ accept_reject_fn = lambda theta : within_support (self .prior , theta ),
509+ num_samples = num_samples ,
510+ num_xos = batch_size ,
511+ show_progress_bars = show_progress_bars ,
512+ max_sampling_batch_size = max_sampling_batch_size ,
513+ max_sampling_time = max_sampling_time ,
514+ )
515+ else :
516+ samples = self .sample_via_ode ((num_samples ,))
488517 samples = samples .reshape (
489518 sample_shape + batch_shape + self .vector_field_estimator .input_shape
490519 )
@@ -499,15 +528,21 @@ def sample_batched(
499528 "max_sampling_batch_size" : max_sampling_batch_size ,
500529 "show_progress_bars" : show_progress_bars ,
501530 }
502- samples , _ = rejection .accept_reject_sample (
503- proposal = self ._sample_via_diffusion ,
504- accept_reject_fn = lambda theta : within_support (self .prior , theta ),
505- num_samples = num_samples ,
506- num_xos = batch_size ,
507- show_progress_bars = show_progress_bars ,
508- max_sampling_batch_size = max_sampling_batch_size ,
509- proposal_sampling_kwargs = proposal_sampling_kwargs ,
510- )
531+ if reject_outside_prior :
532+ samples , _ = rejection .accept_reject_sample (
533+ proposal = self ._sample_via_diffusion ,
534+ accept_reject_fn = lambda theta : within_support (self .prior , theta ),
535+ num_samples = num_samples ,
536+ num_xos = batch_size ,
537+ show_progress_bars = show_progress_bars ,
538+ max_sampling_batch_size = max_sampling_batch_size ,
539+ proposal_sampling_kwargs = proposal_sampling_kwargs ,
540+ max_sampling_time = max_sampling_time ,
541+ )
542+ else :
543+ samples = self ._sample_via_diffusion (
544+ (num_samples ,), ** proposal_sampling_kwargs
545+ )
511546 samples = samples .reshape (
512547 sample_shape + batch_shape + self .vector_field_estimator .input_shape
513548 )
0 commit comments