Commit dd279d0

remove arviz dependency, move wrapping how-to guide

- remove the wrapping logic needed for arviz
- remove the lengthy tutorial on using arviz plots
- instead, add a how-to guide for wrapping posterior samples into an ArviZ object
1 parent 649f5d3 commit dd279d0

File tree

9 files changed: +149 -541 lines changed

docs/advanced_tutorials.rst

Lines changed: 0 additions & 1 deletion
@@ -65,7 +65,6 @@ Diagnostics
 advanced_tutorials/11_diagnostics_simulation_based_calibration.ipynb
 advanced_tutorials/13_diagnostics_lc2st.ipynb
 advanced_tutorials/21_diagnostics_misspecification_checks.ipynb
-advanced_tutorials/14_mcmc_diagnostics_with_arviz.ipynb
 
 
 Visualization

docs/advanced_tutorials/14_mcmc_diagnostics_with_arviz.ipynb

Lines changed: 0 additions & 399 deletions
This file was deleted.

docs/how_to_guide.rst

Lines changed: 1 addition & 0 deletions
@@ -57,6 +57,7 @@ Sampling
 how_to_guide/09_sampler_interface.ipynb
 how_to_guide/10_refine_posterior_with_importance_sampling.ipynb
 how_to_guide/11_iid_sampling_with_nle_or_nre.ipynb
+how_to_guide/12_mcmc_diagnostics_with_arviz.ipynb
 
 
 Diagnostics

docs/how_to_guide/12_mcmc_diagnostics_with_arviz.ipynb

Lines changed: 146 additions & 0 deletions
Large diffs are not rendered by default.
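The notebook's content is not rendered in this diff. As a rough sketch of the workflow it documents, posterior samples drawn from an sbi posterior can be wrapped into an ArviZ InferenceData object by hand. The names `posterior` and `x_o`, the `num_chains` keyword, and the chain-major reshape below are illustrative assumptions, not taken from the diff:

import arviz as az

# Illustrative setup: `posterior` is a fitted sbi MCMCPosterior, `x_o` an observation.
num_chains, num_draws = 4, 1000
samples = posterior.sample((num_chains * num_draws,), x=x_o, num_chains=num_chains)

# ArviZ expects arrays of shape (chain, draw, *event_shape); this assumes the
# returned samples can be regrouped chain-major.
samples_np = samples.detach().cpu().numpy().reshape(num_chains, num_draws, -1)
idata = az.convert_to_inference_data({"theta": samples_np})

# Standard MCMC diagnostics and plots now come from ArviZ directly.
print(az.summary(idata))  # r_hat, effective sample size, MCSE, ...
az.plot_trace(idata)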

mkdocs/docs/tutorials/index.md

Lines changed: 0 additions & 1 deletion
@@ -37,7 +37,6 @@ inference.
 - [Posterior predictive checks](10_diagnostics_posterior_predictive_checks.md)
 - [Simulation-based calibration](11_diagnostics_simulation_based_calibration.md)
 - [Local-C2ST coverage checks](13_diagnostics_lc2st.md)
-- [Density plots and MCMC diagnostics with ArviZ](14_mcmc_diagnostics_with_arviz.md)
 </div>
 
 ## Analysis

pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -29,7 +29,6 @@ dynamic = ["version"]
 readme = "README.md"
 keywords = ["Bayesian inference", "simulation-based inference", "PyTorch"]
 dependencies = [
-    "arviz",
     "joblib>=1.0.0",
     "matplotlib",
     "numpy",

sbi/inference/posteriors/mcmc_posterior.py

Lines changed: 0 additions & 66 deletions
@@ -8,10 +8,8 @@
 from typing import Any, Callable, Dict, Literal, Optional, Union
 from warnings import warn
 
-import arviz as az
 import torch
 import torch.distributions.transforms as torch_tf
-from arviz.data import InferenceData
 from joblib import Parallel, delayed
 from numpy import ndarray
 from pyro.infer.mcmc import HMC, NUTS
@@ -1040,64 +1038,6 @@ def map(
             force_update=force_update,
         )
 
-    def get_arviz_inference_data(self) -> InferenceData:
-        """Returns arviz InferenceData object constructed most recent samples.
-
-        Note: the InferenceData is constructed using the posterior samples generated in
-        most recent call to `.sample(...)`.
-
-        For Pyro and PyMC samplers, InferenceData will contain diagnostics, but for
-        sbi slice samplers, only the samples are added.
-
-        Returns:
-            inference_data: Arviz InferenceData object.
-        """
-        assert self._posterior_sampler is not None, (
-            """No samples have been generated, call .sample() first."""
-        )
-
-        sampler: Union[
-            MCMC, SliceSamplerSerial, SliceSamplerVectorized, PyMCSampler
-        ] = self._posterior_sampler
-
-        # If Pyro sampler and samples not transformed, use arviz' from_pyro.
-        if isinstance(sampler, (HMC, NUTS)) and isinstance(
-            self.theta_transform, torch_tf.IndependentTransform
-        ):
-            inference_data = az.from_pyro(sampler)
-        # If PyMC sampler and samples not transformed, get cached InferenceData.
-        elif isinstance(sampler, PyMCSampler) and isinstance(
-            self.theta_transform, torch_tf.IndependentTransform
-        ):
-            inference_data = sampler.get_inference_data()
-
-        # otherwise get samples from sampler and transform to original space.
-        else:
-            transformed_samples = sampler.get_samples(group_by_chain=True)
-            # Pyro samplers returns dicts, get values.
-            if isinstance(transformed_samples, Dict):
-                # popitem gets last items, [1] get the values as tensor.
-                transformed_samples = transformed_samples.popitem()[1]
-            # Our slice samplers return numpy arrays.
-            elif isinstance(transformed_samples, ndarray):
-                transformed_samples = torch.from_numpy(transformed_samples).type(
-                    torch.float32
-                )
-            # For MultipleIndependent priors transforms first dim must be batch dim.
-            # thus, reshape back and forth to have batch dim in front.
-            samples_shape = transformed_samples.shape
-            samples = self.theta_transform.inv(  # type: ignore
-                transformed_samples.reshape(-1, samples_shape[-1])
-            ).reshape(  # type: ignore
-                *samples_shape
-            )
-
-            inference_data = az.convert_to_inference_data({
-                f"{self.param_name}": samples
-            })
-
-        return inference_data
-
     def __getstate__(self) -> Dict:
         """Get state of MCMCPosterior.
 
@@ -1124,12 +1064,6 @@ def _process_thin_default(thin: int) -> int:
     """
     if thin == -1:
         thin = 1
-        warn(
-            "The default value for thinning in MCMC sampling has been changed from "
-            "10 to 1. This might cause the results differ from the last benchmark.",
-            UserWarning,
-            stacklevel=2,
-        )
 
     return thin

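With get_arviz_inference_data removed, users who want the diagnostics-rich InferenceData that the Pyro samplers provide can still build it themselves, roughly as the removed Pyro branch did. A minimal sketch, assuming sampling was done with a Pyro-based NUTS/HMC sampler and that the underlying sampler object is reachable via the private, non-stable `_posterior_sampler` attribute used in this file:

import arviz as az

# Assumes posterior.sample(...) has already been called with a Pyro NUTS/HMC sampler.
pyro_mcmc = posterior._posterior_sampler
idata = az.from_pyro(pyro_mcmc)  # keeps sampler statistics such as divergences

# As in the removed helper, this is only appropriate when theta_transform does not
# change the parameter space; otherwise transform the samples back first.
print(az.summary(idata))
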
sbi/samplers/mcmc/pymc_wrapper.py

Lines changed: 1 addition & 2 deletions
@@ -7,7 +7,6 @@
 import pymc
 import pytensor.tensor as pt
 import torch
-from arviz.data import InferenceData
 
 from sbi.utils.torchutils import tensor2numpy
 
@@ -206,7 +205,7 @@ def get_samples(
         else:
             return samples[-num_samples:, :]
 
-    def get_inference_data(self) -> InferenceData:
+    def get_inference_data(self) -> Any:
         """Returns InferenceData from last call to self.run,
         which contains diagnostic information in addition to samples

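Since arviz is no longer imported in this module, the return annotation above is loosened to Any. A common alternative, shown here only as a hedged sketch and not as what this commit does, keeps the precise annotation for type checkers via a TYPE_CHECKING-guarded import that adds no runtime dependency:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers; arviz is not imported at runtime.
    from arviz.data import InferenceData


class PyMCSampler:
    def get_inference_data(self) -> "InferenceData":
        """Same behavior as before; only the annotation differs."""
        ...
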
tests/mcmc_test.py

Lines changed: 1 addition & 71 deletions
@@ -3,35 +3,22 @@
 
 from __future__ import annotations
 
-from dataclasses import asdict
-
 import numpy as np
 import pytest
 import torch
 from torch import eye, ones, zeros
-from torch.distributions import Uniform
 
-from sbi.inference import (
-    NLE,
-    MCMCPosterior,
-    likelihood_estimator_based_potential,
-)
 from sbi.inference.posteriors.mcmc_posterior import build_from_potential
 from sbi.inference.posteriors.posterior_parameters import MCMCPosteriorParameters
-from sbi.neural_nets import likelihood_nn
 from sbi.samplers.mcmc.pymc_wrapper import PyMCSampler
 from sbi.samplers.mcmc.slice_numpy import (
     SliceSampler,
     SliceSamplerSerial,
     SliceSamplerVectorized,
 )
-from sbi.simulators.linear_gaussian import (
-    diagonal_linear_gaussian,
-    true_posterior_linear_gaussian_mvn_prior,
-)
+from sbi.simulators.linear_gaussian import true_posterior_linear_gaussian_mvn_prior
 from sbi.utils import BoxUniform
 from sbi.utils.metrics import check_c2st
-from sbi.utils.user_input_checks import process_prior
 
 
 @pytest.mark.mcmc
@@ -198,63 +185,6 @@ def lp_f(x, track_gradients=True):
         "slice_np_vectorized",
     ),
 )
-def test_getting_inference_diagnostics(
-    method, mcmc_params_fast: MCMCPosteriorParameters
-):
-    num_simulations = 100
-    num_samples = 10
-    num_dim = 2
-
-    # Use composed prior to test MultipleIndependent case.
-    prior = [
-        Uniform(low=-ones(1), high=ones(1)),
-        Uniform(low=-ones(1), high=ones(1)),
-    ]
-
-    simulator = diagonal_linear_gaussian
-    density_estimator = likelihood_nn("maf", num_transforms=3)
-    inference = NLE(density_estimator=density_estimator, show_progress_bars=False)
-    prior, *_ = process_prior(prior)
-    theta = prior.sample((num_simulations,))
-    x = simulator(theta)
-    likelihood_estimator = inference.append_simulations(theta, x).train(
-        training_batch_size=num_simulations, max_num_epochs=2
-    )
-
-    x_o = zeros((1, num_dim))
-    potential_fn, theta_transform = likelihood_estimator_based_potential(
-        prior=prior, likelihood_estimator=likelihood_estimator, x_o=x_o
-    )
-    posterior = MCMCPosterior(
-        proposal=prior,
-        potential_fn=potential_fn,
-        theta_transform=theta_transform,
-        **asdict(mcmc_params_fast),
-    )
-    posterior.sample(
-        sample_shape=(num_samples,),
-        method=method,
-    )
-    idata = posterior.get_arviz_inference_data()
-
-    assert hasattr(idata, "posterior"), (
-        f"`MCMCPosterior.get_arviz_inference_data()` for method {method} "
-        f"returned invalid InferenceData. Must contain key 'posterior', "
-        f"but found only {list(idata.keys())}"
-    )
-    samples = getattr(idata.posterior, posterior.param_name).data
-    samples = samples.reshape(-1, samples.shape[-1])[:: mcmc_params_fast.thin][
-        :num_samples
-    ]
-    assert samples.shape == (
-        num_samples,
-        num_dim,
-    ), (
-        f"MCMC samples for method {method} have incorrect shape (n_samples, n_dims). "
-        f"Expected {(num_samples, num_dim)}, got {samples.shape}"
-    )
-
-
 @pytest.mark.mcmc
 def test_direct_mcmc_unconditional():
     "Test MCMCPosterior from user defined potential (unconditional)"
