Skip to content

Entropy not available when determinant is not constant: Bijector #318

@cdelv

Description

@cdelv

Hi,

I'm using Distrax distributions with some custom bijectors. This is an example of my bijector:

class MaxNormSpace(distrax.Bijector):
    r"""
    **Radial max-norm** constraint for vector actions:
    Scales the radius with a `tanh` squashing while preserving direction.

    **Mapping (vector case,** :math:`x \in \mathbb{R}^d`

    .. math::
        r = \lVert \vec{x} \rVert_2,\qquad
        \hat{u} = \begin{cases}
            \frac{\vec{x}}{r}, & r>0,\\[4pt]
            0, & r=0,
        \end{cases}
        \qquad
        y = s \tanh(r) \hat{u},
        \quad s = (1-\epsilon) \texttt{max_norm}.

    Equivalently, :math:`y = b(r)\,x` with :math:`b(r)= s\,\tanh(r)/r` for :math:`r>0`.


    **Jacobian determinant**

    For an isotropic radial map :math:`f(x)=b(r)` with :math:`x \in \mathbb{R}^d`, the Jacobian
    eigenvalues are :math:`b` (multiplicity d-1) on the tangent subspace and :math:`b + r\,b'(r)` on the radial direction, hence

    .. math::
        \bigl|\det J_f(x)\bigr| = b(r)^{\,d-1}\,\bigl(b(r)+r\,b'(r)\bigr)
        = s^d \left(\frac{\tanh r}{r}\right)^{\!d-1} sech^2 r.

    Therefore

    .. math::
        \log\lvert\det J_f(x)\rvert
        = d\log s + (d-1)\bigl(\log\tanh r - \log r\bigr) + \log( sech^2 r),

    We use the stable identity :math:`\log(sech^2 z)=2 [\log 2 - z - softplus(-2z)]`,
    which we apply for good numerical behavior.

    Near :math:`r\approx 0`, we use the second-order expansion.

    .. math::
        \log\lvert\det J_f(x)\rvert \approx d\log s - \tfrac{2}{3} r^2

    to avoid division by :math:`r`.

    Parameters
    ----------
    -max_norm : float
        target radius \(s\) after squashing (default 1.0). We actually use \(s=(1-\varepsilon)\,\texttt{max\_norm}\) to avoid the exact boundary.

    -eps : float
        numerical safety margin used near \(r=0\) and \(r\to\infty\).

    -event_ndims_in : int
        dimensionality of a *single event* seen by the bijector (defaults to 0 for a scalar transform).

    -event_ndims_out : Optional[int]
        standard Distrax/TFP bijector flags.

    -is_constant_jacobian : bool
        standard Distrax/TFP bijector flags.

    -is_constant_log_det : bool
        standard Distrax/TFP bijector flags.

    Note
    ----------
    This bijector is **vector-valued** with ``event_ndims_in = 1`` (i.e., it operates on length-\(d\) action vectors as a
    single event). Do **not** wrap it in `Block` unless you intend to apply it independently to multiple last-axis blocks.
    """

    __slots__ = ()

    def __init__(
        self,
        max_norm: float = 1.0,
        eps: float = 1e-6,
        event_ndims_in: int = 1,
        event_ndims_out: Optional[int] = None,
        is_constant_jacobian: bool = False,
        is_constant_log_det: Optional[bool] = None,
    ):
        super().__init__(
            event_ndims_in=event_ndims_in,
            event_ndims_out=event_ndims_out,
            is_constant_jacobian=is_constant_jacobian,
            is_constant_log_det=is_constant_log_det,
        )
        self.eps = float(eps)
        self.max_norm = float(max_norm)

    @property
    def kws(self) -> Dict:
        return dict(
            max_norm=self.max_norm,
            eps=self.eps,
            event_ndims_in=self.event_ndims_in,
            event_ndims_out=self.event_ndims_out,
            is_constant_jacobian=self.is_constant_jacobian,
            is_constant_log_det=self.is_constant_log_det,
        )

    @staticmethod
    @partial(jax.named_call, name="MaxNormSpace.sec2_log")
    def sec2_log(r):
        # r is scalar radius
        return 2 * (jnp.log(2.0) - r - jax.nn.softplus(-2.0 * r))

    @partial(jax.named_call, name="MaxNormSpace.forward_log_det_jacobian")
    def forward_log_det_jacobian(self, x: Array) -> jax.Array:
        r = jnp.linalg.norm(x, axis=-1)  # shape (...,)
        x = jnp.atleast_1d(x)  # ensures x.ndim >= 1
        d = jnp.asarray(x.shape[-1], x.dtype)  # scalar, works under jit

        # Stable pieces
        log_s = jnp.log((1.0 - self.eps) * self.max_norm + self.eps)
        log_tanh_r = jnp.log(jnp.tanh(r) + self.eps)
        log_r = jnp.log(r + self.eps)
        log_sech2_r = MaxNormSpace.sec2_log(r)

        main = d * log_s + (d - 1.0) * (log_tanh_r - log_r) + log_sech2_r
        small = d * log_s - (2.0 / 3.0) * (r * r)
        return jnp.where(r < self.eps, small, main)

    @partial(jax.named_call, name="MaxNormSpace.forward_and_log_det")
    def forward_and_log_det(self, x: Array) -> Tuple[jax.Array, jax.Array]:
        r = jnp.linalg.norm(x, axis=-1, keepdims=True)
        unit = jnp.where(r > 0.0, x / r, jnp.zeros_like(x))
        y = (1.0 - self.eps) * self.max_norm * jnp.tanh(r) * unit
        return y, self.forward_log_det_jacobian(x)

    @partial(jax.named_call, name="MaxNormSpace.inverse_and_log_det")
    def inverse_and_log_det(self, y: Array) -> Tuple[jax.Array, jax.Array]:
        r = jnp.linalg.norm(y, axis=-1, keepdims=True)
        u = (r / ((1.0 - self.eps) * self.max_norm)).clip(
            -1.0 + self.eps, 1.0 - self.eps
        )
        unit = jnp.where(r > 0.0, y / r, jnp.zeros_like(y))
        x = jnp.arctanh(u) * unit
        return x, -self.forward_log_det_jacobian(x)

    def same_as(self, other: distrax.Bijector) -> bool:
        return type(other) is MaxNormSpace

This bijector is a variation of the Tanh bijector for multivariate distributions that tries to bound the output norm. When I compute the entropy of the distribution, I get this error:

NotImplementedError: entropy is not implemented for this transformed distribution, because its bijector's Jacobian determinant is not known to be constant.

It seems that I can "cheat" and define is_constant_jacobian = True, and the code would work, but the result would be wrong. What do I have to do to make entropy work? Do I need to define some extra methods to patch the problem? Or do I have to override the entropy() method?

I appreciate your help.

Best regards.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions