|
1 | 1 | import warnings |
| 2 | +from typing import Tuple |
2 | 3 |
|
3 | 4 | import torch |
4 | 5 | from joblib import Parallel, delayed |
|
9 | 10 | from sbi.inference.posteriors.mcmc_posterior import MCMCPosterior |
10 | 11 | from sbi.inference.posteriors.vi_posterior import VIPosterior |
11 | 12 | from sbi.sbi_types import Shape |
| 13 | +from sbi.utils import handle_invalid_x |
12 | 14 |
|
13 | 15 |
|
14 | 16 | def get_posterior_samples_on_batch( |
@@ -91,3 +93,25 @@ def sample_fun( |
91 | 93 | posterior_samples.shape[:2] |
92 | 94 | }.""" |
93 | 95 | return posterior_samples |
| 96 | + |
| 97 | + |
| 98 | +def remove_nans_and_infs_in_x(thetas: Tensor, xs: Tensor) -> Tuple[Tensor, Tensor]: |
| 99 | + """Remove NaNs and Infs entries in x from both the theta and x. |
| 100 | +
|
| 101 | + Args: |
| 102 | + thetas: Tensor of shape (num_samples, dim_parameters). |
| 103 | + xs: Tensor of shape (num_samples, dim_observations). |
| 104 | +
|
| 105 | + Returns: |
| 106 | + Tuple of filtered thetas and xs, both of shape (num_valid_samples, ...). |
| 107 | + """ |
| 108 | + is_valid_x, num_nans, num_infs = handle_invalid_x(xs, exclude_invalid_x=True) |
| 109 | + if num_nans > 0 or num_infs > 0: |
| 110 | + warnings.warn( |
| 111 | + f"Found {num_nans} NaNs and {num_infs} Infs in the data. " |
| 112 | + f"These will be ignored below. Beware that only {is_valid_x.sum()} " |
| 113 | + f"/ {len(xs)} samples are left.", |
| 114 | + stacklevel=2, |
| 115 | + ) |
| 116 | + |
| 117 | + return thetas[is_valid_x], xs[is_valid_x] |
0 commit comments