
MultivariateNormalDiag vmap issue #276

@haydn-jones

Description

It's unclear to me why the following code does not work, since MultivariateNormalDiag supports batch dimensions for loc and scale:

import distrax as dx
import jax
import jax.numpy as jnp
from jax import vmap


@jax.jit
def build():
    def single(i):
        return dx.MultivariateNormalDiag(jnp.zeros(10), jnp.ones(10))

    x = vmap(single)(jnp.arange(10))
    return x


dist = build()
dist.loc

produces the following error:

Traceback (most recent call last):
  File ".../test.py", line 17, in <module>
    dist.loc
  File ".../python3.12/site-packages/distrax/_src/distributions/mvn_from_bijector.py", line 103, in loc
    return jnp.broadcast_to(self._loc, shape=shape)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 2087, in broadcast_to
    return util._broadcast_to(array, shape)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../python3.12/site-packages/jax/_src/numpy/util.py", line 422, in _broadcast_to
    raise ValueError(f"Cannot broadcast to shape with fewer dimensions: {arr_shape=} {shape=}")
ValueError: Cannot broadcast to shape with fewer dimensions: arr_shape=(10, 10) shape=(10,)
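The last frame shows that the failure is an ordinary `jnp.broadcast_to` call. The stored location has shape `(10, 10)`, so it carries the leading axis that `vmap` added, while the requested shape `(10,)` appears to describe a single, unbatched 10-dimensional location. That suggests the distribution's shape information was recorded for one example inside `vmap`, while its parameters were stacked afterwards. The minimal sketch below reproduces the same kind of `ValueError` with `jax.numpy` alone. `stacked_loc` and `unbatched_shape` are illustrative names, not distrax internals.

```python
import jax.numpy as jnp

# Illustrative stand-in for the stored location after vmap stacked
# ten 10-dimensional vectors along a new leading axis.
stacked_loc = jnp.zeros((10, 10))

# A target shape with no batch axis, matching a single event.
unbatched_shape = (10,)

raised = False
try:
    jnp.broadcast_to(stacked_loc, unbatched_shape)
except ValueError as err:
    # Same class of error as in the traceback: the target has fewer dimensions.
    raised = True
    print(err)

# Including the batch axis in the target shape broadcasts without error.
ok = jnp.broadcast_to(stacked_loc, (10,) + unbatched_shape)
print(ok.shape)
```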

This seems similar to #239.

Ah, I see in the README that this distribution is specifically called out for being problematic with vmap.