From 8dd4f8e2934126a8323b8c8f00e69b0182b8ecee Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Wed, 19 Feb 2025 21:00:13 +0000 Subject: [PATCH] support 3.13t free-threaded python (#471) * support 3.13t free-threaded python * make dynamic borrow-checking threadsafe * fix compiler error in benches target * fix warning on nightly rust about extern usage * remove parking_lot dependency * Add deadlock-avoidance using direct FFI calls * refactor BorrowFlags tests to not hold a lock in the tests * fix clippy * give BorrowFlagsState fields descriptive names * move thread state guard into the crate root * use ThreadStateGuard to avoid deadlocks acquiring the dtype cache --------- Co-authored-by: Icxolu <10486322+Icxolu@users.noreply.github.com> Co-authored-by: Nathan Goldbaum --- .github/workflows/ci.yml | 3 +- CHANGELOG.md | 3 + Cargo.toml | 5 +- benches/array.rs | 8 +- build.rs | 3 + src/borrow/shared.rs | 262 ++++++++++++++++++++------------------- src/datetime.rs | 18 ++- src/lib.rs | 17 +++ src/npyffi/mod.rs | 6 +- src/strings.rs | 17 ++- 10 files changed, 198 insertions(+), 144 deletions(-) create mode 100644 build.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f4e801413..f6c84b891 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -39,6 +39,7 @@ jobs: "3.11", "3.12", "3.13", + "3.13t", "pypy-3.9", "pypy-3.10", ] @@ -108,7 +109,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 + uses: Quansight-Labs/setup-python@v5 with: python-version: ${{ matrix.python-version }} architecture: ${{ matrix.platform.python-architecture }} diff --git a/CHANGELOG.md b/CHANGELOG.md index d0ef24861..dbbd00574 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,7 @@ # Changelog +- v0.24.0 + - Support Python 3.13t "free-threaded" Python. ([#471](https://github.com/PyO3/rust-numpy/pull/471) + - v0.23.0 - Drop support for PyPy 3.7 and 3.8. ([#470](https://github.com/PyO3/rust-numpy/pull/470)) - Require `Element: Sync` as part of the free-threading support in PyO3 0.23 ([#469](https://github.com/PyO3/rust-numpy/pull/469)) diff --git a/Cargo.toml b/Cargo.toml index f0db8fe44..5054b0082 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,12 +22,15 @@ num-complex = ">= 0.2, < 0.5" num-integer = "0.1" num-traits = "0.2" ndarray = ">= 0.15, < 0.17" -pyo3 = { version = "0.23.3", default-features = false, features = ["macros"] } +pyo3 = { version = "0.23.4", default-features = false, features = ["macros"] } rustc-hash = "2.0" [dev-dependencies] pyo3 = { version = "0.23.3", default-features = false, features = ["auto-initialize"] } nalgebra = { version = ">=0.30, <0.34", default-features = false, features = ["std"] } +[build-dependencies] +pyo3-build-config = { version = "0.23.1", features = ["resolve-config"] } + [package.metadata.docs.rs] all-features = true diff --git a/benches/array.rs b/benches/array.rs index 3df2321bb..756a352bd 100644 --- a/benches/array.rs +++ b/benches/array.rs @@ -6,7 +6,7 @@ use test::{black_box, Bencher}; use std::ops::Range; use numpy::{PyArray1, PyArray2, PyArray3}; -use pyo3::{types::PyAnyMethods, Bound, Python}; +use pyo3::{types::PyAnyMethods, Bound, IntoPyObjectExt, Python}; #[bench] fn extract_success(bencher: &mut Bencher) { @@ -115,7 +115,11 @@ fn from_slice_large(bencher: &mut Bencher) { } fn from_object_slice(bencher: &mut Bencher, size: usize) { - let vec = Python::with_gil(|py| (0..size).map(|val| val.to_object(py)).collect::>()); + let vec = Python::with_gil(|py| { + (0..size) + .map(|val| val.into_py_any(py).unwrap()) + .collect::>() + }); Python::with_gil(|py| { bencher.iter(|| { diff --git a/build.rs b/build.rs new file mode 100644 index 000000000..6b746f1a1 --- /dev/null +++ b/build.rs @@ -0,0 +1,3 @@ +fn main() { + pyo3_build_config::use_pyo3_cfgs(); +} diff --git a/src/borrow/shared.rs b/src/borrow/shared.rs index 36860b275..939098c64 100644 --- a/src/borrow/shared.rs +++ b/src/borrow/shared.rs @@ -3,6 +3,7 @@ use std::ffi::{c_void, CString}; use std::mem::forget; use std::os::raw::{c_char, c_int}; use std::slice::from_raw_parts; +use std::sync::Mutex; use num_integer::gcd; use pyo3::types::{PyAnyMethods, PyCapsuleMethods}; @@ -19,9 +20,6 @@ use crate::npyffi::{PyArrayObject, PyArray_Check, PyDataType_ELSIZE, NPY_ARRAY_W /// This structure will be placed into a capsule at /// `numpy.core.multiarray._RUST_NUMPY_BORROW_CHECKING_API`. /// -/// All functions exposed here assume the GIL is held -/// while they are called. -/// /// Versions are assumed to be backwards-compatible, i.e. /// an extension which knows version N will work using /// any API version M as long as M >= N holds. @@ -45,7 +43,7 @@ unsafe impl Send for Shared {} unsafe extern "C" fn acquire_shared(flags: *mut c_void, array: *mut PyArrayObject) -> c_int { // SAFETY: GIL must be held when calling `acquire_shared`. let py = Python::assume_gil_acquired(); - let flags = &mut *(flags as *mut BorrowFlags); + let flags = &*(flags as *mut BorrowFlags); let address = base_address(py, array); let key = borrow_key(py, array); @@ -63,7 +61,7 @@ unsafe extern "C" fn acquire_mut_shared(flags: *mut c_void, array: *mut PyArrayO // SAFETY: GIL must be held when calling `acquire_shared`. let py = Python::assume_gil_acquired(); - let flags = &mut *(flags as *mut BorrowFlags); + let flags = &*(flags as *mut BorrowFlags); let address = base_address(py, array); let key = borrow_key(py, array); @@ -77,8 +75,7 @@ unsafe extern "C" fn acquire_mut_shared(flags: *mut c_void, array: *mut PyArrayO unsafe extern "C" fn release_shared(flags: *mut c_void, array: *mut PyArrayObject) { // SAFETY: GIL must be held when calling `acquire_shared`. let py = Python::assume_gil_acquired(); - let flags = &mut *(flags as *mut BorrowFlags); - + let flags = &*(flags as *mut BorrowFlags); let address = base_address(py, array); let key = borrow_key(py, array); @@ -88,7 +85,7 @@ unsafe extern "C" fn release_shared(flags: *mut c_void, array: *mut PyArrayObjec unsafe extern "C" fn release_mut_shared(flags: *mut c_void, array: *mut PyArrayObject) { // SAFETY: GIL must be held when calling `acquire_shared`. let py = Python::assume_gil_acquired(); - let flags = &mut *(flags as *mut BorrowFlags); + let flags = &*(flags as *mut BorrowFlags); let address = base_address(py, array); let key = borrow_key(py, array); @@ -253,12 +250,11 @@ impl BorrowKey { type BorrowFlagsInner = FxHashMap<*mut c_void, FxHashMap>; #[derive(Default)] -struct BorrowFlags(BorrowFlagsInner); +struct BorrowFlags(Mutex); impl BorrowFlags { - fn acquire(&mut self, address: *mut c_void, key: BorrowKey) -> Result<(), ()> { - let borrow_flags = &mut self.0; - + fn acquire(&self, address: *mut c_void, key: BorrowKey) -> Result<(), ()> { + let mut borrow_flags = self.0.lock().unwrap(); match borrow_flags.entry(address) { Entry::Occupied(entry) => { let same_base_arrays = entry.into_mut(); @@ -298,11 +294,10 @@ impl BorrowFlags { Ok(()) } - fn release(&mut self, address: *mut c_void, key: BorrowKey) { - let borrow_flags = &mut self.0; + fn release(&self, address: *mut c_void, key: BorrowKey) { + let mut borrow_flags = self.0.lock().unwrap(); let same_base_arrays = borrow_flags.get_mut(&address).unwrap(); - let readers = same_base_arrays.get_mut(&key).unwrap(); *readers -= 1; @@ -316,8 +311,8 @@ impl BorrowFlags { } } - fn acquire_mut(&mut self, address: *mut c_void, key: BorrowKey) -> Result<(), ()> { - let borrow_flags = &mut self.0; + fn acquire_mut(&self, address: *mut c_void, key: BorrowKey) -> Result<(), ()> { + let mut borrow_flags = self.0.lock().unwrap(); match borrow_flags.entry(address) { Entry::Occupied(entry) => { @@ -352,8 +347,8 @@ impl BorrowFlags { Ok(()) } - fn release_mut(&mut self, address: *mut c_void, key: BorrowKey) { - let borrow_flags = &mut self.0; + fn release_mut(&self, address: *mut c_void, key: BorrowKey) { + let mut borrow_flags = self.0.lock().unwrap(); let same_base_arrays = borrow_flags.get_mut(&address).unwrap(); @@ -452,10 +447,38 @@ mod tests { use crate::untyped_array::PyUntypedArrayMethods; use pyo3::ffi::c_str; - fn get_borrow_flags<'py>(py: Python<'py>) -> &'py BorrowFlagsInner { + struct BorrowFlagsState { + #[cfg(not(Py_GIL_DISABLED))] + n_flags: usize, + n_arrays: usize, + flag: Option, + } + + fn get_borrow_flags_state<'py>( + py: Python<'py>, + base: *mut c_void, + key: &BorrowKey, + ) -> BorrowFlagsState { let shared = get_or_insert_shared(py).unwrap(); assert_eq!(shared.version, 1); - unsafe { &(*(shared.flags as *mut BorrowFlags)).0 } + let inner = unsafe { &(*(shared.flags as *mut BorrowFlags)).0 } + .lock() + .unwrap(); + if let Some(base_arrays) = inner.get(&base) { + BorrowFlagsState { + #[cfg(not(Py_GIL_DISABLED))] + n_flags: inner.len(), + n_arrays: base_arrays.len(), + flag: base_arrays.get(key).copied(), + } + } else { + BorrowFlagsState { + #[cfg(not(Py_GIL_DISABLED))] + n_flags: 0, + n_arrays: 0, + flag: None, + } + } } #[test] @@ -782,34 +805,30 @@ mod tests { let _exclusive1 = array1.readwrite(); { - let borrow_flags = get_borrow_flags(py); - assert_eq!(borrow_flags.len(), 1); + let state = get_borrow_flags_state(py, base1, &key1); + #[cfg(not(Py_GIL_DISABLED))] + // borrow checking state is shared and other tests might have registered a borrow + assert_eq!(state.n_flags, 1); - let same_base_arrays = &borrow_flags[&base1]; - assert_eq!(same_base_arrays.len(), 1); - - let flag = same_base_arrays[&key1]; - assert_eq!(flag, -1); + assert_eq!(state.n_arrays, 1); + assert_eq!(state.flag, Some(-1)); } let key2 = borrow_key(py, array2.as_array_ptr()); let _shared2 = array2.readonly(); { - let borrow_flags = get_borrow_flags(py); - assert_eq!(borrow_flags.len(), 2); - - let same_base_arrays = &borrow_flags[&base1]; - assert_eq!(same_base_arrays.len(), 1); + let state = get_borrow_flags_state(py, base1, &key1); + #[cfg(not(Py_GIL_DISABLED))] + // borrow checking state is shared and other tests might have registered a borrow + assert_eq!(state.n_flags, 2); - let flag = same_base_arrays[&key1]; - assert_eq!(flag, -1); + assert_eq!(state.n_arrays, 1); + assert_eq!(state.flag, Some(-1)); - let same_base_arrays = &borrow_flags[&base2]; - assert_eq!(same_base_arrays.len(), 1); - - let flag = same_base_arrays[&key2]; - assert_eq!(flag, 1); + let state = get_borrow_flags_state(py, base2, &key2); + assert_eq!(state.n_arrays, 1); + assert_eq!(state.flag, Some(1)); } }); } @@ -832,14 +851,13 @@ mod tests { let exclusive1 = view1.readwrite(); { - let borrow_flags = get_borrow_flags(py); - assert_eq!(borrow_flags.len(), 1); - - let same_base_arrays = &borrow_flags[&base]; - assert_eq!(same_base_arrays.len(), 1); + let state = get_borrow_flags_state(py, base, &key1); - let flag = same_base_arrays[&key1]; - assert_eq!(flag, -1); + #[cfg(not(Py_GIL_DISABLED))] + // borrow checking state is shared and other tests might have registered a borrow + assert_eq!(state.n_flags, 1); + assert_eq!(state.n_arrays, 1); + assert_eq!(state.flag, Some(-1)); } let view2 = py @@ -852,17 +870,15 @@ mod tests { let shared2 = view2.readonly(); { - let borrow_flags = get_borrow_flags(py); - assert_eq!(borrow_flags.len(), 1); - - let same_base_arrays = &borrow_flags[&base]; - assert_eq!(same_base_arrays.len(), 2); - - let flag = same_base_arrays[&key1]; - assert_eq!(flag, -1); - - let flag = same_base_arrays[&key2]; - assert_eq!(flag, 1); + let state = get_borrow_flags_state(py, base, &key1); + #[cfg(not(Py_GIL_DISABLED))] + // borrow checking state is shared and other tests might have registered a borrow + assert_eq!(state.n_flags, 1); + assert_eq!(state.n_arrays, 2); + assert_eq!(state.flag, Some(-1)); + + let state = get_borrow_flags_state(py, base, &key2); + assert_eq!(state.flag, Some(1)); } let view3 = py @@ -875,20 +891,18 @@ mod tests { let shared3 = view3.readonly(); { - let borrow_flags = get_borrow_flags(py); - assert_eq!(borrow_flags.len(), 1); - - let same_base_arrays = &borrow_flags[&base]; - assert_eq!(same_base_arrays.len(), 2); - - let flag = same_base_arrays[&key1]; - assert_eq!(flag, -1); - - let flag = same_base_arrays[&key2]; - assert_eq!(flag, 2); - - let flag = same_base_arrays[&key3]; - assert_eq!(flag, 2); + let state = get_borrow_flags_state(py, base, &key1); + #[cfg(not(Py_GIL_DISABLED))] + // borrow checking state is shared and other tests might have registered a borrow + assert_eq!(state.n_flags, 1); + assert_eq!(state.n_arrays, 2); + assert_eq!(state.flag, Some(-1)); + + let state = get_borrow_flags_state(py, base, &key2); + assert_eq!(state.flag, Some(2)); + + let state = get_borrow_flags_state(py, base, &key3); + assert_eq!(state.flag, Some(2)); } let view4 = py @@ -901,91 +915,89 @@ mod tests { let shared4 = view4.readonly(); { - let borrow_flags = get_borrow_flags(py); - assert_eq!(borrow_flags.len(), 1); - - let same_base_arrays = &borrow_flags[&base]; - assert_eq!(same_base_arrays.len(), 3); + let state = get_borrow_flags_state(py, base, &key1); + #[cfg(not(Py_GIL_DISABLED))] + // borrow checking state is shared and other tests might have registered a borrow + assert_eq!(state.n_flags, 1); + assert_eq!(state.n_arrays, 3); + assert_eq!(state.flag, Some(-1)); - let flag = same_base_arrays[&key1]; - assert_eq!(flag, -1); + let state = get_borrow_flags_state(py, base, &key2); + assert_eq!(state.flag, Some(2)); - let flag = same_base_arrays[&key2]; - assert_eq!(flag, 2); + let state = get_borrow_flags_state(py, base, &key3); + assert_eq!(state.flag, Some(2)); - let flag = same_base_arrays[&key3]; - assert_eq!(flag, 2); - - let flag = same_base_arrays[&key4]; - assert_eq!(flag, 1); + let state = get_borrow_flags_state(py, base, &key4); + assert_eq!(state.flag, Some(1)); } drop(shared2); { - let borrow_flags = get_borrow_flags(py); - assert_eq!(borrow_flags.len(), 1); - - let same_base_arrays = &borrow_flags[&base]; - assert_eq!(same_base_arrays.len(), 3); + let state = get_borrow_flags_state(py, base, &key1); + #[cfg(not(Py_GIL_DISABLED))] + // borrow checking state is shared and other tests might have registered a borrow + assert_eq!(state.n_flags, 1); + assert_eq!(state.n_arrays, 3); + assert_eq!(state.flag, Some(-1)); - let flag = same_base_arrays[&key1]; - assert_eq!(flag, -1); + let state = get_borrow_flags_state(py, base, &key2); + assert_eq!(state.flag, Some(1)); - let flag = same_base_arrays[&key2]; - assert_eq!(flag, 1); + let state = get_borrow_flags_state(py, base, &key3); + assert_eq!(state.flag, Some(1)); - let flag = same_base_arrays[&key3]; - assert_eq!(flag, 1); - - let flag = same_base_arrays[&key4]; - assert_eq!(flag, 1); + let state = get_borrow_flags_state(py, base, &key4); + assert_eq!(state.flag, Some(1)); } drop(shared3); { - let borrow_flags = get_borrow_flags(py); - assert_eq!(borrow_flags.len(), 1); - - let same_base_arrays = &borrow_flags[&base]; - assert_eq!(same_base_arrays.len(), 2); + let state = get_borrow_flags_state(py, base, &key1); + #[cfg(not(Py_GIL_DISABLED))] + // borrow checking state is shared and other tests might have registered a borrow + assert_eq!(state.n_flags, 1); + assert_eq!(state.n_arrays, 2); + assert_eq!(state.flag, Some(-1)); - let flag = same_base_arrays[&key1]; - assert_eq!(flag, -1); + let state = get_borrow_flags_state(py, base, &key2); + assert_eq!(state.flag, None); - assert!(!same_base_arrays.contains_key(&key2)); + let state = get_borrow_flags_state(py, base, &key3); + assert_eq!(state.flag, None); - assert!(!same_base_arrays.contains_key(&key3)); - - let flag = same_base_arrays[&key4]; - assert_eq!(flag, 1); + let state = get_borrow_flags_state(py, base, &key4); + assert_eq!(state.flag, Some(1)); } drop(exclusive1); { - let borrow_flags = get_borrow_flags(py); - assert_eq!(borrow_flags.len(), 1); - - let same_base_arrays = &borrow_flags[&base]; - assert_eq!(same_base_arrays.len(), 1); - - assert!(!same_base_arrays.contains_key(&key1)); + let state = get_borrow_flags_state(py, base, &key1); + #[cfg(not(Py_GIL_DISABLED))] + // borrow checking state is shared and other tests might have registered a borrow + assert_eq!(state.n_flags, 1); + assert_eq!(state.n_arrays, 1); + assert_eq!(state.flag, None); - assert!(!same_base_arrays.contains_key(&key2)); + let state = get_borrow_flags_state(py, base, &key2); + assert_eq!(state.flag, None); - assert!(!same_base_arrays.contains_key(&key3)); + let state = get_borrow_flags_state(py, base, &key3); + assert_eq!(state.flag, None); - let flag = same_base_arrays[&key4]; - assert_eq!(flag, 1); + let state = get_borrow_flags_state(py, base, &key4); + assert_eq!(state.flag, Some(1)); } drop(shared4); + #[cfg(not(Py_GIL_DISABLED))] + // borrow checking state is shared and other tests might have registered a borrow { - let borrow_flags = get_borrow_flags(py); - assert_eq!(borrow_flags.len(), 0); + assert_eq!(get_borrow_flags_state(py, base, &key1).n_flags, 0); } }); } diff --git a/src/datetime.rs b/src/datetime.rs index 3eef346d8..a65c562bb 100644 --- a/src/datetime.rs +++ b/src/datetime.rs @@ -54,19 +54,20 @@ //! [scalars-datetime64]: https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.datetime64 //! [scalars-timedelta64]: https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.timedelta64 -use std::cell::RefCell; use std::collections::hash_map::Entry; use std::fmt; use std::hash::Hash; use std::marker::PhantomData; +use std::sync::Mutex; -use pyo3::{sync::GILProtected, Bound, Py, Python}; +use pyo3::{Bound, Py, Python}; use rustc_hash::FxHashMap; use crate::dtype::{clone_methods_impl, Element, PyArrayDescr, PyArrayDescrMethods}; use crate::npyffi::{ PyArray_DatetimeDTypeMetaData, PyDataType_C_METADATA, NPY_DATETIMEUNIT, NPY_TYPES, }; +use crate::ThreadStateGuard; /// Represents the [datetime units][datetime-units] supported by NumPy /// @@ -209,8 +210,7 @@ impl fmt::Debug for Timedelta { struct TypeDescriptors { npy_type: NPY_TYPES, - #[allow(clippy::type_complexity)] - dtypes: GILProtected>>>>, + dtypes: Mutex>>>, } impl TypeDescriptors { @@ -218,13 +218,19 @@ impl TypeDescriptors { const unsafe fn new(npy_type: NPY_TYPES) -> Self { Self { npy_type, - dtypes: GILProtected::new(RefCell::new(None)), + dtypes: Mutex::new(None), } } #[allow(clippy::wrong_self_convention)] fn from_unit<'py>(&self, py: Python<'py>, unit: NPY_DATETIMEUNIT) -> Bound<'py, PyArrayDescr> { - let mut dtypes = self.dtypes.get(py).borrow_mut(); + // Detach from the runtime to avoid deadlocking on acquiring the mutex. + let ts_guard = ThreadStateGuard::new(); + + let mut dtypes = self.dtypes.lock().expect("dtype cache poisoned"); + + // Now we hold the mutex so it's safe to re-attach to the runtime. + drop(ts_guard); let dtype = match dtypes.get_or_insert_with(Default::default).entry(unit) { Entry::Occupied(entry) => entry.into_mut(), diff --git a/src/lib.rs b/src/lib.rs index e147a9c18..c5d608ffa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -149,6 +149,23 @@ mod doctest { #[inline(always)] fn cold() {} +/// An RAII guard for avoiding deadlocks with the GIL or other global +/// synchronization events in the Python runtime +// FIXME create a proper MutexExt trait that handles poisoning and upstream to PyO3 +struct ThreadStateGuard(*mut pyo3::ffi::PyThreadState); + +impl ThreadStateGuard { + fn new() -> ThreadStateGuard { + ThreadStateGuard(unsafe { pyo3::ffi::PyEval_SaveThread() }) + } +} + +impl Drop for ThreadStateGuard { + fn drop(&mut self) { + unsafe { pyo3::ffi::PyEval_RestoreThread(self.0) }; + } +} + /// Create a [`PyArray`] with one, two or three dimensions. /// /// This macro is backed by [`ndarray::array`]. diff --git a/src/npyffi/mod.rs b/src/npyffi/mod.rs index b96e69344..bf846f8e2 100644 --- a/src/npyffi/mod.rs +++ b/src/npyffi/mod.rs @@ -53,7 +53,7 @@ macro_rules! impl_api { [$offset: expr; $fname: ident ($($arg: ident: $t: ty),* $(,)?) $(-> $ret: ty)?] => { #[allow(non_snake_case)] pub unsafe fn $fname<'py>(&self, py: Python<'py>, $($arg : $t), *) $(-> $ret)* { - let fptr = self.get(py, $offset) as *const extern fn ($($arg : $t), *) $(-> $ret)*; + let fptr = self.get(py, $offset) as *const extern "C" fn ($($arg : $t), *) $(-> $ret)*; (*fptr)($($arg), *) } }; @@ -69,7 +69,7 @@ macro_rules! impl_api { API_VERSION_2_0, *API_VERSION.get(py).expect("API_VERSION is initialized"), ); - let fptr = self.get(py, $offset) as *const extern fn ($($arg: $t), *) $(-> $ret)*; + let fptr = self.get(py, $offset) as *const extern "C" fn ($($arg: $t), *) $(-> $ret)*; (*fptr)($($arg), *) } @@ -84,7 +84,7 @@ macro_rules! impl_api { API_VERSION_2_0, *API_VERSION.get(py).expect("API_VERSION is initialized"), ); - let fptr = self.get(py, $offset) as *const extern fn ($($arg: $t), *) $(-> $ret)*; + let fptr = self.get(py, $offset) as *const extern "C" fn ($($arg: $t), *) $(-> $ret)*; (*fptr)($($arg), *) } diff --git a/src/strings.rs b/src/strings.rs index 067327865..8826e674a 100644 --- a/src/strings.rs +++ b/src/strings.rs @@ -3,16 +3,15 @@ //! [ascii]: https://numpy.org/doc/stable/reference/c-api/dtype.html#c.NPY_STRING //! [ucs4]: https://numpy.org/doc/stable/reference/c-api/dtype.html#c.NPY_UNICODE -use std::cell::RefCell; use std::collections::hash_map::Entry; use std::fmt; use std::mem::size_of; use std::os::raw::c_char; use std::str; +use std::sync::Mutex; use pyo3::{ ffi::{Py_UCS1, Py_UCS4}, - sync::GILProtected, Bound, Py, Python, }; use rustc_hash::FxHashMap; @@ -20,6 +19,7 @@ use rustc_hash::FxHashMap; use crate::dtype::{clone_methods_impl, Element, PyArrayDescr, PyArrayDescrMethods}; use crate::npyffi::PyDataType_SET_ELSIZE; use crate::npyffi::NPY_TYPES; +use crate::ThreadStateGuard; /// A newtype wrapper around [`[u8; N]`][Py_UCS1] to handle [`byte` scalars][numpy-bytes] while satisfying coherence. /// @@ -160,14 +160,13 @@ unsafe impl Element for PyFixedUnicode { } struct TypeDescriptors { - #[allow(clippy::type_complexity)] - dtypes: GILProtected>>>>, + dtypes: Mutex>>>, } impl TypeDescriptors { const fn new() -> Self { Self { - dtypes: GILProtected::new(RefCell::new(None)), + dtypes: Mutex::new(None), } } @@ -180,7 +179,13 @@ impl TypeDescriptors { byteorder: c_char, size: usize, ) -> Bound<'py, PyArrayDescr> { - let mut dtypes = self.dtypes.get(py).borrow_mut(); + // Detach from the runtime to avoid deadlocking on acquiring the mutex. + let ts_guard = ThreadStateGuard::new(); + + let mut dtypes = self.dtypes.lock().expect("dtype cache poisoned"); + + // Now we hold the mutex so it's safe to re-attach to the runtime. + drop(ts_guard); let dtype = match dtypes.get_or_insert_with(Default::default).entry(size) { Entry::Occupied(entry) => entry.into_mut(),