11# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
22# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
33
4+ import math
45import warnings
56from typing import Dict , Literal , Optional , Union
67
@@ -150,7 +151,9 @@ def sample(
150151 corrector_params : Optional [Dict ] = None ,
151152 steps : int = 500 ,
152153 ts : Optional [Tensor ] = None ,
153- iid_method : Literal ["fnpe" , "gauss" , "auto_gauss" , "jac_gauss" ] = "auto_gauss" ,
154+ iid_method : Optional [
155+ Literal ["fnpe" , "gauss" , "auto_gauss" , "jac_gauss" ]
156+ ] = None ,
154157 iid_params : Optional [Dict ] = None ,
155158 max_sampling_batch_size : int = 10_000 ,
156159 sample_with : Optional [str ] = None ,
@@ -201,19 +204,22 @@ def sample(
201204 x = reshape_to_batch_event (x , self .vector_field_estimator .condition_shape )
202205 is_iid = x .shape [0 ] > 1
203206 self .potential_fn .set_x (
204- x , x_is_iid = is_iid , iid_method = iid_method , iid_params = iid_params
207+ x ,
208+ x_is_iid = is_iid ,
209+ iid_method = iid_method or self .potential_fn .iid_method ,
210+ iid_params = iid_params ,
205211 )
206212
207213 num_samples = torch .Size (sample_shape ).numel ()
208214
209215 if sample_with == "ode" :
210- samples = rejection .accept_reject_sample (
216+ samples , _ = rejection .accept_reject_sample (
211217 proposal = self .sample_via_ode ,
212218 accept_reject_fn = lambda theta : within_support (self .prior , theta ),
213219 num_samples = num_samples ,
214220 show_progress_bars = show_progress_bars ,
215221 max_sampling_batch_size = max_sampling_batch_size ,
216- )[ 0 ]
222+ )
217223 elif sample_with == "sde" :
218224 proposal_sampling_kwargs = {
219225 "predictor" : predictor ,
@@ -225,14 +231,14 @@ def sample(
225231 "max_sampling_batch_size" : max_sampling_batch_size ,
226232 "show_progress_bars" : show_progress_bars ,
227233 }
228- samples = rejection .accept_reject_sample (
234+ samples , _ = rejection .accept_reject_sample (
229235 proposal = self ._sample_via_diffusion ,
230236 accept_reject_fn = lambda theta : within_support (self .prior , theta ),
231237 num_samples = num_samples ,
232238 show_progress_bars = show_progress_bars ,
233239 max_sampling_batch_size = max_sampling_batch_size ,
234240 proposal_sampling_kwargs = proposal_sampling_kwargs ,
235- )[ 0 ]
241+ )
236242 else :
237243 raise ValueError (
238244 f"Expected sample_with to be 'ode' or 'sde', but got { sample_with } ."
@@ -282,13 +288,16 @@ def _sample_via_diffusion(
282288 "The vector field estimator does not support the 'sde' sampling method."
283289 )
284290
285- num_samples = torch .Size (sample_shape ).numel ()
291+ total_samples_needed = torch .Size (sample_shape ).numel ()
286292
287- max_sampling_batch_size = (
293+ # Determine effective batch size for sampling
294+ effective_batch_size = (
288295 self .max_sampling_batch_size
289296 if max_sampling_batch_size is None
290297 else max_sampling_batch_size
291298 )
299+ # Ensure we don't use larger batches than total samples needed
300+ effective_batch_size = min (effective_batch_size , total_samples_needed )
292301
293302 # TODO: the time schedule should be provided by the estimator, see issue #1437
294303 if ts is None :
@@ -297,28 +306,47 @@ def _sample_via_diffusion(
297306 ts = torch .linspace (t_max , t_min , steps )
298307 ts = ts .to (self .device )
299308
309+ # Initialize the diffusion sampler
300310 diffuser = Diffuser (
301311 self .potential_fn ,
302312 predictor = predictor ,
303313 corrector = corrector ,
304314 predictor_params = predictor_params ,
305315 corrector_params = corrector_params ,
306316 )
307- max_sampling_batch_size = min (max_sampling_batch_size , num_samples )
308- samples = []
309- num_iter = num_samples // max_sampling_batch_size
310- num_iter = (
311- num_iter + 1 if (num_samples % max_sampling_batch_size ) != 0 else num_iter
312- )
313- for _ in range (num_iter ):
314- samples .append (
315- diffuser .run (
316- num_samples = max_sampling_batch_size ,
317- ts = ts ,
318- show_progress_bars = show_progress_bars ,
319- )
317+
318+ # Calculate how many batches we need
319+ num_batches = math .ceil (total_samples_needed / effective_batch_size )
320+
321+ # Generate samples in batches
322+ all_samples = []
323+ samples_generated = 0
324+
325+ for _ in range (num_batches ):
326+ # Calculate how many samples to generate in this batch
327+ remaining_samples = total_samples_needed - samples_generated
328+ current_batch_size = min (effective_batch_size , remaining_samples )
329+
330+ # Generate samples for this batch
331+ batch_samples = diffuser .run (
332+ num_samples = current_batch_size ,
333+ ts = ts ,
334+ show_progress_bars = show_progress_bars ,
335+ )
336+
337+ all_samples .append (batch_samples )
338+ samples_generated += current_batch_size
339+
340+ # Concatenate all batches and ensure we return exactly the requested number
341+ samples = torch .cat (all_samples , dim = 0 )[:total_samples_needed ]
342+
343+ # Check for NaN values
344+ if torch .isnan (samples ).any ():
345+ raise RuntimeError (
346+ "NaN values detected during diffusion sampling. This may indicate"
347+ " numerical instability in the vector field or improper time "
348+ "scheduling."
320349 )
321- samples = torch .cat (samples , dim = 0 )[:num_samples ]
322350
323351 return samples
324352
@@ -443,14 +471,14 @@ def sample_batched(
443471 max_sampling_batch_size = capped
444472
445473 if self .sample_with == "ode" :
446- samples = rejection .accept_reject_sample (
474+ samples , _ = rejection .accept_reject_sample (
447475 proposal = self .sample_via_ode ,
448476 accept_reject_fn = lambda theta : within_support (self .prior , theta ),
449477 num_samples = num_samples ,
450478 num_xos = batch_size ,
451479 show_progress_bars = show_progress_bars ,
452480 max_sampling_batch_size = max_sampling_batch_size ,
453- )[ 0 ]
481+ )
454482 samples = samples .reshape (
455483 sample_shape + batch_shape + self .vector_field_estimator .input_shape
456484 )
@@ -465,15 +493,15 @@ def sample_batched(
465493 "max_sampling_batch_size" : max_sampling_batch_size ,
466494 "show_progress_bars" : show_progress_bars ,
467495 }
468- samples = rejection .accept_reject_sample (
496+ samples , _ = rejection .accept_reject_sample (
469497 proposal = self ._sample_via_diffusion ,
470498 accept_reject_fn = lambda theta : within_support (self .prior , theta ),
471499 num_samples = num_samples ,
472500 num_xos = batch_size ,
473501 show_progress_bars = show_progress_bars ,
474502 max_sampling_batch_size = max_sampling_batch_size ,
475503 proposal_sampling_kwargs = proposal_sampling_kwargs ,
476- )[ 0 ]
504+ )
477505 samples = samples .reshape (
478506 sample_shape + batch_shape + self .vector_field_estimator .input_shape
479507 )
0 commit comments