Skip to content

ZeroSum bijector and ZeroSumNormal distribution #1980

Open
@jeffpollock9

Description

@jeffpollock9

NumPyro and PyMC have a zero-sum normal distribution based on a zero-sum bijector (see e.g. NumPyro's `ZeroSumNormal` distribution and `ZeroSumTransform`).

I was wondering whether there is any appetite for adding this to TFP. I already have a simple port working, although it still needs some changes — in particular, possibly allowing a variable number of axes to be constrained to sum to zero:

"""ZeroSum bijector."""

import tensorflow as tf
from tensorflow_probability.python.bijectors import bijector
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import tensorshape_util


class ZeroSum(bijector.AutoCompositeTensorBijector):
    """Bijector from R^(n-1) onto the zero-sum hyperplane of R^n.

    Port of NumPyro's `ZeroSumTransform`, acting on the rightmost axis only.
    `forward` maps an event of size n-1 to an event of size n whose entries
    sum to zero; `inverse` drops the last entry and undoes the shift.

    The forward map is linear with orthonormal columns (it is an isometric
    embedding of R^(n-1) into the hyperplane), so the log-det-Jacobian with
    respect to volume on the hyperplane is identically zero in both
    directions.
    """

    def __init__(self, validate_args=False, name="zero_sum"):
        parameters = dict(locals())
        super(ZeroSum, self).__init__(
            is_constant_jacobian=True,
            forward_min_event_ndims=1,
            validate_args=validate_args,
            parameters=parameters,
            name=name,
        )

    @classmethod
    def _parameter_properties(cls, dtype):
        # The bijector has no tensor-valued parameters.
        return dict()

    def _forward(self, x):
        # n is the size of the *output* event: one more than the input's.
        n = ps.cast(ps.shape(x)[-1], x.dtype) + 1
        sqrt_n = ps.sqrt(n)
        sum_vals = tf.reduce_sum(x, axis=-1, keepdims=True)
        norm = sum_vals / (sqrt_n + n)
        # Appended entry; after subtracting `norm` everywhere, the output
        # sums to exactly zero along the last axis.
        fill_val = norm - sum_vals / sqrt_n
        out = tf.concat([x, fill_val], axis=-1)
        return out - norm

    def _inverse(self, y):
        # n is the size of the *input* event here (the zero-sum vector).
        n = ps.cast(ps.shape(y)[-1], y.dtype)
        sqrt_n = ps.sqrt(n)
        # Keep the trailing axis (size 1) so that `norm` broadcasts against
        # `y[..., :-1]` for any batch shape; `y[..., -1]` would drop it and
        # mis-broadcast for batched inputs.
        last = y[..., -1:]
        # The forward pass sets last = -sum(x) / sqrt(n), so this recovers
        # sum(x).
        sum_vals = -last * sqrt_n
        norm = sum_vals / (sqrt_n + n)
        return y[..., :-1] + norm

    def _inverse_log_det_jacobian(self, y):
        return tf.zeros([], dtype=y.dtype)

    def _forward_log_det_jacobian(self, x):
        return tf.zeros([], dtype=x.dtype)

    @staticmethod
    def _shift_last_static_dim(shape, delta):
        """Returns static `shape` with its last dimension changed by `delta`.

        An unknown last dimension stays unknown instead of raising on
        `None + delta`.
        """
        last = tf.compat.dimension_value(shape[-1])
        new_last = None if last is None else last + delta
        return tensorshape_util.concatenate(shape[:-1], [new_last])

    def _forward_event_shape(self, input_shape):
        # Forward appends one entry to the last axis.
        return self._shift_last_static_dim(input_shape, 1)

    def _forward_event_shape_tensor(self, input_shape):
        # `input_shape` is a 1-D shape vector; add 1 to its last element.
        n = ps.shape(input_shape)[-1]
        return ps.tensor_scatter_nd_add(input_shape, [[n - 1]], [1])

    def _inverse_event_shape(self, input_shape):
        # Inverse removes one entry from the last axis (previously this
        # incorrectly added one, disagreeing with the tensor version below).
        return self._shift_last_static_dim(input_shape, -1)

    def _inverse_event_shape_tensor(self, input_shape):
        # `input_shape` is a 1-D shape vector; subtract 1 from its last element.
        n = ps.shape(input_shape)[-1]
        return ps.tensor_scatter_nd_sub(input_shape, [[n - 1]], [1])

Usage:

import numpy as np
import tensorflow_probability as tfp

tfd = tfp.distributions

# Zero-sum normal over 3 components: push a 2-D standard normal through the
# ZeroSum bijector (event_shape [2] -> [3]).
zero_sum_normal = tfd.TransformedDistribution(
    distribution=tfd.MultivariateNormalDiag(loc=0.0, scale_diag=[1.0, 1.0]),
    bijector=ZeroSum(),
)
zero_sum_normal
# <tfp.distributions.TransformedDistribution 'zero_sumMultivariateNormalDiag' batch_shape=[] event_shape=[3] dtype=float32>

samples = zero_sum_normal.sample(int(1e7))

# Largest absolute per-sample sum: should be at float32 round-off level.
np.max(np.abs(np.sum(samples, axis=-1)))
# 4.7683716e-07

np.mean(samples, axis=0)
# array([ 4.8274879e-04, -5.6865485e-04,  8.5900021e-05], dtype=float32)

# Accumulate in float64: np.std computes in the input's precision, and
# summing 1e7 float32 squared deviations can noticeably bias the result low.
np.std(samples, axis=0, dtype=np.float64)
# expected ~ sqrt(2/3) ~= 0.8165 per component (cf. the NumPyro output below)

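Note: for an event size of n = 3 with unit scale, the theoretical per-component standard deviation is sqrt(1 - 1/n) = sqrt(2/3) ≈ 0.8165. A TFP `np.std` result noticeably below that (such as 0.810) is most likely float32 accumulation error in `np.std` over 1e7 rows rather than a bug in the bijector; passing `dtype=np.float64` to `np.std` avoids it.

Compare to NumPyro: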

import jax.numpy as jnp
import jax.random as jr
import numpyro.distributions as dist

# NumPyro's built-in zero-sum normal with the same event size (3) and unit
# scale, for comparison with the TFP port above.
zero_sum_normal = dist.ZeroSumNormal(scale=jnp.array(1.0), event_shape=[3])

rng = jr.key(123)

# Same number of draws as the TFP example: samples has shape (1e7, 3).
samples = zero_sum_normal.sample(rng, sample_shape=(int(1e7),))

# Largest absolute per-sample sum (axis=1 is the event axis for these 2-D
# samples): should be at float32 round-off level.
jnp.max(jnp.abs(jnp.sum(samples, axis=1)))
# Array(5.9604645e-07, dtype=float32)

jnp.mean(samples, axis=0)
# Array([-1.7739683e-04, -6.7088688e-05,  2.4448565e-04], dtype=float32)

# Close to the theoretical sqrt(2/3) ~= 0.8165 per component.
jnp.std(samples, axis=0)
# Array([0.8164292 , 0.81622946, 0.8164195 ], dtype=float32)

If this is something useful, I can work on bits of it over the next couple of weeks, or if someone else wants to take it over, that's great too.

Thanks.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions