+import warnings
+
 import torch
 from joblib import Parallel, delayed
 from torch import Tensor
 from tqdm import tqdm
 
 from sbi.inference.posteriors.base_posterior import NeuralPosterior
+from sbi.inference.posteriors.mcmc_posterior import MCMCPosterior
+from sbi.inference.posteriors.vi_posterior import VIPosterior
 from sbi.sbi_types import Shape
 
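For orientation: the hunks below modify `get_posterior_samples_on_batch`, whose signature the diff elides. Reconstructed from the identifiers the hunks use (`xs`, `posterior`, `sample_shape`, `num_workers`, `show_progress_bar`, `use_batched_sampling`), it is presumably along these lines; the default values are an assumption, not copied from the sbi source:

def get_posterior_samples_on_batch(
    xs: Tensor,
    posterior: NeuralPosterior,
    sample_shape: Shape,
    num_workers: int = 1,
    show_progress_bar: bool = False,
    use_batched_sampling: bool = True,
) -> Tensor:
    """Sample from the posterior for each x in xs; body as in the hunks below."""
    ...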
@@ -29,18 +32,23 @@ def get_posterior_samples_on_batch(
     Returns:
         posterior_samples: of shape (num_samples, batch_size, dim_parameters).
     """
-    batch_size = len(xs)
+    num_xs = len(xs)
 
-    # Try using batched sampling when implemented.
-    try:
-        # has shape (num_samples, batch_size, dim_parameters)
-        if use_batched_sampling:
+    if use_batched_sampling:
+        try:
+            # has shape (num_samples, num_xs, dim_parameters)
             posterior_samples = posterior.sample_batched(
                 sample_shape, x=xs, show_progress_bars=show_progress_bar
             )
-        else:
-            raise NotImplementedError
-    except NotImplementedError:
+        except (NotImplementedError, AssertionError):
+            warnings.warn(
+                "Batched sampling not implemented for this posterior. "
+                "Falling back to non-batched sampling.",
+                stacklevel=2,
+            )
+            use_batched_sampling = False
+
+    if not use_batched_sampling:
         # We need a function with an extra training step for each new x for VIPosterior.
         def sample_fun(
             posterior: NeuralPosterior, sample_shape: Shape, x: Tensor, seed: int = 0
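The restructured control flow above is a try-once-then-degrade pattern: attempt batched sampling, and on failure warn and flip the flag so the per-x path below runs instead. A minimal self-contained sketch, with a hypothetical stub posterior (not an sbi class) standing in for one without batched support; catching AssertionError as well presumably covers implementations that reject batched x via assertions rather than NotImplementedError:

import warnings

class _StubPosterior:
    # Hypothetical posterior whose batched sampler is unavailable.
    def sample_batched(self, sample_shape, x, show_progress_bars=False):
        raise NotImplementedError

posterior = _StubPosterior()
use_batched_sampling = True

if use_batched_sampling:
    try:
        samples = posterior.sample_batched((10,), x=None)
    except (NotImplementedError, AssertionError):
        warnings.warn("Falling back to non-batched sampling.", stacklevel=2)
        use_batched_sampling = False

if not use_batched_sampling:
    ...  # per-x sampling path, as in the diff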
@@ -51,8 +59,16 @@ def sample_fun(
             torch.manual_seed(seed)
             return posterior.sample(sample_shape, x=x, show_progress_bars=False)
 
+        if isinstance(posterior, (VIPosterior, MCMCPosterior)):
+            warnings.warn(
+                "Using non-batched sampling. Depending on the number of different xs "
+                f"({num_xs}) and the number of parallel workers {num_workers}, "
+                "this might take a lot of time.",
+                stacklevel=2,
+            )
+
         # Run in parallel with progress bar.
-        seeds = torch.randint(0, 2**32, (batch_size,))
+        seeds = torch.randint(0, 2**32, (num_xs,))
         outputs = list(
             tqdm(
                 Parallel(return_as="generator", n_jobs=num_workers)(
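Two details of the parallel path are easy to miss: each task gets an explicit integer seed, so results are reproducible regardless of which worker runs them, and Parallel(return_as="generator") (available in joblib >= 1.3) streams results as they finish, letting tqdm advance per task instead of jumping to 100% at the end. A minimal sketch, with a toy worker standing in for sample_fun:

import torch
from joblib import Parallel, delayed
from tqdm import tqdm

def _sample_one(seed: int) -> torch.Tensor:
    # Stand-in for sample_fun: seed first, then draw "posterior samples".
    torch.manual_seed(seed)
    return torch.randn(100, 2)  # (num_samples, dim_parameters)

seeds = torch.randint(0, 2**32, (8,))  # one seed per x
outputs = list(
    tqdm(
        Parallel(return_as="generator", n_jobs=2)(
            delayed(_sample_one)(int(seed)) for seed in seeds
        ),
        total=len(seeds),
    )
)
assert len(outputs) == 8 and outputs[0].shape == (100, 2)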
@@ -61,7 +77,7 @@ def sample_fun(
                 ),
                 disable=not show_progress_bar,
                 total=len(xs),
-                desc=f"Sampling {batch_size} times {sample_shape} posterior samples.",
+                desc=f"Sampling {num_xs} times {sample_shape} posterior samples.",
             )
         )  # (batch_size, num_samples, dim_parameters)
         # Transpose to shape convention: (sample_shape, batch_size, dim_parameters)
@@ -70,8 +86,8 @@ def sample_fun(
         ).permute(1, 0, 2)
 
     assert posterior_samples.shape[:2] == sample_shape + (
-        batch_size,
-    ), f"""Expected batched posterior samples of shape {
-        sample_shape + (batch_size,)
-    } got {posterior_samples.shape[:2]}."""
+        num_xs,
+    ), f"""Expected batched posterior samples of shape {sample_shape + (num_xs,)}, got {
+        posterior_samples.shape[:2]
+    }."""
     return posterior_samples
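The tail of the function is pure shape bookkeeping: the workers return one (num_samples, dim_parameters) tensor per x, stacking gives (num_xs, num_samples, dim), and the permute restores the documented (num_samples, num_xs, dim) convention, which the assert then checks via tuple concatenation. A toy verification with assumed dimensions:

import torch

num_xs, num_samples, dim_parameters = 3, 100, 2
sample_shape = (num_samples,)

outputs = [torch.randn(num_samples, dim_parameters) for _ in range(num_xs)]

# stack: (num_xs, num_samples, dim) -> permute: (num_samples, num_xs, dim)
posterior_samples = torch.stack(outputs).permute(1, 0, 2)

# sample_shape is a tuple, so concatenation yields the expected leading dims.
assert posterior_samples.shape[:2] == sample_shape + (num_xs,)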