Bug Description
When using `PyArrayLikeDyn` with `AllowTypeChange`, trailing singleton axes may be removed from inputs that are `ndarray`s but have the wrong dtype.
Steps to Reproduce
Cargo.toml
```toml
[package]
name = "singleton-removed"
version = "0.1.0"
edition = "2024"

[dependencies]
numpy = "0.24.0"
pyo3 = { version = "0.24.2", features = ["auto-initialize"] }
```
main.rs
```rust
use numpy::{AllowTypeChange, PyArrayDyn, PyArrayLikeDyn, PyArrayMethods};
use pyo3::ffi::c_str;
use pyo3::prelude::*;

#[pyfunction]
fn double<'py>(
    py: Python<'py>,
    a: PyArrayLikeDyn<'py, f64, AllowTypeChange>,
) -> Bound<'py, PyArrayDyn<f64>> {
    PyArrayDyn::from_owned_array(py, a.to_owned_array() * 2.0)
}

#[pymodule]
fn singleton_removed(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(double, m)?)?;
    Ok(())
}

fn main() -> PyResult<()> {
    pyo3::append_to_inittab!(singleton_removed);
    Python::with_gil(|py| {
        let code = c_str!(include_str!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/example.py"
        )));
        py.run(code, None, None)?;
        Ok(())
    })
}
```
example.py
```python
import singleton_removed
import numpy as np

a = np.ones((3, 1), dtype=np.int32)
b = singleton_removed.double(a)
assert a.shape == b.shape, f"{a.shape=}, {b.shape=}"
```
This fails with the following output. Note the preceding DeprecationWarning from NumPy, which appears to be triggered when the singleton axis is implicitly dropped:
```
<string>:5: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)
Error: PyErr { type: <class 'AssertionError'>, value: AssertionError('a.shape=(3, 1), b.shape=(3,)'), traceback: Some("Traceback (most recent call last):\n  File \"<string>\", line 6, in <module>\n") }
```
If no type change is needed, the axis is preserved. For example, this succeeds:

```python
import singleton_removed
import numpy as np

a = np.ones((3, 1), dtype=np.float64)
b = singleton_removed.double(a)
assert a.shape == b.shape, f"{a.shape=}, {b.shape=}"
```

Non-array inputs are also handled correctly; this succeeds as well:

```python
import singleton_removed
import numpy as np

a = [[1], [2], [3]]
b = singleton_removed.double(a)
assert b.shape == (3, 1), f"{b.shape=}"
```
Oddly enough, if I add a third axis, the trailing singleton axis is no longer removed, and this assertion passes:
```python
import singleton_removed
import numpy as np

a = np.ones((3, 2, 1), dtype=np.int32)
b = singleton_removed.double(a)
assert a.shape == b.shape, f"{a.shape=}, {b.shape=}"
```
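A guess at the mechanism, not verified against the rust-numpy source: for a dynamic-dimensional target, extraction might first try to read the input as a flat sequence of `f64` (like a `Vec<f64>`) before falling back to `numpy.asarray(obj, dtype=...)`. Iterating a `(3, 1)` int32 array yields size-1 rows that `float()` accepts, with exactly the DeprecationWarning above, so the result would be 1-D. Rows of a `(3, 2, 1)` array and the inner lists of `[[1], [2], [3]]` cannot be converted that way, so those inputs would take the shape-preserving fallback. The sketch below models that suspected order of attempts in plain Python; `extract_flat_f64` and `model_array_like` are hypothetical names, not rust-numpy functions.

```python
import warnings

import numpy as np


def extract_flat_f64(obj):
    """Hypothetical model of a flat Vec<f64> extraction: iterate the outer
    axis and convert each item with float(). Returns None on failure."""
    try:
        return [float(item) for item in obj]
    except TypeError:
        return None


def model_array_like(obj):
    """Hypothetical model of the suspected order of attempts for
    PyArrayLikeDyn<f64, AllowTypeChange>; not rust-numpy's actual code."""
    if isinstance(obj, np.ndarray) and obj.dtype == np.float64:
        return obj  # exact dtype match: used as-is, shape preserved
    flat = extract_flat_f64(obj)
    if flat is not None:
        return np.array(flat)  # flat 1-D result: trailing singleton axis lost
    return np.asarray(obj, dtype=np.float64)  # fallback: shape preserved


with warnings.catch_warnings():
    # float() on a size-1 array with ndim > 0 emits a DeprecationWarning
    # (NumPy >= 1.25); silence it for this demonstration.
    warnings.simplefilter("ignore", DeprecationWarning)
    for value in (
        np.ones((3, 1), dtype=np.int32),
        np.ones((3, 1), dtype=np.float64),
        [[1], [2], [3]],
        np.ones((3, 2, 1), dtype=np.int32),
    ):
        print(np.shape(value), "->", model_array_like(value).shape)
```

With NumPy 2.2, this model reproduces all four observations in this report. Only the wrong-dtype `(3, 1)` array comes back as `(3,)`. The other three inputs keep their shapes.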
Relevant Info
- Python version: 3.13.3
- NumPy version: 2.2.5
- PyO3 version: 0.24.2
- rust-numpy version: 0.24.0
- rustc version: 1.86.0
- OS: Ubuntu 24.04.2 LTS (noble), via WSL