22
33import logging
44from copy import copy
5+ from importlib .metadata import version
56
6- import arviz as az
77import numpy as np
88import pymc as pm
9+ from arviz_base import extract , from_dict
910from tqdm import tqdm
1011
11- from simuk .plots import plot_results
12-
1312
1413class quiet_logging :
1514 """Turn off logging for PyMC, Bambi and PyTensor."""
@@ -93,31 +92,48 @@ def __init__(
9392
9493 self .simulations = {name : [] for name in self .var_names }
9594 self ._simulations_complete = 0
96- self ._seed = seed
95+ self .seed = seed
96+ self ._seeds = self ._get_seeds ()
9797
9898 def _get_seeds (self ):
9999 """Set the random seed, and generate seeds for all the simulations."""
100- if self ._seed is not None :
101- np .random .seed (self ._seed )
102- return np .random .randint (2 ** 30 , size = self .num_simulations )
100+ rng = np .random .default_rng (self .seed )
101+ return rng .integers (0 , 2 ** 30 , size = self .num_simulations )
103102
104103 def _get_prior_predictive_samples (self ):
105104 """Generate samples to use for the simulations."""
106105 with self .model :
107- idata = pm .sample_prior_predictive (samples = self .num_simulations )
108- prior_pred = az .extract (idata , group = "prior_predictive" )
109- prior = az .extract (idata , group = "prior" )
106+ idata = pm .sample_prior_predictive (
107+ samples = self .num_simulations , random_seed = self ._seeds [0 ]
108+ )
109+ prior_pred = extract (idata , group = "prior_predictive" , keep_dataset = True )
110+ prior = extract (idata , group = "prior" , keep_dataset = True )
110111 return prior , prior_pred
111112
112113 def _get_posterior_samples (self , prior_predictive_draw ):
113114 """Generate posterior samples conditioned to a prior predictive sample."""
114115 new_model = pm .observe (self .model , prior_predictive_draw )
115116 with new_model :
116- check = pm .sample (** self .sample_kwargs )
117+ check = pm .sample (
118+ ** self .sample_kwargs , random_seed = self ._seeds [self ._simulations_complete ]
119+ )
117120
118- posterior = az . extract (check , group = "posterior" )
121+ posterior = extract (check , group = "posterior" , keep_dataset = True )
119122 return posterior
120123
124+ def _convert_to_datatree (self ):
125+ self .simulations = from_dict (
126+ {"prior_sbc" : self .simulations },
127+ attrs = {
128+ "/" : {
129+ "inferece_library" : self .engine ,
130+ "inferece_library_version" : version (self .engine ),
131+ "modeling_interface" : "simuk" ,
132+ "modeling_interface_version" : version ("simuk" ),
133+ }
134+ },
135+ )
136+
121137 @quiet_logging ("pymc" , "pytensor.gof.compilelock" , "bambi" )
122138 def run_simulations (self ):
123139 """Run all the simulations.
@@ -127,7 +143,6 @@ def run_simulations(self):
127143 seed was passed initially, it will still be respected (that is, the resulting
128144 simulations will be identical to running without pausing in the middle).
129145 """
130- seeds = self ._get_seeds ()
131146 prior , prior_pred = self ._get_prior_predictive_samples ()
132147
133148 progress = tqdm (
@@ -142,8 +157,6 @@ def run_simulations(self):
142157 for var_name in self .observed_vars
143158 }
144159
145- np .random .seed (seeds [idx ])
146-
147160 posterior = self ._get_posterior_samples (prior_predictive_draw )
148161 for name in self .var_names :
149162 self .simulations [name ].append (
@@ -153,34 +166,8 @@ def run_simulations(self):
153166 progress .update ()
154167 finally :
155168 self .simulations = {
156- k : v [: self ._simulations_complete ] for k , v in self .simulations .items ()
169+ k : np .stack (v [: self ._simulations_complete ])[None , :]
170+ for k , v in self .simulations .items ()
157171 }
172+ self ._convert_to_datatree ()
158173 progress .close ()
159-
160- def plot_results (self , kind = "ecdf" , var_names = None , color = "C0" ):
161- """Visual diagnostic for SBC.
162-
163- Currently it support two options: `ecdf` for the empirical CDF plots
164- of the difference between prior and posterior. `hist` for the rank
165- histogram.
166-
167-
168- Parameters
169- ----------
170- simulations : dict[str] -> listlike
171- The SBC.simulations dictionary.
172- kind : str
173- What kind of plot to make. Supported values are 'ecdf' (default) and 'hist'
174- var_names : list[str]
175- Variables to plot (defaults to all)
176- figsize : tuple
177- Figure size for the plot. If None, it will be defined automatically.
178- color : str
179- Color to use for the eCDF or histogram
180-
181- Returns
182- -------
183- fig, axes
184- matplotlib figure and axes
185- """
186- return plot_results (self .simulations , kind = kind , var_names = var_names , color = color )
0 commit comments