Skip to content

Commit b9a48c7

Browse files
committed
apply ruff to tests and tutorials, ignore long lines.
1 parent a6a220d commit b9a48c7

File tree

6 files changed

+46
-3
lines changed

6 files changed

+46
-3
lines changed

sbi/diagnostics/lc2st.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from torch import Tensor
1313
from tqdm import tqdm
1414

15+
from sbi.utils.diagnostics_utils import remove_nans_and_infs
16+
1517

1618
class LC2ST:
1719
def __init__(
@@ -83,6 +85,8 @@ def __init__(
8385
[2] : https://github.com/sbi-dev/sbi/blob/main/sbi/utils/metrics.py
8486
"""
8587

88+
thetas, xs = remove_nans_and_infs(thetas, xs)
89+
8690
assert (
8791
thetas.shape[0] == xs.shape[0] == posterior_samples.shape[0]
8892
), "Number of samples must match"

sbi/diagnostics/sbc.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
from sbi.inference import DirectPosterior
1414
from sbi.inference.posteriors.base_posterior import NeuralPosterior
1515
from sbi.inference.posteriors.vi_posterior import VIPosterior
16-
from sbi.utils.diagnostics_utils import get_posterior_samples_on_batch
16+
from sbi.utils.diagnostics_utils import (
17+
get_posterior_samples_on_batch,
18+
remove_nans_and_infs,
19+
)
1720
from sbi.utils.metrics import c2st
1821

1922

@@ -54,6 +57,9 @@ def run_sbc(
5457
ranks: ranks of the ground truth parameters under the inferred
5558
dap_samples: samples from the data averaged posterior.
5659
"""
60+
61+
thetas, xs = remove_nans_and_infs(thetas, xs)
62+
5763
num_sbc_samples = thetas.shape[0]
5864

5965
if num_sbc_samples < 100:

sbi/diagnostics/tarp.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
from torch import Tensor
1414

1515
from sbi.inference.posteriors.base_posterior import NeuralPosterior
16-
from sbi.utils.diagnostics_utils import get_posterior_samples_on_batch
16+
from sbi.utils.diagnostics_utils import (
17+
get_posterior_samples_on_batch,
18+
remove_nans_and_infs,
19+
)
1720
from sbi.utils.metrics import l2
1821

1922

@@ -61,6 +64,9 @@ def run_tarp(
6164
ecp: Expected coverage probability (``ecp``), see equation 4 of the paper
6265
alpha: credibility values, see equation 2 of the paper
6366
"""
67+
68+
thetas, xs = remove_nans_and_infs(thetas, xs)
69+
6470
num_tarp_samples, dim_theta = thetas.shape
6571

6672
posterior_samples = get_posterior_samples_on_batch(

sbi/utils/diagnostics_utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import warnings
2+
from typing import Tuple
23

34
import torch
45
from joblib import Parallel, delayed
@@ -9,6 +10,7 @@
910
from sbi.inference.posteriors.mcmc_posterior import MCMCPosterior
1011
from sbi.inference.posteriors.vi_posterior import VIPosterior
1112
from sbi.sbi_types import Shape
13+
from sbi.utils import handle_invalid_x
1214

1315

1416
def get_posterior_samples_on_batch(
@@ -91,3 +93,20 @@ def sample_fun(
9193
posterior_samples.shape[:2]
9294
}."""
9395
return posterior_samples
96+
97+
98+
def remove_nans_and_infs(thetas: Tensor, xs: Tensor) -> Tuple[Tensor, Tensor]:
99+
"""Remove NaNs and Infs from the data."""
100+
is_valid_x, num_nans, num_infs = handle_invalid_x(xs)
101+
if num_nans > 0 or num_infs > 0:
102+
warnings.warn(
103+
f"Found {num_nans} NaNs and {num_infs} Infs in the data. "
104+
f"These will be ignored below. Beware that only {is_valid_x.sum()} "
105+
f"/ {len(xs)} samples are left.",
106+
stacklevel=2,
107+
)
108+
109+
thetas = thetas[is_valid_x]
110+
xs = xs[is_valid_x]
111+
112+
return thetas, xs

sbi/utils/sbiutils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ def nle_nre_apt_msg_on_invalid_x(
387387

388388
if num_nans + num_infs > 0:
389389
if exclude_invalid_x:
390-
logging.warn(
390+
logging.warning(
391391
f"Found {num_nans} NaN simulations and {num_infs} Inf simulations."
392392
f"These will be discarded from training due to "
393393
f"`exclude_invalid_x=True`. Please be aware that this gives "

tests/inference_with_NaN_simulator_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from torch.distributions import MultivariateNormal
88

99
from sbi import utils as utils
10+
from sbi.diagnostics import run_sbc
1011
from sbi.inference import (
1112
NPE_A,
1213
NPE_C,
@@ -113,6 +114,13 @@ def linear_gaussian_nan(
113114
# Compute the c2st and assert it is near chance level of 0.5.
114115
check_c2st(samples, target_samples, alg=f"{method}")
115116

117+
# run sbc
118+
num_sbc_samples = 100
119+
thetas = prior.sample((num_sbc_samples,))
120+
xs = simulator(thetas)
121+
ranks, daps = run_sbc(thetas, xs, posterior, num_posterior_samples=1000)
122+
assert torch.isfinite(ranks).all()
123+
116124

117125
@pytest.mark.slow
118126
def test_inference_with_restriction_estimator():

0 commit comments

Comments
 (0)