Skip to content
This repository was archived by the owner on Dec 19, 2025. It is now read-only.

Commit 5c37d11

Browse files
NaderAlAwarJBludauIvanGrigorik
committed
ENH: add suppport for Kokkos::complex Views (#61)
* ENH: add suppport for Kokkos::complex Views * ENH: register Kokkos::complex as a numpy dtype * ENH: bind remaining arithmetic operators for complex numbers * ENH: read dtype properly for complex dtypes * ENH: properly format views with complex data types * ENH: fix issue with getting Kokkos::complex offsets when CUDA is enabled * FIX: add back if condition checking for float64 (removed by accident) * complex: fix comples type tests and conversion * fix atomics * formatting * fix atomics comparison * c++ formatting * version: update kokkos version in the comment --------- Co-authored-by: JBludau <104908666+JBludau@users.noreply.github.com> Co-authored-by: Ivan Grigorik <givan502@gmail.com>
1 parent 2b9cb60 commit 5c37d11

File tree

10 files changed

+318
-21
lines changed

10 files changed

+318
-21
lines changed

CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ SET(libpykokkos_SOURCES
6464
${CMAKE_CURRENT_LIST_DIR}/src/available.cpp
6565
${CMAKE_CURRENT_LIST_DIR}/src/common.cpp
6666
${CMAKE_CURRENT_LIST_DIR}/src/tools.cpp
67-
${CMAKE_CURRENT_LIST_DIR}/src/execution_spaces.cpp)
67+
${CMAKE_CURRENT_LIST_DIR}/src/execution_spaces.cpp
68+
${CMAKE_CURRENT_LIST_DIR}/src/complex_dtypes.cpp)
6869

