
Commit c22b438

docs: Add documentation template
1 parent e1fdfff commit c22b438

7 files changed: +324 -6 lines changed

book.toml

+6
@@ -0,0 +1,6 @@
[book]
authors = ["Adrian Seyboldt"]
language = "en"
multilingual = false
src = "docs"
title = "nutpie"
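
Aside: this book.toml looks like an mdBook configuration; with src = "docs", the SUMMARY.md added below under docs/ would serve as the table of contents for the rendered book.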

docs/SUMMARY.md

+8
@@ -0,0 +1,8 @@
# Summary

[Introduction](../README.md)

- [Usage with PyMC](./pymc-usage.md)
- [Usage with Stan](./stan-usage.md)
- [Adaptation with normalizing flows](./nf-adapt.md)
- [Benchmarks](./benchmarks.md)

docs/benchmarks.md

+1
@@ -0,0 +1 @@
# Benchmarks

docs/nf-adapt.md

+3
@@ -0,0 +1,3 @@
# Adaptation with normalizing flows

**Experimental**

docs/pymc-usage.md

+1
@@ -0,0 +1 @@
# Usage with PyMC models

docs/stan-usage.md

+3
@@ -0,0 +1,3 @@
# Usage with Stan models

foobar

python/nutpie/compiled_pyfunc.py

+302-6
@@ -9,6 +9,304 @@
from nutpie.sample import CompiledModel


def make_transform_adapter(*, verbose=False, window_size=2000):
    import jax
    import equinox as eqx
    import jax.numpy as jnp
    import flowjax
    import flowjax.train
    import flowjax.flows
    import optax
    import traceback

    # Fisher-divergence style loss: pull draws and gradients back through the
    # flow and penalize (draw + grad) ** 2, which vanishes when the transformed
    # density is a standard normal (whose score is -z).
    class FisherLoss:
        @eqx.filter_jit
        def __call__(
            self,
            params,
            static,
            x,
            condition=None,
            key=None,
        ):
            flow = flowjax.train.losses.unwrap(
                eqx.combine(params, static, is_leaf=eqx.is_inexact_array)
            )

            def compute_loss(bijection, draw, grad):
                draw, grad, _ = bijection.inverse_gradient_and_val(
                    draw, grad, jnp.array(0.0)
                )
                return ((draw + grad) ** 2).sum()

            assert x.shape[1] == 2
            draws = x[:, 0, :]
            grads = x[:, 1, :]
            return jnp.log(
                jax.vmap(compute_loss, [None, 0, 0])(
                    flow.bijection, draws, grads
                ).mean()
            )

    def _get_untransformed(bijection, draw_trafo, grad_trafo):
        bijection = flowjax.train.losses.unwrap(bijection)
        draw = bijection.inverse(draw_trafo)
        _, pull_grad_fn = jax.vjp(bijection.transform_and_log_det, draw)
        (grad,) = pull_grad_fn((grad_trafo, 1.0))
        return draw, grad

    pull_points = eqx.filter_jit(jax.vmap(_get_untransformed, [None, 0, 0]))

    def fit_flow(key, bijection, positions, gradients, **kwargs):
        flow = flowjax.flows.Transformed(
            flowjax.distributions.StandardNormal(bijection.shape), bijection
        )

        points = jnp.transpose(jnp.array([positions, gradients]), [1, 0, 2])

        key, train_key = jax.random.split(key)

        fit, losses, opt_state = flowjax.train.fit_to_data(
            key=train_key,
            dist=flow,
            x=points,
            loss_fn=FisherLoss(),
            **kwargs,
        )

        draws_pulled, grads_pulled = pull_points(fit.bijection, positions, gradients)
        final_cost = np.log(((draws_pulled + grads_pulled) ** 2).sum(1).mean(0))
        return fit, final_cost, opt_state

    # Build the flow: a diagonal affine layer initialized from draw/gradient
    # standard deviations, followed by n_layers affine coupling layers.
    def make_flow(seed, positions, gradients, *, n_layers):
        positions = np.array(positions)
        gradients = np.array(gradients)

        n_draws, n_dim = positions.shape

        if n_dim < 2:
            n_layers = 0

        assert positions.shape == gradients.shape
        assert n_draws > 0

        if n_draws == 0:
            raise ValueError("No draws")
        elif n_draws == 1:
            diag = 1 / jnp.abs(gradients[0])
            mean = jnp.zeros_like(diag)
        else:
            diag = jnp.sqrt(positions.std(0) / gradients.std(0))
            mean = positions.mean(0) + diag * gradients.mean(0)

        key = jax.random.PRNGKey(seed % (2**63))

        flows = [
            flowjax.flows.Affine(loc=mean, scale=diag),
        ]

        for layer in range(n_layers):
            key, key_couple, key_permute, key_init = jax.random.split(key, 4)

            scale = flowjax.wrappers.Parameterize(
                lambda x: jnp.exp(jnp.arcsinh(x)), jnp.array(0.0)
            )
            affine = eqx.tree_at(
                where=lambda aff: aff.scale,
                pytree=flowjax.bijections.Affine(),
                replace=scale,
            )

            coupling = flowjax.bijections.coupling.Coupling(
                key_couple,
                transformer=affine,
                untransformed_dim=n_dim // 2,
                dim=n_dim,
                nn_activation=jax.nn.gelu,
                nn_width=n_dim // 2,
                nn_depth=1,
            )

            if layer == n_layers - 1:
                flow = coupling
            else:
                flow = flowjax.flows._add_default_permute(coupling, n_dim, key_permute)

            flows.append(flow)

        return flowjax.bijections.Chain(flows[::-1])

    @eqx.filter_jit
    def _init_from_transformed_position(logp_fn, bijection, transformed_position):
        bijection = flowjax.train.losses.unwrap(bijection)
        (untransformed_position, logdet), pull_grad = jax.vjp(
            bijection.transform_and_log_det, transformed_position
        )
        logp, untransformed_gradient = jax.value_and_grad(lambda x: logp_fn(x)[0])(
            untransformed_position
        )
        (transformed_gradient,) = pull_grad((untransformed_gradient, 1.0))
        return (
            logp,
            logdet,
            untransformed_position,
            untransformed_gradient,
            transformed_gradient,
        )

    @eqx.filter_jit
    def _init_from_untransformed_position(logp_fn, bijection, untransformed_position):
        logp, untransformed_gradient = jax.value_and_grad(lambda x: logp_fn(x)[0])(
            untransformed_position
        )
        logdet, transformed_position, transformed_gradient = _inv_transform(
            bijection, untransformed_position, untransformed_gradient
        )
        return (
            logp,
            logdet,
            untransformed_gradient,
            transformed_position,
            transformed_gradient,
        )

    @eqx.filter_jit
    def _inv_transform(bijection, untransformed_position, untransformed_gradient):
        bijection = flowjax.train.losses.unwrap(bijection)
        transformed_position, transformed_gradient, logdet = (
            bijection.inverse_gradient_and_val(
                untransformed_position, untransformed_gradient, 0.0
            )
        )
        return -logdet, transformed_position, transformed_gradient

    # Adapter driven by the sampler: keeps the current bijection and refits it
    # from the most recent draws and gradients.
    class TransformAdapter:
        def __init__(
            self,
            seed,
            position,
            gradient,
            chain,
            *,
            logp_fn,
            make_flow_fn,
            verbose=False,
            window_size=2000,
        ):
            self._logp_fn = logp_fn
            self._make_flow_fn = make_flow_fn
            self._chain = chain
            self._verbose = verbose
            self._window_size = window_size
            try:
                self._bijection = make_flow_fn(seed, [position], [gradient], n_layers=0)
            except Exception as e:
                print("make_flow", e)
                print(traceback.format_exc())
                raise
            self.index = 0

        @property
        def transformation_id(self):
            return self.index

        def update(self, seed, positions, gradients):
            self.index += 1
            if self._verbose:
                print(f"Chain {self._chain}: Total available points: {len(positions)}")
            n_draws = len(positions)
            if n_draws == 0:
                return
            try:
                if self.index <= 10:
                    self._bijection = self._make_flow_fn(
                        seed, positions[-10:], gradients[-10:], n_layers=0
                    )
                    return

                positions = np.array(positions[-self._window_size :])
                gradients = np.array(gradients[-self._window_size :])

                assert np.isfinite(positions).all()
                assert np.isfinite(gradients).all()

                if len(self._bijection.bijections) == 1:
                    self._bijection = self._make_flow_fn(
                        seed, positions, gradients, n_layers=8
                    )

                # make_flow might still only return a single trafo for 1d problems
                if len(self._bijection.bijections) == 1:
                    return

                # TODO don't reuse seed
                key = jax.random.PRNGKey(seed % (2**63))
                fit, final_cost, _ = fit_flow(
                    key,
                    self._bijection,
                    positions,
                    gradients,
                    show_progress=self._verbose,
                    optimizer=optax.adabelief(1e-3),
                    batch_size=128,
                )
                if self._verbose:
                    print(f"Chain {self._chain}: final cost {final_cost}")
                if np.isfinite(final_cost).all():
                    self._bijection = fit.bijection
                else:
                    self._bijection = self._make_flow_fn(
                        seed, positions, gradients, n_layers=0
                    )
            except Exception as e:
                print("update error:", e)
                print(traceback.format_exc())

        def init_from_transformed_position(self, transformed_position):
            try:
                logp, logdet, *arrays = _init_from_transformed_position(
                    self._logp_fn,
                    self._bijection,
                    jnp.array(transformed_position),
                )
                return float(logp), float(logdet), *[np.array(val) for val in arrays]
            except Exception as e:
                print(e)
                print(traceback.format_exc())
                raise

        def init_from_untransformed_position(self, untransformed_position):
            try:
                logp, logdet, *arrays = _init_from_untransformed_position(
                    self._logp_fn,
                    self._bijection,
                    jnp.array(untransformed_position),
                )
                return float(logp), float(logdet), *[np.array(val) for val in arrays]
            except Exception as e:
                print(e)
                print(traceback.format_exc())
                raise

        def inv_transform(self, position, gradient):
            try:
                logdet, *arrays = _inv_transform(
                    self._bijection, jnp.array(position), jnp.array(gradient)
                )
                return logdet, *[np.array(val) for val in arrays]
            except Exception as e:
                print(e)
                print(traceback.format_exc())
                raise

    return partial(
        TransformAdapter,
        verbose=verbose,
        window_size=window_size,
        make_flow_fn=make_flow,
    )


@dataclass(frozen=True)
class PyFuncModel(CompiledModel):
    _make_logp_func: Callable
@@ -59,19 +357,17 @@ def make_expand_func(seed1, seed2, chain):
             expand_fn = self._make_expand_func(seed1, seed2, chain)
             return partial(expand_fn, **self._shared_data)
 
-        if self._make_transform_adapter is not None:
-            make_transform_adapter = partial(
-                self._make_transform_adapter, logp_fn=self._raw_logp_fn
-            )
+        if self._raw_logp_fn is not None:
+            make_adapter = partial(make_transform_adapter(), logp_fn=self._raw_logp_fn)
         else:
-            make_transform_adapter = None
+            make_adapter = None
 
         return _lib.PyModel(
             make_logp_func,
             make_expand_func,
             self._variables,
             self.n_dim,
-            make_transform_adapter,
+            make_adapter,
         )
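
For orientation, the loss above is a Fisher-divergence style objective: draws and their score-function gradients are pulled back through the flow, and the penalty (draw + grad) ** 2 vanishes exactly when the transformed density is a standard normal. The following standalone sketch (NumPy only; the toy target and the names true_scale and fisher_cost are illustrative, not part of nutpie or flowjax) shows this objective together with the diagonal affine initialization used in make_flow:

import numpy as np

rng = np.random.default_rng(0)
true_scale = np.array([0.5, 3.0, 10.0])

# Toy draws from N(0, diag(true_scale**2)); the score is grad log p(x) = -x / true_scale**2.
positions = rng.normal(0.0, true_scale, size=(1000, 3))
gradients = -positions / true_scale**2

# Diagonal affine initialization as in make_flow: scale from the ratio of
# position and gradient standard deviations.
diag = np.sqrt(positions.std(0) / gradients.std(0))
mean = positions.mean(0) + diag * gradients.mean(0)

def fisher_cost(positions, gradients, loc, scale):
    # Pull draws and gradients back through x = loc + scale * z: for this affine
    # map the transformed score is scale * grad, and the cost is smallest when
    # z behaves like a standard normal with score -z (so z + grad_z is near zero).
    z = (positions - loc) / scale
    grad_z = scale * gradients
    return np.log(((z + grad_z) ** 2).sum(1).mean(0))

print(fisher_cost(positions, gradients, 0.0, 1.0))   # identity transform: large cost
print(fisher_cost(positions, gradients, mean, diag))  # estimated affine: much smaller

The same quantity is what fit_flow reports as final_cost; update() only accepts a fitted flow when that value is finite, otherwise it falls back to the plain affine transform.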