Skip to content

Commit abebd0f

Browse files
committed
Fixes from review
1 parent 17eda7f commit abebd0f

File tree

3 files changed

+33
-49
lines changed

3 files changed

+33
-49
lines changed

src/npyffi/array.rs

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -342,12 +342,7 @@ impl PyArrayAPI {
342342
dst: *mut PyArrayObject,
343343
src: *mut PyArrayObject,
344344
) -> c_int {
345-
let api_version = *API_VERSION.get(py).expect("API_VERSION is initialized");
346-
let offset = if api_version < API_VERSION_2_0 {
347-
82
348-
} else {
349-
50
350-
};
345+
let offset = if is_numpy_2(py) { 50 } else { 82 };
351346
let fptr = self.get(py, offset)
352347
as *const extern "C" fn(dst: *mut PyArrayObject, src: *mut PyArrayObject) -> c_int;
353348
(*fptr)(dst, src)
@@ -360,12 +355,7 @@ impl PyArrayAPI {
360355
out: *mut PyArrayObject,
361356
mp: *mut PyArrayObject,
362357
) -> c_int {
363-
let api_version = *API_VERSION.get(py).expect("API_VERSION is initialized");
364-
let offset = if api_version < API_VERSION_2_0 {
365-
83
366-
} else {
367-
51
368-
};
358+
let offset = if is_numpy_2(py) { 51 } else { 83 };
369359
let fptr = self.get(py, offset)
370360
as *const extern "C" fn(out: *mut PyArrayObject, mp: *mut PyArrayObject) -> c_int;
371361
(*fptr)(out, mp)

src/npyffi/mod.rs

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use pyo3::{
2020

2121
pub const API_VERSION_2_0: c_uint = 0x00000012;
2222

23-
pub static API_VERSION: GILOnceCell<c_uint> = GILOnceCell::new();
23+
static API_VERSION: GILOnceCell<c_uint> = GILOnceCell::new();
2424

2525
fn get_numpy_api<'py>(
2626
py: Python<'py>,
@@ -36,15 +36,16 @@ fn get_numpy_api<'py>(
3636
// so we can safely cache a pointer into its interior.
3737
forget(capsule);
3838

39-
API_VERSION.get_or_init(py, || unsafe {
40-
#[allow(non_snake_case)]
41-
let PyArray_GetNDArrayCFeatureVersion = api.offset(211) as *const extern "C" fn() -> c_uint;
42-
(*PyArray_GetNDArrayCFeatureVersion)()
43-
});
44-
4539
Ok(api)
4640
}
4741

42+
fn is_numpy_2<'py>(py: Python<'py>) -> bool {
43+
let api_version = *API_VERSION.get_or_init(py, || unsafe {
44+
PY_ARRAY_API.PyArray_GetNDArrayCFeatureVersion(py)
45+
});
46+
api_version >= API_VERSION_2_0
47+
}
48+
4849
// Implements wrappers for NumPy's Array and UFunc API
4950
macro_rules! impl_api {
5051
// API available on all versions
@@ -60,15 +61,13 @@ macro_rules! impl_api {
6061
[$offset: expr; NumPy1; $fname: ident ($($arg: ident: $t: ty),* $(,)?) $(-> $ret: ty)?] => {
6162
#[allow(non_snake_case)]
6263
pub unsafe fn $fname<'py>(&self, py: Python<'py>, $($arg : $t), *) $(-> $ret)* {
63-
let api_version = *API_VERSION.get(py).expect("API_VERSION is initialized");
64-
if api_version >= API_VERSION_2_0 {
65-
panic!(
66-
"{} requires API < {:08X} (NumPy 1) but the runtime version is API {:08X}",
67-
stringify!($fname),
68-
API_VERSION_2_0,
69-
api_version,
70-
)
71-
}
64+
assert!(
65+
!is_numpy_2(py),
66+
"{} requires API < {:08X} (NumPy 1) but the runtime version is API {:08X}",
67+
stringify!($fname),
68+
API_VERSION_2_0,
69+
*API_VERSION.get(py).expect("API_VERSION is initialized"),
70+
);
7271
let fptr = self.get(py, $offset) as *const extern fn ($($arg: $t), *) $(-> $ret)*;
7372
(*fptr)($($arg), *)
7473
}
@@ -77,15 +76,13 @@ macro_rules! impl_api {
7776
[$offset: expr; NumPy2; $fname: ident ($($arg: ident: $t: ty),* $(,)?) $(-> $ret: ty)?] => {
7877
#[allow(non_snake_case)]
7978
pub unsafe fn $fname<'py>(&self, py: Python<'py>, $($arg : $t), *) $(-> $ret)* {
80-
let api_version = *API_VERSION.get(py).expect("API_VERSION is initialized");
81-
if api_version < API_VERSION_2_0 {
82-
panic!(
83-
"{} requires API {:08X} or greater (NumPy 2) but the runtime version is API {:08X}",
84-
stringify!($fname),
85-
API_VERSION_2_0,
86-
api_version,
87-
)
88-
}
79+
assert!(
80+
is_numpy_2(py),
81+
"{} requires API {:08X} or greater (NumPy 2) but the runtime version is API {:08X}",
82+
stringify!($fname),
83+
API_VERSION_2_0,
84+
*API_VERSION.get(py).expect("API_VERSION is initialized"),
85+
);
8986
let fptr = self.get(py, $offset) as *const extern fn ($($arg: $t), *) $(-> $ret)*;
9087
(*fptr)($($arg), *)
9188
}

src/npyffi/objects.rs

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -110,26 +110,24 @@ pub unsafe fn PyDataType_SET_ELSIZE<'py>(
110110
dtype: *mut PyArray_Descr,
111111
size: npy_intp,
112112
) {
113-
let api_version = *API_VERSION.get(py).expect("API_VERSION is initialized");
114-
if api_version < API_VERSION_2_0 {
113+
if is_numpy_2(py) {
115114
unsafe {
116-
(*(dtype as *mut PyArray_DescrProto)).elsize = size as c_int;
115+
(*(dtype as *mut _PyArray_Descr_NumPy2)).elsize = size;
117116
}
118117
} else {
119118
unsafe {
120-
(*(dtype as *mut _PyArray_Descr_NumPy2)).elsize = size;
119+
(*(dtype as *mut PyArray_DescrProto)).elsize = size as c_int;
121120
}
122121
}
123122
}
124123

125124
#[allow(non_snake_case)]
126125
#[inline(always)]
127126
pub unsafe fn PyDataType_FLAGS<'py>(py: Python<'py>, dtype: *const PyArray_Descr) -> npy_uint64 {
128-
let api_version = *API_VERSION.get(py).expect("API_VERSION is initialized");
129-
if api_version < API_VERSION_2_0 {
130-
unsafe { (*(dtype as *mut PyArray_DescrProto)).flags as c_uchar as npy_uint64 }
131-
} else {
127+
if is_numpy_2(py) {
132128
unsafe { (*(dtype as *mut _PyArray_Descr_NumPy2)).flags }
129+
} else {
130+
unsafe { (*(dtype as *mut PyArray_DescrProto)).flags as c_uchar as npy_uint64 }
133131
}
134132
}
135133

@@ -141,11 +139,10 @@ macro_rules! define_descr_accessor {
141139
if $legacy_only && !PyDataType_ISLEGACY(dtype) {
142140
$default
143141
} else {
144-
let api_version = *API_VERSION.get(py).expect("API_VERSION is initialized");
145-
if api_version < API_VERSION_2_0 {
146-
unsafe { (*(dtype as *mut PyArray_DescrProto)).$property as $type }
147-
} else {
142+
if is_numpy_2(py) {
148143
unsafe { (*(dtype as *const _PyArray_Descr_NumPy1)).$property }
144+
} else {
145+
unsafe { (*(dtype as *mut PyArray_DescrProto)).$property as $type }
149146
}
150147
}
151148
}

0 commit comments

Comments
 (0)