from matplotlib import pyplot as plt

from surjectors import (
    AffineMaskedAutoregressiveInferenceFunnel,
    Chain,
    MaskedAutoregressive,
    TransformedDistribution,
)
from surjectors.nn import MADE, make_mlp
from surjectors.util import as_batch_iterator, unstack

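# Decoder used by the funnel layers below: an MLP that maps the retained
# variables z to the mean and log-scale of a diagonal Gaussian over the
# dimensions the funnel discards.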
def _decoder_fn(n_dim):
    decoder_net = make_mlp([4, 4, n_dim * 2])

    def _fn(z):
        params = decoder_net(z)
        mu, log_scale = jnp.split(params, 2, -1)
        return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale)))

    return _fn

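# Bijector used by the masked autoregressive layer: the conditioner emits a
# (mean, log-scale) pair per dimension, which parameterizes an inverted
# scalar affine transform.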
def _made_bijector_fn(params):
    means, log_scales = unstack(params, -1)
    return surjectors.Inverse(surjectors.ScalarAffine(means, jnp.exp(log_scales)))

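# Builds the flow as a Haiku-transformed function that returns log-probabilities.
# Layers 0 and 2 are affine masked autoregressive inference funnels that halve
# the dimensionality; layer 1 is a plain masked autoregressive bijection. The
# base distribution is a standard normal over the reduced dimensionality.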
def make_model(n_dimensions):
    def _flow(**kwargs):
        n_dim = n_dimensions
        layers = []
        for i in range(3):
            if i != 1:
                layer = AffineMaskedAutoregressiveInferenceFunnel(
                    n_keep=int(n_dim / 2),
                    decoder=_decoder_fn(int(n_dim / 2)),
                    conditioner=MADE(int(n_dim / 2), [8, 8], 2),
                )
                n_dim = int(n_dim / 2)
            else:
                layer = MaskedAutoregressive(
                    conditioner=MADE(n_dim, [8, 8], 2),
                    bijector_fn=_made_bijector_fn,
                )
            layers.append(layer)
            # TODO(simon): needs to change order
            # layers.append(Permutation(order, 1))
        chain = Chain(layers)

        base_distribution = distrax.Independent(
            distrax.Normal(jnp.zeros(n_dim), jnp.ones(n_dim)),
            reinterpreted_batch_ndims=1,
        )
        td = TransformedDistribution(base_distribution, chain)
        return td.log_prob(**kwargs)

    td = hk.transform(_flow)
    td = hk.without_apply_rng(td)
    return td

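# Maximum-likelihood training: minibatch gradient descent with Adam on the
# negative log-likelihood, returning the fitted parameters and the per-epoch
# training losses.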
def train(rng_seq, data, model, max_n_iter=1000):
    train_iter = as_batch_iterator(next(rng_seq), data, 100, True)
    params = model.init(next(rng_seq), **train_iter(0))

    optimizer = optax.adam(1e-4)
    state = optimizer.init(params)

    @jax.jit
    def step(params, state, **batch):
        def loss_fn(params):
            lp = model.apply(params, **batch)
            return -jnp.sum(lp)

        loss, grads = jax.value_and_grad(loss_fn)(params)
        updates, new_state = optimizer.update(grads, state, params)
        new_params = optax.apply_updates(params, updates)
        return loss, new_params, new_state

    losses = np.zeros(max_n_iter)
    for i in range(max_n_iter):
        train_loss = 0.0
        for j in range(train_iter.num_batches):
            batch = train_iter(j)
            batch_loss, params, state = step(params, state, **batch)
            train_loss += batch_loss
        losses[i] = train_loss

    return params, losses

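# Fits the flow to synthetic 20-dimensional standard-normal data, plots the
# training loss curve, and evaluates the trained model on a fresh batch.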
def run(n_iter):
    n, p = 1000, 20
    rng_seq = hk.PRNGSequence(2)
    y = jr.normal(next(rng_seq), shape=(n, p))
    data = namedtuple("named_dataset", "y")(y)

    model = make_model(p)
    params, losses = train(rng_seq, data, model, n_iter)
    plt.plot(losses)
    plt.show()

    y = jr.normal(next(rng_seq), shape=(10, p))
    print(model.apply(params, **{"y": y}))

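# Command-line entry point; run e.g. as `python <this_script>.py --n-iter 2000`
# (the actual file name depends on where the example lives in the repository).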
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--n-iter", type=int, default=1_000)
    args = parser.parse_args()
    run(args.n_iter)