
Commit ab4253c

Fix doctests (1/x)
1 parent 1b44054 commit ab4253c

26 files changed: +1262 −1368 lines changed

.gitattributes

Lines changed: 1 addition & 0 deletions
```diff
@@ -0,0 +1 @@
+*.ipynb linguist-vendored
```
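Marking notebooks as `linguist-vendored` excludes them from GitHub's language statistics. As a quick check that the attribute is picked up, `git check-attr` can be queried against any notebook path; the path below is only illustrative:

```shell
# Ask git which value the linguist-vendored attribute has for a notebook.
# The path is a placeholder; substitute any *.ipynb file in the repository.
git check-attr linguist-vendored -- docs/notebook.ipynb
# expected: docs/notebook.ipynb: linguist-vendored: set
```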

.pre-commit-config.yaml

Lines changed: 6 additions & 6 deletions
```diff
@@ -20,15 +20,15 @@ repos:
     args: ["--ignore-missing-imports"]
     files: "(surjectors|examples)"
 
-- repo: https://github.com/jorisroovers/gitlint
-  rev: v0.19.1
-  hooks:
-  - id: gitlint
-  - id: gitlint-ci
-
 - repo: https://github.com/astral-sh/ruff-pre-commit
   rev: v0.3.0
   hooks:
   - id: ruff
     args: [ --fix ]
   - id: ruff-format
+
+- repo: https://github.com/jorisroovers/gitlint
+  rev: v0.18.0
+  hooks:
+  - id: gitlint
+  - id: gitlint-ci
```
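To confirm the reordered configuration still runs cleanly, the hooks can be executed once over the whole tree; a minimal sketch, assuming `pre-commit` is available in the active environment:

```shell
# Register the hooks with git, then run every configured hook against all files.
pre-commit install
pre-commit run --all-files
```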

.python-version

Lines changed: 1 addition & 0 deletions
```diff
@@ -0,0 +1 @@
+3.11.9
```
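Tools such as `uv` and `pyenv` read `.python-version` to decide which interpreter to use. A minimal sketch of verifying the pin, assuming `uv` manages the project environment:

```shell
# uv resolves the interpreter from .python-version when creating the venv.
uv venv
source .venv/bin/activate
python --version   # expected: Python 3.11.9
```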

README.md

Lines changed: 14 additions & 7 deletions
```diff
@@ -82,12 +82,19 @@ Contributions in the form of pull requests are more than welcome. A good way to
 
 In order to contribute:
 
-1) Clone `Surjectors` and install `hatch` via `pip install hatch`,
-2) create a new branch locally `git checkout -b feature/my-new-feature` or `git checkout -b issue/fixes-bug`,
-3) implement your contribution and ideally a test case,
-4) test it by calling `hatch run test` on the (Unix) command line,
-5) submit a PR 🙂
-
+1) Clone `surjectors` and install `uv` from [here](https://github.com/astral-sh/uv).
+2) Create a new branch locally `git checkout -b feature/my-new-feature` or `git checkout -b issue/fixes-bug`.
+3) Install all dependencies via `uv sync --all-groups`.
+4) Activate the virtual environment: `source .venv/bin/activate`.
+5) Install `pre-commit` and `gitlint` via:
+
+```shell
+pre-commit install
+gitlint install-hook
+```
+6) Implement your contribution and ideally a test case.
+7) Test it by calling `make format`, `make lints` and `make tests` on the (Unix) command line.
+8) Submit a PR 🙂.
 
 ## Citing Surjectors
 
@@ -109,4 +116,4 @@ If you find our work relevant to your research, please consider citing:
 
 ## Author
 
-Simon Dirmeier <a href="mailto:sfyrbnd @ pm me">sfyrbnd @ pm me</a>
+Simon Dirmeier <a href="mailto:simd23 @ pm me">simd23 @ pm me</a>
```
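Taken end to end, the updated contribution steps correspond roughly to the following session; a sketch only, assuming a Unix shell, that the clone URL points at the upstream repository, and that the `Makefile` targets named in the README exist:

```shell
git clone https://github.com/dirmeier/surjectors && cd surjectors   # URL assumed
git checkout -b feature/my-new-feature

uv sync --all-groups          # install all dependency groups into .venv
source .venv/bin/activate     # activate the project environment

pre-commit install            # register the pre-commit hooks
gitlint install-hook          # register the commit-message hook

make format && make lints && make tests
```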

examples/autoregressive_inference_surjection.py

Lines changed: 77 additions & 77 deletions
@@ -11,112 +11,112 @@

Every removed line in this hunk reappears with identical text, so the change is a whitespace-only re-indentation of the example. The resulting file, from line 11 onwards, reads:

```python
from matplotlib import pyplot as plt

from surjectors import (
    AffineMaskedAutoregressiveInferenceFunnel,
    Chain,
    MaskedAutoregressive,
    TransformedDistribution,
)
from surjectors.nn import MADE, make_mlp
from surjectors.util import as_batch_iterator, unstack


def _decoder_fn(n_dim):
    decoder_net = make_mlp([4, 4, n_dim * 2])

    def _fn(z):
        params = decoder_net(z)
        mu, log_scale = jnp.split(params, 2, -1)
        return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale)))

    return _fn


def _made_bijector_fn(params):
    means, log_scales = unstack(params, -1)
    return surjectors.Inverse(surjectors.ScalarAffine(means, jnp.exp(log_scales)))


def make_model(n_dimensions):
    def _flow(**kwargs):
        n_dim = n_dimensions
        layers = []
        for i in range(3):
            if i != 1:
                layer = AffineMaskedAutoregressiveInferenceFunnel(
                    n_keep=int(n_dim / 2),
                    decoder=_decoder_fn(int(n_dim / 2)),
                    conditioner=MADE(int(n_dim / 2), [8, 8], 2),
                )
                n_dim = int(n_dim / 2)
            else:
                layer = MaskedAutoregressive(
                    conditioner=MADE(n_dim, [8, 8], 2),
                    bijector_fn=_made_bijector_fn,
                )
            layers.append(layer)
            # TODO(simon): needs to change order
            # layers.append(Permutation(order, 1))
        chain = Chain(layers)

        base_distribution = distrax.Independent(
            distrax.Normal(jnp.zeros(n_dim), jnp.ones(n_dim)),
            reinterpreted_batch_ndims=1,
        )
        td = TransformedDistribution(base_distribution, chain)
        return td.log_prob(**kwargs)

    td = hk.transform(_flow)
    td = hk.without_apply_rng(td)
    return td


def train(rng_seq, data, model, max_n_iter=1000):
    train_iter = as_batch_iterator(next(rng_seq), data, 100, True)
    params = model.init(next(rng_seq), **train_iter(0))

    optimizer = optax.adam(1e-4)
    state = optimizer.init(params)

    @jax.jit
    def step(params, state, **batch):
        def loss_fn(params):
            lp = model.apply(params, **batch)
            return -jnp.sum(lp)

        loss, grads = jax.value_and_grad(loss_fn)(params)
        updates, new_state = optimizer.update(grads, state, params)
        new_params = optax.apply_updates(params, updates)
        return loss, new_params, new_state

    losses = np.zeros(max_n_iter)
    for i in range(max_n_iter):
        train_loss = 0.0
        for j in range(train_iter.num_batches):
            batch = train_iter(j)
            batch_loss, params, state = step(params, state, **batch)
            train_loss += batch_loss
        losses[i] = train_loss

    return params, losses


def run(n_iter):
    n, p = 1000, 20
    rng_seq = hk.PRNGSequence(2)
    y = jr.normal(next(rng_seq), shape=(n, p))
    data = namedtuple("named_dataset", "y")(y)

    model = make_model(p)
    params, losses = train(rng_seq, data, model, n_iter)
    plt.plot(losses)
    plt.show()

    y = jr.normal(next(rng_seq), shape=(10, p))
    print(model.apply(params, **{"y": y}))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--n-iter", type=int, default=1_000)
    args = parser.parse_args()
    run(args.n_iter)
```
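Because the script exposes an `--n-iter` flag via `argparse`, a short run can be kicked off directly from the repository root; this assumes the dependencies imported at the top of the file (JAX, distrax, haiku, optax) are installed:

```shell
# Train the surjective autoregressive flow example for a small number of iterations.
python examples/autoregressive_inference_surjection.py --n-iter 100
```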
