17
17
from paramax .wrappers import AbstractUnwrappable
18
18
19
19
20
- _NN_ACTIVATION = jax .nn .leaky_relu
20
+ _NN_ACTIVATION = jax .nn .gelu
21
+
21
22
22
23
def _generate_sequences (k , r_vals ):
23
24
"""
@@ -324,6 +325,7 @@ class AsymmetricAffine(bijections.AbstractBijection):
324
325
scale: Scale parameter σ (positive)
325
326
theta: Asymmetry parameter θ (positive)
326
327
"""
328
+
327
329
shape : tuple [int , ...] = ()
328
330
cond_shape : ClassVar [None ] = None
329
331
loc : Array
@@ -340,6 +342,7 @@ def __init__(
340
342
* (arraylike_to_array (a , dtype = float ) for a in (loc , scale , theta )),
341
343
)
342
344
self .shape = scale .shape
345
+ assert self .shape == ()
343
346
self .scale = Parameterize (lambda x : x + jnp .sqrt (1 + x ** 2 ), jnp .zeros (()))
344
347
self .theta = Parameterize (lambda x : x + jnp .sqrt (1 + x ** 2 ), jnp .zeros (()))
345
348
@@ -348,17 +351,18 @@ def _log_derivative_f(self, x, mu, sigma, theta):
348
351
theta = jnp .log (theta )
349
352
350
353
sinh_theta = jnp .sinh (theta )
351
- #sinh_theta = (theta - 1 / theta) / 2
354
+ # sinh_theta = (theta - 1 / theta) / 2
352
355
cosh_theta = jnp .cosh (theta )
353
- #cosh_theta = (theta + 1 / theta) / 2
356
+ # cosh_theta = (theta + 1 / theta) / 2
354
357
numerator = sinh_theta * x * (abs_x + 2.0 )
355
- denominator = (abs_x + 1.0 )** 2
358
+ denominator = (abs_x + 1.0 ) ** 2
356
359
term = numerator / denominator
357
360
dy_dx = sigma * (cosh_theta + term )
358
361
return jnp .log (dy_dx )
359
362
360
- def transform_and_log_det (self , x : ArrayLike , condition : ArrayLike | None = None ) -> tuple [Array , Array ]:
361
-
363
+ def transform_and_log_det (
364
+ self , x : ArrayLike , condition : ArrayLike | None = None
365
+ ) -> tuple [Array , Array ]:
362
366
def transform (x , mu , sigma , theta ):
363
367
weight = (jax .nn .soft_sign (x ) + 1 ) / 2
364
368
z = x * sigma
@@ -372,17 +376,22 @@ def transform(x, mu, sigma, theta):
372
376
y = transform (x , mu , sigma , theta )
373
377
logjac = self ._log_derivative_f (x , mu , sigma , theta )
374
378
return y , logjac .sum ()
379
+ # y, jac = jax.value_and_grad(transform, argnums=0)(x, mu, sigma, theta)
380
+ # return y, jnp.log(jac)
375
381
376
- def inverse_and_log_det (self , y : ArrayLike , condition : ArrayLike | None = None ) -> tuple [Array , Array ]:
377
-
382
+ def inverse_and_log_det (
383
+ self , y : ArrayLike , condition : ArrayLike | None = None
384
+ ) -> tuple [Array , Array ]:
378
385
def inverse (y , mu , sigma , theta ):
379
386
delta = y - mu
380
387
inv_theta = 1 / theta
381
388
382
389
# Case 1: y >= mu (delta >= 0)
383
390
a = sigma * (theta + inv_theta )
384
- discriminant_pos = jnp .square (a - 2.0 * delta ) + 16.0 * sigma * theta * delta
385
- discriminant_pos = jnp .where (discriminant_pos < 0 , 1. , discriminant_pos )
391
+ discriminant_pos = (
392
+ jnp .square (a - 2.0 * delta ) + 16.0 * sigma * theta * delta
393
+ )
394
+ discriminant_pos = jnp .where (discriminant_pos < 0 , 1.0 , discriminant_pos )
386
395
sqrt_pos = jnp .sqrt (discriminant_pos )
387
396
numerator_pos = 2.0 * delta - a + sqrt_pos
388
397
denominator_pos = 4.0 * sigma * theta
@@ -391,8 +400,10 @@ def inverse(y, mu, sigma, theta):
391
400
# Case 2: y < mu (delta < 0)
392
401
sigma_part = sigma * (1.0 + theta * theta )
393
402
term2 = 2.0 * delta * theta
394
- inside_sqrt_neg = jnp .square (sigma_part + term2 ) - 16.0 * sigma * delta * theta
395
- inside_sqrt_neg = jnp .where (inside_sqrt_neg < 0 , 1. , inside_sqrt_neg )
403
+ inside_sqrt_neg = (
404
+ jnp .square (sigma_part + term2 ) - 16.0 * sigma * delta * theta
405
+ )
406
+ inside_sqrt_neg = jnp .where (inside_sqrt_neg < 0 , 1.0 , inside_sqrt_neg )
396
407
sqrt_neg = jnp .sqrt (inside_sqrt_neg )
397
408
numerator_neg = sigma_part + term2 - sqrt_neg
398
409
denominator_neg = 4.0 * sigma
@@ -407,6 +418,8 @@ def inverse(y, mu, sigma, theta):
407
418
x = inverse (y , mu , sigma , theta )
408
419
logjac = self ._log_derivative_f (x , mu , sigma , theta )
409
420
return x , - logjac .sum ()
421
+ # x, jac = jax.value_and_grad(inverse, argnums=0)(y, mu, sigma, theta)
422
+ # return x, jnp.log(jac)
410
423
411
424
412
425
class MvScale (bijections .AbstractBijection ):
@@ -499,7 +512,6 @@ def __init__(
499
512
self .requires_vmap = False
500
513
conditioner_output_size = num_params
501
514
502
-
503
515
self .transformer_constructor = constructor
504
516
self .untransformed_dim = untransformed_dim
505
517
self .dim = dim
@@ -509,7 +521,9 @@ def __init__(
509
521
if conditioner is None :
510
522
conditioner = eqx .nn .MLP (
511
523
in_size = (
512
- untransformed_dim if cond_dim is None else untransformed_dim + cond_dim
524
+ untransformed_dim
525
+ if cond_dim is None
526
+ else untransformed_dim + cond_dim
513
527
),
514
528
out_size = conditioner_output_size ,
515
529
width_size = nn_width ,
@@ -542,7 +556,9 @@ def _flat_params_to_transformer(self, params: Array):
542
556
if self .requires_vmap :
543
557
dim = self .dim - self .untransformed_dim
544
558
transformer_params = jnp .reshape (params , (dim , - 1 ))
545
- transformer = eqx .filter_vmap (self .transformer_constructor )(transformer_params )
559
+ transformer = eqx .filter_vmap (self .transformer_constructor )(
560
+ transformer_params
561
+ )
546
562
return bijections .Vmap (transformer , in_axes = eqx .if_array (0 ))
547
563
else :
548
564
transformer = self .transformer_constructor (params )
@@ -612,7 +628,7 @@ def make_elemwise(key, loc):
612
628
replace = theta ,
613
629
)
614
630
615
- return affine
631
+ return bijections . Invert ( affine )
616
632
617
633
def make (key ):
618
634
keys = jax .random .split (key , count + 1 )
@@ -626,11 +642,9 @@ def make(key):
626
642
return bijections .Vmap (make_affine , in_axes = eqx .if_array (0 ))
627
643
628
644
629
- def make_coupling (key , dim , n_untransformed , ** kwargs ):
645
+ def make_coupling (key , dim , n_untransformed , * , inner_mvscale = False , * *kwargs ):
630
646
n_transformed = dim - n_untransformed
631
647
632
- mvscale = make_mvscale (key , n_transformed , 1 , randomize_base = True )
633
-
634
648
nn_width = kwargs .get ("nn_width" , None )
635
649
nn_depth = kwargs .get ("nn_depth" , None )
636
650
@@ -646,12 +660,11 @@ def make_coupling(key, dim, n_untransformed, **kwargs):
646
660
else :
647
661
nn_depth = len (nn_width )
648
662
649
- transformer = bijections .Chain (
650
- [
651
- make_elemwise_trafo (key , n_transformed , count = 3 ),
652
- #mvscale,
653
- ]
654
- )
663
+ transformer = make_elemwise_trafo (key , n_transformed , count = 3 )
664
+
665
+ if inner_mvscale :
666
+ mvscale = make_mvscale (key , n_transformed , 1 , randomize_base = True )
667
+ transformer = bijections .Chain ([transformer , mvscale ])
655
668
656
669
def make_mlp (out_size ):
657
670
if isinstance (nn_width , tuple ):
0 commit comments