
Trailing singleton dimensions are removed during dtype conversion #491

Open
@jakemoran


Bug Description

When using PyArrayLikeDyn with AllowTypeChange, trailing singleton axes may be removed from ndarray inputs whose dtype differs from the requested element type (for example, an int32 array passed where f64 is expected).

Steps to Reproduce

Cargo.toml

[package]
name = "singleton-removed"
version = "0.1.0"
edition = "2024"

[dependencies]
numpy = "0.24.0"
pyo3 = { version = "0.24.2", features = ["auto-initialize"] }

main.rs

use numpy::{AllowTypeChange, PyArrayDyn, PyArrayLikeDyn, PyArrayMethods};
use pyo3::ffi::c_str;
use pyo3::prelude::*;

#[pyfunction]
fn double<'py>(
    py: Python<'py>,
    a: PyArrayLikeDyn<'py, f64, AllowTypeChange>,
) -> Bound<'py, PyArrayDyn<f64>> {
    PyArrayDyn::from_owned_array(py, a.to_owned_array() * 2.0)
}

#[pymodule]
fn singleton_removed(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(double, m)?)?;
    Ok(())
}

fn main() -> PyResult<()> {
    pyo3::append_to_inittab!(singleton_removed);
    Python::with_gil(|py| {
        let code = c_str!(include_str!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/example.py"
        )));
        py.run(code, None, None)?;

        Ok(())
    })
}

example.py

import singleton_removed
import numpy as np

a = np.ones((3, 1), dtype=np.int32)
b = singleton_removed.double(a)
assert a.shape == b.shape, f"{a.shape=}, {b.shape=}"

This results in the following error, preceded by a NumPy deprecation warning that seems to be emitted when the singleton axis is implicitly removed:

<string>:5: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)
Error: PyErr { type: <class 'AssertionError'>, value: AssertionError('a.shape=(3, 1), b.shape=(3,)'), traceback: Some("Traceback (most recent call last):\n  File \"<string>\", line 6, in <module>\n") }
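The warning text describes NumPy's rule for converting an array to a Python float: an array holding exactly one element converts (since NumPy 1.25 with a DeprecationWarning if ndim > 0), and any other size raises TypeError. The sketch below is a minimal pure-Rust model of that rule, not NumPy code; the names ScalarConversion and float_conversion are hypothetical.

```rust
/// Outcome of calling Python's float() on a NumPy array (simplified model).
#[derive(Debug, PartialEq)]
enum ScalarConversion {
    /// 0-d array: converts silently.
    Ok,
    /// ndim > 0 but exactly one element: converts, with a DeprecationWarning (NumPy >= 1.25).
    OkDeprecated,
    /// Any other element count: raises TypeError.
    TypeError,
}

/// Models float(arr) for an array with the given shape.
fn float_conversion(shape: &[usize]) -> ScalarConversion {
    // Element count is the product of the axis lengths (1 for a 0-d array).
    let size: usize = shape.iter().product();
    if size != 1 {
        ScalarConversion::TypeError
    } else if shape.is_empty() {
        ScalarConversion::Ok
    } else {
        ScalarConversion::OkDeprecated
    }
}
```

Under this model, each row of a (3, 1) array has shape (1,) and lands in OkDeprecated, which matches the warning above, whereas each element of a (3, 2, 1) array has shape (2, 1) and lands in TypeError.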

If no type change occurs, the axis is preserved, e.g.

import singleton_removed
import numpy as np

a = np.ones((3, 1), dtype=np.float64)
b = singleton_removed.double(a)
assert a.shape == b.shape, f"{a.shape=}, {b.shape=}"

succeeds. Non-array inputs also behave properly, e.g.

import singleton_removed
import numpy as np

a = [[1], [2], [3]]
b = singleton_removed.double(a)
assert b.shape == (3, 1), f"{b.shape=}"

also succeeds.

Oddly enough, if I add a third axis, the trailing singleton dimension is no longer removed:

import singleton_removed
import numpy as np

a = np.ones((3, 2, 1), dtype=np.int32)
b = singleton_removed.double(a)
assert a.shape == b.shape, f"{a.shape=}, {b.shape=}"
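One explanation consistent with all four observations, offered as an assumption rather than a confirmed diagnosis: for dynamic-dimension targets, PyArrayLike extraction may try a fast path that builds a flat Vec<f64> by iterating the outer axis and converting each element to a float, before falling back to numpy.asarray(..., dtype=float64). If so, a wrong-dtype (3, 1) array succeeds on that path and comes back with shape (3,), while a (3, 2, 1) array fails it and keeps its shape through asarray. The sketch below models that assumed decision sequence in plain Rust; Kind and extracted_shape are hypothetical names, and the actual code path in rust-numpy 0.24.0 is not verified here.

```rust
/// Simplified description of the Python argument (hypothetical model).
#[derive(Debug, Clone, Copy, PartialEq)]
enum Kind {
    /// NumPy array whose dtype already matches the target (float64).
    ArrayMatchingDtype,
    /// NumPy array with another dtype, e.g. int32.
    ArrayOtherDtype,
    /// (Possibly nested) Python list.
    List,
}

/// Shape of the array the Rust function would receive, under the assumed order:
/// 1. exact-dtype downcast, 2. flat Vec<f64> fast path, 3. numpy.asarray fallback.
fn extracted_shape(kind: Kind, shape: &[usize]) -> Vec<usize> {
    // Step 1: an array that already has the right dtype is used as-is.
    if kind == Kind::ArrayMatchingDtype {
        return shape.to_vec();
    }
    // Step 2: iterate the outer axis and try to convert every element to a float.
    if let Some((&outer, inner)) = shape.split_first() {
        let element_is_float = match kind {
            // A list element is a number only if the list is flat.
            Kind::List => inner.is_empty(),
            // An array element converts iff it holds exactly one value.
            _ => inner.iter().product::<usize>() == 1,
        };
        if element_is_float {
            // The result is one-dimensional: every inner axis is dropped.
            return vec![outer];
        }
    }
    // Step 3: numpy.asarray(obj, dtype=float64) preserves the original shape.
    shape.to_vec()
}
```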

Relevant Info

Python Version

3.13.3

NumPy Version

2.2.5

PyO3 Version

0.24.2

rust-numpy Version

0.24.0

rustc Version

1.86.0

OS

Distributor ID: Ubuntu
Description:    Ubuntu 24.04.2 LTS
Release:        24.04
Codename:       noble

(via WSL)
