Skip to content

Commit 50a8094

Browse files
Jake VanderPlasDistraxDev
authored andcommitted
Replace usage of xla.abstractify with core.get_aval where feasible.
PiperOrigin-RevId: 730598559
1 parent c02708a commit 50a8094

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

distrax/_src/utils/jittable.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,9 @@ def _is_jax_data(x):
6161
if isinstance(x, (bool, int, float)) or x is None:
6262
return False
6363

64-
# Otherwise, try to make it into a tracer. If it succeeds, then it's JAX data.
64+
# Otherwise, attempt to trace the value. If it succeeds, then it's JAX data.
6565
try:
66-
jax.interpreters.xla.abstractify(x)
66+
jax.eval_shape(lambda x: x, x)
6767
return True
6868
except TypeError:
6969
return False

0 commit comments

Comments
 (0)