Skip to content

Commit 376e59a

Browse files
authored
Add Distrax classes and methods to surjectors (#34)
1 parent 68c2f8a commit 376e59a

File tree

7 files changed

+99
-24
lines changed

7 files changed

+99
-24
lines changed

.github/workflows/examples.yaml

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
name: examples
2+
3+
on:
4+
push:
5+
branches: [ main ]
6+
pull_request:
7+
branches: [ main ]
8+
9+
jobs:
10+
precommit:
11+
name: Pre-commit checks
12+
runs-on: ubuntu-latest
13+
steps:
14+
- uses: actions/checkout@v3
15+
16+
examples:
17+
runs-on: ubuntu-latest
18+
needs:
19+
- precommit
20+
strategy:
21+
matrix:
22+
python-version: [3.11]
23+
steps:
24+
- uses: actions/checkout@v3
25+
- name: Set up Python ${{ matrix.python-version }}
26+
uses: actions/setup-python@v3
27+
with:
28+
python-version: ${{ matrix.python-version }}
29+
- name: Install dependencies
30+
run: |
31+
pip install hatch matplotlib
32+
- name: Build package
33+
run: |
34+
pip install jaxlib jax
35+
pip install .
36+
- name: Run tests
37+
run: |
38+
python examples/autoregressive_inference_surjection.py --n-iter 10
39+
python examples/conditional_density_estimation.py --n-iter 10 --model coupling
40+
python examples/conditional_density_estimation.py --n-iter 10 --model autoregressive
41+
python examples/coupling_inference_surjection.py --n-iter 10

examples/autoregressive_inference_surjection.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import argparse
12
from collections import namedtuple
23

34
import distrax
@@ -99,14 +100,14 @@ def loss_fn(params):
99100
return params, losses
100101

101102

102-
def run():
103+
def run(n_iter):
103104
n, p = 1000, 20
104105
rng_seq = hk.PRNGSequence(2)
105106
y = jr.normal(next(rng_seq), shape=(n, p))
106107
data = namedtuple("named_dataset", "y")(y)
107108

108109
model = make_model(p)
109-
params, losses = train(rng_seq, data, model)
110+
params, losses = train(rng_seq, data, model, n_iter)
110111
plt.plot(losses)
111112
plt.show()
112113

@@ -115,4 +116,7 @@ def run():
115116

116117

117118
if __name__ == "__main__":
118-
run()
119+
parser = argparse.ArgumentParser()
120+
parser.add_argument("--n-iter", type=int, default=1_000)
121+
args = parser.parse_args()
122+
run(args.n_iter)

examples/conditional_density_estimation.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import argparse
2+
13
import distrax
24
import haiku as hk
35
import jax
@@ -49,9 +51,7 @@ def _flow(method, **kwargs):
4951
layer = MaskedAutoregressive(
5052
bijector_fn=_bijector_fn,
5153
conditioner=MADE(
52-
2,
53-
[32, 32, 2 * 2],
54-
2,
54+
2, [32, 32], 2,
5555
w_init=hk.initializers.TruncatedNormal(0.01),
5656
b_init=jnp.zeros,
5757
),
@@ -104,7 +104,7 @@ def loss_fn(params):
104104
return params, losses
105105

106106

107-
def run():
107+
def run(n_iter, model):
108108
n = 10000
109109
thetas = distrax.Normal(jnp.zeros(2), jnp.full(2, 10)).sample(
110110
seed=random.PRNGKey(0), sample_shape=(n,)
@@ -114,8 +114,8 @@ def run():
114114
)
115115
data = named_dataset(y, thetas)
116116

117-
model = make_model(2)
118-
params, losses = train(hk.PRNGSequence(2), data, model)
117+
model = make_model(2, model)
118+
params, losses = train(hk.PRNGSequence(2), data, model, n_iter)
119119
samples = model.apply(
120120
params,
121121
random.PRNGKey(2),
@@ -129,4 +129,10 @@ def run():
129129

130130

131131
if __name__ == "__main__":
132-
run()
132+
parser = argparse.ArgumentParser()
133+
parser.add_argument("--n-iter", type=int, default=1_000)
134+
parser.add_argument("--model", type=str, default="coupling")
135+
args = parser.parse_args()
136+
run(args.n_iter, args.model)
137+
138+

examples/coupling_inference_surjection.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import argparse
12
from collections import namedtuple
23

34
import distrax
@@ -99,14 +100,14 @@ def loss_fn(params):
99100
return params, losses
100101

101102

102-
def run():
103+
def run(n_iter):
103104
n, p = 1000, 20
104105
rng_seq = hk.PRNGSequence(2)
105106
y = jr.normal(next(rng_seq), shape=(n, p))
106107
data = namedtuple("named_dataset", "y")(y)
107108

108109
model = make_model(p)
109-
params, losses = train(rng_seq, data, model)
110+
params, losses = train(rng_seq, data, model, n_iter)
110111
plt.plot(losses)
111112
plt.show()
112113

@@ -115,4 +116,9 @@ def run():
115116

116117

117118
if __name__ == "__main__":
118-
run()
119+
parser = argparse.ArgumentParser()
120+
parser.add_argument("--n-iter", type=int, default=1_000)
121+
args = parser.parse_args()
122+
run(args.n_iter)
123+
124+

surjectors/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""surjectors: Surjection layers for density estimation with normalizing flows."""
22

3-
__version__ = "0.3.2"
3+
__version__ = "0.3.3"
4+
5+
from distrax import ScalarAffine
46

57
from surjectors._src.bijectors.affine_masked_autoregressive import (
68
AffineMaskedAutoregressive,
@@ -60,4 +62,5 @@
6062
"MLPInferenceFunnel",
6163
"Slice",
6264
# "Augment",
65+
"ScalarAffine",
6366
]

surjectors/_src/bijectors/permutation.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
import distrax
21
from jax import numpy as jnp
32

3+
from surjectors._src.bijectors.bijector import Bijector
4+
45

56
# pylint: disable=arguments-renamed
6-
class Permutation(distrax.Bijector):
7+
class Permutation(Bijector):
78
"""Permute the dimensions of a vector.
89
910
Args:
@@ -20,13 +21,13 @@ class Permutation(distrax.Bijector):
2021
"""
2122

2223
def __init__(self, permutation, event_ndims_in: int):
23-
super().__init__(event_ndims_in)
2424
self.permutation = permutation
25+
self.event_ndims_in = event_ndims_in
2526

26-
def _forward_and_likelihood_contribution(self, z):
27+
def _forward_and_likelihood_contribution(self, z, **kwargs):
2728
return z[..., self.permutation], jnp.full(jnp.shape(z)[:-1], 0.0)
2829

29-
def _inverse_and_likelihood_contribution(self, y):
30+
def _inverse_and_likelihood_contribution(self, y, **kwargs):
3031
size = self.permutation.size
3132
permutation_inv = (
3233
jnp.zeros(size, dtype=jnp.result_type(int))

surjectors/_src/distributions/transformed_distribution.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
from typing import Union
2+
13
import chex
24
import distrax
35
import haiku as hk
46
from distrax import Distribution
57
from jax import Array
8+
from tensorflow_probability.substrates.jax import distributions as tfd
69

710
from surjectors._src.surjectors.surjector import Surjector
811

@@ -31,7 +34,11 @@ class TransformedDistribution:
3134
>>> )
3235
"""
3336

34-
def __init__(self, base_distribution: Distribution, transform: Surjector):
37+
def __init__(
38+
self,
39+
base_distribution: Union[Distribution, tfd.Distribution],
40+
transform: Surjector,
41+
):
3542
self.base_distribution = base_distribution
3643
self.transform = transform
3744

@@ -118,10 +125,17 @@ def sample_and_log_prob(self, sample_shape=(), x: Array = None):
118125
if x is not None:
119126
chex.assert_equal(sample_shape[0], x.shape[0])
120127

121-
z, lp_z = self.base_distribution.sample_and_log_prob(
122-
seed=hk.next_rng_key(),
123-
sample_shape=sample_shape,
124-
)
128+
try:
129+
z, lp_z = self.base_distribution.sample_and_log_prob(
130+
seed=hk.next_rng_key(),
131+
sample_shape=sample_shape,
132+
)
133+
except AttributeError:
134+
z, lp_z = self.base_distribution.experimental_sample_and_log_prob(
135+
seed=hk.next_rng_key(),
136+
sample_shape=sample_shape,
137+
)
138+
125139
y, fldj = self.transform.forward_and_likelihood_contribution(z, x=x)
126140
lp = lp_z - fldj
127141
return y, lp

0 commit comments

Comments (0)