Skip to content

Commit 9f8b73a

Browse files
authored
Merge pull request #32 from arviz-devs/ecdf_pit
Remove plots, use arviz instead
2 parents 9ddc557 + 2ed36af commit 9f8b73a

File tree

8 files changed

+57
-189
lines changed

8 files changed

+57
-189
lines changed

docs/examples/gallery/sbc.md

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@ This example demonstrates how to use the `SBC` class for simulation-based calibr
1515

1616
```{jupyter-execute}
1717
18-
import matplotlib.pyplot as plt
18+
from arviz_plots import plot_ecdf_pit, style
1919
import numpy as np
2020
import simuk
21+
style.use("arviz-variat")
2122
```
2223

2324
::::::{tab-set}
@@ -42,24 +43,28 @@ with pm.Model() as centered_eight:
4243
y_obs = pm.Normal('y', mu=theta, sigma=sigma, observed=data)
4344
```
4445

45-
Pass the model to the SBC class, set the number of simulations to 100, and run the simulations. This process may take
46-
some time since the model runs multiple times.
46+
Pass the model to the SBC class, set the number of simulations to 100, and run the simulations. This process may take
47+
some time since the model runs multiple times (100 in this example).
4748

4849
```{jupyter-execute}
4950
5051
sbc = simuk.SBC(centered_eight,
5152
num_simulations=100,
5253
sample_kwargs={'draws': 25, 'tune': 50})
5354
54-
sbc.run_simulations()
55+
sbc.run_simulations();
5556
```
5657

57-
To compare the prior and posterior distributions, we will plot the results. You can adjust the type of visualization
58-
using the ``kind`` parameter. We use the empirical CDF to compare the differences between the prior and posterior.
58+
To compare the prior and posterior distributions, we will plot the results from the simulations,
59+
using the ArviZ function `plot_ecdf_pit`.
60+
We expect a uniform distribution; the gray envelope corresponds to the 94% credible interval.
5961

6062
```{jupyter-execute}
6163
62-
sbc.plot_results(kind="ecdf")
64+
plot_ecdf_pit(sbc.simulations,
65+
pc_kwargs={'col_wrap':4},
66+
plot_kwargs={"xlabel":False},
67+
)
6368
```
6469

6570
:::::
@@ -80,7 +85,7 @@ df = pd.DataFrame({"x": x, "y": y})
8085
bmb_model = bmb.Model("y ~ x", df)
8186
```
8287

83-
Pass the model to the `SBC` class, set the number of simulations to 100, and run the simulations.
88+
Pass the model to the `SBC` class, set the number of simulations to 100, and run the simulations.
8489
This process may take some time, as the model runs multiple times.
8590

8691
```{jupyter-execute}
@@ -89,15 +94,14 @@ sbc = simuk.SBC(bmb_model,
8994
num_simulations=100,
9095
sample_kwargs={'draws': 25, 'tune': 50})
9196
92-
sbc.run_simulations()
97+
sbc.run_simulations();
9398
```
9499

95-
To compare the prior and posterior distributions, we will plot the results. You can customize the visualization type
96-
using the `kind` parameter. The example below displays a histogram.
100+
To compare the prior and posterior distributions, we will plot the results from the simulations.
101+
We expect a uniform distribution; the gray envelope corresponds to the 94% credible interval.
97102

98103
```{jupyter-execute}
99-
100-
sbc.plot_results(kind="hist")
104+
plot_ecdf_pit(sbc.simulations)
101105
```
102106

103107
:::::

docs/examples/img/sbc.png

68.8 KB
Loading

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ dynamic = ["version"]
2525
description = "Simulation based calibration and generation of synthetic data."
2626
dependencies = [
2727
"pymc>=5.20",
28-
"arviz>=0.20.0",
28+
"arviz_base>=0.4.0",
2929
"tqdm"
3030
]
3131

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,5 @@ pre-commit>=2.19
66
ipytest==0.13.0
77
pymc>=5.20.1
88
bambi>=0.13.0
9-
arviz>=0.20.0
9+
arviz_base>=0.4.0
1010
ruff==0.9.1

requirements-docs.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
pydata-sphinx-theme>=0.6.3
22
myst-nb
3-
pymc @ git+https://github.com/pymc-devs/pymc@main
3+
pymc>=5.20.1
44
bambi>=0.15.0
5+
arviz_plots @ git+https://github.com/arviz-devs/arviz-plots@main
56
sphinx>=4
67
sphinx-copybutton
78
sphinx_tabs

simuk/plots.py

Lines changed: 0 additions & 124 deletions
This file was deleted.

simuk/sbc.py

Lines changed: 31 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,13 @@
22

33
import logging
44
from copy import copy
5+
from importlib.metadata import version
56

6-
import arviz as az
77
import numpy as np
88
import pymc as pm
9+
from arviz_base import extract, from_dict
910
from tqdm import tqdm
1011

11-
from simuk.plots import plot_results
12-
1312

