Commit da887e8

Formatting and minor changes to verbose output

1 parent 1ec988a

Showing 4 changed files with 109 additions and 58 deletions.

docs/nf-adapt.qmd (+2, -2)

@@ -85,8 +85,8 @@ compiled = compiled.with_transform_adapt(
 
 # Sample with normalizing flow adaptation
 trace_nf = nutpie.sample(compiled, transform_adapt=True, seed=1, chains=2, cores=1)
-assert trace_no_nf.sample_stats.diverging.sum() == 0
-assert (arviz.ess(trace_no_nf) > 500).all().to_array().all()
+assert trace_nf.sample_stats.diverging.sum() == 0
+assert (arviz.ess(trace_nf) > 500).all().to_array().all()
 ```
 
 The flow adaptation occurs during warmup, so the number of warmup draws should
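For context, the documented snippet is meant to run end-to-end: the fix above makes the assertions check the flow-adapted trace (`trace_nf`) that the snippet actually produces, rather than a `trace_no_nf` that is never defined there. A minimal, self-contained sketch of the full flow (the toy PyMC model is an illustrative assumption, not part of this commit):

```python
# Hypothetical end-to-end run of the documented snippet.
# The model below is an assumption for illustration only.
import arviz
import nutpie
import pymc as pm

with pm.Model() as model:
    pm.Normal("x", shape=5)

# Normalizing-flow adaptation is used with the JAX backend.
compiled = nutpie.compile_pymc_model(model, backend="jax", gradient_backend="jax")
trace_nf = nutpie.sample(compiled, transform_adapt=True, seed=1, chains=2, cores=1)

assert trace_nf.sample_stats.diverging.sum() == 0
assert (arviz.ess(trace_nf) > 500).all().to_array().all()
```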

python/nutpie/compile_pymc.py (+1, -2)

@@ -8,7 +8,6 @@
 from importlib.util import find_spec
 from math import prod
 from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union, cast
-import threading
 
 import numpy as np
 import pandas as pd
@@ -645,7 +644,7 @@ def _make_functions(
     """
     import pytensor
     import pytensor.tensor as pt
-    from pymc.pytensorf import compile_pymc
+    from pymc.pytensorf import compile as compile_pymc
 
     shapes = _compute_shapes(model)
 
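The import change tracks an upstream rename: newer PyMC releases expose `pymc.pytensorf.compile` in place of `compile_pymc`. If older PyMC versions still needed to be supported, a guarded import along these lines is one option (a sketch, not part of this commit):

```python
# Sketch: tolerate both the old and the new PyMC name,
# assuming only the function name changed upstream.
try:
    from pymc.pytensorf import compile as compile_pymc  # newer PyMC
except ImportError:
    from pymc.pytensorf import compile_pymc  # older PyMC
```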
python/nutpie/normalizing_flow.py (+38, -25)

@@ -17,7 +17,8 @@
 from paramax.wrappers import AbstractUnwrappable
 
 
-_NN_ACTIVATION = jax.nn.leaky_relu
+_NN_ACTIVATION = jax.nn.gelu
+
 
 def _generate_sequences(k, r_vals):
     """
@@ -324,6 +325,7 @@ class AsymmetricAffine(bijections.AbstractBijection):
         scale: Scale parameter σ (positive)
         theta: Asymmetry parameter θ (positive)
     """
+
     shape: tuple[int, ...] = ()
     cond_shape: ClassVar[None] = None
     loc: Array
@@ -340,6 +342,7 @@ def __init__(
             *(arraylike_to_array(a, dtype=float) for a in (loc, scale, theta)),
         )
         self.shape = scale.shape
+        assert self.shape == ()
         self.scale = Parameterize(lambda x: x + jnp.sqrt(1 + x**2), jnp.zeros(()))
         self.theta = Parameterize(lambda x: x + jnp.sqrt(1 + x**2), jnp.zeros(()))
 
@@ -348,17 +351,18 @@ def _log_derivative_f(self, x, mu, sigma, theta):
         theta = jnp.log(theta)
 
         sinh_theta = jnp.sinh(theta)
-        #sinh_theta = (theta - 1 / theta) / 2
+        # sinh_theta = (theta - 1 / theta) / 2
         cosh_theta = jnp.cosh(theta)
-        #cosh_theta = (theta + 1 / theta) / 2
+        # cosh_theta = (theta + 1 / theta) / 2
         numerator = sinh_theta * x * (abs_x + 2.0)
-        denominator = (abs_x + 1.0)**2
+        denominator = (abs_x + 1.0) ** 2
         term = numerator / denominator
         dy_dx = sigma * (cosh_theta + term)
         return jnp.log(dy_dx)
 
-    def transform_and_log_det(self, x: ArrayLike, condition: ArrayLike | None = None) -> tuple[Array, Array]:
-
+    def transform_and_log_det(
+        self, x: ArrayLike, condition: ArrayLike | None = None
+    ) -> tuple[Array, Array]:
         def transform(x, mu, sigma, theta):
             weight = (jax.nn.soft_sign(x) + 1) / 2
             z = x * sigma
@@ -372,17 +376,22 @@ def transform(x, mu, sigma, theta):
         y = transform(x, mu, sigma, theta)
         logjac = self._log_derivative_f(x, mu, sigma, theta)
         return y, logjac.sum()
+        # y, jac = jax.value_and_grad(transform, argnums=0)(x, mu, sigma, theta)
+        # return y, jnp.log(jac)
 
-    def inverse_and_log_det(self, y: ArrayLike, condition: ArrayLike | None = None) -> tuple[Array, Array]:
-
+    def inverse_and_log_det(
+        self, y: ArrayLike, condition: ArrayLike | None = None
+    ) -> tuple[Array, Array]:
         def inverse(y, mu, sigma, theta):
             delta = y - mu
             inv_theta = 1 / theta
 
             # Case 1: y >= mu (delta >= 0)
             a = sigma * (theta + inv_theta)
-            discriminant_pos = jnp.square(a - 2.0 * delta) + 16.0 * sigma * theta * delta
-            discriminant_pos = jnp.where(discriminant_pos < 0, 1., discriminant_pos)
+            discriminant_pos = (
+                jnp.square(a - 2.0 * delta) + 16.0 * sigma * theta * delta
+            )
+            discriminant_pos = jnp.where(discriminant_pos < 0, 1.0, discriminant_pos)
             sqrt_pos = jnp.sqrt(discriminant_pos)
             numerator_pos = 2.0 * delta - a + sqrt_pos
             denominator_pos = 4.0 * sigma * theta
@@ -391,8 +400,10 @@ def inverse(y, mu, sigma, theta):
             # Case 2: y < mu (delta < 0)
             sigma_part = sigma * (1.0 + theta * theta)
             term2 = 2.0 * delta * theta
-            inside_sqrt_neg = jnp.square(sigma_part + term2) - 16.0 * sigma * delta * theta
-            inside_sqrt_neg = jnp.where(inside_sqrt_neg < 0, 1., inside_sqrt_neg)
+            inside_sqrt_neg = (
+                jnp.square(sigma_part + term2) - 16.0 * sigma * delta * theta
+            )
+            inside_sqrt_neg = jnp.where(inside_sqrt_neg < 0, 1.0, inside_sqrt_neg)
             sqrt_neg = jnp.sqrt(inside_sqrt_neg)
             numerator_neg = sigma_part + term2 - sqrt_neg
             denominator_neg = 4.0 * sigma
@@ -407,6 +418,8 @@ def inverse(y, mu, sigma, theta):
         x = inverse(y, mu, sigma, theta)
         logjac = self._log_derivative_f(x, mu, sigma, theta)
         return x, -logjac.sum()
+        # x, jac = jax.value_and_grad(inverse, argnums=0)(y, mu, sigma, theta)
+        # return x, jnp.log(jac)
 
 
 class MvScale(bijections.AbstractBijection):
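The commented-out `jax.value_and_grad` lines kept in both methods hint at how the closed-form log-determinant can be cross-checked against autodiff. A sketch of such a consistency test (the positional constructor arguments are assumptions read off the diff; the new `assert self.shape == ()` restricts the bijection to scalar shape, and the `Parameterize` wrappers presumably need unwrapping before direct use):

```python
# Sketch: check that inverse_and_log_det round-trips transform_and_log_det
# and that the closed-form log-det matches autodiff.
import jax
import jax.numpy as jnp
import paramax
from nutpie.normalizing_flow import AsymmetricAffine

# loc, scale, theta as positional arguments is an assumption from the diff;
# unwrap resolves the Parameterize-wrapped scale/theta parameters.
bij = paramax.unwrap(AsymmetricAffine(0.0, 1.0, 1.0))
x = jnp.array(0.7)

y, logdet = bij.transform_and_log_det(x)
x_back, neg_logdet = bij.inverse_and_log_det(y)

assert jnp.allclose(x, x_back, atol=1e-6)
assert jnp.allclose(logdet, -neg_logdet, atol=1e-6)

# Autodiff cross-check of the forward derivative, mirroring the
# commented-out value_and_grad variant above.
dy_dx = jax.grad(lambda x_: bij.transform_and_log_det(x_)[0])(x)
assert jnp.allclose(logdet, jnp.log(dy_dx), atol=1e-5)
```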
@@ -499,7 +512,6 @@ def __init__(
         self.requires_vmap = False
         conditioner_output_size = num_params
 
-
         self.transformer_constructor = constructor
         self.untransformed_dim = untransformed_dim
         self.dim = dim
@@ -509,7 +521,9 @@ def __init__(
         if conditioner is None:
             conditioner = eqx.nn.MLP(
                 in_size=(
-                    untransformed_dim if cond_dim is None else untransformed_dim + cond_dim
+                    untransformed_dim
+                    if cond_dim is None
+                    else untransformed_dim + cond_dim
                 ),
                 out_size=conditioner_output_size,
                 width_size=nn_width,
@@ -542,7 +556,9 @@ def _flat_params_to_transformer(self, params: Array):
         if self.requires_vmap:
             dim = self.dim - self.untransformed_dim
             transformer_params = jnp.reshape(params, (dim, -1))
-            transformer = eqx.filter_vmap(self.transformer_constructor)(transformer_params)
+            transformer = eqx.filter_vmap(self.transformer_constructor)(
+                transformer_params
+            )
             return bijections.Vmap(transformer, in_axes=eqx.if_array(0))
         else:
             transformer = self.transformer_constructor(params)
@@ -612,7 +628,7 @@ def make_elemwise(key, loc):
             replace=theta,
         )
 
-        return affine
+        return bijections.Invert(affine)
 
     def make(key):
         keys = jax.random.split(key, count + 1)
@@ -626,11 +642,9 @@ def make(key):
         return bijections.Vmap(make_affine, in_axes=eqx.if_array(0))
 
 
-def make_coupling(key, dim, n_untransformed, **kwargs):
+def make_coupling(key, dim, n_untransformed, *, inner_mvscale=False, **kwargs):
     n_transformed = dim - n_untransformed
 
-    mvscale = make_mvscale(key, n_transformed, 1, randomize_base=True)
-
     nn_width = kwargs.get("nn_width", None)
     nn_depth = kwargs.get("nn_depth", None)
 
@@ -646,12 +660,11 @@ def make_coupling(key, dim, n_untransformed, **kwargs):
     else:
         nn_depth = len(nn_width)
 
-    transformer = bijections.Chain(
-        [
-            make_elemwise_trafo(key, n_transformed, count=3),
-            #mvscale,
-        ]
-    )
+    transformer = make_elemwise_trafo(key, n_transformed, count=3)
+
+    if inner_mvscale:
+        mvscale = make_mvscale(key, n_transformed, 1, randomize_base=True)
+        transformer = bijections.Chain([transformer, mvscale])
 
     def make_mlp(out_size):
         if isinstance(nn_width, tuple):
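Two behavioral notes on these hunks: `make_elemwise` now returns `bijections.Invert(affine)`, so the elementwise transformer applies the affine map in its inverse direction, and the previously dead `mvscale` layer (built unconditionally but left commented out in the chain) is now opt-in via the keyword-only `inner_mvscale` flag. A hedged usage sketch (argument values are illustrative assumptions; defaults are left to the function):

```python
# Sketch: building coupling layers with and without the inner MvScale layer.
import jax
from nutpie.normalizing_flow import make_coupling

key = jax.random.PRNGKey(0)
dim, n_untransformed = 10, 5

# Default: elementwise transformer only; the mvscale bijection is no
# longer constructed just to be discarded.
coupling = make_coupling(key, dim, n_untransformed)

# Opt in to the extra multivariate scale inside the transformer chain:
coupling_mv = make_coupling(key, dim, n_untransformed, inner_mvscale=True)
```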
