Skip to content

Commit d3ed455

Browse files
ICGog authored and Google-ML-Automation committed
Move array_garbage_collection_guard logic from PyArray_tp_clear to PyArray_tp_finalize.
PiperOrigin-RevId: 859832848
1 parent d8f3b88 commit d3ed455

File tree

5 files changed

+62
-9
lines changed

5 files changed

+62
-9
lines changed

docs/jax.rst

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,7 @@ Configuration
4242
.. autosummary::
4343
:toctree: _autosummary
4444

45+
array_garbage_collection_guard
4546
config
4647
check_tracer_leaks
4748
checking_leaks

jax/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -68,6 +68,7 @@
6868
jax2tf_associative_scan_reductions as jax2tf_associative_scan_reductions,
6969
legacy_prng_key as legacy_prng_key,
7070
threefry_partitionable as threefry_partitionable,
71+
array_garbage_collection_guard as array_garbage_collection_guard,
7172
transfer_guard as transfer_guard,
7273
transfer_guard_host_to_device as transfer_guard_host_to_device,
7374
transfer_guard_device_to_device as transfer_guard_device_to_device,

jax/_src/test_util.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -227,6 +227,7 @@ def get_output() -> str:
227227
f.seek(0)
228228
captured = f.read()
229229
os.dup2(original_fd, fp.fileno())
230+
os.close(original_fd)
230231

231232

232233
capture_stdout = partial(_capture_output, sys.stdout)

jaxlib/py_array.cc

Lines changed: 21 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -332,8 +332,11 @@ extern "C" int PyArray_tp_traverse(PyObject* self, visitproc visit, void* arg) {
332332
return 0;
333333
}
334334

335-
// dynamic_attr: Allow the GC to clear the dictionary.
336-
extern "C" int PyArray_tp_clear(PyObject* self) {
335+
extern "C" void PyArray_tp_finalize(PyObject* self) {
336+
// This method assumes that `PyObject_CallFinalizerFromDealloc` is not called
337+
// from `PyArray_tp_dealloc`. If this assumption is violated, then the garbage
338+
// collector guard would trigger for an array deallocated via reference
339+
// counting.
337340
switch (auto guard_level = GetGarbageCollectArrayGuard(); guard_level) {
338341
case GarbageCollectionGuardLevel::kAllow:
339342
break;
@@ -361,13 +364,28 @@ extern "C" int PyArray_tp_clear(PyObject* self) {
361364
if (guard_level == GarbageCollectionGuardLevel::kFatal) {
362365
Py_FatalError(error_msg.c_str());
363366
} else {
367+
#if PY_VERSION_HEX < 0x030C0000
368+
PyObject *err_type, *err_value, *err_traceback;
369+
PyErr_Fetch(&err_type, &err_value, &err_traceback);
370+
#else
371+
PyObject* exc = PyErr_GetRaisedException();
372+
#endif
364373
PyErr_SetString(PyExc_RuntimeError, error_msg.c_str());
365374
PyErr_Print();
366375
PyErr_Clear();
376+
#if PY_VERSION_HEX < 0x030C0000
377+
PyErr_Restore(err_type, err_value, err_traceback);
378+
#else
379+
PyErr_SetRaisedException(exc);
380+
#endif
367381
}
368382
break;
369383
}
370384
}
385+
}
386+
387+
// dynamic_attr: Allow the GC to clear the dictionary.
388+
extern "C" int PyArray_tp_clear(PyObject* self) {
371389
#if PY_VERSION_HEX < 0x030C0000
372390
PyObject*& dict = *_PyObject_GetDictPtr(self);
373391
Py_CLEAR(dict);
@@ -2079,6 +2097,7 @@ PyMemberDef array_impl_members[] = {
20792097

20802098
PyType_Slot array_impl_slots[] = {
20812099
{Py_tp_new, reinterpret_cast<void*>(PyArray_tp_new)},
2100+
{Py_tp_finalize, reinterpret_cast<void*>(PyArray_tp_finalize)},
20822101
{Py_tp_dealloc, reinterpret_cast<void*>(PyArray_tp_dealloc)},
20832102
{Py_tp_members, reinterpret_cast<void*>(array_impl_members)},
20842103
{Py_tp_traverse, reinterpret_cast<void*>(PyArray_tp_traverse)},

tests/garbage_collection_guard_test.py

Lines changed: 38 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -12,17 +12,32 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import dataclasses
1516
import gc
1617
import weakref
1718

1819
from absl.testing import absltest
1920
import jax
2021
from jax._src import config
22+
from jax._src.lib import jaxlib_extension_version
2123
import jax._src.test_util as jtu
2224
import jax.numpy as jnp
2325

2426
jax.config.parse_flags_with_absl()
2527

28+
_GC_ERROR_MESSAGE = "`jax.Array` was deleted by the Python garbage collector"
29+
30+
31+
@dataclasses.dataclass()
32+
class _A:
33+
a: jax.Array
34+
ref: "_B"
35+
36+
37+
@dataclasses.dataclass()
38+
class _B:
39+
ref: _A | None = None
40+
2641

2742
def _create_array_cycle():
2843
"""Creates a reference cycle of two jax.Arrays."""
@@ -45,11 +60,9 @@ def test_gced_array_is_not_logged_by_default(self):
4560
self.assertIsNone(ref()) # Cycle collected.
4661
# Check that no error message is logged because
4762
# `array_garbage_collection_guard` defaults to `allow`.
48-
self.assertNotIn(
49-
"`jax.Array` was deleted by the Python garbage collector", stderr(),
50-
)
63+
self.assertNotIn(_GC_ERROR_MESSAGE, stderr())
5164

52-
def test_gced_array_is_logged(self):
65+
def test_array_part_of_cycle_is_logged(self):
5366
with config.array_garbage_collection_guard("log"):
5467
with jtu.capture_stderr() as stderr:
5568
# Create a reference cycle of two jax.Arrays.
@@ -59,9 +72,27 @@ def test_gced_array_is_logged(self):
5972
self.assertIsNone(ref()) # Cycle collected.
6073
# Verify that an error message is logged because two jax.Arrays were garbage
6174
# collected.
62-
self.assertIn(
63-
"`jax.Array` was deleted by the Python garbage collector", stderr()
64-
)
75+
self.assertIn(_GC_ERROR_MESSAGE, stderr())
76+
77+
def test_array_reachable_from_cycle_is_logged(self):
78+
if jaxlib_extension_version < 401:
79+
self.skipTest("This functionality is not yet supported in jaxlib.")
80+
with config.array_garbage_collection_guard("log"):
81+
with jtu.capture_stderr() as stderr:
82+
b = _B()
83+
a = _A(a=jnp.array([1, 2, 3]), ref=b)
84+
b.ref = a
85+
del a
86+
del b
87+
gc.collect()
88+
self.assertIn(_GC_ERROR_MESSAGE, stderr())
89+
90+
def test_no_error_is_logged_when_array_is_not_gced(self):
91+
with config.array_garbage_collection_guard("log"):
92+
with jtu.capture_stderr() as stderr:
93+
a = jnp.array([1, 2, 3])
94+
del a
95+
self.assertNotIn(_GC_ERROR_MESSAGE, stderr())
6596

6697

6798
if __name__ == "__main__":

0 commit comments

Comments
 (0)