We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent c02708a commit aeb76d0Copy full SHA for aeb76d0
distrax/_src/utils/jittable.py
@@ -61,9 +61,9 @@ def _is_jax_data(x):
61
if isinstance(x, (bool, int, float)) or x is None:
62
return False
63
64
- # Otherwise, try to make it into a tracer. If it succeeds, then it's JAX data.
+ # Otherwise, attempt to trace the value. If it succeeds, then it's JAX data.
65
try:
66
- jax.interpreters.xla.abstractify(x)
+ jax.eval_shape(lambda x: x, x)
67
return True
68
except TypeError:
69
0 commit comments