44from copy import copy
55from importlib .metadata import version
66
7+ try :
8+ import pymc as pm
9+ except ImportError :
10+ pass
11+ try :
12+ import jax
13+ from numpyro .handlers import seed , trace
14+ from numpyro .infer import MCMC , Predictive
15+ from numpyro .infer .mcmc import MCMCKernel
16+ except ImportError :
17+ pass
18+
719import numpy as np
8- import pymc as pm
20+ from arviz import from_numpyro
921from arviz_base import extract , from_dict
1022from tqdm import tqdm
1123
@@ -35,8 +47,8 @@ class SBC:
3547
3648 Parameters
3749 ----------
38- model : function
39- A PyMC or Bambi model. If a PyMC model the data needs to be defined as
50+ model : pymc.Model, bambi.Model or numpyro.infer.mcmc.MCMCKernel
51+ A PyMC model, a Bambi model, or a NumPyro MCMC kernel. If a PyMC model is given, the data needs to be defined as
4052 mutable data.
4153 num_simulations : int
4254 How many simulations to run
@@ -45,6 +57,9 @@ class SBC:
4557 seed : int (optional)
4658 Random seed. This persists even if running the simulations is
4759 paused for whatever reason.
60+ data_dir : dict (optional)
61+ Keyword arguments passed to the NumPyro model. Only used when ``model`` is
62+ an MCMC kernel.
4863
4964 Example
5065 -------
@@ -61,39 +76,63 @@ class SBC:
6176
6277 """
6378
64- def __init__ (
65- self ,
66- model ,
67- num_simulations = 1000 ,
68- sample_kwargs = None ,
69- seed = None ,
70- ):
71- if isinstance (model , pm .Model ):
79+ def __init__ (self , model , num_simulations = 1000 , sample_kwargs = None , seed = None , data_dir = None ):
80+ if hasattr (model , "basic_RVs" ) and isinstance (model , pm .Model ):
7281 self .engine = "pymc"
7382 self .model = model
74- else :
83+ elif hasattr ( model , "formula" ) :
7584 self .engine = "bambi"
7685 model .build ()
7786 self .bambi_model = model
7887 self .model = model .backend .model
7988 self .formula = model .formula
8089 self .new_data = copy (model .data )
81-
82- self .observed_vars = [obs_rvs .name for obs_rvs in self .model .observed_RVs ]
90+ elif isinstance (model , MCMCKernel ):
91+ self .engine = "numpyro"
92+ self .numpyro_model = model
93+ self .model = self .numpyro_model .model
94+ self .run_simulations = self ._run_simulations_numpyro
95+ self .data_dir = data_dir
96+ else :
97+ raise ValueError (
98+ "model should be one of pymc.Model, bambi.Model, or numpyro.infer.mcmc.MCMCKernel"
99+ )
83100 self .num_simulations = num_simulations
84-
85- self .var_names = [v .name for v in self .model .free_RVs ]
86-
87101 if sample_kwargs is None :
88102 sample_kwargs = {}
89- sample_kwargs .setdefault ("progressbar" , False )
90- sample_kwargs .setdefault ("compute_convergence_checks" , False )
103+ if self .engine == "numpyro" :
104+ sample_kwargs .setdefault ("num_warmup" , 1000 )
105+ sample_kwargs .setdefault ("num_samples" , 1000 )
106+ sample_kwargs .setdefault ("progress_bar" , False )
107+ else :
108+ sample_kwargs .setdefault ("progressbar" , False )
109+ sample_kwargs .setdefault ("compute_convergence_checks" , False )
91110 self .sample_kwargs = sample_kwargs
92-
93- self .simulations = {name : [] for name in self .var_names }
94- self ._simulations_complete = 0
95111 self .seed = seed
96112 self ._seeds = self ._get_seeds ()
113+ self ._extract_variable_names ()
114+ self .simulations = {name : [] for name in self .var_names }
115+ self ._simulations_complete = 0
116+
117+ def _extract_variable_names (self ):
118+ """Extract observed and free variables from the model."""
119+ if self .engine == "numpyro" :
120+ with trace () as tr :
121+ with seed (rng_seed = int (self ._seeds [0 ])):
122+ self .numpyro_model .model (** self .data_dir )
123+ self .var_names = [
124+ name
125+ for name , site in tr .items ()
126+ if site ["type" ] == "sample" and not site .get ("is_observed" , False )
127+ ]
128+ self .observed_vars = [
129+ name
130+ for name , site in tr .items ()
131+ if site ["type" ] == "sample" and site .get ("is_observed" , False )
132+ ]
133+ else :
134+ self .observed_vars = [obs .name for obs in self .model .observed_RVs ]
135+ self .var_names = [v .name for v in self .model .free_RVs ]
97136
98137 def _get_seeds (self ):
99138 """Set the random seed, and generate seeds for all the simulations."""
@@ -110,6 +149,15 @@ def _get_prior_predictive_samples(self):
110149 prior = extract (idata , group = "prior" , keep_dataset = True )
111150 return prior , prior_pred
112151
def _get_prior_predictive_samples_numpyro(self):
    """Generate samples to use for the simulations using numpyro.

    Returns
    -------
    prior : dict
        Draws for every sampled site that is not an observed variable.
    prior_pred : dict
        Draws for the observed variables.
    """
    observed = set(self.observed_vars)
    # Entries named after observed variables are dropped from the model
    # arguments before sampling. data_dir may be None when no arguments were given.
    model_kwargs = {k: v for k, v in (self.data_dir or {}).items() if k not in observed}
    predictive = Predictive(self.model, num_samples=self.num_simulations)
    # Cast to a plain int, matching how the same seed is used when the model is traced.
    rng_key = jax.random.PRNGKey(int(self._seeds[0]))
    samples = predictive(rng_key, **model_kwargs)

    prior, prior_pred = {}, {}
    for name, draws in samples.items():
        if name in observed:
            prior_pred[name] = draws
        else:
            prior[name] = draws
    return prior, prior_pred
160+
113161 def _get_posterior_samples (self , prior_predictive_draw ):
114162 """Generate posterior samples conditioned to a prior predictive sample."""
115163 new_model = pm .observe (self .model , prior_predictive_draw )
@@ -121,6 +169,14 @@ def _get_posterior_samples(self, prior_predictive_draw):
121169 posterior = extract (check , group = "posterior" , keep_dataset = True )
122170 return posterior
123171
def _get_posterior_samples_numpyro(self, prior_predictive_draw):
    """Generate posterior samples using numpyro conditioned to a prior predictive sample.

    Parameters
    ----------
    prior_predictive_draw : dict
        One prior predictive draw per observed variable, passed to the model
        as keyword arguments next to the non-observed entries of ``data_dir``.

    Returns
    -------
    The "posterior" group of the sampled MCMC run converted with ``from_numpyro``.
    """
    observed = set(self.observed_vars)
    # data_dir may be None when no model arguments were supplied.
    model_kwargs = {k: v for k, v in (self.data_dir or {}).items() if k not in observed}
    mcmc = MCMC(self.numpyro_model, **self.sample_kwargs)
    # One seed per simulation, indexed by the number of simulations completed so far.
    # Cast to a plain int, matching how the seed is used when the model is traced.
    # NOTE(review): the first simulation uses self._seeds[0], the same seed the
    # prior predictive draw is generated with - confirm this reuse is intended.
    rng_key = jax.random.PRNGKey(int(self._seeds[self._simulations_complete]))
    mcmc.run(rng_key, **model_kwargs, **prior_predictive_draw)
    return from_numpyro(mcmc)["posterior"]
179+
124180 def _convert_to_datatree (self ):
125181 self .simulations = from_dict (
126182 {"prior_sbc" : self .simulations },
@@ -171,3 +227,30 @@ def run_simulations(self):
171227 }
172228 self ._convert_to_datatree ()
173229 progress .close ()
230+
@quiet_logging("numpyro")
def _run_simulations_numpyro(self):
    """Run all the simulations for Numpyro Model.

    For each simulation, a posterior is sampled conditioned on one prior
    predictive draw. For every free variable, the number of chain-0 posterior
    draws smaller than the matching prior draw is appended to
    ``self.simulations``. Whatever was completed is stacked and converted with
    ``_convert_to_datatree`` even when the loop is interrupted.
    """
    prior, prior_pred = self._get_prior_predictive_samples_numpyro()
    progress = tqdm(
        initial=self._simulations_complete,
        total=self.num_simulations,
    )
    try:
        while self._simulations_complete < self.num_simulations:
            idx = self._simulations_complete
            prior_predictive_draw = {k: v[idx] for k, v in prior_pred.items()}
            posterior = self._get_posterior_samples_numpyro(prior_predictive_draw)
            for name in self.var_names:
                # NOTE(review): only chain 0 is compared; draws from any other
                # chain are ignored - confirm this is intended.
                below_prior = posterior[name].sel(chain=0) < prior[name][idx]
                self.simulations[name].append(below_prior.sum(axis=0).values)
            self._simulations_complete += 1
            progress.update()
    finally:
        # Close the bar first so it is released even if the conversion below fails.
        progress.close()
        # np.stack raises on an empty list. When nothing completed (failure in
        # the first simulation, or num_simulations == 0) skip the conversion so
        # the original exception is not masked.
        if self._simulations_complete > 0:
            # Slicing drops a partial entry left behind by an interrupted simulation.
            self.simulations = {
                k: np.stack(v[: self._simulations_complete])[None, :]
                for k, v in self.simulations.items()
            }
            self._convert_to_datatree()
0 commit comments