@@ -54,7 +54,7 @@ def compute_loss(bijection, draw, grad, logp):
         return cost
 
     costs = jax.vmap(compute_loss, [None, 0, 0, 0])(
-        flow.bijection,
+        flow,
        draws,
        grads,
        logps,
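(A note on the call above: the second argument to `jax.vmap`, here `[None, 0, 0, 0]`, is `in_axes`. `None` broadcasts the flow unchanged to every invocation, while `0` maps over the leading axis of the draws, gradients, and log-densities. A minimal, self-contained sketch of the same pattern, with illustrative names not taken from this PR:)

```python
import jax
import jax.numpy as jnp

def loss(params, draw):
    # `params` is shared across the batch; `draw` is one row of the batch.
    return jnp.sum((draw - params) ** 2)

params = jnp.ones(3)
draws = jnp.zeros((10, 3))

# in_axes=[None, 0]: broadcast `params`, map `loss` over axis 0 of `draws`.
costs = jax.vmap(loss, [None, 0])(params, draws)
assert costs.shape == (10,)
```

(The `transform` call further down additionally passes `(0, 0, 0)` as `out_axes`, stacking each of the three outputs along a new leading axis.)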
@@ -69,7 +69,7 @@ def compute_loss(bijection, draw, grad, logp):
         return cost
 
     costs = jax.vmap(compute_loss, [None, 0, 0, 0])(
-        flow.bijection,
+        flow,
        draws,
        grads,
        logps,
@@ -86,7 +86,7 @@ def compute_loss(bijection, draw, grad, logp):
     else:
 
         def transform(draw, grad, logp):
-            return flow.bijection.inverse_gradient_and_val_(draw, grad, logp)
+            return flow.inverse_gradient_and_val_(draw, grad, logp)
 
         draws, grads, logps = jax.vmap(transform, [0, 0, 0], (0, 0, 0))(
             draws, grads, logps
@@ -98,9 +98,7 @@ def transform(draw, grad, logp):
 
 
 def fit_flow(key, bijection, loss_fn, draws, grads, logps, **kwargs):
-    flow = flowjax.flows.Transformed(
-        flowjax.distributions.StandardNormal(bijection.shape), bijection
-    )
+    flow = bijection
 
     key, train_key = jax.random.split(key)
 
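(For readers unfamiliar with the wrapper being deleted here: in flowjax, a bare bijection becomes a trainable distribution by pairing it with a base distribution, and sampling pushes base draws through the bijection. A hedged sketch of what `fit_flow` previously constructed, using the same API names as this file; `as_flow` is an illustrative helper, not part of the PR:)

```python
import flowjax.distributions
import flowjax.flows


def as_flow(bijection):
    # Pair the bijection with a standard-normal base of matching shape.
    base = flowjax.distributions.StandardNormal(bijection.shape)
    return flowjax.flows.Transformed(base, bijection)
```

(After this change the caller passes in the already-wrapped flow, so `fit_flow` returns the fitted flow directly instead of unwrapping `fit.bijection`.)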
@@ -113,7 +111,7 @@ def fit_flow(key, bijection, loss_fn, draws, grads, logps, **kwargs):
         return_best=True,
         **kwargs,
     )
-    return fit.bijection, losses, losses["opt_state"]
+    return fit, losses, losses["opt_state"]
 
 
 @eqx.filter_jit
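(`eqx.filter_jit`, used below, is equinox's wrapper around `jax.jit`: it traces the array leaves of any pytree argument and treats non-array leaves as static, so whole modules can be passed without manually declaring static arguments. A tiny sketch with an illustrative function, not from this PR:)

```python
import equinox as eqx
import jax.numpy as jnp


@eqx.filter_jit
def scaled_sum(x, scale):
    # `x` is a JAX array and is traced; a plain-Python `scale`
    # is a non-array leaf and is treated as static.
    return scale * jnp.sum(x)


print(scaled_sum(jnp.arange(3.0), 2.0))
```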
@@ -298,9 +296,7 @@ def update(self, seed, positions, gradients, logps):
 
         fit = self._make_flow_fn(seed, positions, gradients, n_layers=0)
 
-        flow = flowjax.flows.Transformed(
-            flowjax.distributions.StandardNormal(fit.shape), fit
-        )
+        flow = fit
         params, static = eqx.partition(flow, eqx.is_inexact_array)
         new_loss = self._loss_fn(params, static, positions, gradients, logps)
 
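(`eqx.partition(flow, eqx.is_inexact_array)` splits the module pytree into two pytrees of identical structure: one holding only the floating-point arrays, i.e. the trainable leaves, and one holding everything else, which `self._loss_fn` presumably recombines with `eqx.combine`. A minimal sketch with an illustrative module, not from this PR:)

```python
import equinox as eqx
import jax.numpy as jnp


class Affine(eqx.Module):
    weight: jnp.ndarray
    name: str  # not an inexact array, so it lands on the static side


model = Affine(weight=jnp.ones(3), name="affine")

# `params` keeps the float arrays; `static` keeps the rest of the pytree.
params, static = eqx.partition(model, eqx.is_inexact_array)

# A loss function can reassemble the two halves before using the model.
restored = eqx.combine(params, static)
```

(This is the usual equinox idiom for threading only the trainable arrays through gradient transformations while keeping non-array state out of JAX's way.)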
@@ -341,9 +337,7 @@ def update(self, seed, positions, gradients, logps):
             untransformed_dim=self._untransformed_dim,
             zero_init=self._zero_init,
         )
-        flow = flowjax.flows.Transformed(
-            flowjax.distributions.StandardNormal(base.shape), base
-        )
+        flow = base
         params, static = eqx.partition(flow, eqx.is_inexact_array)
         if self._verbose:
             print(
@@ -356,9 +350,9 @@ def update(self, seed, positions, gradients, logps):
                 self._loss_fn(
                     params,
                     static,
-                    positions[-100:],
-                    gradients[-100:],
-                    logps[-100:],
+                    positions[-128:],
+                    gradients[-128:],
+                    logps[-128:],
                 ),
             )
         else:
@@ -392,10 +386,7 @@ def update(self, seed, positions, gradients, logps):
             self._opt_state = None
             return
 
-        flow = flowjax.flows.Transformed(
-            flowjax.distributions.StandardNormal(self._bijection.shape),
-            self._bijection,
-        )
+        flow = self._bijection
         params, static = eqx.partition(flow, eqx.is_inexact_array)
         old_loss = self._loss_fn(
             params, static, positions[-128:], gradients[-128:], logps[-128:]
@@ -420,9 +411,7 @@ def update(self, seed, positions, gradients, logps):
             max_patience=self._max_patience,
         )
 
-        flow = flowjax.flows.Transformed(
-            flowjax.distributions.StandardNormal(fit.shape), fit
-        )
+        flow = fit
         params, static = eqx.partition(flow, eqx.is_inexact_array)
         new_loss = self._loss_fn(
             params, static, positions[-128:], gradients[-128:], logps[-128:]
@@ -432,10 +421,7 @@ def update(self, seed, positions, gradients, logps):
             print(f"Chain {self._chain}: New loss {new_loss}, old loss {old_loss}")
 
         if not np.isfinite(old_loss):
-            flow = flowjax.flows.Transformed(
-                flowjax.distributions.StandardNormal(self._bijection.shape),
-                self._bijection,
-            )
+            flow = self._bijection
             params, static = eqx.partition(flow, eqx.is_inexact_array)
             print(
                 self._loss_fn(
@@ -449,9 +435,7 @@ def update(self, seed, positions, gradients, logps):
             )
 
         if not np.isfinite(new_loss):
-            flow = flowjax.flows.Transformed(
-                flowjax.distributions.StandardNormal(fit.shape), fit
-            )
+            flow = fit
             params, static = eqx.partition(flow, eqx.is_inexact_array)
             print(
                 self._loss_fn(