@@ -89,7 +89,9 @@ class alignas(std::uint16_t) bfloat16 {
8989 // caused by something else in jacobi or isai.
9090 constexpr bfloat16 () noexcept : data_ (0 ){};
9191
92- template <typename T, typename = std::enable_if_t <std::is_scalar<T>::value>>
92+ template <typename T,
93+ typename = std::enable_if_t <std::is_scalar<T>::value ||
94+ std::is_same_v<T, half>>>
9395 bfloat16 (const T& val) : data_ (0 )
9496 {
9597 this ->float2bfloat16 (static_cast <float >(val));
@@ -135,12 +137,16 @@ class alignas(std::uint16_t) bfloat16 {
135137
136138 // Do operation with different type
137139 // If it is floating point, using floating point as type.
138- // If it is integer, using bfloat16 as type
140+ // If it is bfloat16, using float as type.
141+ // If it is integer, using bfloat16 as type.
139142#define BFLOAT16_FRIEND_OPERATOR (_op, _opeq ) \
140143 template <typename T> \
141144 friend std::enable_if_t < \
142- !std::is_same<T, bfloat16>::value && std::is_scalar<T>::value, \
143- std::conditional_t <std::is_floating_point<T>::value, T, bfloat16>> \
145+ !std::is_same<T, bfloat16>::value && \
146+ (std::is_scalar<T>::value || std::is_same_v<T, half>), \
147+ std::conditional_t < \
148+ std::is_floating_point<T>::value, T, \
149+ std::conditional_t <std::is_same_v<T, half>, float , bfloat16>>> \
144150 operator _op (const bfloat16& hf, const T& val) \
145151 { \
146152 using type = \
@@ -151,8 +157,11 @@ class alignas(std::uint16_t) bfloat16 {
151157 } \
152158 template <typename T> \
153159 friend std::enable_if_t < \
154- !std::is_same<T, bfloat16>::value && std::is_scalar<T>::value, \
155- std::conditional_t <std::is_floating_point<T>::value, T, bfloat16>> \
160+ !std::is_same<T, bfloat16>::value && \
161+ (std::is_scalar<T>::value || std::is_same_v<T, half>), \
162+ std::conditional_t < \
163+ std::is_floating_point<T>::value, T, \
164+ std::conditional_t <std::is_same_v<T, half>, float , bfloat16>>> \
156165 operator _op (const T& val, const bfloat16& hf) \
157166 { \
158167 using type = \
@@ -255,23 +264,29 @@ class complex<gko::bfloat16> {
255264 : real_(real), imag_(imag)
256265 {}
257266
258- template <typename T, typename U,
259- typename = std::enable_if_t <std::is_scalar<T>::value &&
260- std::is_scalar<U>::value>>
267+ template <
268+ typename T, typename U,
269+ typename = std::enable_if_t <
270+ (std::is_scalar<T>::value || std::is_same_v<T, gko::half>)&&(
271+ std::is_scalar<U>::value || std::is_same_v<U, gko::half>)>>
261272 explicit complex (const T& real, const U& imag)
262273 : real_(static_cast <value_type>(real)),
263274 imag_(static_cast <value_type>(imag))
264275 {}
265276
266- template <typename T, typename = std::enable_if_t <std::is_scalar<T>::value>>
277+ template <typename T,
278+ typename = std::enable_if_t <std::is_scalar<T>::value ||
279+ std::is_same_v<T, gko::half>>>
267280 complex (const T& real)
268281 : real_(static_cast <value_type>(real)),
269282 imag_ (static_cast <value_type>(0 .f))
270283 {}
271284
272285 // When using complex(real, imag), MSVC with CUDA try to recognize the
273286 // complex is a member not constructor.
274- template <typename T, typename = std::enable_if_t <std::is_scalar<T>::value>>
287+ template <typename T,
288+ typename = std::enable_if_t <std::is_scalar<T>::value ||
289+ std::is_same_v<T, gko::half>>>
275290 explicit complex (const complex <T>& other)
276291 : real_(static_cast <value_type>(other.real())),
277292 imag_(static_cast <value_type>(other.imag()))
0 commit comments