|
| 1 | +import logging |
| 2 | +from collections.abc import Callable |
| 3 | + |
| 4 | +import arviz as az |
| 5 | +import numpy as np |
| 6 | +from getdist import MCSamples |
| 7 | +from pyro.infer import HMC, MCMC, NUTS, Predictive |
| 8 | +from pyro.infer.mcmc import RandomWalkKernel |
| 9 | + |
| 10 | +from autoemulate.core.types import TensorLike |
| 11 | + |
| 12 | + |
| 13 | +class BayesianMixin: |
| 14 | + """Mixin class for Bayesian calibration methods.""" |
| 15 | + |
| 16 | + logger: logging.Logger |
| 17 | + model: Callable |
| 18 | + observations: dict[str, TensorLike] | None |
| 19 | + |
| 20 | + def _get_kernel( |
| 21 | + self, |
| 22 | + sampler: str, |
| 23 | + model_kwargs: dict[str, TensorLike] | None = None, |
| 24 | + **sampler_kwargs, |
| 25 | + ): |
| 26 | + """Get the appropriate MCMC kernel based on sampler choice.""" |
| 27 | + # TODO: consider how to pass model args, functools.partial? |
| 28 | + model_kwargs = model_kwargs or {} |
| 29 | + sampler = sampler.lower() |
| 30 | + if sampler == "nuts": |
| 31 | + self.logger.debug("Using NUTS kernel.") |
| 32 | + return NUTS(self.model, **sampler_kwargs) |
| 33 | + if sampler == "hmc": |
| 34 | + step_size = sampler_kwargs.pop("step_size", 0.01) |
| 35 | + trajectory_length = sampler_kwargs.pop("trajectory_length", 1.0) |
| 36 | + self.logger.debug( |
| 37 | + "Using HMC kernel with step_size=%s, trajectory_length=%s", |
| 38 | + step_size, |
| 39 | + trajectory_length, |
| 40 | + ) |
| 41 | + return HMC( |
| 42 | + self.model, |
| 43 | + step_size=step_size, |
| 44 | + trajectory_length=trajectory_length, |
| 45 | + **sampler_kwargs, |
| 46 | + ) |
| 47 | + if sampler == "metropolis": |
| 48 | + self.logger.debug("Using Metropolis (RandomWalkKernel).") |
| 49 | + return RandomWalkKernel(self.model, **sampler_kwargs) |
| 50 | + self.logger.error("Unknown sampler: %s", sampler) |
| 51 | + raise ValueError(f"Unknown sampler: {sampler}") |
| 52 | + |
| 53 | + def run_mcmc( |
| 54 | + self, |
| 55 | + warmup_steps: int = 500, |
| 56 | + num_samples: int = 1000, |
| 57 | + num_chains: int = 1, |
| 58 | + initial_params: dict[str, TensorLike] | None = None, |
| 59 | + model_kwargs: dict | None = None, |
| 60 | + sampler: str = "nuts", |
| 61 | + **sampler_kwargs, |
| 62 | + ) -> MCMC: |
| 63 | + """ |
| 64 | + Run Markov Chain Monte Carlo (MCMC). Defaults to using the NUTS sampler. |
| 65 | +
|
| 66 | + Parameters |
| 67 | + ---------- |
| 68 | + warmup_steps: int |
| 69 | + Number of warm up steps to run per chain (i.e., burn-in). These samples are |
| 70 | + discarded. Defaults to 500. |
| 71 | + num_samples: int |
| 72 | + Number of samples to draw after warm up. Defaults to 1000. |
| 73 | + num_chains: int |
| 74 | + Number of parallel chains to run. Defaults to 1. |
| 75 | + initial_params: dict[str, TensorLike] | None |
| 76 | + Optional dictionary specifiying initial values for each calibration |
| 77 | + parameter per chain. The tensors must be of length `num_chains`. |
| 78 | + model_kwargs: dict | None |
| 79 | + Optional dictionary of keyword arguments to pass to the model. |
| 80 | + sampler: str |
| 81 | + The MCMC kernel to use, one of "hmc", "nuts" or "metropolis". |
| 82 | + **sampler_kwargs |
| 83 | + Additional keyword arguments to pass to the MCMC kernel. |
| 84 | +
|
| 85 | + Returns |
| 86 | + ------- |
| 87 | + MCMC |
| 88 | + The Pyro MCMC object. Methods include `summary()` and `get_samples()`. |
| 89 | + """ |
| 90 | + # Check initial param values match number of chains |
| 91 | + |
| 92 | + if initial_params is not None: |
| 93 | + for param, init_vals in initial_params.items(): |
| 94 | + if init_vals.shape[0] != num_chains: |
| 95 | + msg = ( |
| 96 | + "An initial value must be provided for each chain, parameter " |
| 97 | + f"{param} tensor only has {init_vals.shape[0]} values." |
| 98 | + ) |
| 99 | + self.logger.error(msg) |
| 100 | + raise ValueError(msg) |
| 101 | + self.logger.debug( |
| 102 | + "Initial parameters provided for MCMC: %s", initial_params |
| 103 | + ) |
| 104 | + |
| 105 | + # Run NUTS |
| 106 | + kernel = self._get_kernel(sampler, model_kwargs=model_kwargs, **sampler_kwargs) |
| 107 | + mcmc = MCMC( |
| 108 | + kernel, |
| 109 | + warmup_steps=warmup_steps, |
| 110 | + num_samples=num_samples, |
| 111 | + num_chains=num_chains, |
| 112 | + # If None, init values are sampled from the prior. |
| 113 | + initial_params=initial_params, |
| 114 | + # Multiprocessing |
| 115 | + mp_context="spawn" if num_chains > 1 else None, |
| 116 | + ) |
| 117 | + self.logger.info("Starting MCMC run.") |
| 118 | + mcmc.run() |
| 119 | + self.logger.info("MCMC run completed.") |
| 120 | + return mcmc |
| 121 | + |
| 122 | + def posterior_predictive(self, mcmc: MCMC) -> dict[str, TensorLike]: |
| 123 | + """ |
| 124 | + Return posterior predictive samples. |
| 125 | +
|
| 126 | + Parameters |
| 127 | + ---------- |
| 128 | + mcmc: MCMC |
| 129 | + The MCMC object. |
| 130 | +
|
| 131 | + Returns |
| 132 | + ------- |
| 133 | + TensorLike |
| 134 | + Tensor of posterior predictive samples [n_mcmc_samples, n_obs, n_outputs]. |
| 135 | + """ |
| 136 | + posterior_samples = mcmc.get_samples() |
| 137 | + posterior_predictive = Predictive(self.model, posterior_samples) |
| 138 | + samples = posterior_predictive(predict=True) |
| 139 | + self.logger.debug("Posterior predictive samples generated.") |
| 140 | + return samples |
| 141 | + |
| 142 | + def to_arviz( |
| 143 | + self, mcmc: MCMC, posterior_predictive: bool = False |
| 144 | + ) -> az.InferenceData: |
| 145 | + """ |
| 146 | + Convert MCMC object to Arviz InferenceData object for plotting. |
| 147 | +
|
| 148 | + Parameters |
| 149 | + ---------- |
| 150 | + mcmc: MCMC |
| 151 | + The MCMC object. |
| 152 | + posterior_predictive: bool |
| 153 | + Whether to include posterior predictive samples. Defaults to False. |
| 154 | +
|
| 155 | + Returns |
| 156 | + ------- |
| 157 | + az.InferenceData |
| 158 | + """ |
| 159 | + pp_samples = None |
| 160 | + if posterior_predictive: |
| 161 | + self.logger.info("Including posterior predictive samples in Arviz output.") |
| 162 | + pp_samples = self.posterior_predictive(mcmc) |
| 163 | + |
| 164 | + # Need to create dataset manually for Metropolis Hastings |
| 165 | + # This is because az.from_pyro expects kernel with `divergences` |
| 166 | + if isinstance(mcmc.kernel, RandomWalkKernel): |
| 167 | + self.logger.debug( |
| 168 | + "Using manual conversion for Metropolis (RandomWalkKernel) kernel." |
| 169 | + ) |
| 170 | + if posterior_predictive: |
| 171 | + if self.observations is None: |
| 172 | + msg = ( |
| 173 | + "Observations must be provided to include observed_data in " |
| 174 | + "Arviz InferenceData." |
| 175 | + ) |
| 176 | + self.logger.error(msg) |
| 177 | + raise ValueError(msg) |
| 178 | + az_data = az.InferenceData( |
| 179 | + posterior=az.convert_to_dataset( |
| 180 | + mcmc.get_samples(group_by_chain=True) |
| 181 | + ), |
| 182 | + posterior_predictive=az.convert_to_dataset(pp_samples), |
| 183 | + observed_data=az.convert_to_dataset(self.observations), |
| 184 | + ) |
| 185 | + else: |
| 186 | + az_data = az.InferenceData( |
| 187 | + posterior=az.convert_to_dataset( |
| 188 | + mcmc.get_samples(group_by_chain=True) |
| 189 | + ), |
| 190 | + ) |
| 191 | + else: |
| 192 | + self.logger.debug("Using az.from_pyro for conversion.") |
| 193 | + az_data = az.from_pyro(mcmc, posterior_predictive=pp_samples) |
| 194 | + |
| 195 | + self.logger.info("Arviz InferenceData conversion complete.") |
| 196 | + return az_data |
| 197 | + |
| 198 | + @staticmethod |
| 199 | + def to_getdist( |
| 200 | + data: MCMC | az.InferenceData, |
| 201 | + label: str, |
| 202 | + use_weights: bool = True, |
| 203 | + weight_name: str = "weight", |
| 204 | + ) -> MCSamples: |
| 205 | + """Convert Pyro MCMC or ArviZ InferenceData to GetDist MCSamples. |
| 206 | +
|
| 207 | + This lightweight helper extends the original implementation to also accept |
| 208 | + SMC / other results already converted to ArviZ InferenceData. If a weight |
| 209 | + variable (default: smc_weight) is present in sample_stats it will be |
| 210 | + used as importance weights. |
| 211 | +
|
| 212 | + Parameters |
| 213 | + ---------- |
| 214 | + data: MCMC | az.InferenceData |
| 215 | + The Pyro MCMC object or an ArviZ InferenceData object containing posterior |
| 216 | + samples. |
| 217 | + label: str |
| 218 | + Label for the MCSamples object. |
| 219 | + use_weights: bool |
| 220 | + If True and `data` is an `InferenceData` with `weight_name` in |
| 221 | + `sample_stats` then those weights are applied. Defaults to True. |
| 222 | + weight_name: str |
| 223 | + Name of the weight variable inside `sample_stats` to look up. |
| 224 | +
|
| 225 | + Returns |
| 226 | + ------- |
| 227 | + MCSamples |
| 228 | + The GetDist MCSamples object. |
| 229 | + """ |
| 230 | + if isinstance(data, MCMC): |
| 231 | + samples_dict = data.get_samples() |
| 232 | + arr = np.array(list(samples_dict.values())).T |
| 233 | + names = list(samples_dict.keys()) |
| 234 | + weights = None |
| 235 | + else: |
| 236 | + posterior = data.posterior # type: ignore[attr-defined] |
| 237 | + names = list(posterior.data_vars) |
| 238 | + cols = [] |
| 239 | + for name in names: |
| 240 | + vals = np.asarray(posterior[name].values) |
| 241 | + # Expect shape (chain, draw) for scalar parameters |
| 242 | + if vals.ndim != 2: |
| 243 | + msg = ( |
| 244 | + f"Posterior variable '{name}' has shape {vals.shape}; " |
| 245 | + "only scalar parameter sites (chain, draw) supported here." |
| 246 | + ) |
| 247 | + raise ValueError(msg) |
| 248 | + cols.append(vals.reshape(-1)) |
| 249 | + arr = np.vstack(cols).T # (n_total_draws, n_params) |
| 250 | + weights = None |
| 251 | + sample_stats = getattr(data, "sample_stats", None) # type: ignore[attr-defined] |
| 252 | + if use_weights and sample_stats is not None and weight_name in sample_stats: |
| 253 | + w = np.asarray(sample_stats[weight_name].values) |
| 254 | + if w.ndim == 2: # (chain, draw) |
| 255 | + weights = w.reshape(-1) |
| 256 | + return MCSamples(samples=arr, names=names, label=label, weights=weights) |
0 commit comments