|
56 | 56 | namespace Space { |
57 | 57 | namespace SpaceDim { |
58 | 58 |
|
| 59 | +// Helper to convert values to string, with specialization for complex types |
| 60 | +template <typename T> |
| 61 | +std::string value_to_string(const T &val) { |
| 62 | + return std::to_string(val); |
| 63 | +} |
| 64 | + |
| 65 | +template <typename T> |
| 66 | +std::string value_to_string(const Kokkos::complex<T> &val) { |
| 67 | + return "(" + std::to_string(val.real()) + "+" + std::to_string(val.imag()) + |
| 68 | + "i)"; |
| 69 | +} |
| 70 | + |
| 71 | +// Helper to detect if a type is complex |
| 72 | +template <typename T> |
| 73 | +struct is_complex : std::false_type {}; |
| 74 | + |
| 75 | +template <typename T> |
| 76 | +struct is_complex<Kokkos::complex<T>> : std::true_type {}; |
| 77 | + |
59 | 78 | // this function creates bindings for the atomic type returned to python from |
60 | 79 | // views with MemoryTrait<Kokkos::Atomic | ...> |
61 | 80 | template <size_t DataIdx, size_t SpaceIdx, size_t DimIdx, size_t LayoutIdx> |
@@ -102,29 +121,51 @@ void generate_atomic_variant(py::module &_mod) { |
102 | 121 | _atomic.def( |
103 | 122 | "__str__", |
104 | 123 | [](atomic_type &_obj) { |
105 | | - return std::to_string(static_cast<value_type>(_obj)); |
| 124 | + return value_to_string(static_cast<value_type>(_obj)); |
106 | 125 | }, |
107 | 126 | "String repr"); |
108 | 127 |
|
109 | 128 | _atomic.def( |
110 | | - "__eq__", [](atomic_type &_obj, value_type _v) { return (_obj == _v); }, |
111 | | - py::is_operator()); |
112 | | - _atomic.def( |
113 | | - "__ne__", [](atomic_type &_obj, value_type _v) { return (_obj != _v); }, |
114 | | - py::is_operator()); |
115 | | - _atomic.def( |
116 | | - "__lt__", [](atomic_type &_obj, value_type _v) { return (_obj < _v); }, |
117 | | - py::is_operator()); |
118 | | - _atomic.def( |
119 | | - "__gt__", [](atomic_type &_obj, value_type _v) { return (_obj > _v); }, |
120 | | - py::is_operator()); |
121 | | - _atomic.def( |
122 | | - "__le__", [](atomic_type &_obj, value_type _v) { return (_obj <= _v); }, |
| 129 | + "__eq__", |
| 130 | + [](atomic_type &_obj, value_type _v) { |
| 131 | + return (static_cast<value_type>(_obj) == _v); |
| 132 | + }, |
123 | 133 | py::is_operator()); |
124 | 134 | _atomic.def( |
125 | | - "__ge__", [](atomic_type &_obj, value_type _v) { return (_obj >= _v); }, |
| 135 | + "__ne__", |
| 136 | + [](atomic_type &_obj, value_type _v) { |
| 137 | + return (static_cast<value_type>(_obj) != _v); |
| 138 | + }, |
126 | 139 | py::is_operator()); |
127 | 140 |
|
| 141 | + // Only bind ordering operators for non-complex types |
| 142 | + if constexpr (!is_complex<value_type>::value) { |
| 143 | + _atomic.def( |
| 144 | + "__lt__", |
| 145 | + [](atomic_type &_obj, value_type _v) { |
| 146 | + return (static_cast<value_type>(_obj) < _v); |
| 147 | + }, |
| 148 | + py::is_operator()); |
| 149 | + _atomic.def( |
| 150 | + "__gt__", |
| 151 | + [](atomic_type &_obj, value_type _v) { |
| 152 | + return (static_cast<value_type>(_obj) > _v); |
| 153 | + }, |
| 154 | + py::is_operator()); |
| 155 | + _atomic.def( |
| 156 | + "__le__", |
| 157 | + [](atomic_type &_obj, value_type _v) { |
| 158 | + return (static_cast<value_type>(_obj) <= _v); |
| 159 | + }, |
| 160 | + py::is_operator()); |
| 161 | + _atomic.def( |
| 162 | + "__ge__", |
| 163 | + [](atomic_type &_obj, value_type _v) { |
| 164 | + return (static_cast<value_type>(_obj) >= _v); |
| 165 | + }, |
| 166 | + py::is_operator()); |
| 167 | + } |
| 168 | + |
128 | 169 | // self type |
129 | 170 | _atomic.def(py::self + py::self); |
130 | 171 | _atomic.def(py::self - py::self); |
|
0 commit comments