6970
SET(libpykokkos_HEADERS
7071
${CMAKE_CURRENT_LIST_DIR}/include/libpykokkos.hpp

include/fwd.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,8 @@ enum KokkosViewDataType {
133133
Uint64,
134134
Float32,
135135
Float64,
136+
ComplexFloat32,
137+
ComplexFloat64,
136138
ViewDataTypesEnd
137139
};
138140

include/libpykokkos.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,5 @@ void generate_atomic_variants(py::module& kokkos);
5858
void generate_backend_versions(py::module& kokkos);
5959
void generate_pool_variants(py::module& kokkos);
6060
void generate_execution_spaces(py::module& kokkos);
61+
void generate_complex_dtypes(py::module& kokkos);
6162
void destroy_callbacks();

include/traits.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ VIEW_DATA_TYPE(uint32_t, Uint32, "uint32", "unsigned", "unsigned_int")
8585
VIEW_DATA_TYPE(uint64_t, Uint64, "uint64", "unsigned_long")
8686
VIEW_DATA_TYPE(float, Float32, "float32", "float")
8787
VIEW_DATA_TYPE(double, Float64, "float64", "double")
88+
VIEW_DATA_TYPE(Kokkos::complex<float>, ComplexFloat32, "complex_float32_dtype",
89+
"complex_float_dtype")
90+
VIEW_DATA_TYPE(Kokkos::complex<double>, ComplexFloat64, "complex_float64_dtype",
91+
"complex_double_dtype")
8892

8993
//----------------------------------------------------------------------------//
9094
// <data-type> <enum> <string identifiers>

include/variants/atomics.hpp

Lines changed: 56 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,25 @@
5656
namespace Space {
5757
namespace SpaceDim {
5858

59+
// Helper to convert values to string, with specialization for complex types
60+
template <typename T>
61+
std::string value_to_string(const T &val) {
62+
return std::to_string(val);
63+
}
64+
65+
template <typename T>
66+
std::string value_to_string(const Kokkos::complex<T> &val) {
67+
return "(" + std::to_string(val.real()) + "+" + std::to_string(val.imag()) +
68+
"i)";
69+
}
70+
71+
// Helper to detect if a type is complex
72+
template <typename T>
73+
struct is_complex : std::false_type {};
74+
75+
template <typename T>
76+
struct is_complex<Kokkos::complex<T>> : std::true_type {};
77+
5978
// this function creates bindings for the atomic type returned to python from
6079
// views with MemoryTrait<Kokkos::Atomic | ...>
6180
template <size_t DataIdx, size_t SpaceIdx, size_t DimIdx, size_t LayoutIdx>
@@ -102,29 +121,51 @@ void generate_atomic_variant(py::module &_mod) {
102121
_atomic.def(
103122
"__str__",
104123
[](atomic_type &_obj) {
105-
return std::to_string(static_cast<value_type>(_obj));
124+
return value_to_string(static_cast<value_type>(_obj));
106125
},
107126
"String repr");
108127

109128
_atomic.def(
110-
"__eq__", [](atomic_type &_obj, value_type _v) { return (_obj == _v); },
111-
py::is_operator());
112-
_atomic.def(
113-
"__ne__", [](atomic_type &_obj, value_type _v) { return (_obj != _v); },
114-
py::is_operator());
115-
_atomic.def(
116-
"__lt__", [](atomic_type &_obj, value_type _v) { return (_obj < _v); },
117-
py::is_operator());
118-
_atomic.def(
119-
"__gt__", [](atomic_type &_obj, value_type _v) { return (_obj > _v); },
120-
py::is_operator());
121-
_atomic.def(
122-
"__le__", [](atomic_type &_obj, value_type _v) { return (_obj <= _v); },
129+
"__eq__",
130+
[](atomic_type &_obj, value_type _v) {
131+
return (static_cast<value_type>(_obj) == _v);
132+
},
123133
py::is_operator());
124134
_atomic.def(
125-
"__ge__", [](atomic_type &_obj, value_type _v) { return (_obj >= _v); },
135+
"__ne__",
136+
[](atomic_type &_obj, value_type _v) {
137+
return (static_cast<value_type>(_obj) != _v);
138+
},
126139
py::is_operator());
127140

141+
// Only bind ordering operators for non-complex types
142+
if constexpr (!is_complex<value_type>::value) {
143+
_atomic.def(
144+
"__lt__",
145+
[](atomic_type &_obj, value_type _v) {
146+
return (static_cast<value_type>(_obj) < _v);
147+
},
148+
py::is_operator());
149+
_atomic.def(
150+
"__gt__",
151+
[](atomic_type &_obj, value_type _v) {
152+
return (static_cast<value_type>(_obj) > _v);
153+
},
154+
py::is_operator());
155+
_atomic.def(
156+
"__le__",
157+
[](atomic_type &_obj, value_type _v) {
158+
return (static_cast<value_type>(_obj) <= _v);
159+
},
160+
py::is_operator());
161+
_atomic.def(
162+
"__ge__",
163+
[](atomic_type &_obj, value_type _v) {
164+
return (static_cast<value_type>(_obj) >= _v);
165+
},
166+
py::is_operator());
167+
}
168+
128169
// self type
129170
_atomic.def(py::self + py::self);
130171
_atomic.def(py::self - py::self);

include/views.hpp

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@
4444

4545
#pragma once
4646

47+
#include <pybind11/numpy.h>
48+
4749
#include <Kokkos_Core.hpp>
4850
#include <Kokkos_DynRankView.hpp>
4951
#include <iostream>
@@ -66,6 +68,21 @@ RetT get_extents(Tp &m, std::index_sequence<Idx...>) {
6668
template <typename Up, size_t Idx, typename Tp>
6769
constexpr auto get_stride(Tp &m);
6870

71+
template <typename Tp>
72+
inline std::string get_format() {
73+
return py::format_descriptor<Tp>::format();
74+
}
75+
76+
template <>
77+
inline std::string get_format<Kokkos::complex<float>>() {
78+
return py::format_descriptor<std::complex<float>>::format();
79+
}
80+
81+
template <>
82+
inline std::string get_format<Kokkos::complex<double>>() {
83+
return py::format_descriptor<std::complex<double>>::format();
84+
}
85+
6986
template <typename Up, typename Tp, size_t... Idx,
7087
typename RetT = std::array<size_t, sizeof...(Idx)>>
7188
RetT get_strides(Tp &m, std::index_sequence<Idx...>) {
@@ -325,12 +342,13 @@ void generate_view(py::module &_mod, const std::string &_name,
325342
_view.def_buffer([_ndim](ViewT &m) -> py::buffer_info {
326343
auto _extents = get_extents(m, std::make_index_sequence<DimIdx + 1>{});
327344
auto _strides = get_stride<Tp>(m, std::make_index_sequence<DimIdx + 1>{});
345+
auto _format = get_format<Tp>();
328346
return py::buffer_info(m.data(), // Pointer to buffer
329347
sizeof(Tp), // Size of one scalar
330-
py::format_descriptor<Tp>::format(), // Descriptor
331-
_ndim, // Number of dimensions
332-
_extents, // Buffer dimensions
333-
_strides // Strides (in bytes) for each index
348+
_format, // Descriptor
349+
_ndim, // Number of dimensions
350+
_extents, // Buffer dimensions
351+
_strides // Strides (in bytes) for each index
334352
);
335353
});
336354

kokkos/utility.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,15 @@ def convert_dtype(_dtype, _module=None):
6565
import numpy as np
6666

6767
_module = np
68+
69+
# Handle complex dtypes mapping
70+
if _true_dtype == "complex_float32_dtype":
71+
_true_dtype = "complex64"
72+
elif (
73+
_true_dtype == "complex_float64_dtype" or _true_dtype == "complex_double_dtype"
74+
):
75+
_true_dtype = "complex128"
76+
6877
return getattr(_module, _true_dtype)
6978

7079

@@ -101,6 +110,10 @@ def read_dtype(_dtype):
101110
return lib.float32
102111
elif _dtype == np.float64:
103112
return lib.float64
113+
elif _dtype == np.complex64:
114+
return lib.complex_float32_dtype
115+
elif _dtype == np.complex128:
116+
return lib.complex_float64_dtype
104117
except ImportError:
105118
pass
106119

0 commit comments

Comments
 (0)