Skip to content

Commit 88555c2

Browse files
committed
Fix data race in weakref_lru_cache under free-threading.
When run under an optimized build and Python 3.13.2t, I saw the following high probability crash in lax_control_flow_test: ``` Stack trace of thread 3526917: #0 0x00007f0898c4bf91 dump_frame (libpython3.13t.so.1.0 + 0x24bf91) #1 0x00007f0898c4b73f dump_traceback (libpython3.13t.so.1.0 + 0x24b73f) #2 0x00007f0898c4b86f _Py_DumpTracebackThreads (libpython3.13t.so.1.0 + 0x24b86f) #3 0x00007f0898cd4fe0 faulthandler_dump_traceback (libpython3.13t.so.1.0 + 0x2d4fe0) #4 0x00007f0898cd4f44 faulthandler_fatal_error (libpython3.13t.so.1.0 + 0x2d4f44) #5 0x00007f0898849e20 __restore_rt (libc.so.6 + 0x3fe20) #6 0x00007f07eb80e493 _ZNSt8__detail16_Hashtable_allocISaINS_10_Hash_nodeISt4pairIKN3jax15WeakrefLRUCache15WeakrefCacheKeyENS4_17WeakrefCacheValueEELb1EEEEE18_M_deallocate_nodeEPS9_ (libjax_common.so + 0x2c0e493) #7 0x00007f07eb80e13e _ZN3jax15WeakrefLRUCache5ClearEv (libjax_common.so + 0x2c0e13e) #8 0x00007f07eb812e37 _ZZN8nanobind6detail11func_createILb0ELb1EZNS_16cpp_function_defIN3jax15WeakrefLRUCacheEvS4_JEJNS_5scopeENS_4nameENS_9is_methodENS_9lock_selfEEEEvMT1_FT0_DpT2_EDpRKT3_EUlPS4_E_vJSJ_EJLm0EEJS5_S6_S7_S8_EEEP> #9 0x00007f07eb7fff70 _ZN8nanobind6detailL25nb_func_vectorcall_simpleEP7_objectPKS2_mS2_ (libjax_common.so + 0x2bfff70) #10 0x00007f0898dbbdee _PyObject_VectorcallTstate (libpython3.13t.so.1.0 + 0x3bbdee) #11 0x00007f0898d1d4db _PyEval_EvalFrame (libpython3.13t.so.1.0 + 0x31d4db) #12 0x00007f0898d1ee78 _PyObject_VectorcallTstate (libpython3.13t.so.1.0 + 0x31ee78) #13 0x00007f0898dc0054 _PyVectorcall_Call (libpython3.13t.so.1.0 + 0x3c0054) #14 0x00007f0898d1d4db _PyEval_EvalFrame (libpython3.13t.so.1.0 + 0x31d4db) #15 0x00007f0898d1e02c _PyObject_VectorcallDictTstate (libpython3.13t.so.1.0 + 0x31e02c) #16 0x00007f0898ed8e35 slot_tp_call (libpython3.13t.so.1.0 + 0x4d8e35) #17 0x00007f0898dbc312 _PyObject_MakeTpCall (libpython3.13t.so.1.0 + 0x3bc312) #18 0x00007f0898d1d4db _PyEval_EvalFrame (libpython3.13t.so.1.0 + 0x31d4db) #19 0x00007f0898d1ef54 _PyObject_VectorcallTstate (libpython3.13t.so.1.0 + 0x31ef54) #20 0x00007f0899094c1f thread_run (libpython3.13t.so.1.0 + 0x694c1f) #21 0x00007f0898fa0c58 pythread_wrapper (libpython3.13t.so.1.0 + 0x5a0c58) #22 0x00007f089889c103 start_thread (libc.so.6 + 0x92103) #23 0x00007f089891a7b8 __clone3 (libc.so.6 + 0x1107b8) ``` It appears that this is due to freeing Python objects during unordered_map::clear(), which may release the enclosing critical section (`nb::lock_self()` on the method). Fix this by deferring destruction of the both the keys and the values to after the map's destruction.
1 parent a516988 commit 88555c2

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

jaxlib/weakref_lru_cache.cc

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ limitations under the License.
3434
#include "absl/synchronization/notification.h"
3535
#include "nanobind/nanobind.h"
3636
#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep
37-
#include "nanobind/stl/string.h" // IWYU pragma: keep
38-
#include "nanobind/stl/vector.h" // IWYU pragma: keep
37+
#include "nanobind/stl/string.h" // IWYU pragma: keep
38+
#include "nanobind/stl/vector.h" // IWYU pragma: keep
3939
#include "xla/pjrt/lru_cache.h"
4040
#include "xla/tsl/platform/logging.h"
4141

@@ -309,16 +309,7 @@ class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
309309
result.currsize = lru_list_.Size();
310310
return result;
311311
}
312-
void Clear() {
313-
total_queries_ = misses_ = 0;
314-
std::vector<std::shared_ptr<Cache>> deferred_deletes;
315-
deferred_deletes.reserve(entries_.size());
316-
for (auto& entry : entries_) {
317-
deferred_deletes.push_back(std::move(entry.second.cache));
318-
}
319-
entries_.clear();
320-
deferred_deletes.clear();
321-
}
312+
void Clear();
322313

323314
nb::callable cache_context_fn_;
324315
nb::callable fn_;
@@ -361,6 +352,17 @@ class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
361352
static PyType_Slot slots_[];
362353
};
363354

355+
void WeakrefLRUCache::Clear() {
356+
total_queries_ = misses_ = 0;
357+
std::vector<std::pair<WeakrefCacheKey, WeakrefCacheValue>> deferred_deletes;
358+
deferred_deletes.reserve(entries_.size());
359+
for (auto& entry : entries_) {
360+
deferred_deletes.emplace_back(entry.first, std::move(entry.second));
361+
}
362+
entries_.clear();
363+
deferred_deletes.clear();
364+
}
365+
364366
/* static */ PyType_Slot WeakrefLRUCache::slots_[] = {
365367
{Py_tp_traverse, (void*)WeakrefLRUCache::tp_traverse},
366368
{Py_tp_clear, (void*)WeakrefLRUCache::tp_clear},

0 commit comments

Comments
 (0)