Skip to content
This repository was archived by the owner on Dec 19, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ SET(libpykokkos_SOURCES
${CMAKE_CURRENT_LIST_DIR}/src/available.cpp
${CMAKE_CURRENT_LIST_DIR}/src/common.cpp
${CMAKE_CURRENT_LIST_DIR}/src/tools.cpp
${CMAKE_CURRENT_LIST_DIR}/src/execution_spaces.cpp)
${CMAKE_CURRENT_LIST_DIR}/src/execution_spaces.cpp
${CMAKE_CURRENT_LIST_DIR}/src/complex_dtypes.cpp)

SET(libpykokkos_HEADERS
${CMAKE_CURRENT_LIST_DIR}/include/libpykokkos.hpp
Expand Down
2 changes: 2 additions & 0 deletions include/fwd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ enum KokkosViewDataType {
Uint64,
Float32,
Float64,
ComplexFloat32,
ComplexFloat64,
ViewDataTypesEnd
};

Expand Down
1 change: 1 addition & 0 deletions include/libpykokkos.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,5 @@ void generate_atomic_variants(py::module& kokkos);
void generate_backend_versions(py::module& kokkos);
void generate_pool_variants(py::module& kokkos);
void generate_execution_spaces(py::module& kokkos);
void generate_complex_dtypes(py::module& kokkos);
void destroy_callbacks();
4 changes: 4 additions & 0 deletions include/traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ VIEW_DATA_TYPE(uint32_t, Uint32, "uint32", "unsigned", "unsigned_int")
VIEW_DATA_TYPE(uint64_t, Uint64, "uint64", "unsigned_long")
VIEW_DATA_TYPE(float, Float32, "float32", "float")
VIEW_DATA_TYPE(double, Float64, "float64", "double")
VIEW_DATA_TYPE(Kokkos::complex<float>, ComplexFloat32, "complex_float32_dtype",
"complex_float_dtype")
VIEW_DATA_TYPE(Kokkos::complex<double>, ComplexFloat64, "complex_float64_dtype",
"complex_double_dtype")

//----------------------------------------------------------------------------//
// <data-type> <enum> <string identifiers>
Expand Down
71 changes: 56 additions & 15 deletions include/variants/atomics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,25 @@
namespace Space {
namespace SpaceDim {

// Helper to convert values to string, with specialization for complex types
template <typename T>
std::string value_to_string(const T &val) {
return std::to_string(val);
}

template <typename T>
std::string value_to_string(const Kokkos::complex<T> &val) {
return "(" + std::to_string(val.real()) + "+" + std::to_string(val.imag()) +
"i)";
}

// Helper to detect if a type is complex
template <typename T>
struct is_complex : std::false_type {};

template <typename T>
struct is_complex<Kokkos::complex<T>> : std::true_type {};

// this function creates bindings for the atomic type returned to python from
// views with MemoryTrait<Kokkos::Atomic | ...>
template <size_t DataIdx, size_t SpaceIdx, size_t DimIdx, size_t LayoutIdx>
Expand Down Expand Up @@ -102,29 +121,51 @@ void generate_atomic_variant(py::module &_mod) {
_atomic.def(
"__str__",
[](atomic_type &_obj) {
return std::to_string(static_cast<value_type>(_obj));
return value_to_string(static_cast<value_type>(_obj));
},
"String repr");

_atomic.def(
"__eq__", [](atomic_type &_obj, value_type _v) { return (_obj == _v); },
py::is_operator());
_atomic.def(
"__ne__", [](atomic_type &_obj, value_type _v) { return (_obj != _v); },
py::is_operator());
_atomic.def(
"__lt__", [](atomic_type &_obj, value_type _v) { return (_obj < _v); },
py::is_operator());
_atomic.def(
"__gt__", [](atomic_type &_obj, value_type _v) { return (_obj > _v); },
py::is_operator());
_atomic.def(
"__le__", [](atomic_type &_obj, value_type _v) { return (_obj <= _v); },
"__eq__",
[](atomic_type &_obj, value_type _v) {
return (static_cast<value_type>(_obj) == _v);
},
py::is_operator());
_atomic.def(
"__ge__", [](atomic_type &_obj, value_type _v) { return (_obj >= _v); },
"__ne__",
[](atomic_type &_obj, value_type _v) {
return (static_cast<value_type>(_obj) != _v);
},
py::is_operator());

// Only bind ordering operators for non-complex types
if constexpr (!is_complex<value_type>::value) {
_atomic.def(
Comment on lines +141 to +143
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thx ... this actually is a problem we still have in core ... complex has only a subset of the operations implemented ...

"__lt__",
[](atomic_type &_obj, value_type _v) {
return (static_cast<value_type>(_obj) < _v);
},
py::is_operator());
_atomic.def(
"__gt__",
[](atomic_type &_obj, value_type _v) {
return (static_cast<value_type>(_obj) > _v);
},
py::is_operator());
_atomic.def(
"__le__",
[](atomic_type &_obj, value_type _v) {
return (static_cast<value_type>(_obj) <= _v);
},
py::is_operator());
_atomic.def(
"__ge__",
[](atomic_type &_obj, value_type _v) {
return (static_cast<value_type>(_obj) >= _v);
},
py::is_operator());
}

// self type
_atomic.def(py::self + py::self);
_atomic.def(py::self - py::self);
Expand Down
26 changes: 22 additions & 4 deletions include/views.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@

#pragma once

#include <pybind11/numpy.h>

#include <Kokkos_Core.hpp>
#include <Kokkos_DynRankView.hpp>
#include <iostream>
Expand All @@ -66,6 +68,21 @@ RetT get_extents(Tp &m, std::index_sequence<Idx...>) {
template <typename Up, size_t Idx, typename Tp>
constexpr auto get_stride(Tp &m);

template <typename Tp>
inline std::string get_format() {
return py::format_descriptor<Tp>::format();
}

template <>
inline std::string get_format<Kokkos::complex<float>>() {
return py::format_descriptor<std::complex<float>>::format();
}

template <>
inline std::string get_format<Kokkos::complex<double>>() {
return py::format_descriptor<std::complex<double>>::format();
}

template <typename Up, typename Tp, size_t... Idx,
typename RetT = std::array<size_t, sizeof...(Idx)>>
RetT get_strides(Tp &m, std::index_sequence<Idx...>) {
Expand Down Expand Up @@ -325,12 +342,13 @@ void generate_view(py::module &_mod, const std::string &_name,
_view.def_buffer([_ndim](ViewT &m) -> py::buffer_info {
auto _extents = get_extents(m, std::make_index_sequence<DimIdx + 1>{});
auto _strides = get_stride<Tp>(m, std::make_index_sequence<DimIdx + 1>{});
auto _format = get_format<Tp>();
return py::buffer_info(m.data(), // Pointer to buffer
sizeof(Tp), // Size of one scalar
py::format_descriptor<Tp>::format(), // Descriptor
_ndim, // Number of dimensions
_extents, // Buffer dimensions
_strides // Strides (in bytes) for each index
_format, // Descriptor
_ndim, // Number of dimensions
_extents, // Buffer dimensions
_strides // Strides (in bytes) for each index
);
});

Expand Down
13 changes: 13 additions & 0 deletions kokkos/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,15 @@ def convert_dtype(_dtype, _module=None):
import numpy as np

_module = np

# Handle complex dtypes mapping
if _true_dtype == "complex_float32_dtype":
_true_dtype = "complex64"
elif (
_true_dtype == "complex_float64_dtype" or _true_dtype == "complex_double_dtype"
):
_true_dtype = "complex128"

return getattr(_module, _true_dtype)


Expand Down Expand Up @@ -101,6 +110,10 @@ def read_dtype(_dtype):
return lib.float32
elif _dtype == np.float64:
return lib.float64
elif _dtype == np.complex64:
return lib.complex_float32_dtype
elif _dtype == np.complex128:
return lib.complex_float64_dtype
except ImportError:
pass

Expand Down
Loading