Commit 5b6ec5d

Don't wrap bijections in Transformed

1 parent 5dcda60

1 file changed: +14 −30

python/nutpie/transform_adapter.py (+14 −30)
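The change throughout: stop wrapping bijections in a Transformed distribution before partitioning or training them, and operate on the bijection directly. A minimal sketch of the old versus new pattern, mirroring the calls in the diff below (Affine is just a concrete stand-in for the bijection nutpie actually builds):

    import equinox as eqx
    import flowjax.bijections
    import flowjax.distributions
    import flowjax.flows
    import jax.numpy as jnp

    # Any flowjax bijection; Affine is a stand-in here.
    bijection = flowjax.bijections.Affine(loc=jnp.zeros(2), scale=jnp.ones(2))

    # Old pattern: wrap the bijection in a distribution, partition the wrapper.
    flow = flowjax.flows.Transformed(
        flowjax.distributions.StandardNormal(bijection.shape), bijection
    )
    params, static = eqx.partition(flow, eqx.is_inexact_array)

    # New pattern: partition the bijection itself.
    params, static = eqx.partition(bijection, eqx.is_inexact_array)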
@@ -54,7 +54,7 @@ def compute_loss(bijection, draw, grad, logp):
         return cost
 
     costs = jax.vmap(compute_loss, [None, 0, 0, 0])(
-        flow.bijection,
+        flow,
         draws,
         grads,
         logps,
@@ -69,7 +69,7 @@ def compute_loss(bijection, draw, grad, logp):
         return cost
 
     costs = jax.vmap(compute_loss, [None, 0, 0, 0])(
-        flow.bijection,
+        flow,
        draws,
         grads,
         logps,
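In both loss helpers, the vmap in_axes spec `[None, 0, 0, 0]` passes the (now unwrapped) flow to every call unchanged while mapping over the leading axis of the draws, grads, and logps batches. A self-contained illustration with a stand-in loss function (none of these names are nutpie's):

    import jax
    import jax.numpy as jnp

    def compute_loss(scale, draw, grad, logp):
        # Stand-in: one scalar cost per draw.
        return jnp.sum(scale * draw) + jnp.sum(grad) + logp

    scale = jnp.ones(3)        # broadcast to every call (axis None)
    draws = jnp.ones((5, 3))   # mapped over axis 0
    grads = jnp.zeros((5, 3))
    logps = jnp.zeros(5)

    costs = jax.vmap(compute_loss, [None, 0, 0, 0])(scale, draws, grads, logps)
    print(costs.shape)  # (5,)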
@@ -86,7 +86,7 @@ def compute_loss(bijection, draw, grad, logp):
     else:
 
         def transform(draw, grad, logp):
-            return flow.bijection.inverse_gradient_and_val_(draw, grad, logp)
+            return flow.inverse_gradient_and_val_(draw, grad, logp)
 
         draws, grads, logps = jax.vmap(transform, [0, 0, 0], (0, 0, 0))(
             draws, grads, logps
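Here the mapped function returns a 3-tuple, so the out_axes spec `(0, 0, 0)` yields three batched arrays. A minimal sketch (the transform body is a stand-in, not nutpie's inverse_gradient_and_val_):

    import jax
    import jax.numpy as jnp

    def transform(draw, grad, logp):
        # Stand-in for a change of variables with log-det correction.
        return draw * 2.0, grad / 2.0, logp - draw.size * jnp.log(2.0)

    draws, grads, logps = jax.vmap(transform, [0, 0, 0], (0, 0, 0))(
        jnp.ones((4, 2)), jnp.ones((4, 2)), jnp.zeros(4)
    )
    print(draws.shape, grads.shape, logps.shape)  # (4, 2) (4, 2) (4,)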
@@ -98,9 +98,7 @@ def transform(draw, grad, logp):
 
 
 def fit_flow(key, bijection, loss_fn, draws, grads, logps, **kwargs):
-    flow = flowjax.flows.Transformed(
-        flowjax.distributions.StandardNormal(bijection.shape), bijection
-    )
+    flow = bijection
 
     key, train_key = jax.random.split(key)
 
@@ -113,7 +111,7 @@ def fit_flow(key, bijection, loss_fn, draws, grads, logps, **kwargs):
         return_best=True,
         **kwargs,
     )
-    return fit.bijection, losses, losses["opt_state"]
+    return fit, losses, losses["opt_state"]
 
 
 @eqx.filter_jit
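Both fit_flow hunks rely on a flowjax bijection being itself an equinox Module (a PyTree of arrays plus static structure), so it can be trained and returned directly; callers now receive the fitted bijection instead of unwrapping `fit.bijection`. A minimal check, again using flowjax's Affine as a stand-in for the bijection nutpie builds:

    import equinox as eqx
    import jax.numpy as jnp
    from flowjax.bijections import Affine

    bij = Affine(loc=jnp.zeros(2), scale=jnp.ones(2))
    params, static = eqx.partition(bij, eqx.is_inexact_array)  # trainable leaves
    print(bij.shape)  # (2,) -- the shape StandardNormal was built from before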
@@ -298,9 +296,7 @@ def update(self, seed, positions, gradients, logps):
 
         fit = self._make_flow_fn(seed, positions, gradients, n_layers=0)
 
-        flow = flowjax.flows.Transformed(
-            flowjax.distributions.StandardNormal(fit.shape), fit
-        )
+        flow = fit
         params, static = eqx.partition(flow, eqx.is_inexact_array)
         new_loss = self._loss_fn(params, static, positions, gradients, logps)
 
@@ -341,9 +337,7 @@ def update(self, seed, positions, gradients, logps):
             untransformed_dim=self._untransformed_dim,
             zero_init=self._zero_init,
         )
-        flow = flowjax.flows.Transformed(
-            flowjax.distributions.StandardNormal(base.shape), base
-        )
+        flow = base
         params, static = eqx.partition(flow, eqx.is_inexact_array)
         if self._verbose:
             print(
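Each call site now partitions the bare bijection, and `_loss_fn` receives the two halves separately so the static structure stays constant under jit and is recombined inside. A schematic of that params/static convention (the Linear module and loss body are stand-ins, not nutpie's):

    import equinox as eqx
    import jax.numpy as jnp

    class Linear(eqx.Module):
        w: jnp.ndarray

        def __call__(self, x):
            return x @ self.w

    @eqx.filter_jit
    def loss_fn(params, static, draws):
        model = eqx.combine(params, static)  # rebuild the module under jit
        return jnp.mean(model(draws) ** 2)   # stand-in loss

    model = Linear(jnp.ones((3, 1)))
    params, static = eqx.partition(model, eqx.is_inexact_array)
    print(loss_fn(params, static, jnp.ones((5, 3))))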
@@ -356,9 +350,9 @@ def update(self, seed, positions, gradients, logps):
                 self._loss_fn(
                     params,
                     static,
-                    positions[-100:],
-                    gradients[-100:],
-                    logps[-100:],
+                    positions[-128:],
+                    gradients[-128:],
+                    logps[-128:],
                 ),
             )
         else:
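The verbose-path loss here previously looked at the last 100 draws; it now uses the same 128-draw trailing window as the other loss evaluations in update. Negative slicing makes the window a no-op when fewer draws exist:

    import numpy as np

    positions = np.arange(200).reshape(100, 2)
    window = positions[-128:]   # all 100 rows: fewer than 128 available
    print(window.shape)         # (100, 2)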
@@ -392,10 +386,7 @@ def update(self, seed, positions, gradients, logps):
             self._opt_state = None
             return
 
-        flow = flowjax.flows.Transformed(
-            flowjax.distributions.StandardNormal(self._bijection.shape),
-            self._bijection,
-        )
+        flow = self._bijection
         params, static = eqx.partition(flow, eqx.is_inexact_array)
         old_loss = self._loss_fn(
             params, static, positions[-128:], gradients[-128:], logps[-128:]
@@ -420,9 +411,7 @@ def update(self, seed, positions, gradients, logps):
             max_patience=self._max_patience,
         )
 
-        flow = flowjax.flows.Transformed(
-            flowjax.distributions.StandardNormal(fit.shape), fit
-        )
+        flow = fit
         params, static = eqx.partition(flow, eqx.is_inexact_array)
         new_loss = self._loss_fn(
             params, static, positions[-128:], gradients[-128:], logps[-128:]
@@ -432,10 +421,7 @@ def update(self, seed, positions, gradients, logps):
             print(f"Chain {self._chain}: New loss {new_loss}, old loss {old_loss}")
 
         if not np.isfinite(old_loss):
-            flow = flowjax.flows.Transformed(
-                flowjax.distributions.StandardNormal(self._bijection.shape),
-                self._bijection,
-            )
+            flow = self._bijection
             params, static = eqx.partition(flow, eqx.is_inexact_array)
             print(
                 self._loss_fn(
@@ -449,9 +435,7 @@ def update(self, seed, positions, gradients, logps):
             )
 
         if not np.isfinite(new_loss):
-            flow = flowjax.flows.Transformed(
-                flowjax.distributions.StandardNormal(fit.shape), fit
-            )
+            flow = fit
             params, static = eqx.partition(flow, eqx.is_inexact_array)
             print(
                 self._loss_fn(
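The two isfinite guards mirror each other: whichever loss came back non-finite, update() re-partitions the corresponding (now unwrapped) bijection and prints a fresh loss evaluation for diagnosis. A schematic of that control flow with stand-in values (the real code calls self._loss_fn on the trailing 128-draw window):

    import numpy as np

    old_loss, new_loss = 1.7, float("nan")
    for label, loss in [("old", old_loss), ("new", new_loss)]:
        if not np.isfinite(loss):
            # update() prints the self._loss_fn diagnostics at this point.
            print(f"{label} loss is non-finite: {loss}")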
