Skip to content

Commit d3f22b5

Browse files
authored
expose batched sampling option in diagnostics (#1321)
* expose batched sampling option; error handling
* further improvements
* undo batch_size option
1 parent 06890eb commit d3f22b5

File tree

3 files changed

+34
-16
lines changed

3 files changed

+34
-16
lines changed

sbi/diagnostics/sbc.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,7 @@ def run_sbc(
4848
num_workers: number of CPU cores to use in parallel for running
4949
`num_sbc_samples` inferences.
5050
show_progress_bar: whether to display a progress over sbc runs.
51-
use_batched_sampling: whether to use batched sampling for posterior
52-
samples.
51+
use_batched_sampling: whether to use batched sampling for posterior samples.
5352
5453
Returns:
5554
ranks: ranks of the ground truth parameters under the inferred

sbi/diagnostics/tarp.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def run_tarp(
2828
distance: Callable = l2,
2929
num_bins: Optional[int] = 30,
3030
z_score_theta: bool = True,
31+
use_batched_sampling: bool = True,
3132
) -> Tuple[Tensor, Tensor]:
3233
"""
3334
Estimates coverage of samples given true values thetas with the TARP method.
@@ -54,6 +55,7 @@ def run_tarp(
5455
num_bins: number of bins to use for the credibility values.
5556
If ``None``, then ``num_sims // 10`` bins are used.
5657
z_score_theta : whether to normalize parameters before coverage test.
58+
use_batched_sampling: whether to use batched sampling for posterior samples.
5759
5860
Returns:
5961
ecp: Expected coverage probability (``ecp``), see equation 4 of the paper
@@ -67,6 +69,7 @@ def run_tarp(
6769
(num_posterior_samples,),
6870
num_workers,
6971
show_progress_bar=show_progress_bar,
72+
use_batched_sampling=use_batched_sampling,
7073
)
7174
assert posterior_samples.shape == (
7275
num_posterior_samples,

sbi/utils/diagnostics_utils.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
import warnings
2+
13
import torch
24
from joblib import Parallel, delayed
35
from torch import Tensor
46
from tqdm import tqdm
57

68
from sbi.inference.posteriors.base_posterior import NeuralPosterior
9+
from sbi.inference.posteriors.mcmc_posterior import MCMCPosterior
710
from sbi.inference.posteriors.vi_posterior import VIPosterior
811
from sbi.sbi_types import Shape
912

@@ -29,18 +32,23 @@ def get_posterior_samples_on_batch(
2932
Returns:
3033
posterior_samples: of shape (num_samples, batch_size, dim_parameters).
3134
"""
32-
batch_size = len(xs)
35+
num_xs = len(xs)
3336

34-
# Try using batched sampling when implemented.
35-
try:
36-
# has shape (num_samples, batch_size, dim_parameters)
37-
if use_batched_sampling:
37+
if use_batched_sampling:
38+
try:
39+
# has shape (num_samples, num_xs, dim_parameters)
3840
posterior_samples = posterior.sample_batched(
3941
sample_shape, x=xs, show_progress_bars=show_progress_bar
4042
)
41-
else:
42-
raise NotImplementedError
43-
except NotImplementedError:
43+
except (NotImplementedError, AssertionError):
44+
warnings.warn(
45+
"Batched sampling not implemented for this posterior. "
46+
"Falling back to non-batched sampling.",
47+
stacklevel=2,
48+
)
49+
use_batched_sampling = False
50+
51+
if not use_batched_sampling:
4452
# We need a function with extra training step for new x for VIPosterior.
4553
def sample_fun(
4654
posterior: NeuralPosterior, sample_shape: Shape, x: Tensor, seed: int = 0
@@ -51,8 +59,16 @@ def sample_fun(
5159
torch.manual_seed(seed)
5260
return posterior.sample(sample_shape, x=x, show_progress_bars=False)
5361

62+
if isinstance(posterior, (VIPosterior, MCMCPosterior)):
63+
warnings.warn(
64+
"Using non-batched sampling. Depending on the number of different xs "
65+
f"( {num_xs}) and the number of parallel workers {num_workers}, "
66+
"this might take a lot of time.",
67+
stacklevel=2,
68+
)
69+
5470
# Run in parallel with progress bar.
55-
seeds = torch.randint(0, 2**32, (batch_size,))
71+
seeds = torch.randint(0, 2**32, (num_xs,))
5672
outputs = list(
5773
tqdm(
5874
Parallel(return_as="generator", n_jobs=num_workers)(
@@ -61,7 +77,7 @@ def sample_fun(
6177
),
6278
disable=not show_progress_bar,
6379
total=len(xs),
64-
desc=f"Sampling {batch_size} times {sample_shape} posterior samples.",
80+
desc=f"Sampling {num_xs} times {sample_shape} posterior samples.",
6581
)
6682
) # (batch_size, num_samples, dim_parameters)
6783
# Transpose to shape convention: (sample_shape, batch_size, dim_parameters)
@@ -70,8 +86,8 @@ def sample_fun(
7086
).permute(1, 0, 2)
7187

7288
assert posterior_samples.shape[:2] == sample_shape + (
73-
batch_size,
74-
), f"""Expected batched posterior samples of shape {
75-
sample_shape + (batch_size,)
76-
} got {posterior_samples.shape[:2]}."""
89+
num_xs,
90+
), f"""Expected batched posterior samples of shape {sample_shape + (num_xs,)} got {
91+
posterior_samples.shape[:2]
92+
}."""
7793
return posterior_samples

0 commit comments

Comments (0)