Skip to content

Commit edec058

Browse files
authored
fix: implement is_exact_type_of to check dtype and shape for PyArray (#550)
* fix: implement is_exact_type_of to check dtype and shape * fix: use PyArray_CheckExact in is_exact_type_of to reject subclasses * docs: add changelog entry for PR #550
1 parent 8e6f59d commit edec058

3 files changed

Lines changed: 44 additions & 3 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
- Unreleased
44
- Fix PyArray_DTypeMeta definition when Py_LIMITED_API is disabled ([#532](https://github.com/PyO3/rust-numpy/pull/532))
55
- The NumPy C API binding has been updated to target the ABI v2, while maintaining runtime compatibility with NumPy v1 targeting the API v1.15. The higher interface is unchanged. ([#537](https://github.com/PyO3/rust-numpy/pull/537))
6+
- Fix `is_exact_type_of` / `cast_exact` to use `PyArray_CheckExact` instead of `PyArray_Check`, correctly rejecting subclasses of `ndarray`. ([#550](https://github.com/PyO3/rust-numpy/pull/550))
67

78
- v0.28.0
89
- Fix mismatched behavior between `PyArrayLike1` and `PyArrayLike2` when used with floats ([#520](https://github.com/PyO3/rust-numpy/pull/520))

src/array.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,18 +132,25 @@ unsafe impl<T: Element, D: Dimension> PyTypeInfo for PyArray<T, D> {
132132
}
133133

134134
fn is_type_of(ob: &Bound<'_, PyAny>) -> bool {
135-
Self::extract::<IgnoreError>(ob).is_ok()
135+
Self::extract::<IgnoreError>(ob, npyffi::PyArray_Check).is_ok()
136+
}
137+
138+
fn is_exact_type_of(ob: &Bound<'_, PyAny>) -> bool {
139+
Self::extract::<IgnoreError>(ob, npyffi::PyArray_CheckExact).is_ok()
136140
}
137141
}
138142

139143
impl<T: Element, D: Dimension> PyArray<T, D> {
140-
fn extract<'a, 'py, E>(ob: &'a Bound<'py, PyAny>) -> Result<&'a Bound<'py, Self>, E>
144+
fn extract<'a, 'py, E>(
145+
ob: &'a Bound<'py, PyAny>,
146+
check: unsafe fn(Python<'py>, *mut ffi::PyObject) -> c_int,
147+
) -> Result<&'a Bound<'py, Self>, E>
141148
where
142149
E: From<CastError<'a, 'py>> + From<DimensionalityError> + From<TypeError<'py>>,
143150
{
144151
// Check if the object is an array.
145152
let array = unsafe {
146-
if npyffi::PyArray_Check(ob.py(), ob.as_ptr()) == 0 {
153+
if check(ob.py(), ob.as_ptr()) == 0 {
147154
return Err(CastError::new(
148155
ob.as_borrowed(),
149156
<Self as PyTypeCheck>::classinfo_object(ob.py()),

tests/array.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,39 @@ fn is_instance() {
251251
});
252252
}
253253

254+
#[test]
255+
fn cast_exact_checks_dtype_and_shape() {
256+
Python::attach(|py| {
257+
let arr_f64 = PyArray2::<f64>::zeros(py, [3, 5], false);
258+
let arr_f32 = PyArray2::<f32>::zeros(py, [3, 5], false);
259+
let arr_f64_3d = PyArray::<f64, _>::zeros(py, [3, 5, 7], false);
260+
261+
// cast_exact should succeed when dtype and shape both match
262+
assert!(arr_f64.as_any().cast_exact::<PyArray2<f64>>().is_ok());
263+
assert!(arr_f32.as_any().cast_exact::<PyArray2<f32>>().is_ok());
264+
265+
// cast_exact should fail when dtype does not match
266+
assert!(arr_f64.as_any().cast_exact::<PyArray2<f32>>().is_err());
267+
assert!(arr_f32.as_any().cast_exact::<PyArray2<f64>>().is_err());
268+
269+
// cast_exact should fail when dimensionality does not match
270+
assert!(arr_f64_3d.as_any().cast_exact::<PyArray2<f64>>().is_err());
271+
272+
// cast_exact should reject subclasses of ndarray; cast should accept them
273+
let masked = py
274+
.eval(
275+
pyo3::ffi::c_str!(
276+
"__import__('numpy').ma.MaskedArray([[1.0, 2.0], [3.0, 4.0]], dtype='float64')"
277+
),
278+
None,
279+
None,
280+
)
281+
.unwrap();
282+
assert!(masked.cast::<PyArray2<f64>>().is_ok());
283+
assert!(masked.cast_exact::<PyArray2<f64>>().is_err());
284+
});
285+
}
286+
254287
#[test]
255288
fn from_vec2() {
256289
Python::attach(|py| {

0 commit comments

Comments
 (0)