1413
class quiet_logging:
1514
"""Turn off logging for PyMC, Bambi and PyTensor."""
@@ -93,31 +92,48 @@ def __init__(
9392

9493
self.simulations = {name: [] for name in self.var_names}
9594
self._simulations_complete = 0
96-
self._seed = seed
95+
self.seed = seed
96+
self._seeds = self._get_seeds()
9797

9898
def _get_seeds(self):
9999
"""Set the random seed, and generate seeds for all the simulations."""
100-
if self._seed is not None:
101-
np.random.seed(self._seed)
102-
return np.random.randint(2**30, size=self.num_simulations)
100+
rng = np.random.default_rng(self.seed)
101+
return rng.integers(0, 2**30, size=self.num_simulations)
103102

104103
def _get_prior_predictive_samples(self):
105104
"""Generate samples to use for the simulations."""
106105
with self.model:
107-
idata = pm.sample_prior_predictive(samples=self.num_simulations)
108-
prior_pred = az.extract(idata, group="prior_predictive")
109-
prior = az.extract(idata, group="prior")
106+
idata = pm.sample_prior_predictive(
107+
samples=self.num_simulations, random_seed=self._seeds[0]
108+
)
109+
prior_pred = extract(idata, group="prior_predictive", keep_dataset=True)
110+
prior = extract(idata, group="prior", keep_dataset=True)
110111
return prior, prior_pred
111112

112113
def _get_posterior_samples(self, prior_predictive_draw):
113114
"""Generate posterior samples conditioned to a prior predictive sample."""
114115
new_model = pm.observe(self.model, prior_predictive_draw)
115116
with new_model:
116-
check = pm.sample(**self.sample_kwargs)
117+
check = pm.sample(
118+
**self.sample_kwargs, random_seed=self._seeds[self._simulations_complete]
119+
)
117120

118-
posterior = az.extract(check, group="posterior")
121+
posterior = extract(check, group="posterior", keep_dataset=True)
119122
return posterior
120123

124+
def _convert_to_datatree(self):
125+
self.simulations = from_dict(
126+
{"prior_sbc": self.simulations},
127+
attrs={
128+
"/": {
129+
"inferece_library": self.engine,
130+
"inferece_library_version": version(self.engine),
131+
"modeling_interface": "simuk",
132+
"modeling_interface_version": version("simuk"),
133+
}
134+
},
135+
)
136+
121137
@quiet_logging("pymc", "pytensor.gof.compilelock", "bambi")
122138
def run_simulations(self):
123139
"""Run all the simulations.
@@ -127,7 +143,6 @@ def run_simulations(self):
127143
seed was passed initially, it will still be respected (that is, the resulting
128144
simulations will be identical to running without pausing in the middle).
129145
"""
130-
seeds = self._get_seeds()
131146
prior, prior_pred = self._get_prior_predictive_samples()
132147

133148
progress = tqdm(
@@ -142,8 +157,6 @@ def run_simulations(self):
142157
for var_name in self.observed_vars
143158
}
144159

145-
np.random.seed(seeds[idx])
146-
147160
posterior = self._get_posterior_samples(prior_predictive_draw)
148161
for name in self.var_names:
149162
self.simulations[name].append(
@@ -153,34 +166,8 @@ def run_simulations(self):
153166
progress.update()
154167
finally:
155168
self.simulations = {
156-
k: v[: self._simulations_complete] for k, v in self.simulations.items()
169+
k: np.stack(v[: self._simulations_complete])[None, :]
170+
for k, v in self.simulations.items()
157171
}
172+
self._convert_to_datatree()
158173
progress.close()
159-
160-
def plot_results(self, kind="ecdf", var_names=None, color="C0"):
161-
"""Visual diagnostic for SBC.
162-
163-
Currently it support two options: `ecdf` for the empirical CDF plots
164-
of the difference between prior and posterior. `hist` for the rank
165-
histogram.
166-
167-
168-
Parameters
169-
----------
170-
simulations : dict[str] -> listlike
171-
The SBC.simulations dictionary.
172-
kind : str
173-
What kind of plot to make. Supported values are 'ecdf' (default) and 'hist'
174-
var_names : list[str]
175-
Variables to plot (defaults to all)
176-
figsize : tuple
177-
Figure size for the plot. If None, it will be defined automatically.
178-
color : str
179-
Color to use for the eCDF or histogram
180-
181-
Returns
182-
-------
183-
fig, axes
184-
matplotlib figure and axes
185-
"""
186-
return plot_results(self.simulations, kind=kind, var_names=var_names, color=color)

simuk/tests/test_sbc.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,18 @@
1717
theta = pm.Normal("theta", mu=mu, sigma=tau, shape=8)
1818
y_obs = pm.Normal("y", mu=theta, sigma=sigma, observed=data)
1919

20-
x = np.random.normal(0, 1, 200)
20+
x = np.random.normal(0, 1, 20)
2121
y = 2 + np.random.normal(x, 1)
2222
df = pd.DataFrame({"x": x, "y": y})
2323
bmb_model = bmb.Model("y ~ x", df)
2424

2525

26-
@pytest.mark.parametrize("model, kind", [(centered_eight, "ecdf"), (bmb_model, "hist")])
27-
def test_sbc(model, kind):
26+
@pytest.mark.parametrize("model", [centered_eight, bmb_model])
27+
def test_sbc(model):
2828
sbc = simuk.SBC(
2929
model,
3030
num_simulations=10,
31-
sample_kwargs={"draws": 25, "tune": 50},
31+
sample_kwargs={"draws": 5, "tune": 5},
3232
)
3333
sbc.run_simulations()
34-
sbc.plot_results(kind=kind)
34+
assert "prior_sbc" in sbc.simulations

0 commit comments

Comments
 (0)