Skip to content

Commit fc9e2b4

Browse files
authored
Merge pull request #30 from rohanbabbar04/numpyro
Add support for numpyro models in SBC
2 parents 9f8b73a + 2df4fae commit fc9e2b4

File tree

6 files changed

+176
-23
lines changed

6 files changed

+176
-23
lines changed

docs/examples/gallery/sbc.md

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,47 @@ plot_ecdf_pit(sbc.simulations)
105105
```
106106

107107
:::::
108+
109+
:::::{tab-item} Numpyro
110+
:sync: numpyro
111+
112+
We define a Numpyro Model, we use the centered eight schools model.
113+
114+
```{jupyter-execute}
115+
import numpyro
116+
import numpyro.distributions as dist
117+
from jax import random
118+
from numpyro.infer import NUTS
119+
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
120+
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
121+
def eight_schools_cauchy_prior(J, sigma, y=None):
122+
mu = numpyro.sample("mu", dist.Normal(0, 5))
123+
tau = numpyro.sample("tau", dist.HalfCauchy(5))
124+
with numpyro.plate("J", J):
125+
theta = numpyro.sample("theta", dist.Normal(mu, tau))
126+
numpyro.sample("y", dist.Normal(theta, sigma), obs=y)
127+
# We use the NUTS sampler
128+
nuts_kernel = NUTS(eight_schools_cauchy_prior)
129+
```
130+
131+
Pass the model to the `SBC` class, set the number of simulations to 100, and run the simulations. For numpyro model,
132+
we pass in the ``data_dir`` parameter.
133+
134+
```{jupyter-execute}
135+
sbc = simuk.SBC(nuts_kernel,
136+
sample_kwargs={"num_warmup": 50, "num_samples": 75},
137+
num_simulations=100,
138+
data_dir={"J": 8, "sigma": sigma, "y": y},
139+
)
140+
sbc.run_simulations()
141+
```
142+
143+
To compare the prior and posterior distributions, we will plot the results.
144+
We expect a uniform distribution, the gray envelope corresponds to the 94% credible interval.
145+
146+
```{jupyter-execute}
147+
plot_ecdf_pit(sbc.simulations,
148+
pc_kwargs={'col_wrap':4},
149+
plot_kwargs={"xlabel":False}
150+
)
151+
```

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ classifiers = [
2424
dynamic = ["version"]
2525
description = "Simulation based calibration and generation of synthetic data."
2626
dependencies = [
27-
"pymc>=5.20",
27+
"arviz>=0.20.0",
2828
"arviz_base>=0.4.0",
2929
"tqdm"
3030
]

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@ pymc>=5.20.1
88
bambi>=0.13.0
99
arviz_base>=0.4.0
1010
ruff==0.9.1
11+
numpyro>=0.17.0

requirements-docs.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ sphinx_tabs
99
sphinx-design
1010
numpydoc
1111
jupyter-sphinx
12+
numpyro>=0.17.0

simuk/sbc.py

Lines changed: 105 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,20 @@
44
from copy import copy
55
from importlib.metadata import version
66

7+
try:
8+
import pymc as pm
9+
except ImportError:
10+
pass
11+
try:
12+
import jax
13+
from numpyro.handlers import seed, trace
14+
from numpyro.infer import MCMC, Predictive
15+
from numpyro.infer.mcmc import MCMCKernel
16+
except ImportError:
17+
pass
18+
719
import numpy as np
8-
import pymc as pm
20+
from arviz import from_numpyro
921
from arviz_base import extract, from_dict
1022
from tqdm import tqdm
1123

@@ -35,8 +47,8 @@ class SBC:
3547
3648
Parameters
3749
----------
38-
model : function
39-
A PyMC or Bambi model. If a PyMC model the data needs to be defined as
50+
model : pymc.Model, bambi.Model or numpyro.infer.mcmc.MCMCKernel
51+
A PyMC, Bambi model or Numpyro MCMC kernel. If a PyMC model the data needs to be defined as
4052
mutable data.
4153
num_simulations : int
4254
How many simulations to run
@@ -45,6 +57,9 @@ class SBC:
4557
seed : int (optional)
4658
Random seed. This persists even if running the simulations is
4759
paused for whatever reason.
60+
data_dir : dict
61+
Keyword arguments passed to numpyro model, intended for use when providing
62+
an MCMC Kernel model.
4863
4964
Example
5065
-------
@@ -61,39 +76,63 @@ class SBC:
6176
6277
"""
6378

64-
def __init__(
65-
self,
66-
model,
67-
num_simulations=1000,
68-
sample_kwargs=None,
69-
seed=None,
70-
):
71-
if isinstance(model, pm.Model):
79+
def __init__(self, model, num_simulations=1000, sample_kwargs=None, seed=None, data_dir=None):
80+
if hasattr(model, "basic_RVs") and isinstance(model, pm.Model):
7281
self.engine = "pymc"
7382
self.model = model
74-
else:
83+
elif hasattr(model, "formula"):
7584
self.engine = "bambi"
7685
model.build()
7786
self.bambi_model = model
7887
self.model = model.backend.model
7988
self.formula = model.formula
8089
self.new_data = copy(model.data)
81-
82-
self.observed_vars = [obs_rvs.name for obs_rvs in self.model.observed_RVs]
90+
elif isinstance(model, MCMCKernel):
91+
self.engine = "numpyro"
92+
self.numpyro_model = model
93+
self.model = self.numpyro_model.model
94+
self.run_simulations = self._run_simulations_numpyro
95+
self.data_dir = data_dir
96+
else:
97+
raise ValueError(
98+
"model should be one of pymc.Model, bambi.Model, or numpyro.infer.mcmc.MCMCKernel"
99+
)
83100
self.num_simulations = num_simulations
84-
85-
self.var_names = [v.name for v in self.model.free_RVs]
86-
87101
if sample_kwargs is None:
88102
sample_kwargs = {}
89-
sample_kwargs.setdefault("progressbar", False)
90-
sample_kwargs.setdefault("compute_convergence_checks", False)
103+
if self.engine == "numpyro":
104+
sample_kwargs.setdefault("num_warmup", 1000)
105+
sample_kwargs.setdefault("num_samples", 1000)
106+
sample_kwargs.setdefault("progress_bar", False)
107+
else:
108+
sample_kwargs.setdefault("progressbar", False)
109+
sample_kwargs.setdefault("compute_convergence_checks", False)
91110
self.sample_kwargs = sample_kwargs
92-
93-
self.simulations = {name: [] for name in self.var_names}
94-
self._simulations_complete = 0
95111
self.seed = seed
96112
self._seeds = self._get_seeds()
113+
self._extract_variable_names()
114+
self.simulations = {name: [] for name in self.var_names}
115+
self._simulations_complete = 0
116+
117+
def _extract_variable_names(self):
118+
"""Extract observed and free variables from the model."""
119+
if self.engine == "numpyro":
120+
with trace() as tr:
121+
with seed(rng_seed=int(self._seeds[0])):
122+
self.numpyro_model.model(**self.data_dir)
123+
self.var_names = [
124+
name
125+
for name, site in tr.items()
126+
if site["type"] == "sample" and not site.get("is_observed", False)
127+
]
128+
self.observed_vars = [
129+
name
130+
for name, site in tr.items()
131+
if site["type"] == "sample" and site.get("is_observed", False)
132+
]
133+
else:
134+
self.observed_vars = [obs.name for obs in self.model.observed_RVs]
135+
self.var_names = [v.name for v in self.model.free_RVs]
97136

98137
def _get_seeds(self):
99138
"""Set the random seed, and generate seeds for all the simulations."""
@@ -110,6 +149,15 @@ def _get_prior_predictive_samples(self):
110149
prior = extract(idata, group="prior", keep_dataset=True)
111150
return prior, prior_pred
112151

152+
def _get_prior_predictive_samples_numpyro(self):
153+
"""Generate samples to use for the simulations using numpyro."""
154+
predictive = Predictive(self.model, num_samples=self.num_simulations)
155+
free_vars_data = {k: v for k, v in self.data_dir.items() if k not in self.observed_vars}
156+
samples = predictive(jax.random.PRNGKey(self._seeds[0]), **free_vars_data)
157+
prior = {k: v for k, v in samples.items() if k not in self.observed_vars}
158+
prior_pred = {k: v for k, v in samples.items() if k in self.observed_vars}
159+
return prior, prior_pred
160+
113161
def _get_posterior_samples(self, prior_predictive_draw):
114162
"""Generate posterior samples conditioned to a prior predictive sample."""
115163
new_model = pm.observe(self.model, prior_predictive_draw)
@@ -121,6 +169,14 @@ def _get_posterior_samples(self, prior_predictive_draw):
121169
posterior = extract(check, group="posterior", keep_dataset=True)
122170
return posterior
123171

172+
def _get_posterior_samples_numpyro(self, prior_predictive_draw):
173+
"""Generate posterior samples using numpyro conditioned to a prior predictive sample."""
174+
mcmc = MCMC(self.numpyro_model, **self.sample_kwargs)
175+
rng_seed = jax.random.PRNGKey(self._seeds[self._simulations_complete])
176+
free_vars_data = {k: v for k, v in self.data_dir.items() if k not in self.observed_vars}
177+
mcmc.run(rng_seed, **free_vars_data, **prior_predictive_draw)
178+
return from_numpyro(mcmc)["posterior"]
179+
124180
def _convert_to_datatree(self):
125181
self.simulations = from_dict(
126182
{"prior_sbc": self.simulations},
@@ -171,3 +227,30 @@ def run_simulations(self):
171227
}
172228
self._convert_to_datatree()
173229
progress.close()
230+
231+
@quiet_logging("numpyro")
232+
def _run_simulations_numpyro(self):
233+
"""Run all the simulations for Numpyro Model."""
234+
prior, prior_pred = self._get_prior_predictive_samples_numpyro()
235+
progress = tqdm(
236+
initial=self._simulations_complete,
237+
total=self.num_simulations,
238+
)
239+
try:
240+
while self._simulations_complete < self.num_simulations:
241+
idx = self._simulations_complete
242+
prior_predictive_draw = {k: v[idx] for k, v in prior_pred.items()}
243+
posterior = self._get_posterior_samples_numpyro(prior_predictive_draw)
244+
for name in self.var_names:
245+
self.simulations[name].append(
246+
(posterior[name].sel(chain=0) < prior[name][idx]).sum(axis=0).values
247+
)
248+
self._simulations_complete += 1
249+
progress.update()
250+
finally:
251+
self.simulations = {
252+
k: np.stack(v[: self._simulations_complete])[None, :]
253+
for k, v in self.simulations.items()
254+
}
255+
self._convert_to_datatree()
256+
progress.close()

simuk/tests/test_sbc.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
import bambi as bmb
22
import numpy as np
3+
import numpyro
4+
import numpyro.distributions as dist
35
import pandas as pd
46
import pymc as pm
57
import pytest
8+
from numpyro.infer import NUTS
69

710
import simuk
811

@@ -32,3 +35,24 @@ def test_sbc(model):
3235
)
3336
sbc.run_simulations()
3437
assert "prior_sbc" in sbc.simulations
38+
39+
40+
def test_sbc_numpyro():
41+
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
42+
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
43+
44+
def eight_schools_cauchy_prior(J, sigma, y=None):
45+
mu = numpyro.sample("mu", dist.Normal(0, 5))
46+
tau = numpyro.sample("tau", dist.HalfCauchy(5))
47+
with numpyro.plate("J", J):
48+
theta = numpyro.sample("theta", dist.Normal(mu, tau))
49+
numpyro.sample("y", dist.Normal(theta, sigma), obs=y)
50+
51+
sbc = simuk.SBC(
52+
NUTS(eight_schools_cauchy_prior),
53+
data_dir={"J": 8, "sigma": sigma, "y": y},
54+
num_simulations=10,
55+
sample_kwargs={"num_warmup": 50, "num_samples": 25},
56+
)
57+
sbc.run_simulations()
58+
assert "prior_sbc" in sbc.simulations

0 commit comments

Comments
 (0)