Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/pybind11/detail/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,10 @@
# define PYBIND11_HAS_STRING_VIEW 1
#endif

#if defined(PYBIND11_CPP20) && defined(__has_include) && __has_include(<span>)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't correct. Use the cpp feature label for this instead: https://en.cppreference.com/w/cpp/experimental/feature_test.html#cpp_lib_span this resolves the ambiguity of some pre CPP20 happen to having a global header called span somewhere

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Skylion007 Does this look correct? — e9de405

# define PYBIND11_HAS_SPAN 1
#endif

#if (defined(PYPY_VERSION) || defined(GRAALVM_PYTHON)) && !defined(PYBIND11_SIMPLE_GIL_MANAGEMENT)
# define PYBIND11_SIMPLE_GIL_MANAGEMENT
#endif
Expand Down
18 changes: 18 additions & 0 deletions include/pybind11/numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
#include <utility>
#include <vector>

#ifdef PYBIND11_HAS_SPAN
# include <span>
#endif

#if defined(PYBIND11_NUMPY_1_ONLY)
# error "PYBIND11_NUMPY_1_ONLY is no longer supported (see PR #5595)."
#endif
Expand Down Expand Up @@ -1143,6 +1147,13 @@ class array : public buffer {
/// Dimensions of the array
const ssize_t *shape() const { return detail::array_proxy(m_ptr)->dimensions; }

#ifdef PYBIND11_HAS_SPAN
/// Dimensions of the array as a span
std::span<const ssize_t, std::dynamic_extent> shape_span() const {
return std::span(shape(), static_cast<std::size_t>(ndim()));
}
#endif

/// Dimension along a given axis
ssize_t shape(ssize_t dim) const {
if (dim >= ndim()) {
Expand All @@ -1154,6 +1165,13 @@ class array : public buffer {
/// Strides of the array
const ssize_t *strides() const { return detail::array_proxy(m_ptr)->strides; }

#ifdef PYBIND11_HAS_SPAN
/// Strides of the array as a span
std::span<const ssize_t, std::dynamic_extent> strides_span() const {
return std::span(strides(), static_cast<std::size_t>(ndim()));
}
#endif

/// Stride along a given axis
ssize_t strides(ssize_t dim) const {
if (dim >= ndim()) {
Expand Down
17 changes: 17 additions & 0 deletions tests/test_numpy_array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include <cstdint>
#include <utility>
#include <vector>

// Size / dtype checks.
struct DtypeCheck {
Expand Down Expand Up @@ -246,6 +247,22 @@ TEST_SUBMODULE(numpy_array, sm) {
sm.def("nbytes", [](const arr &a) { return a.nbytes(); });
sm.def("owndata", [](const arr &a) { return a.owndata(); });

#ifdef PYBIND11_HAS_SPAN
// test_shape_strides_span
sm.def("shape_span", [](const arr &a) {
auto span = a.shape_span();
return std::vector<ssize_t>(span.begin(), span.end());
});
sm.def("strides_span", [](const arr &a) {
auto span = a.strides_span();
return std::vector<ssize_t>(span.begin(), span.end());
});
// Test that spans can be used to construct new arrays
sm.def("array_from_spans", [](const arr &a) {
return py::array(a.dtype(), a.shape_span(), a.strides_span(), a.data(), a);
});
#endif

// test_index_offset
def_index_fn(index_at, const arr &);
def_index_fn(index_at_t, const arr_t &);
Expand Down
39 changes: 39 additions & 0 deletions tests/test_numpy_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,45 @@ def test_array_attributes():
assert not m.owndata(a)


@pytest.mark.skipif(not hasattr(m, "shape_span"), reason="std::span not available")
def test_shape_strides_span():
# Test 0-dimensional array (scalar)
a = np.array(42, "f8")
assert m.ndim(a) == 0
assert m.shape_span(a) == []
assert m.strides_span(a) == []

# Test 1-dimensional array
a = np.array([1, 2, 3, 4], "u2")
assert m.ndim(a) == 1
assert m.shape_span(a) == [4]
assert m.strides_span(a) == [2]

# Test 2-dimensional array
a = np.array([[1, 2, 3], [4, 5, 6]], "u2").view()
a.flags.writeable = False
assert m.ndim(a) == 2
assert m.shape_span(a) == [2, 3]
assert m.strides_span(a) == [6, 2]

# Test 3-dimensional array
a = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], "i4")
assert m.ndim(a) == 3
assert m.shape_span(a) == [2, 2, 2]
# Verify spans match regular shape/strides
assert list(m.shape_span(a)) == list(m.shape(a))
assert list(m.strides_span(a)) == list(m.strides(a))

# Test that spans can be used to construct new arrays
original = np.array([[1, 2, 3], [4, 5, 6]], "f4")
new_array = m.array_from_spans(original)
assert new_array.shape == original.shape
assert new_array.strides == original.strides
assert new_array.dtype == original.dtype
# Verify data is shared (since we pass the same data pointer)
np.testing.assert_array_equal(new_array, original)


@pytest.mark.parametrize(
("args", "ret"), [([], 0), ([0], 0), ([1], 3), ([0, 1], 1), ([1, 2], 5)]
)
Expand Down
Loading