Skip to content

Commit 4a480c1

Browse files
hawkinspGoogle-ML-Automation
authored andcommitted
Only consider Tracers with ShapedArray avals as instances of Array.
This fixes a TODO from #33420. PiperOrigin-RevId: 835283505
1 parent ba4b18b commit 4a480c1

File tree

2 files changed

+8
-12
lines changed

2 files changed

+8
-12
lines changed

jax/_src/pallas/mosaic/pallas_call_registration.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ def _maybe_cast_to_int(x: jax.Array | jax_core.AbstractValue):
4848
"""
4949
assert isinstance(
5050
x, (jax.Array, jax_core.ShapedArray, state_types.AbstractLinVal)
51+
) or (
52+
isinstance(x, jax_core.Tracer)
53+
and isinstance(x.aval, state_types.AbstractLinVal)
5154
), type(x)
5255
if isinstance(x, jax.Array):
5356
if dtypes.issubdtype(x.dtype, jax.numpy.bool_):

jaxlib/py_array.cc

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2201,18 +2201,11 @@ absl::Status PyArray::Register(nb::module_& m) {
22012201
if (tracer_class.ptr() && self.ptr() == base_type.ptr() &&
22022202
PyObject_TypeCheck(x.ptr(), reinterpret_cast<PyTypeObject*>(
22032203
tracer_class.ptr())) != 0) {
2204-
// TODO(phawkins): we would like to change this to use the logic below
2205-
// but it is a somewhat breaking change. Let us defer it to a future
2206-
// PR.
2207-
return true;
2208-
// auto is_traced_array_fn =
2209-
// nb::getattr(x, "_is_traced_array", nb::none());
2210-
// if (!is_traced_array_fn.is_none()) {
2211-
// try {
2212-
// return nb::cast<bool>(is_traced_array_fn());
2213-
// } catch (...) {
2214-
// }
2215-
// }
2204+
auto is_traced_array_fn =
2205+
nb::getattr(x, "_is_traced_array", nb::none());
2206+
if (!is_traced_array_fn.is_none()) {
2207+
return nb::cast<bool>(is_traced_array_fn());
2208+
}
22162209
}
22172210
return false;
22182211
},

0 commit comments

Comments
 (0)