Skip to content

Commit

Permalink
Fixes from review
Browse files Browse the repository at this point in the history
  • Loading branch information
maffoo committed Sep 20, 2024
1 parent 17eda7f commit abebd0f
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 49 deletions.
14 changes: 2 additions & 12 deletions src/npyffi/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,12 +342,7 @@ impl PyArrayAPI {
dst: *mut PyArrayObject,
src: *mut PyArrayObject,
) -> c_int {
let api_version = *API_VERSION.get(py).expect("API_VERSION is initialized");
let offset = if api_version < API_VERSION_2_0 {
82
} else {
50
};
let offset = if is_numpy_2(py) { 50 } else { 82 };
let fptr = self.get(py, offset)
as *const extern "C" fn(dst: *mut PyArrayObject, src: *mut PyArrayObject) -> c_int;
(*fptr)(dst, src)
Expand All @@ -360,12 +355,7 @@ impl PyArrayAPI {
out: *mut PyArrayObject,
mp: *mut PyArrayObject,
) -> c_int {
let api_version = *API_VERSION.get(py).expect("API_VERSION is initialized");
let offset = if api_version < API_VERSION_2_0 {
83
} else {
51
};
let offset = if is_numpy_2(py) { 51 } else { 83 };
let fptr = self.get(py, offset)
as *const extern "C" fn(out: *mut PyArrayObject, mp: *mut PyArrayObject) -> c_int;
(*fptr)(out, mp)
Expand Down
47 changes: 22 additions & 25 deletions src/npyffi/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use pyo3::{

pub const API_VERSION_2_0: c_uint = 0x00000012;

pub static API_VERSION: GILOnceCell<c_uint> = GILOnceCell::new();
static API_VERSION: GILOnceCell<c_uint> = GILOnceCell::new();

fn get_numpy_api<'py>(
py: Python<'py>,
Expand All @@ -36,15 +36,16 @@ fn get_numpy_api<'py>(
// so we can safely cache a pointer into its interior.
forget(capsule);

API_VERSION.get_or_init(py, || unsafe {
#[allow(non_snake_case)]
let PyArray_GetNDArrayCFeatureVersion = api.offset(211) as *const extern "C" fn() -> c_uint;
(*PyArray_GetNDArrayCFeatureVersion)()
});

Ok(api)
}

fn is_numpy_2<'py>(py: Python<'py>) -> bool {
let api_version = *API_VERSION.get_or_init(py, || unsafe {
PY_ARRAY_API.PyArray_GetNDArrayCFeatureVersion(py)
});
api_version >= API_VERSION_2_0
}

// Implements wrappers for NumPy's Array and UFunc API
macro_rules! impl_api {
// API available on all versions
Expand All @@ -60,15 +61,13 @@ macro_rules! impl_api {
[$offset: expr; NumPy1; $fname: ident ($($arg: ident: $t: ty),* $(,)?) $(-> $ret: ty)?] => {
#[allow(non_snake_case)]
pub unsafe fn $fname<'py>(&self, py: Python<'py>, $($arg : $t), *) $(-> $ret)* {
let api_version = *API_VERSION.get(py).expect("API_VERSION is initialized");
if api_version >= API_VERSION_2_0 {
panic!(
"{} requires API < {:08X} (NumPy 1) but the runtime version is API {:08X}",
stringify!($fname),
API_VERSION_2_0,
api_version,
)
}
assert!(
!is_numpy_2(py),
"{} requires API < {:08X} (NumPy 1) but the runtime version is API {:08X}",
stringify!($fname),
API_VERSION_2_0,
*API_VERSION.get(py).expect("API_VERSION is initialized"),
);
let fptr = self.get(py, $offset) as *const extern fn ($($arg: $t), *) $(-> $ret)*;
(*fptr)($($arg), *)
}
Expand All @@ -77,15 +76,13 @@ macro_rules! impl_api {
[$offset: expr; NumPy2; $fname: ident ($($arg: ident: $t: ty),* $(,)?) $(-> $ret: ty)?] => {
#[allow(non_snake_case)]
pub unsafe fn $fname<'py>(&self, py: Python<'py>, $($arg : $t), *) $(-> $ret)* {
let api_version = *API_VERSION.get(py).expect("API_VERSION is initialized");
if api_version < API_VERSION_2_0 {
panic!(
"{} requires API {:08X} or greater (NumPy 2) but the runtime version is API {:08X}",
stringify!($fname),
API_VERSION_2_0,
api_version,
)
}
assert!(
is_numpy_2(py),
"{} requires API {:08X} or greater (NumPy 2) but the runtime version is API {:08X}",
stringify!($fname),
API_VERSION_2_0,
*API_VERSION.get(py).expect("API_VERSION is initialized"),
);
let fptr = self.get(py, $offset) as *const extern fn ($($arg: $t), *) $(-> $ret)*;
(*fptr)($($arg), *)
}
Expand Down
21 changes: 9 additions & 12 deletions src/npyffi/objects.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,26 +110,24 @@ pub unsafe fn PyDataType_SET_ELSIZE<'py>(
dtype: *mut PyArray_Descr,
size: npy_intp,
) {
let api_version = *API_VERSION.get(py).expect("API_VERSION is initialized");
if api_version < API_VERSION_2_0 {
if is_numpy_2(py) {
unsafe {
(*(dtype as *mut PyArray_DescrProto)).elsize = size as c_int;
(*(dtype as *mut _PyArray_Descr_NumPy2)).elsize = size;
}
} else {
unsafe {
(*(dtype as *mut _PyArray_Descr_NumPy2)).elsize = size;
(*(dtype as *mut PyArray_DescrProto)).elsize = size as c_int;
}
}
}

#[allow(non_snake_case)]
#[inline(always)]
pub unsafe fn PyDataType_FLAGS<'py>(py: Python<'py>, dtype: *const PyArray_Descr) -> npy_uint64 {
let api_version = *API_VERSION.get(py).expect("API_VERSION is initialized");
if api_version < API_VERSION_2_0 {
unsafe { (*(dtype as *mut PyArray_DescrProto)).flags as c_uchar as npy_uint64 }
} else {
if is_numpy_2(py) {
unsafe { (*(dtype as *mut _PyArray_Descr_NumPy2)).flags }
} else {
unsafe { (*(dtype as *mut PyArray_DescrProto)).flags as c_uchar as npy_uint64 }
}
}

Expand All @@ -141,11 +139,10 @@ macro_rules! define_descr_accessor {
if $legacy_only && !PyDataType_ISLEGACY(dtype) {
$default
} else {
let api_version = *API_VERSION.get(py).expect("API_VERSION is initialized");
if api_version < API_VERSION_2_0 {
unsafe { (*(dtype as *mut PyArray_DescrProto)).$property as $type }
} else {
if is_numpy_2(py) {
unsafe { (*(dtype as *const _PyArray_Descr_NumPy1)).$property }
} else {
unsafe { (*(dtype as *mut PyArray_DescrProto)).$property as $type }
}
}
}
Expand Down

0 comments on commit abebd0f

Please sign in to comment.