@@ -20,7 +20,7 @@ use pyo3::{
20
20
21
21
pub const API_VERSION_2_0 : c_uint = 0x00000012 ;
22
22
23
- pub static API_VERSION : GILOnceCell < c_uint > = GILOnceCell :: new ( ) ;
23
+ static API_VERSION : GILOnceCell < c_uint > = GILOnceCell :: new ( ) ;
24
24
25
25
fn get_numpy_api < ' py > (
26
26
py : Python < ' py > ,
@@ -36,15 +36,16 @@ fn get_numpy_api<'py>(
36
36
// so we can safely cache a pointer into its interior.
37
37
forget ( capsule) ;
38
38
39
- API_VERSION . get_or_init ( py, || unsafe {
40
- #[ allow( non_snake_case) ]
41
- let PyArray_GetNDArrayCFeatureVersion = api. offset ( 211 ) as * const extern "C" fn ( ) -> c_uint ;
42
- ( * PyArray_GetNDArrayCFeatureVersion ) ( )
43
- } ) ;
44
-
45
39
Ok ( api)
46
40
}
47
41
42
+ fn is_numpy_2 < ' py > ( py : Python < ' py > ) -> bool {
43
+ let api_version = * API_VERSION . get_or_init ( py, || unsafe {
44
+ PY_ARRAY_API . PyArray_GetNDArrayCFeatureVersion ( py)
45
+ } ) ;
46
+ api_version >= API_VERSION_2_0
47
+ }
48
+
48
49
// Implements wrappers for NumPy's Array and UFunc API
49
50
macro_rules! impl_api {
50
51
// API available on all versions
@@ -60,15 +61,13 @@ macro_rules! impl_api {
60
61
[ $offset: expr; NumPy1 ; $fname: ident ( $( $arg: ident: $t: ty) ,* $( , ) ?) $( -> $ret: ty) ?] => {
61
62
#[ allow( non_snake_case) ]
62
63
pub unsafe fn $fname<' py>( & self , py: Python <' py>, $( $arg : $t) , * ) $( -> $ret) * {
63
- let api_version = * API_VERSION . get( py) . expect( "API_VERSION is initialized" ) ;
64
- if api_version >= API_VERSION_2_0 {
65
- panic!(
66
- "{} requires API < {:08X} (NumPy 1) but the runtime version is API {:08X}" ,
67
- stringify!( $fname) ,
68
- API_VERSION_2_0 ,
69
- api_version,
70
- )
71
- }
64
+ assert!(
65
+ !is_numpy_2( py) ,
66
+ "{} requires API < {:08X} (NumPy 1) but the runtime version is API {:08X}" ,
67
+ stringify!( $fname) ,
68
+ API_VERSION_2_0 ,
69
+ * API_VERSION . get( py) . expect( "API_VERSION is initialized" ) ,
70
+ ) ;
72
71
let fptr = self . get( py, $offset) as * const extern fn ( $( $arg: $t) , * ) $( -> $ret) * ;
73
72
( * fptr) ( $( $arg) , * )
74
73
}
@@ -77,15 +76,13 @@ macro_rules! impl_api {
77
76
[ $offset: expr; NumPy2 ; $fname: ident ( $( $arg: ident: $t: ty) ,* $( , ) ?) $( -> $ret: ty) ?] => {
78
77
#[ allow( non_snake_case) ]
79
78
pub unsafe fn $fname<' py>( & self , py: Python <' py>, $( $arg : $t) , * ) $( -> $ret) * {
80
- let api_version = * API_VERSION . get( py) . expect( "API_VERSION is initialized" ) ;
81
- if api_version < API_VERSION_2_0 {
82
- panic!(
83
- "{} requires API {:08X} or greater (NumPy 2) but the runtime version is API {:08X}" ,
84
- stringify!( $fname) ,
85
- API_VERSION_2_0 ,
86
- api_version,
87
- )
88
- }
79
+ assert!(
80
+ is_numpy_2( py) ,
81
+ "{} requires API {:08X} or greater (NumPy 2) but the runtime version is API {:08X}" ,
82
+ stringify!( $fname) ,
83
+ API_VERSION_2_0 ,
84
+ * API_VERSION . get( py) . expect( "API_VERSION is initialized" ) ,
85
+ ) ;
89
86
let fptr = self . get( py, $offset) as * const extern fn ( $( $arg: $t) , * ) $( -> $ret) * ;
90
87
( * fptr) ( $( $arg) , * )
91
88
}
0 commit comments