Description
It's unclear to me why the following code does not work, given that `MultivariateNormalDiag` supports batch dimensions for `loc` and `scale`:
```python
import distrax as dx
import jax
import jax.numpy as jnp
from jax import vmap


@jax.jit
def build():
    def single(i):
        return dx.MultivariateNormalDiag(jnp.zeros(10), jnp.ones(10))

    x = vmap(single)(jnp.arange(10))
    return x


dist = build()
dist.loc
```

Accessing `dist.loc` produces the following error:
```
Traceback (most recent call last):
  File ".../test.py", line 17, in <module>
    dist.loc
  File ".../python3.12/site-packages/distrax/_src/distributions/mvn_from_bijector.py", line 103, in loc
    return jnp.broadcast_to(self._loc, shape=shape)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 2087, in broadcast_to
    return util._broadcast_to(array, shape)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../python3.12/site-packages/jax/_src/numpy/util.py", line 422, in _broadcast_to
    raise ValueError(f"Cannot broadcast to shape with fewer dimensions: {arr_shape=} {shape=}")
ValueError: Cannot broadcast to shape with fewer dimensions: arr_shape=(10, 10) shape=(10,)
```
This seems similar to #239
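One plausible reading of the traceback, assuming distrax keeps shape information (such as the batch and event shapes used by `loc`) as static pytree metadata rather than as array leaves: `vmap` adds a batch axis to every array leaf of the object it returns, but it never touches static metadata. The stored `_loc` therefore becomes `(10, 10)` while the recorded target shape stays `(10,)`, which matches `arr_shape=(10, 10) shape=(10,)` in the error. The sketch below reproduces that mismatch in plain JAX with a hypothetical `StaticShapeBox` class; it is an illustration of the mechanism, not distrax's actual code.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class StaticShapeBox:
    """Hypothetical container: `value` is an array leaf, `shape` is static metadata."""

    def __init__(self, value):
        self.value = value
        self.shape = tuple(value.shape)  # recorded once, at construction time

    def tree_flatten(self):
        # `value` is a child, so vmap adds a batch axis to it.
        # `shape` goes into aux_data, which vmap passes through unchanged.
        return (self.value,), self.shape

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        obj = object.__new__(cls)
        obj.value = children[0]
        obj.shape = aux_data
        return obj

    def broadcast_value(self):
        # Analogous to the `loc` property in the traceback above.
        return jnp.broadcast_to(self.value, self.shape)


def single(i):
    return StaticShapeBox(jnp.zeros(10))


box = jax.vmap(single)(jnp.arange(10))
print(box.value.shape)  # the leaf gained a batch axis
print(box.shape)        # the static metadata did not
try:
    box.broadcast_value()
except ValueError as err:
    print("ValueError:", err)
```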
Ah, I see in the README that this distribution is specifically called out for being problematic with vmap.
Red-Portal