Skip to content

Commit c337a7a

Browse files
Jake VanderPlasDistraxDev
authored andcommitted
Replace deprecated jax.core.get_aval with jax.typeof
`core.get_aval` is deprecated as of JAX v0.8.2 and will be removed in a future release. PiperOrigin-RevId: 840774897
1 parent d901057 commit c337a7a

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

distrax/_src/utils/math.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020
import chex
2121
import jax
22-
from jax import core as jax_core
2322
from jax.custom_derivatives import SymbolicZero
2423
import jax.numpy as jnp
2524

@@ -72,7 +71,7 @@ def multiply_no_nan_jvp(
7271
x, y = primals
7372
x_dot, y_dot = tangents
7473
primal_out = multiply_no_nan(x, y)
75-
primal_aval = jax_core.get_aval(primal_out)
74+
primal_aval = jax.typeof(primal_out)
7675
result_aval = primal_aval.at_least_vspace()
7776
tangent_out_1 = scale_maybe_symbolic(result_aval, x_dot, y)
7877
tangent_out_2 = scale_maybe_symbolic(result_aval, y_dot, x)

0 commit comments

Comments
 (0)