Skip to content

Commit 094413b

Browse files
authored
Add NPSE + AiO and add more unit tests (#55)
1 parent e6acf94 commit 094413b

53 files changed

Lines changed: 1862 additions & 293 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/ci.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ jobs:
7373
pip install hatch
7474
- name: Build package
7575
run: |
76-
pip install jaxlib==0.4.24 jax==0.4.24
76+
pip install jaxlib jax
7777
- name: Run tests
7878
run: |
7979
hatch run test:test

.github/workflows/examples.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,14 @@ jobs:
3535
pip install .
3636
- name: Run tests
3737
run: |
38-
python examples/bivariate_gaussian-smcabc.py --n-rounds 1
38+
python examples/gaussian_linear-aio.py --n-iter 10
39+
python examples/gaussian_linear-smcabc.py --n-rounds 1
3940
python examples/mixture_model-cmpe.py --n-iter 10
4041
python examples/mixture_model-nle.py --n-iter 10
4142
python examples/mixture_model-nle.py --n-iter 10 --use-spf
4243
python examples/mixture_model-npe.py --n-iter 10
4344
python examples/mixture_model-nre.py --n-iter 10
45+
python examples/mixture_model-npse.py --n-iter 10
4446
python examples/slcp-fmpe.py --n-iter 10
4547
python examples/slcp-nass_nle.py --n-iter 10 --n-rounds 1
4648
python examples/slcp-nass_smcabc.py --n-iter 10 --n-rounds 1

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ repos:
1515
- id: trailing-whitespace
1616

1717
- repo: https://github.com/pre-commit/mirrors-mypy
18-
rev: v0.910-1
18+
rev: v1.14.1
1919
hooks:
2020
- id: mypy
2121
args: ["--ignore-missing-imports"]

README.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,19 @@ In order to contribute:
8989
4) test it by calling `make tests`, `make lints` and `make format` on the (Unix) command line,
9090
5) submit a PR 🙂
9191

92+
## Citing sbijax
93+
94+
If you find our work relevant to your research, please consider citing:
95+
96+
```
97+
@article{dirmeier2024simulation,
98+
title={Simulation-based inference with the Python Package sbijax},
99+
author={Dirmeier, Simon and Ulzega, Simone and Mira, Antonietta and Albert, Carlo},
100+
journal={arXiv preprint arXiv:2409.19435},
101+
year={2024}
102+
}
103+
```
104+
92105
## Acknowledgements
93106

94107
> [!NOTE]

docs/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ License
130130
:hidden:
131131

132132
sbijax
133+
sbijax.experimental
133134
sbijax.mcmc
134135
sbijax.nn
135136
sbijax.util

