Skip to content

Commit 1b6e098

Browse files
hawkinspGoogle-ML-Automation
authored andcommitted
Only consider Tracers with ShapedArray avals as instances of Array.
This fixes a TODO from #33420. PiperOrigin-RevId: 861851484
1 parent 9f327f6 commit 1b6e098

File tree

3 files changed

+11
-12
lines changed

3 files changed

+11
-12
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
2020
* Bug fixes:
2121
* Deprecations:
2222
* Changes:
23+
* JAX tracers that are not of `Array` type (e.g., of `Ref` type) will no
24+
longer report themselves to be instances of `Array`.
25+
2326

2427
## JAX 0.9.0 (January 20, 2026)
2528

jax/_src/pallas/mosaic/pallas_call_registration.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ def _maybe_cast_to_int(x: jax.Array | jax_core.AbstractValue):
4848
"""
4949
assert isinstance(
5050
x, (jax.Array, jax_core.ShapedArray, state_types.AbstractLinVal)
51+
) or (
52+
isinstance(x, jax_core.Tracer)
53+
and isinstance(x.aval, state_types.AbstractLinVal)
5154
), type(x)
5255
if isinstance(x, jax.Array):
5356
if dtypes.issubdtype(x.dtype, jax.numpy.bool_):

jaxlib/py_array.cc

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2221,18 +2221,11 @@ absl::Status PyArray::Register(nb::module_& m) {
22212221
if (tracer_class.ptr() && self.ptr() == base_type.ptr() &&
22222222
PyObject_TypeCheck(x.ptr(), reinterpret_cast<PyTypeObject*>(
22232223
tracer_class.ptr())) != 0) {
2224-
// TODO(phawkins): we would like to change this to use the logic below
2225-
// but it is a somewhat breaking change. Let us defer it to a future
2226-
// PR.
2227-
return true;
2228-
// auto is_traced_array_fn =
2229-
// nb::getattr(x, "_is_traced_array", nb::none());
2230-
// if (!is_traced_array_fn.is_none()) {
2231-
// try {
2232-
// return nb::cast<bool>(is_traced_array_fn());
2233-
// } catch (...) {
2234-
// }
2235-
// }
2224+
auto is_traced_array_fn =
2225+
nb::getattr(x, "_is_traced_array", nb::none());
2226+
if (!is_traced_array_fn.is_none()) {
2227+
return nb::cast<bool>(is_traced_array_fn());
2228+
}
22362229
}
22372230
return false;
22382231
},

0 commit comments

Comments
 (0)