numpyro and pymc have a zero-sum normal distribution based on a zero-sum bijector (see e.g. numpyro zero sum normal and numpyro zero sum transform).
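For context, my reading of the transform is that it maps $x \in \mathbb{R}^{n-1}$ to a vector $y \in \mathbb{R}^{n}$ whose components sum to zero. Writing $s = \sum_i x_i$:

$$y_i = x_i - \frac{s}{\sqrt{n} + n} \quad (0 \le i < n-1), \qquad y_{n-1} = -\frac{s}{\sqrt{n}},$$

which one can check satisfies $\sum_i y_i = 0$, and the inverse recovers $s$ from the last component alone.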
I was wondering if there is any appetite for adding this to TFP? I already have a simple port working (it still needs some changes, in particular perhaps allowing a variable number of axes to be constrained to sum to zero):
"""ZeroSum bijector."""
import tensorflow as tf
from tensorflow_probability.python.bijectors import bijector
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import tensorshape_util
class ZeroSum(bijector.AutoCompositeTensorBijector):
def __init__(self, validate_args=False, name="zero_sum"):
parameters = dict(locals())
super(ZeroSum, self).__init__(
is_constant_jacobian=True,
forward_min_event_ndims=1,
validate_args=validate_args,
parameters=parameters,
name=name,
)
@classmethod
def _parameter_properties(cls, dtype):
return dict()
def _forward(self, x):
n = ps.cast(ps.shape(x)[-1], x.dtype) + 1
sum_vals = tf.reduce_sum(x, axis=-1, keepdims=True)
norm = sum_vals / (ps.sqrt(n) + n)
fill_val = norm - sum_vals / ps.sqrt(n)
out = tf.concat([x, fill_val], axis=-1)
return out - norm
def _inverse(self, y):
normalized_axis = ps.rank(y) - 1
n = ps.cast(ps.shape(y)[normalized_axis], y.dtype)
last = y[..., -1]
sum_vals = -last * ps.sqrt(n)
norm = sum_vals / (ps.sqrt(n) + n)
slice_before = (slice(None, None),) * normalized_axis
return y[(*slice_before, slice(None, -1))] + norm
def _inverse_log_det_jacobian(self, y):
return tf.zeros([], dtype=y.dtype)
def _forward_log_det_jacobian(self, x):
return tf.zeros([], dtype=x.dtype)
def _forward_event_shape(self, input_shape):
return tensorshape_util.concatenate(input_shape[:-1], input_shape[-1] + 1)
def _forward_event_shape_tensor(self, input_shape):
n = ps.shape(input_shape)[-1]
return ps.tensor_scatter_nd_add(input_shape, [[n - 1]], [1])
def _inverse_event_shape(self, input_shape):
return tensorshape_util.concatenate(input_shape[:-1], input_shape[-1] + 1)
def _inverse_event_shape_tensor(self, input_shape):
n = ps.shape(input_shape)[-1]
return ps.tensor_scatter_nd_sub(input_shape, [[n - 1]], [1])
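A quick round-trip spot check on a batch (just a sanity check, not a proper test):

```python
import numpy as np
import tensorflow as tf

bij = ZeroSum()
x = tf.random.normal([4, 2])  # batch of 4 unconstrained points in R^2
y = bij.forward(x)            # batch of 4 points in R^3, each summing to zero
print(np.max(np.abs(tf.reduce_sum(y, axis=-1).numpy())))  # ~0 up to float32 roundoff
print(np.max(np.abs((bij.inverse(y) - x).numpy())))       # ~0 up to float32 roundoff
```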
Usage:

```python
import numpy as np
import tensorflow_probability as tfp

tfd = tfp.distributions

zero_sum_normal = tfd.TransformedDistribution(
    distribution=tfd.MultivariateNormalDiag(loc=0.0, scale_diag=[1.0, 1.0]),
    bijector=ZeroSum(),
)
zero_sum_normal
# <tfp.distributions.TransformedDistribution 'zero_sumMultivariateNormalDiag' batch_shape=[] event_shape=[3] dtype=float32>

samples = zero_sum_normal.sample(int(1e7))
np.max(np.abs(np.sum(samples, axis=-1)))
# 4.7683716e-07
np.mean(samples, axis=0)
# array([ 4.8274879e-04, -5.6865485e-04,  8.5900021e-05], dtype=float32)
np.std(samples, axis=0)
# array([0.8100124, 0.8101918, 0.8100392], dtype=float32)
```
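Since the inverse and its (zero) log-det Jacobian are defined, `log_prob` should also go through; a quick shape check (I haven't cross-checked the density values against numpyro's):

```python
lp = zero_sum_normal.log_prob(samples[:5])
lp.shape  # TensorShape([5])
```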
Compare to numpyro:

```python
import jax.numpy as jnp
import jax.random as jr
import numpyro.distributions as dist

zero_sum_normal = dist.ZeroSumNormal(scale=jnp.array(1.0), event_shape=[3])
rng = jr.key(123)
samples = zero_sum_normal.sample(rng, sample_shape=(int(1e7),))
jnp.max(jnp.abs(jnp.sum(samples, axis=1)))
# Array(5.9604645e-07, dtype=float32)
jnp.mean(samples, axis=0)
# Array([-1.7739683e-04, -6.7088688e-05,  2.4448565e-04], dtype=float32)
jnp.std(samples, axis=0)
# Array([0.8164292 , 0.81622946, 0.8164195 ], dtype=float32)
```
If this would be useful, I can work on bits of it over the next couple of weeks, or if someone else wants to take it over, that's great too.
Thanks.