From e58c0afc31a669d6c975a6929fde198d17fef3c8 Mon Sep 17 00:00:00 2001 From: Nathan Goldbaum Date: Mon, 17 Feb 2025 14:51:56 -0700 Subject: [PATCH] refactor BorrowFlags tests to not hold a lock in the tests --- src/borrow/shared.rs | 209 +++++++++++++++++++++---------------------- 1 file changed, 102 insertions(+), 107 deletions(-) diff --git a/src/borrow/shared.rs b/src/borrow/shared.rs index 52b3c70f..ffffa0d5 100644 --- a/src/borrow/shared.rs +++ b/src/borrow/shared.rs @@ -247,15 +247,14 @@ impl BorrowKey { } } -type BorrowFlagsInner = Mutex>>; +type BorrowFlagsInner = FxHashMap<*mut c_void, FxHashMap>; #[derive(Default)] -struct BorrowFlags(BorrowFlagsInner); +struct BorrowFlags(Mutex); impl BorrowFlags { fn acquire(&self, address: *mut c_void, key: BorrowKey) -> Result<(), ()> { let mut borrow_flags = self.0.lock().unwrap(); - match borrow_flags.entry(address) { Entry::Occupied(entry) => { let same_base_arrays = entry.into_mut(); @@ -448,10 +447,27 @@ mod tests { use crate::untyped_array::PyUntypedArrayMethods; use pyo3::ffi::c_str; - fn get_borrow_flags<'py>(py: Python<'py>) -> &'py BorrowFlagsInner { + struct BorrowFlagsState(usize, usize, Option); + + fn get_borrow_flags_state<'py>( + py: Python<'py>, + base: *mut c_void, + key: &BorrowKey, + ) -> BorrowFlagsState { let shared = get_or_insert_shared(py).unwrap(); assert_eq!(shared.version, 1); - unsafe { &(*(shared.flags as *mut BorrowFlags)).0 } + let inner = unsafe { &(*(shared.flags as *mut BorrowFlags)).0 } + .lock() + .unwrap(); + if let Some(base_arrays) = inner.get(&base) { + BorrowFlagsState( + inner.len(), + base_arrays.len(), + base_arrays.get(key).map(|x| *x), + ) + } else { + BorrowFlagsState(0, 0, None) + } } #[test] @@ -778,36 +794,30 @@ mod tests { let _exclusive1 = array1.readwrite(); { - let borrow_flags = get_borrow_flags(py).lock().unwrap(); + let state = get_borrow_flags_state(py, base1, &key1); #[cfg(not(Py_GIL_DISABLED))] - assert_eq!(borrow_flags.len(), 1); - - let same_base_arrays = &borrow_flags[&base1]; - assert_eq!(same_base_arrays.len(), 1); + // borrow checking state is shared and other tests might have registered a borrow + assert_eq!(state.0, 1); - let flag = same_base_arrays[&key1]; - assert_eq!(flag, -1); + assert_eq!(state.1, 1); + assert_eq!(state.2, Some(-1)); } let key2 = borrow_key(py, array2.as_array_ptr()); let _shared2 = array2.readonly(); { - let borrow_flags = get_borrow_flags(py).lock().unwrap(); + let state = get_borrow_flags_state(py, base1, &key1); #[cfg(not(Py_GIL_DISABLED))] - assert_eq!(borrow_flags.len(), 2); - - let same_base_arrays = &borrow_flags[&base1]; - assert_eq!(same_base_arrays.len(), 1); + // borrow checking state is shared and other tests might have registered a borrow + assert_eq!(state.0, 2); - let flag = same_base_arrays[&key1]; - assert_eq!(flag, -1); + assert_eq!(state.1, 1); + assert_eq!(state.2, Some(-1)); - let same_base_arrays = &borrow_flags[&base2]; - assert_eq!(same_base_arrays.len(), 1); - - let flag = same_base_arrays[&key2]; - assert_eq!(flag, 1); + let state = get_borrow_flags_state(py, base2, &key2); + assert_eq!(state.1, 1); + assert_eq!(state.2, Some(1)); } }); } @@ -830,15 +840,13 @@ mod tests { let exclusive1 = view1.readwrite(); { - let borrow_flags = get_borrow_flags(py).lock().unwrap(); - #[cfg(not(Py_GIL_DISABLED))] - assert_eq!(borrow_flags.len(), 1); - - let same_base_arrays = &borrow_flags[&base]; - assert_eq!(same_base_arrays.len(), 1); + let state = get_borrow_flags_state(py, base, &key1); - let flag = same_base_arrays[&key1]; - assert_eq!(flag, -1); + #[cfg(not(Py_GIL_DISABLED))] + // borrow checking state is shared and other tests might have registered a borrow + assert_eq!(state.0, 1); + assert_eq!(state.1, 1); + assert_eq!(state.2, Some(-1)); } let view2 = py @@ -851,18 +859,15 @@ mod tests { let shared2 = view2.readonly(); { - let borrow_flags = get_borrow_flags(py).lock().unwrap(); + let state = get_borrow_flags_state(py, base, &key1); #[cfg(not(Py_GIL_DISABLED))] - assert_eq!(borrow_flags.len(), 1); - - let same_base_arrays = &borrow_flags[&base]; - assert_eq!(same_base_arrays.len(), 2); + // borrow checking state is shared and other tests might have registered a borrow + assert_eq!(state.0, 1); + assert_eq!(state.1, 2); + assert_eq!(state.2, Some(-1)); - let flag = same_base_arrays[&key1]; - assert_eq!(flag, -1); - - let flag = same_base_arrays[&key2]; - assert_eq!(flag, 1); + let state = get_borrow_flags_state(py, base, &key2); + assert_eq!(state.2, Some(1)); } let view3 = py @@ -875,21 +880,18 @@ mod tests { let shared3 = view3.readonly(); { - let borrow_flags = get_borrow_flags(py).lock().unwrap(); + let state = get_borrow_flags_state(py, base, &key1); #[cfg(not(Py_GIL_DISABLED))] - assert_eq!(borrow_flags.len(), 1); - - let same_base_arrays = &borrow_flags[&base]; - assert_eq!(same_base_arrays.len(), 2); - - let flag = same_base_arrays[&key1]; - assert_eq!(flag, -1); + // borrow checking state is shared and other tests might have registered a borrow + assert_eq!(state.0, 1); + assert_eq!(state.1, 2); + assert_eq!(state.2, Some(-1)); - let flag = same_base_arrays[&key2]; - assert_eq!(flag, 2); + let state = get_borrow_flags_state(py, base, &key2); + assert_eq!(state.2, Some(2)); - let flag = same_base_arrays[&key3]; - assert_eq!(flag, 2); + let state = get_borrow_flags_state(py, base, &key3); + assert_eq!(state.2, Some(2)); } let view4 = py @@ -902,96 +904,89 @@ mod tests { let shared4 = view4.readonly(); { - let borrow_flags = get_borrow_flags(py).lock().unwrap(); + let state = get_borrow_flags_state(py, base, &key1); #[cfg(not(Py_GIL_DISABLED))] - assert_eq!(borrow_flags.len(), 1); + // borrow checking state is shared and other tests might have registered a borrow + assert_eq!(state.0, 1); + assert_eq!(state.1, 3); + assert_eq!(state.2, Some(-1)); - let same_base_arrays = &borrow_flags[&base]; - assert_eq!(same_base_arrays.len(), 3); + let state = get_borrow_flags_state(py, base, &key2); + assert_eq!(state.2, Some(2)); - let flag = same_base_arrays[&key1]; - assert_eq!(flag, -1); + let state = get_borrow_flags_state(py, base, &key3); + assert_eq!(state.2, Some(2)); - let flag = same_base_arrays[&key2]; - assert_eq!(flag, 2); - - let flag = same_base_arrays[&key3]; - assert_eq!(flag, 2); - - let flag = same_base_arrays[&key4]; - assert_eq!(flag, 1); + let state = get_borrow_flags_state(py, base, &key4); + assert_eq!(state.2, Some(1)); } drop(shared2); { - let borrow_flags = get_borrow_flags(py).lock().unwrap(); + let state = get_borrow_flags_state(py, base, &key1); #[cfg(not(Py_GIL_DISABLED))] - assert_eq!(borrow_flags.len(), 1); - - let same_base_arrays = &borrow_flags[&base]; - assert_eq!(same_base_arrays.len(), 3); + // borrow checking state is shared and other tests might have registered a borrow + assert_eq!(state.0, 1); + assert_eq!(state.1, 3); + assert_eq!(state.2, Some(-1)); - let flag = same_base_arrays[&key1]; - assert_eq!(flag, -1); + let state = get_borrow_flags_state(py, base, &key2); + assert_eq!(state.2, Some(1)); - let flag = same_base_arrays[&key2]; - assert_eq!(flag, 1); + let state = get_borrow_flags_state(py, base, &key3); + assert_eq!(state.2, Some(1)); - let flag = same_base_arrays[&key3]; - assert_eq!(flag, 1); - - let flag = same_base_arrays[&key4]; - assert_eq!(flag, 1); + let state = get_borrow_flags_state(py, base, &key4); + assert_eq!(state.2, Some(1)); } drop(shared3); { - let borrow_flags = get_borrow_flags(py).lock().unwrap(); + let state = get_borrow_flags_state(py, base, &key1); #[cfg(not(Py_GIL_DISABLED))] - assert_eq!(borrow_flags.len(), 1); - - let same_base_arrays = &borrow_flags[&base]; - assert_eq!(same_base_arrays.len(), 2); - - let flag = same_base_arrays[&key1]; - assert_eq!(flag, -1); + // borrow checking state is shared and other tests might have registered a borrow + assert_eq!(state.0, 1); + assert_eq!(state.1, 2); + assert_eq!(state.2, Some(-1)); - assert!(!same_base_arrays.contains_key(&key2)); + let state = get_borrow_flags_state(py, base, &key2); + assert_eq!(state.2, None); - assert!(!same_base_arrays.contains_key(&key3)); + let state = get_borrow_flags_state(py, base, &key3); + assert_eq!(state.2, None); - let flag = same_base_arrays[&key4]; - assert_eq!(flag, 1); + let state = get_borrow_flags_state(py, base, &key4); + assert_eq!(state.2, Some(1)); } drop(exclusive1); { - let borrow_flags = get_borrow_flags(py).lock().unwrap(); + let state = get_borrow_flags_state(py, base, &key1); #[cfg(not(Py_GIL_DISABLED))] - assert_eq!(borrow_flags.len(), 1); - - let same_base_arrays = &borrow_flags[&base]; - assert_eq!(same_base_arrays.len(), 1); - - assert!(!same_base_arrays.contains_key(&key1)); + // borrow checking state is shared and other tests might have registered a borrow + assert_eq!(state.0, 1); + assert_eq!(state.1, 1); + assert_eq!(state.2, None); - assert!(!same_base_arrays.contains_key(&key2)); + let state = get_borrow_flags_state(py, base, &key2); + assert_eq!(state.2, None); - assert!(!same_base_arrays.contains_key(&key3)); + let state = get_borrow_flags_state(py, base, &key3); + assert_eq!(state.2, None); - let flag = same_base_arrays[&key4]; - assert_eq!(flag, 1); + let state = get_borrow_flags_state(py, base, &key4); + assert_eq!(state.2, Some(1)); } drop(shared4); #[cfg(not(Py_GIL_DISABLED))] + // borrow checking state is shared and other tests might have registered a borrow { - let borrow_flags = get_borrow_flags(py).lock().unwrap(); - assert_eq!(borrow_flags.len(), 0); + assert_eq!(get_borrow_flags_state(py, base, &key1).0, 0); } }); }