Skip to content

Commit 7d85212

Browse files
authored
Add a test that runs everything (#12)
1 parent 594b6d8 commit 7d85212

6 files changed

Lines changed: 99 additions & 28 deletions

File tree

.pre-commit-config.yaml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ repos:
4040
language: python
4141
language_version: python3
4242
types: [python]
43+
args: ["-c", "pyproject.toml"]
44+
additional_dependencies: ["toml"]
4345
files: "(sbijax|examples)"
4446

4547
- repo: https://github.com/PyCQA/flake8
@@ -58,6 +60,17 @@ repos:
5860
args: ["--ignore-missing-imports"]
5961
files: "(sbijax|examples)"
6062

63+
- repo: https://github.com/nbQA-dev/nbQA
64+
rev: 1.6.3
65+
hooks:
66+
- id: nbqa-black
67+
- id: nbqa-pyupgrade
68+
args: [--py39-plus]
69+
- id: nbqa-isort
70+
args: ['--profile=black']
71+
- id: nbqa-flake8
72+
args: ['--ignore=E501,E203,E302,E402,E731,W503']
73+
6174
- repo: https://github.com/jorisroovers/gitlint
6275
rev: v0.18.0
6376
hooks:

README.md

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22

33
[![status](http://www.repostatus.org/badges/latest/concept.svg)](http://www.repostatus.org/#concept)
44
[![ci](https://github.com/dirmeier/sbijax/actions/workflows/ci.yaml/badge.svg)](https://github.com/dirmeier/sbijax/actions/workflows/ci.yaml)
5+
[![version](https://img.shields.io/pypi/v/sbijax.svg?colorB=black&style=flat)](https://pypi.org/project/sbijax/)
56

67
> Simulation-based inference in JAX
78
89
## About
910

1011
SbiJAX implements several algorithms for simulation-based inference using
11-
[BlackJAX](https://github.com/blackjax-devs/blackjax), [Haiku](https://github.com/deepmind/dm-haiku) and [JAX](https://github.com/google/jax).
12+
[JAX](https://github.com/google/jax), [Haiku](https://github.com/deepmind/dm-haiku) and [BlackJAX](https://github.com/blackjax-devs/blackjax).
1213

1314
SbiJAX so far implements
1415

@@ -37,29 +38,6 @@ To install the latest GitHub <RELEASE>, use:
3738
pip install git+https://github.com/dirmeier/sbijax@<RELEASE>
3839
```
3940

40-
## Contributing
41-
42-
Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled
43-
["good first issue"](https://github.com/dirmeier/sbijax/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22). In order to contribute:
44-
45-
1) Fork the repository and install `hatch` and `pre-commit`
46-
47-
```bash
48-
pip install hatch pre-commit
49-
pre-commit install
50-
```
51-
52-
2) Create a new branch in your fork and implement your contribution
53-
54-
3) Test your contribution/implementation by calling `hatch run test` on the (Unix) command line before submitting a PR
55-
56-
```bash
57-
hatch run test:lint
58-
hatch run test:test
59-
```
60-
61-
4) Submit a pull request :slightly_smiling_face:
62-
6341
## Author
6442

6543
Simon Dirmeier <a href="mailto:sfyrbnd @ pm me">sfyrbnd @ pm me</a>

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,16 @@ dependencies = [
2525
"dm-haiku>=0.0.9",
2626
"flax>=0.6.3",
2727
"optax>=0.1.3",
28+
"surjectors@git+https://git@github.com/dirmeier/surjectors@v0.2.2",
2829
]
2930
dynamic = ["version"]
3031

3132
[project.urls]
3233
homepage = "https://github.com/dirmeier/sbijax"
3334

35+
[tool.hatch.metadata]
36+
allow-direct-references = true
37+
3438
[tool.hatch.version]
3539
path = "sbijax/__init__.py"
3640

@@ -50,7 +54,7 @@ dependencies = [
5054

5155
[tool.hatch.envs.test.scripts]
5256
lint = 'pylint sbijax'
53-
test = 'pytest -v --doctest-modules --cov=./sbi --cov-report=xml sbijax'
57+
test = 'pytest -v --doctest-modules --cov=./sbijax --cov-report=xml sbijax'
5458

5559

5660
[tool.black]

sbijax/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
sbijax: Simulation-based inference in JAX
33
"""
44

5-
__version__ = "0.0.10"
5+
__version__ = "0.0.11"
66

77

88
from sbijax.abc.rejection_abc import RejectionABC

sbijax/snl.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import numpy as np
66
import optax
77
from absl import logging
8+
9+
# TODO(simon): this is a bit an annoying dependency to have
810
from flax.training.early_stopping import EarlyStopping
911
from jax import numpy as jnp
1012

sbijax/snl_test.py

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,80 @@
11
# pylint: skip-file
2-
import chex
2+
3+
import distrax
4+
import haiku as hk
5+
import optax
6+
from jax import numpy as jnp
7+
from surjectors import Chain, MaskedCoupling, TransformedDistribution
8+
from surjectors.conditioners import mlp_conditioner
9+
from surjectors.util import make_alternating_binary_mask
10+
11+
from sbijax import SNL
12+
13+
14+
def prior_model_fns():
15+
p = distrax.Independent(
16+
distrax.Uniform(jnp.full(2, -3.0), jnp.full(2, 3.0)), 1
17+
)
18+
return p.sample, p.log_prob
19+
20+
21+
def simulator_fn(seed, theta):
22+
p = distrax.MultivariateNormalDiag(theta, 0.1 * jnp.ones_like(theta))
23+
y = p.sample(seed=seed)
24+
return y
25+
26+
27+
def log_density_fn(theta, y):
28+
prior = distrax.Uniform(jnp.full(2, -3.0), jnp.full(2, 3.0))
29+
likelihood = distrax.MultivariateNormalDiag(
30+
theta, 0.1 * jnp.ones_like(theta)
31+
)
32+
33+
lp = jnp.sum(prior.log_prob(theta)) + jnp.sum(likelihood.log_prob(y))
34+
return lp
35+
36+
37+
def make_model(dim):
38+
def _bijector_fn(params):
39+
means, log_scales = jnp.split(params, 2, -1)
40+
return distrax.ScalarAffine(means, jnp.exp(log_scales))
41+
42+
def _flow(method, **kwargs):
43+
layers = []
44+
for i in range(2):
45+
mask = make_alternating_binary_mask(dim, i % 2 == 0)
46+
layer = MaskedCoupling(
47+
mask=mask,
48+
bijector=_bijector_fn,
49+
conditioner=mlp_conditioner([8, 8, dim * 2]),
50+
)
51+
layers.append(layer)
52+
chain = Chain(layers)
53+
base_distribution = distrax.Independent(
54+
distrax.Normal(jnp.zeros(dim), jnp.ones(dim)),
55+
1,
56+
)
57+
td = TransformedDistribution(base_distribution, chain)
58+
return td(method, **kwargs)
59+
60+
td = hk.transform(_flow)
61+
td = hk.without_apply_rng(td)
62+
return td
363

464

565
def test_snl():
6-
chex.assert_equal(1, 1)
66+
rng_seq = hk.PRNGSequence(0)
67+
y_observed = jnp.array([-1.0, 1.0])
68+
69+
prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
70+
fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn
71+
72+
snl = SNL(fns, make_model(2))
73+
params, info = snl.fit(
74+
next(rng_seq),
75+
y_observed,
76+
n_rounds=1,
77+
optimizer=optax.adam(1e-4),
78+
sampler="slice",
79+
)
80+
_ = snl.sample_posterior(params, 2, 100, 50, sampler="slice")

0 commit comments

Comments
 (0)