Skip to content

Commit

Permalink
refactor BorrowFlags tests to not hold a lock in the tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ngoldbaum committed Feb 17, 2025
1 parent cc109bf commit e58c0af
Showing 1 changed file with 102 additions and 107 deletions.
209 changes: 102 additions & 107 deletions src/borrow/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,15 +247,14 @@ impl BorrowKey {
}
}

type BorrowFlagsInner = Mutex<FxHashMap<*mut c_void, FxHashMap<BorrowKey, isize>>>;
type BorrowFlagsInner = FxHashMap<*mut c_void, FxHashMap<BorrowKey, isize>>;

#[derive(Default)]
struct BorrowFlags(BorrowFlagsInner);
struct BorrowFlags(Mutex<BorrowFlagsInner>);

impl BorrowFlags {
fn acquire(&self, address: *mut c_void, key: BorrowKey) -> Result<(), ()> {
let mut borrow_flags = self.0.lock().unwrap();

match borrow_flags.entry(address) {
Entry::Occupied(entry) => {
let same_base_arrays = entry.into_mut();
Expand Down Expand Up @@ -448,10 +447,27 @@ mod tests {
use crate::untyped_array::PyUntypedArrayMethods;
use pyo3::ffi::c_str;

fn get_borrow_flags<'py>(py: Python<'py>) -> &'py BorrowFlagsInner {
struct BorrowFlagsState(usize, usize, Option<isize>);

fn get_borrow_flags_state<'py>(
py: Python<'py>,
base: *mut c_void,
key: &BorrowKey,
) -> BorrowFlagsState {
let shared = get_or_insert_shared(py).unwrap();
assert_eq!(shared.version, 1);
unsafe { &(*(shared.flags as *mut BorrowFlags)).0 }
let inner = unsafe { &(*(shared.flags as *mut BorrowFlags)).0 }
.lock()
.unwrap();
if let Some(base_arrays) = inner.get(&base) {
BorrowFlagsState(
inner.len(),
base_arrays.len(),
base_arrays.get(key).map(|x| *x),
)
} else {
BorrowFlagsState(0, 0, None)
}
}

#[test]
Expand Down Expand Up @@ -778,36 +794,30 @@ mod tests {
let _exclusive1 = array1.readwrite();

{
let borrow_flags = get_borrow_flags(py).lock().unwrap();
let state = get_borrow_flags_state(py, base1, &key1);
#[cfg(not(Py_GIL_DISABLED))]
assert_eq!(borrow_flags.len(), 1);

let same_base_arrays = &borrow_flags[&base1];
assert_eq!(same_base_arrays.len(), 1);
// borrow checking state is shared and other tests might have registered a borrow
assert_eq!(state.0, 1);

let flag = same_base_arrays[&key1];
assert_eq!(flag, -1);
assert_eq!(state.1, 1);
assert_eq!(state.2, Some(-1));
}

let key2 = borrow_key(py, array2.as_array_ptr());
let _shared2 = array2.readonly();

{
let borrow_flags = get_borrow_flags(py).lock().unwrap();
let state = get_borrow_flags_state(py, base1, &key1);
#[cfg(not(Py_GIL_DISABLED))]
assert_eq!(borrow_flags.len(), 2);

let same_base_arrays = &borrow_flags[&base1];
assert_eq!(same_base_arrays.len(), 1);
// borrow checking state is shared and other tests might have registered a borrow
assert_eq!(state.0, 2);

let flag = same_base_arrays[&key1];
assert_eq!(flag, -1);
assert_eq!(state.1, 1);
assert_eq!(state.2, Some(-1));

let same_base_arrays = &borrow_flags[&base2];
assert_eq!(same_base_arrays.len(), 1);

let flag = same_base_arrays[&key2];
assert_eq!(flag, 1);
let state = get_borrow_flags_state(py, base2, &key2);
assert_eq!(state.1, 1);
assert_eq!(state.2, Some(1));
}
});
}
Expand All @@ -830,15 +840,13 @@ mod tests {
let exclusive1 = view1.readwrite();

{
let borrow_flags = get_borrow_flags(py).lock().unwrap();
#[cfg(not(Py_GIL_DISABLED))]
assert_eq!(borrow_flags.len(), 1);

let same_base_arrays = &borrow_flags[&base];
assert_eq!(same_base_arrays.len(), 1);
let state = get_borrow_flags_state(py, base, &key1);

let flag = same_base_arrays[&key1];
assert_eq!(flag, -1);
#[cfg(not(Py_GIL_DISABLED))]
// borrow checking state is shared and other tests might have registered a borrow
assert_eq!(state.0, 1);
assert_eq!(state.1, 1);
assert_eq!(state.2, Some(-1));
}

let view2 = py
Expand All @@ -851,18 +859,15 @@ mod tests {
let shared2 = view2.readonly();

{
let borrow_flags = get_borrow_flags(py).lock().unwrap();
let state = get_borrow_flags_state(py, base, &key1);
#[cfg(not(Py_GIL_DISABLED))]
assert_eq!(borrow_flags.len(), 1);

let same_base_arrays = &borrow_flags[&base];
assert_eq!(same_base_arrays.len(), 2);
// borrow checking state is shared and other tests might have registered a borrow
assert_eq!(state.0, 1);
assert_eq!(state.1, 2);
assert_eq!(state.2, Some(-1));

let flag = same_base_arrays[&key1];
assert_eq!(flag, -1);

let flag = same_base_arrays[&key2];
assert_eq!(flag, 1);
let state = get_borrow_flags_state(py, base, &key2);
assert_eq!(state.2, Some(1));
}

let view3 = py
Expand All @@ -875,21 +880,18 @@ mod tests {
let shared3 = view3.readonly();

{
let borrow_flags = get_borrow_flags(py).lock().unwrap();
let state = get_borrow_flags_state(py, base, &key1);
#[cfg(not(Py_GIL_DISABLED))]
assert_eq!(borrow_flags.len(), 1);

let same_base_arrays = &borrow_flags[&base];
assert_eq!(same_base_arrays.len(), 2);

let flag = same_base_arrays[&key1];
assert_eq!(flag, -1);
// borrow checking state is shared and other tests might have registered a borrow
assert_eq!(state.0, 1);
assert_eq!(state.1, 2);
assert_eq!(state.2, Some(-1));

let flag = same_base_arrays[&key2];
assert_eq!(flag, 2);
let state = get_borrow_flags_state(py, base, &key2);
assert_eq!(state.2, Some(2));

let flag = same_base_arrays[&key3];
assert_eq!(flag, 2);
let state = get_borrow_flags_state(py, base, &key3);
assert_eq!(state.2, Some(2));
}

let view4 = py
Expand All @@ -902,96 +904,89 @@ mod tests {
let shared4 = view4.readonly();

{
let borrow_flags = get_borrow_flags(py).lock().unwrap();
let state = get_borrow_flags_state(py, base, &key1);
#[cfg(not(Py_GIL_DISABLED))]
assert_eq!(borrow_flags.len(), 1);
// borrow checking state is shared and other tests might have registered a borrow
assert_eq!(state.0, 1);
assert_eq!(state.1, 3);
assert_eq!(state.2, Some(-1));

let same_base_arrays = &borrow_flags[&base];
assert_eq!(same_base_arrays.len(), 3);
let state = get_borrow_flags_state(py, base, &key2);
assert_eq!(state.2, Some(2));

let flag = same_base_arrays[&key1];
assert_eq!(flag, -1);
let state = get_borrow_flags_state(py, base, &key3);
assert_eq!(state.2, Some(2));

let flag = same_base_arrays[&key2];
assert_eq!(flag, 2);

let flag = same_base_arrays[&key3];
assert_eq!(flag, 2);

let flag = same_base_arrays[&key4];
assert_eq!(flag, 1);
let state = get_borrow_flags_state(py, base, &key4);
assert_eq!(state.2, Some(1));
}

drop(shared2);

{
let borrow_flags = get_borrow_flags(py).lock().unwrap();
let state = get_borrow_flags_state(py, base, &key1);
#[cfg(not(Py_GIL_DISABLED))]
assert_eq!(borrow_flags.len(), 1);

let same_base_arrays = &borrow_flags[&base];
assert_eq!(same_base_arrays.len(), 3);
// borrow checking state is shared and other tests might have registered a borrow
assert_eq!(state.0, 1);
assert_eq!(state.1, 3);
assert_eq!(state.2, Some(-1));

let flag = same_base_arrays[&key1];
assert_eq!(flag, -1);
let state = get_borrow_flags_state(py, base, &key2);
assert_eq!(state.2, Some(1));

let flag = same_base_arrays[&key2];
assert_eq!(flag, 1);
let state = get_borrow_flags_state(py, base, &key3);
assert_eq!(state.2, Some(1));

let flag = same_base_arrays[&key3];
assert_eq!(flag, 1);

let flag = same_base_arrays[&key4];
assert_eq!(flag, 1);
let state = get_borrow_flags_state(py, base, &key4);
assert_eq!(state.2, Some(1));
}

drop(shared3);

{
let borrow_flags = get_borrow_flags(py).lock().unwrap();
let state = get_borrow_flags_state(py, base, &key1);
#[cfg(not(Py_GIL_DISABLED))]
assert_eq!(borrow_flags.len(), 1);

let same_base_arrays = &borrow_flags[&base];
assert_eq!(same_base_arrays.len(), 2);

let flag = same_base_arrays[&key1];
assert_eq!(flag, -1);
// borrow checking state is shared and other tests might have registered a borrow
assert_eq!(state.0, 1);
assert_eq!(state.1, 2);
assert_eq!(state.2, Some(-1));

assert!(!same_base_arrays.contains_key(&key2));
let state = get_borrow_flags_state(py, base, &key2);
assert_eq!(state.2, None);

assert!(!same_base_arrays.contains_key(&key3));
let state = get_borrow_flags_state(py, base, &key3);
assert_eq!(state.2, None);

let flag = same_base_arrays[&key4];
assert_eq!(flag, 1);
let state = get_borrow_flags_state(py, base, &key4);
assert_eq!(state.2, Some(1));
}

drop(exclusive1);

{
let borrow_flags = get_borrow_flags(py).lock().unwrap();
let state = get_borrow_flags_state(py, base, &key1);
#[cfg(not(Py_GIL_DISABLED))]
assert_eq!(borrow_flags.len(), 1);

let same_base_arrays = &borrow_flags[&base];
assert_eq!(same_base_arrays.len(), 1);

assert!(!same_base_arrays.contains_key(&key1));
// borrow checking state is shared and other tests might have registered a borrow
assert_eq!(state.0, 1);
assert_eq!(state.1, 1);
assert_eq!(state.2, None);

assert!(!same_base_arrays.contains_key(&key2));
let state = get_borrow_flags_state(py, base, &key2);
assert_eq!(state.2, None);

assert!(!same_base_arrays.contains_key(&key3));
let state = get_borrow_flags_state(py, base, &key3);
assert_eq!(state.2, None);

let flag = same_base_arrays[&key4];
assert_eq!(flag, 1);
let state = get_borrow_flags_state(py, base, &key4);
assert_eq!(state.2, Some(1));
}

drop(shared4);

#[cfg(not(Py_GIL_DISABLED))]
// borrow checking state is shared and other tests might have registered a borrow
{
let borrow_flags = get_borrow_flags(py).lock().unwrap();
assert_eq!(borrow_flags.len(), 0);
assert_eq!(get_borrow_flags_state(py, base, &key1).0, 0);
}
});
}
Expand Down

0 comments on commit e58c0af

Please sign in to comment.