Commit cec066d

Merge remote-tracking branch 'origin/main' into joss-v1.0.0
2 parents 7d1b1ec + 24f32d7 commit cec066d

84 files changed: +9418 -2988 lines changed

.all-contributorsrc

Lines changed: 9 additions & 0 deletions
@@ -256,6 +256,15 @@
       "contributions": [
         "bug"
       ]
+    },
+    {
+      "login": "ritkaarsingh30",
+      "name": "Ritkaar Singh",
+      "avatar_url": "https://avatars.githubusercontent.com/u/85431642?v=4",
+      "profile": "https://github.com/ritkaarsingh30",
+      "contributions": [
+        "doc"
+      ]
     }
   ]
 }

.github/workflows/ci.yaml

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install -e .[dev]
+          pip install -e .[dev,spatiotemporal]

       - name: Test with pytest
         run: |

.github/workflows/docs.yaml

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ jobs:
       - name: "Install jupyterbook"
         run: pip install -r docs/requirements.txt
       - name: "Install autoemulate"
-        run: pip install git+https://github.com/alan-turing-institute/autoemulate.git
+        run: pip install -e .
       - name: "Run jupyterbook"
         run: jupyter-book build docs --all
       - name: "Deploy"
Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+name: github-repo-stats-for-autoemulate
+
+# This workflow uses a github action to fetch GitHub repository stats
+# (traffic data, clones, views, referrers, popular content) for the
+# repository `alan-turing-institute/autoemulate` and to generate a
+# report file as well as stats on the branch "github-repo-stats".
+# There is a link to the stats at the top of the readme.
+
+on:
+  schedule:
+    # Run this once per day, towards the end of the day for keeping the most
+    # recent data point most meaningful (hours are interpreted in UTC).
+    - cron: "0 23 * * *"
+  workflow_dispatch: # Allow for running this manually.
+
+jobs:
+  j1:
+    name: repostats-for-autoemulate
+    runs-on: ubuntu-latest
+    steps:
+      - name: run-ghrs
+        uses: jgehrcke/github-repo-stats@RELEASE
+        with:
+          # Define the stats repository (the repo to fetch
+          # stats for and to generate the report for).
+          # Remove the parameter when the stats repository
+          # and the data repository are the same.
+          repository: alan-turing-institute/autoemulate
+          # Set a GitHub API token that can read the GitHub
+          # repository traffic API for the stats repository,
+          # and that can push commits to the data repository
+          # (which this workflow file lives in, to store data
+          # and the report files).
+          ghtoken: ${{ secrets.ghrs_github_api_token }}

.github/workflows/precommit.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ jobs:
3333
python -m venv .venv
3434
source .venv/bin/activate
3535
python -m pip install --upgrade pip
36-
pip install -e .[dev]
36+
pip install -e .[dev,spatiotemporal]
3737
3838
- uses: pre-commit/[email protected]
3939
with:

.pre-commit-config.yaml

Lines changed: 3 additions & 2 deletions
@@ -1,7 +1,7 @@
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: v0.11.4
+    rev: v0.12.11
     hooks:
       # Run the linter.
       - id: ruff
@@ -13,11 +13,12 @@ repos:
         types_or: [ python, pyi ]
         files: ^autoemulate/|^tests/|^benchmarks/
   - repo: https://github.com/RobertCraigie/pyright-python
-    rev: v1.1.398
+    rev: v1.1.405
     hooks:
       - id: pyright
         files: ^autoemulate/|^tests/|^benchmarks/
   - repo: https://github.com/kynan/nbstripout
     rev: 0.8.1
     hooks:
       - id: nbstripout
+        exclude: ^case_studies/
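
Note on the `files` and `exclude` patterns in this config: pre-commit treats both as Python regular expressions and matches them with `re.search` against each path relative to the repository root. The sketch below approximates that filter so the effect of the new `exclude: ^case_studies/` line can be checked by hand. The helper name `hook_applies` and the example paths are hypothetical, and the hook's file-type filtering (for example `types_or`) is deliberately ignored here.

```python
import re

# Patterns copied from the .pre-commit-config.yaml diff above.
LINT_FILES = r"^autoemulate/|^tests/|^benchmarks/"
NBSTRIPOUT_EXCLUDE = r"^case_studies/"


def hook_applies(path: str, files: str = "", exclude: str = "^$") -> bool:
    """Approximate pre-commit's path filter (hypothetical helper).

    A path is kept when it matches `files` and does not match `exclude`.
    The defaults mirror pre-commit's own: match everything, exclude nothing.
    """
    return bool(re.search(files, path)) and not re.search(exclude, path)


# Hypothetical paths, used only for illustration.
example_paths = [
    "case_studies/demo/run.ipynb",
    "docs/tutorial.ipynb",
    "docs/case_studies/run.ipynb",
]
for path in example_paths:
    print(path, hook_applies(path, exclude=NBSTRIPOUT_EXCLUDE))
```

Because the pattern is anchored with `^`, only a top-level `case_studies/` directory is skipped; a nested directory of the same name would still be processed.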

README.md

Lines changed: 5 additions & 0 deletions
@@ -5,10 +5,14 @@
 [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
 [![All Contributors](https://img.shields.io/github/all-contributors/alan-turing-institute/autoemulate?color=ee8449&style=flat-square)](#contributors)
 [![Documentation](https://img.shields.io/badge/documentation-blue)](https://alan-turing-institute.github.io/autoemulate/)
+[![Github Stats](https://img.shields.io/badge/repostats-orange)](https://github.com/alan-turing-institute/autoemulate/blob/github-repo-stats/alan-turing-institute/autoemulate/latest-report/report.pdf)
+

 <!-- SPHINX-START -->
 Simulations of physical systems are often slow and need lots of compute, which makes them impractical for real-world applications like digital twins, or when they have to run thousands of times for sensitivity analyses. The goal of `AutoEmulate` is to make it easy to replace simulations with fast, accurate emulators. To do this, `AutoEmulate` automatically fits and compares various emulators, ranging from simple models like Radial Basis Functions and Second Order Polynomials to more complex models like Support Vector Machines and Gaussian Processes to find the best emulator for a simulation.

+>[!WARNING]
+>Although AutoEmulate is currently on version 1.x, we are not following semantic versioning at the moment. The convention for V1 is that breaking and major changes will be made between minor versions (1.1 -> 1.2). Bug fixes will be made in patch versions (1.1.1 -> 1.1.2). We plan to implement true semantic versioning in v2 of the package. We recommend pinning the minor version of AutoEmulate if using it downstream and carefully reading release notes.

 ## Documentation

@@ -63,6 +67,7 @@ You can find the project documentation [here](https://alan-turing-institute.gith
     <tr>
       <td align="center" valign="top" width="14.28%"><a href="https://jvwilliams23.github.io"><img src="https://avatars.githubusercontent.com/u/48445365?v=4?s=100" width="100px;" alt="Josh Williams"/><br /><sub><b>Josh Williams</b></sub></a><br /><a href="#bug-jvwilliams23" title="Bug reports">🐛</a> <a href="#ideas-jvwilliams23" title="Ideas, Planning, & Feedback">🤔</a></td>
       <td align="center" valign="top" width="14.28%"><a href="https://github.com/LevanBokeria"><img src="https://avatars.githubusercontent.com/u/7816766?v=4?s=100" width="100px;" alt="Levan Bokeria"/><br /><sub><b>Levan Bokeria</b></sub></a><br /><a href="#bug-LevanBokeria" title="Bug reports">🐛</a></td>
+      <td align="center" valign="top" width="14.28%"><a href="https://github.com/ritkaarsingh30"><img src="https://avatars.githubusercontent.com/u/85431642?v=4?s=100" width="100px;" alt="Ritkaar Singh"/><br /><sub><b>Ritkaar Singh</b></sub></a><br /><a href="#doc-ritkaarsingh30" title="Documentation">📖</a></td>
     </tr>
   </tbody>
 </table>
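
Note on the README versioning warning in this diff: since breaking changes may land between minor versions, a downstream pin should allow patch releases only. A minimal sketch of such a pin, assuming plain `major.minor.patch` version strings. The helper names `minor_pin` and `allowed_by_pin` are hypothetical and not part of AutoEmulate, and the version numbers follow the README's own 1.1 -> 1.2 example.

```python
def _major_minor(version: str) -> tuple[int, int]:
    """Return the (major, minor) pair of a 'major.minor[.patch]' string."""
    parts = version.split(".")
    return int(parts[0]), int(parts[1])


def minor_pin(version: str, package: str = "autoemulate") -> str:
    """Build a requirement string that stays within one minor series."""
    major, minor = _major_minor(version)
    return f"{package}>={major}.{minor},<{major}.{minor + 1}"


def allowed_by_pin(candidate: str, pinned: str) -> bool:
    """True when `candidate` is in the same minor series as `pinned`."""
    return _major_minor(candidate) == _major_minor(pinned)


print(minor_pin("1.1.1"))                # requirement string for the 1.1 series
print(allowed_by_pin("1.1.2", "1.1.1"))  # patch release (bug fixes)
print(allowed_by_pin("1.2.0", "1.1.1"))  # next minor release (may break)
```

The resulting string can be placed in a requirements file or a `dependencies` list; patch releases are then picked up automatically, while moving to the next minor version remains a deliberate step taken after reading the release notes.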

autoemulate/calibration/base.py

Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
import logging
2+
from collections.abc import Callable
3+
4+
import arviz as az
5+
import numpy as np
6+
from getdist import MCSamples
7+
from pyro.infer import HMC, MCMC, NUTS, Predictive
8+
from pyro.infer.mcmc import RandomWalkKernel
9+
10+
from autoemulate.core.types import TensorLike
11+
12+
13+
class BayesianMixin:
14+
"""Mixin class for Bayesian calibration methods."""
15+
16+
logger: logging.Logger
17+
model: Callable
18+
observations: dict[str, TensorLike] | None
19+
20+
def _get_kernel(
21+
self,
22+
sampler: str,
23+
model_kwargs: dict[str, TensorLike] | None = None,
24+
**sampler_kwargs,
25+
):
26+
"""Get the appropriate MCMC kernel based on sampler choice."""
27+
# TODO: consider how to pass model args, functools.partial?
28+
model_kwargs = model_kwargs or {}
29+
sampler = sampler.lower()
30+
if sampler == "nuts":
31+
self.logger.debug("Using NUTS kernel.")
32+
return NUTS(self.model, **sampler_kwargs)
33+
if sampler == "hmc":
34+
step_size = sampler_kwargs.pop("step_size", 0.01)
35+
trajectory_length = sampler_kwargs.pop("trajectory_length", 1.0)
36+
self.logger.debug(
37+
"Using HMC kernel with step_size=%s, trajectory_length=%s",
38+
step_size,
39+
trajectory_length,
40+
)
41+
return HMC(
42+
self.model,
43+
step_size=step_size,
44+
trajectory_length=trajectory_length,
45+
**sampler_kwargs,
46+
)
47+
if sampler == "metropolis":
48+
self.logger.debug("Using Metropolis (RandomWalkKernel).")
49+
return RandomWalkKernel(self.model, **sampler_kwargs)
50+
self.logger.error("Unknown sampler: %s", sampler)
51+
raise ValueError(f"Unknown sampler: {sampler}")
52+
53+
def run_mcmc(
54+
self,
55+
warmup_steps: int = 500,
56+
num_samples: int = 1000,
57+
num_chains: int = 1,
58+
initial_params: dict[str, TensorLike] | None = None,
59+
model_kwargs: dict | None = None,
60+
sampler: str = "nuts",
61+
**sampler_kwargs,
62+
) -> MCMC:
63+
"""
64+
Run Markov Chain Monte Carlo (MCMC). Defaults to using the NUTS sampler.
65+
66+
Parameters
67+
----------
68+
warmup_steps: int
69+
Number of warm up steps to run per chain (i.e., burn-in). These samples are
70+
discarded. Defaults to 500.
71+
num_samples: int
72+
Number of samples to draw after warm up. Defaults to 1000.
73+
num_chains: int
74+
Number of parallel chains to run. Defaults to 1.
75+
initial_params: dict[str, TensorLike] | None
76+
Optional dictionary specifiying initial values for each calibration
77+
parameter per chain. The tensors must be of length `num_chains`.
78+
model_kwargs: dict | None
79+
Optional dictionary of keyword arguments to pass to the model.
80+
sampler: str
81+
The MCMC kernel to use, one of "hmc", "nuts" or "metropolis".
82+
**sampler_kwargs
83+
Additional keyword arguments to pass to the MCMC kernel.
84+
85+
Returns
86+
-------
87+
MCMC
88+
The Pyro MCMC object. Methods include `summary()` and `get_samples()`.
89+
"""
90+
# Check initial param values match number of chains
91+
92+
if initial_params is not None:
93+
for param, init_vals in initial_params.items():
94+
if init_vals.shape[0] != num_chains:
95+
msg = (
96+
"An initial value must be provided for each chain, parameter "
97+
f"{param} tensor only has {init_vals.shape[0]} values."
98+
)
99+
self.logger.error(msg)
100+
raise ValueError(msg)
101+
self.logger.debug(
102+
"Initial parameters provided for MCMC: %s", initial_params
103+
)
104+
105+
# Run NUTS
106+
kernel = self._get_kernel(sampler, model_kwargs=model_kwargs, **sampler_kwargs)
107+
mcmc = MCMC(
108+
kernel,
109+
warmup_steps=warmup_steps,
110+
num_samples=num_samples,
111+
num_chains=num_chains,
112+
# If None, init values are sampled from the prior.
113+
initial_params=initial_params,
114+
# Multiprocessing
115+
mp_context="spawn" if num_chains > 1 else None,
116+
)
117+
self.logger.info("Starting MCMC run.")
118+
mcmc.run()
119+
self.logger.info("MCMC run completed.")
120+
return mcmc
121+
122+
def posterior_predictive(self, mcmc: MCMC) -> dict[str, TensorLike]:
123+
"""
124+
Return posterior predictive samples.
125+
126+
Parameters
127+
----------
128+
mcmc: MCMC
129+
The MCMC object.
130+
131+
Returns
132+
-------
133+
TensorLike
134+
Tensor of posterior predictive samples [n_mcmc_samples, n_obs, n_outputs].
135+
"""
136+
posterior_samples = mcmc.get_samples()
137+
posterior_predictive = Predictive(self.model, posterior_samples)
138+
samples = posterior_predictive(predict=True)
139+
self.logger.debug("Posterior predictive samples generated.")
140+
return samples
141+
142+
def to_arviz(
143+
self, mcmc: MCMC, posterior_predictive: bool = False
144+
) -> az.InferenceData:
145+
"""
146+
Convert MCMC object to Arviz InferenceData object for plotting.
147+
148+
Parameters
149+
----------
150+
mcmc: MCMC
151+
The MCMC object.
152+
posterior_predictive: bool
153+
Whether to include posterior predictive samples. Defaults to False.
154+
155+
Returns
156+
-------
157+
az.InferenceData
158+
"""
159+
pp_samples = None
160+
if posterior_predictive:
161+
self.logger.info("Including posterior predictive samples in Arviz output.")
162+
pp_samples = self.posterior_predictive(mcmc)
163+
164+
# Need to create dataset manually for Metropolis Hastings
165+
# This is because az.from_pyro expects kernel with `divergences`
166+
if isinstance(mcmc.kernel, RandomWalkKernel):
167+
self.logger.debug(
168+
"Using manual conversion for Metropolis (RandomWalkKernel) kernel."
169+
)
170+
if posterior_predictive:
171+
if self.observations is None:
172+
msg = (
173+
"Observations must be provided to include observed_data in "
174+
"Arviz InferenceData."
175+
)
176+
self.logger.error(msg)
177+
raise ValueError(msg)
178+
az_data = az.InferenceData(
179+
posterior=az.convert_to_dataset(
180+
mcmc.get_samples(group_by_chain=True)
181+
),
182+
posterior_predictive=az.convert_to_dataset(pp_samples),
183+
observed_data=az.convert_to_dataset(self.observations),
184+
)
185+
else:
186+
az_data = az.InferenceData(
187+
posterior=az.convert_to_dataset(
188+
mcmc.get_samples(group_by_chain=True)
189+
),
190+
)
191+
else:
192+
self.logger.debug("Using az.from_pyro for conversion.")
193+
az_data = az.from_pyro(mcmc, posterior_predictive=pp_samples)
194+
195+
self.logger.info("Arviz InferenceData conversion complete.")
196+
return az_data
197+
198+
@staticmethod
199+
def to_getdist(
200+
data: MCMC | az.InferenceData,
201+
label: str,
202+
use_weights: bool = True,
203+
weight_name: str = "weight",
204+
) -> MCSamples:
205+
"""Convert Pyro MCMC or ArviZ InferenceData to GetDist MCSamples.
206+
207+
This lightweight helper extends the original implementation to also accept
208+
SMC / other results already converted to ArviZ InferenceData. If a weight
209+
variable (default: smc_weight) is present in sample_stats it will be
210+
used as importance weights.
211+
212+
Parameters
213+
----------
214+
data: MCMC | az.InferenceData
215+
The Pyro MCMC object or an ArviZ InferenceData object containing posterior
216+
samples.
217+
label: str
218+
Label for the MCSamples object.
219+
use_weights: bool
220+
If True and `data` is an `InferenceData` with `weight_name` in
221+
`sample_stats` then those weights are applied. Defaults to True.
222+
weight_name: str
223+
Name of the weight variable inside `sample_stats` to look up.
224+
225+
Returns
226+
-------
227+
MCSamples
228+
The GetDist MCSamples object.
229+
"""
230+
if isinstance(data, MCMC):
231+
samples_dict = data.get_samples()
232+
arr = np.array(list(samples_dict.values())).T
233+
names = list(samples_dict.keys())
234+
weights = None
235+
else:
236+
posterior = data.posterior # type: ignore[attr-defined]
237+
names = list(posterior.data_vars)
238+
cols = []
239+
for name in names:
240+
vals = np.asarray(posterior[name].values)
241+
# Expect shape (chain, draw) for scalar parameters
242+
if vals.ndim != 2:
243+
msg = (
244+
f"Posterior variable '{name}' has shape {vals.shape}; "
245+
"only scalar parameter sites (chain, draw) supported here."
246+
)
247+
raise ValueError(msg)
248+
cols.append(vals.reshape(-1))
249+
arr = np.vstack(cols).T # (n_total_draws, n_params)
250+
weights = None
251+
sample_stats = getattr(data, "sample_stats", None) # type: ignore[attr-defined]
252+
if use_weights and sample_stats is not None and weight_name in sample_stats:
253+
w = np.asarray(sample_stats[weight_name].values)
254+
if w.ndim == 2: # (chain, draw)
255+
weights = w.reshape(-1)
256+
return MCSamples(samples=arr, names=names, label=label, weights=weights)
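
Note on the `InferenceData` branch of `to_getdist`: each scalar parameter arrives as a `(chain, draw)` array, is flattened chain by chain, and the flattened columns are stacked into an `(n_total_draws, n_params)` sample matrix, with weights flattened in the same order so they stay aligned with the rows. The sketch below reproduces only that reshaping step with nested Python lists standing in for the NumPy arrays; `flatten_chains` and the toy values are hypothetical and are not part of AutoEmulate, ArviZ or GetDist.

```python
def flatten_chains(posterior, weights=None):
    """Flatten {name: [chain0_draws, chain1_draws, ...]} into sample rows.

    Hypothetical stand-in for the reshaping in `to_getdist`; nested lists
    play the role of (chain, draw) arrays.
    """
    names = list(posterior)
    cols = []
    for name in names:
        chains = posterior[name]
        # Only scalar parameters are supported: every draw must be a number.
        for chain in chains:
            if any(isinstance(v, (list, tuple)) for v in chain):
                raise ValueError(f"Posterior variable '{name}' is not scalar.")
        # Chain-major flattening, the same order as reshape(-1) on (chain, draw).
        cols.append([v for chain in chains for v in chain])
    # Transpose the columns into rows: one row per draw, one entry per parameter.
    rows = [list(row) for row in zip(*cols)]
    flat_weights = None
    if weights is not None:
        flat_weights = [w for chain in weights for w in chain]
    return rows, names, flat_weights


# Toy data: 2 chains, 3 draws each, 2 scalar parameters.
toy_posterior = {
    "a": [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
    "b": [[10.0, 11.0, 12.0], [13.0, 14.0, 15.0]],
}
toy_weights = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
rows, names, flat_weights = flatten_chains(toy_posterior, toy_weights)
print(len(rows), names)
```

Flattening the parameters and the weights in the same chain-major order is what keeps row `i` of the sample matrix paired with weight `i`.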
