Skip to content

Commit 76c1e1b

Browse files
authored
fix: add nan handling in diagnostics (#1359)
* add NaN and Inf check for diagnostics, adapt test. * feedback
1 parent e1305b9 commit 76c1e1b

File tree

6 files changed

+52
-3
lines changed

6 files changed

+52
-3
lines changed

sbi/diagnostics/lc2st.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from torch import Tensor
1313
from tqdm import tqdm
1414

15+
from sbi.utils.diagnostics_utils import remove_nans_and_infs_in_x
16+
1517

1618
class LC2ST:
1719
def __init__(
@@ -83,6 +85,9 @@ def __init__(
8385
[2] : https://github.com/sbi-dev/sbi/blob/main/sbi/utils/metrics.py
8486
"""
8587

88+
# check inputs
89+
thetas, xs = remove_nans_and_infs_in_x(thetas, xs)
90+
8691
assert thetas.shape[0] == xs.shape[0] == posterior_samples.shape[0], (
8792
"Number of samples must match"
8893
)

sbi/diagnostics/sbc.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
from sbi.inference import DirectPosterior
1414
from sbi.inference.posteriors.base_posterior import NeuralPosterior
1515
from sbi.inference.posteriors.vi_posterior import VIPosterior
16-
from sbi.utils.diagnostics_utils import get_posterior_samples_on_batch
16+
from sbi.utils.diagnostics_utils import (
17+
get_posterior_samples_on_batch,
18+
remove_nans_and_infs_in_x,
19+
)
1720
from sbi.utils.metrics import c2st
1821

1922

@@ -54,6 +57,9 @@ def run_sbc(
5457
ranks: ranks of the ground truth parameters under the inferred
5558
dap_samples: samples from the data averaged posterior.
5659
"""
60+
61+
thetas, xs = remove_nans_and_infs_in_x(thetas, xs)
62+
5763
num_sbc_samples = thetas.shape[0]
5864

5965
if num_sbc_samples < 100:

sbi/diagnostics/tarp.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
from torch import Tensor
1414

1515
from sbi.inference.posteriors.base_posterior import NeuralPosterior
16-
from sbi.utils.diagnostics_utils import get_posterior_samples_on_batch
16+
from sbi.utils.diagnostics_utils import (
17+
get_posterior_samples_on_batch,
18+
remove_nans_and_infs_in_x,
19+
)
1720
from sbi.utils.metrics import l2
1821

1922

@@ -61,6 +64,9 @@ def run_tarp(
6164
ecp: Expected coverage probability (``ecp``), see equation 4 of the paper
6265
alpha: credibility values, see equation 2 of the paper
6366
"""
67+
68+
thetas, xs = remove_nans_and_infs_in_x(thetas, xs)
69+
6470
num_tarp_samples, dim_theta = thetas.shape
6571

6672
posterior_samples = get_posterior_samples_on_batch(

sbi/utils/diagnostics_utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import warnings
2+
from typing import Tuple
23

34
import torch
45
from joblib import Parallel, delayed
@@ -9,6 +10,7 @@
910
from sbi.inference.posteriors.mcmc_posterior import MCMCPosterior
1011
from sbi.inference.posteriors.vi_posterior import VIPosterior
1112
from sbi.sbi_types import Shape
13+
from sbi.utils import handle_invalid_x
1214

1315

1416
def get_posterior_samples_on_batch(
@@ -91,3 +93,25 @@ def sample_fun(
9193
posterior_samples.shape[:2]
9294
}."""
9395
return posterior_samples
96+
97+
98+
def remove_nans_and_infs_in_x(thetas: Tensor, xs: Tensor) -> Tuple[Tensor, Tensor]:
    """Remove samples whose x contains NaNs or Infs from both thetas and xs.

    Validity is decided on ``xs`` only, by calling ``handle_invalid_x`` with
    ``exclude_invalid_x=True``. The resulting selector is then applied to both
    ``thetas`` and ``xs`` so that the (theta, x) pairs stay aligned.

    Args:
        thetas: Tensor of shape (num_samples, dim_parameters).
        xs: Tensor of shape (num_samples, dim_observations).

    Returns:
        Tuple of filtered thetas and xs, both of shape (num_valid_samples, ...).

    Warns:
        UserWarning: If ``handle_invalid_x`` reports at least one NaN or Inf.
    """
    # NOTE(review): assumes `is_valid_x` is a per-sample boolean mask of length
    # num_samples and that the two counts are plain numbers -- confirm against
    # `sbi.utils.handle_invalid_x`.
    is_valid_x, num_nans, num_infs = handle_invalid_x(xs, exclude_invalid_x=True)
    if num_nans > 0 or num_infs > 0:
        # Convert the 0-dim tensor to a plain int so the message reads
        # "95 / 100" instead of "tensor(95) / 100".
        num_valid = int(is_valid_x.sum())
        warnings.warn(
            f"Found {num_nans} NaNs and {num_infs} Infs in the data. "
            f"These will be ignored below. Beware that only {num_valid} "
            f"/ {len(xs)} samples are left.",
            # Attribute the warning to the caller of this helper.
            stacklevel=2,
        )

    return thetas[is_valid_x], xs[is_valid_x]

sbi/utils/sbiutils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ def nle_nre_apt_msg_on_invalid_x(
387387

388388
if num_nans + num_infs > 0:
389389
if exclude_invalid_x:
390-
logging.warn(
390+
logging.warning(
391391
f"Found {num_nans} NaN simulations and {num_infs} Inf simulations."
392392
f"These will be discarded from training due to "
393393
f"`exclude_invalid_x=True`. Please be aware that this gives "

tests/inference_with_NaN_simulator_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from torch.distributions import MultivariateNormal
88

99
from sbi import utils as utils
10+
from sbi.diagnostics import run_sbc
1011
from sbi.inference import (
1112
NPE_A,
1213
NPE_C,
@@ -113,6 +114,13 @@ def linear_gaussian_nan(
113114
# Compute the c2st and assert it is near chance level of 0.5.
114115
check_c2st(samples, target_samples, alg=f"{method}")
115116

117+
# run sbc
118+
num_sbc_samples = 100
119+
thetas = prior.sample((num_sbc_samples,))
120+
xs = simulator(thetas)
121+
ranks, daps = run_sbc(thetas, xs, posterior, num_posterior_samples=1000)
122+
assert torch.isfinite(ranks).all()
123+
116124

117125
@pytest.mark.slow
118126
def test_inference_with_restriction_estimator():

0 commit comments

Comments
 (0)