docs/references.bib

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,23 +36,20 @@ @article{dirmeier2023simulation
3636
@inproceedings{papama2019neural,
3737
title = {Sequential Neural Likelihood: Fast Likelihood-free Inference with Autoregressive Flows},
3838
author = {Papamakarios, George and Sterratt, David and Murray, Iain},
39-
booktitle = {Proceedings of the Twenty-Second International Conference on Artificial Intelligence and Statistics},
40-
year = {2019},
41-
doi = {10.48550/arXiv.1805.07226}
39+
booktitle = {International Conference on Artificial Intelligence and Statistics},
40+
year = {2019}
4241
}
4342
@inproceedings{greenberg2019automatic,
4443
title = {Automatic posterior transformation for likelihood-free inference},
4544
author = {Greenberg, David and Nonnenmacher, Marcel and Macke, Jakob},
4645
booktitle = {International Conference on Machine Learning},
47-
year = {2019},
48-
doi = {10.48550/arXiv.1905.0748}
46+
year = {2019}
4947
}
5048
@inproceedings{miller2022contrast,
5149
author = {Miller, Benjamin K and Weniger, Christoph and Forr\'{e}, Patrick},
5250
booktitle = {Advances in Neural Information Processing Systems},
5351
title = {Contrastive Neural Ratio Estimation},
54-
year = {2022},
55-
doi = {10.48550/arXiv.2210.06170}
52+
year = {2022}
5653
}
5754
@article{beaumont2009adaptive,
5855
title={Adaptive approximate {B}ayesian computation},
@@ -62,8 +59,7 @@ @article{beaumont2009adaptive
6259
number={4},
6360
pages={983--990},
6461
year={2009},
65-
publisher={Oxford University Press},
66-
doi={10.1093/biomet/asp052}
62+
publisher={Oxford University Press}
6763
}
6864
@article{albert2015simulated,
6965
title={A simulated annealing approach to approximate {B}ayes computations},
@@ -74,3 +70,15 @@ @article{albert2015simulated
7470
year={2015},
7571
publisher={Springer}
7672
}
73+
@inproceedings{sharrock2024sequential,
74+
title={Sequential Neural Score Estimation: Likelihood-Free Inference with Conditional Score Based Diffusion Models},
75+
author={Louis Sharrock and Jack Simons and Song Liu and Mark Beaumont},
76+
booktitle={International Conference on Machine Learning},
77+
year={2024}
78+
}
79+
@inproceedings{gloeckler2024allinone,
80+
title={All-in-one simulation-based inference},
81+
author={Manuel Gloeckler and Michael Deistler and Christian Dietrich Weilbach and Frank Wood and Jakob H. Macke},
82+
booktitle={International Conference on Machine Learning},
83+
year={2024},
84+
}

docs/sbijax.experimental.rst

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
``sbijax.experimental``
2+
=======================
3+
4+
.. currentmodule:: sbijax.experimental
5+
6+
``sbijax.experimental`` contains experimental code that might get ported to the
7+
main code base or possibly deleted again.
8+
9+
.. autosummary::
10+
AiO
11+
NPSE
12+
13+
.. autoclass:: AiO
14+
:members: fit, simulate_data, simulate_data_and_possibly_append, sample_posterior
15+
16+
.. autoclass:: NPSE
17+
:members: fit, simulate_data, simulate_data_and_possibly_append, sample_posterior
18+
19+
20+
.. currentmodule:: sbijax.experimental.nn
21+
22+
.. autosummary::
23+
make_score_model
24+
make_simformer_based_score_model
25+
ScoreModel
26+
27+
.. autofunction:: make_simformer_based_score_model
28+
29+
.. autofunction:: make_score_model
30+
31+
.. autoclass:: ScoreModel
32+
:members: __call__

examples/gaussian_linear-aio.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""All-in-one simulation-based inference.
2+
3+
Demonstrates AiO on a linear Gaussian model.
4+
"""
5+
import matplotlib.pyplot as plt
6+
import numpy as np
7+
from jax import numpy as jnp, random as jr
8+
from tensorflow_probability.substrates.jax import distributions as tfd
9+
10+
from sbijax import plot_posterior
11+
from sbijax.experimental import AiO
12+
from sbijax.experimental.nn import make_simformer_based_score_model
13+
14+
15+
def prior_fn():
16+
prior = tfd.JointDistributionNamed(dict(
17+
theta=tfd.Normal(jnp.zeros(5), 1)
18+
), batch_ndims=0)
19+
return prior
20+
21+
22+
def simulator_fn(seed, theta):
23+
mean = theta["theta"].reshape(-1, 5)
24+
y = tfd.Normal(mean, 0.1).sample(seed=seed)
25+
return y
26+
27+
28+
def run(n_iter):
29+
y_observed = jnp.linspace(-2.0, 2.0, 5)
30+
fns = prior_fn, simulator_fn
31+
mask = jnp.zeros((10, 10))
32+
mask = mask.at[np.arange(5, 10), np.arange(5)].set(1)
33+
mask = mask + mask.T + jnp.eye(10)
34+
35+
neural_network = make_simformer_based_score_model(5, mask, 1, 1)
36+
model = AiO(fns, neural_network)
37+
38+
data, _ = model.simulate_data(jr.PRNGKey(1), n_simulations=10_000)
39+
params, info = model.fit(jr.PRNGKey(2), data=data, n_early_stopping_patience=25, n_iter=n_iter)
40+
inference_result, _ = model.sample_posterior(jr.PRNGKey(3), params, y_observed)
41+
42+
plot_posterior(inference_result, point_estimate="mean")
43+
plt.show()
44+
45+
46+
if __name__ == "__main__":
47+
import argparse
48+
parser = argparse.ArgumentParser()
49+
parser.add_argument("--n-iter", type=int, default=1_000)
50+
args = parser.parse_args()
51+
run(args.n_iter)

examples/mixture_model-fmpe.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
"""Flow matching posterior estimation example.
2+
3+
Demonstrates FPME on a simple mixture model.
4+
"""
5+
import matplotlib.pyplot as plt
6+
from jax import numpy as jnp, random as jr
7+
from tensorflow_probability.substrates.jax import distributions as tfd
8+
9+
from sbijax import plot_posterior
10+
from sbijax import FMPE
11+
from sbijax.nn import make_cnf
12+
13+
14+
def prior_fn():
15+
prior = tfd.JointDistributionNamed(dict(
16+
theta=tfd.Normal(jnp.zeros(2), 1)
17+
), batch_ndims=0)
18+
return prior
19+
20+
21+
def simulator_fn(seed, theta):
22+
mean = theta["theta"].reshape(-1, 2)
23+
n = mean.shape[0]
24+
data_key, cat_key = jr.split(seed)
25+
categories = tfd.Categorical(logits=jnp.zeros(2)).sample(seed=cat_key, sample_shape=(n,))
26+
scales = jnp.array([1.0, 0.1])[categories].reshape(-1, 1)
27+
y = tfd.Normal(mean, scales).sample(seed=data_key)
28+
return y
29+
30+
31+
def run(n_iter):
32+
y_observed = jnp.array([-2.0, 2.0])
33+
fns = prior_fn, simulator_fn
34+
neural_network = make_cnf(2)
35+
model = FMPE(fns, neural_network)
36+
37+
data, _ = model.simulate_data(jr.PRNGKey(1), n_simulations=20_000)
38+
params, info = model.fit(jr.PRNGKey(2), data=data, n_early_stopping_patience=25, n_iter=n_iter)
39+
inference_result, _ = model.sample_posterior(jr.PRNGKey(3), params, y_observed)
40+
41+
plot_posterior(inference_result)
42+
plt.show()
43+
44+
45+
if __name__ == "__main__":
46+
import argparse
47+
parser = argparse.ArgumentParser()
48+
parser.add_argument("--n-iter", type=int, default=1_000)
49+
args = parser.parse_args()
50+
run(args.n_iter)

0 commit comments

Comments
 (0)