Skip to content

Commit 20a6cc4

Browse files
Jake VanderPlasDistraxDev
authored andcommitted
Fix wrong call signatures for jnp.shape, jnp.size, and jnp.ndim
Passing lists or sequences to these functions is deprecated starting in jax-ml/jax#26641 PiperOrigin-RevId: 730489201
1 parent c02708a commit 20a6cc4

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

distrax/_src/distributions/distribution_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def log_prob(self, value):
3636
@property
3737
def event_shape(self):
3838
"""Shape of the events."""
39-
return jnp.shape([])
39+
return np.shape([])
4040

4141

4242
class DummyMultivariateDist(distribution.Distribution):

0 commit comments

Comments
 (0)