Skip to content

Commit 41735db

Browse files
committed
fix/add to allow half and bfloat16 at the same time
1 parent cab3c04 commit 41735db

10 files changed

Lines changed: 132 additions & 29 deletions

File tree

common/unified/components/fill_array_kernels.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ void fill_seq_array(std::shared_ptr<const DefaultExecutor> exec,
4141
run_kernel(
4242
exec,
4343
[] GKO_KERNEL(auto idx, auto array) {
44-
if constexpr (std::is_same_v<remove_complex<ValueType>, float16>) {
44+
if constexpr (std::is_same_v<remove_complex<ValueType>, float16> ||
45+
std::is_same_v<remove_complex<ValueType>, bfloat16>) {
4546
// __half can not be from int64_t
4647
// __hip_bfloat16 can not be from long long
4748
array[idx] = static_cast<float>(idx);

core/config/config_helper.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,8 @@ get_value(const pnode& config)
213213
*/
214214
template <typename ValueType>
215215
inline std::enable_if_t<std::is_floating_point<ValueType>::value ||
216-
std::is_same<ValueType, float16>::value,
216+
std::is_same<ValueType, float16>::value ||
217+
std::is_same<ValueType, bfloat16>::value,
217218
ValueType>
218219
get_value(const pnode& config)
219220
{

core/config/type_descriptor_helper.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,11 @@ TYPE_STRING_OVERLOAD(void, "void");
4040
TYPE_STRING_OVERLOAD(double, "float64");
4141
TYPE_STRING_OVERLOAD(float, "float32");
4242
TYPE_STRING_OVERLOAD(float16, "float16");
43+
TYPE_STRING_OVERLOAD(bfloat16, "bfloat16");
4344
TYPE_STRING_OVERLOAD(std::complex<double>, "complex<float64>");
4445
TYPE_STRING_OVERLOAD(std::complex<float>, "complex<float32>");
4546
TYPE_STRING_OVERLOAD(std::complex<float16>, "complex<float16>");
47+
TYPE_STRING_OVERLOAD(std::complex<bfloat16>, "complex<bfloat16>");
4648
TYPE_STRING_OVERLOAD(int32, "int32");
4749
TYPE_STRING_OVERLOAD(int64, "int64");
4850

cuda/solver/common_trs_kernels.cuh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,9 @@ __global__ void sptrsv_naive_caching_kernel(
415415
// memory operation on the half-precision shared_memory seem to give
416416
// wrong result. we use float in shared_memory.
417417
using SharedValueType = std::conditional_t<
418-
std::is_same<remove_complex<ValueType>, device_type<float16>>::value,
418+
std::is_same<remove_complex<ValueType>, device_type<float16>>::value ||
419+
std::is_same<remove_complex<ValueType>,
420+
device_type<bfloat16>>::value,
419421
std::conditional_t<is_complex<ValueType>(), thrust::complex<float>,
420422
float>,
421423
ValueType>;

include/ginkgo/core/base/bfloat16.hpp

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,9 @@ class alignas(std::uint16_t) bfloat16 {
8989
// caused by something else in jacobi or isai.
9090
constexpr bfloat16() noexcept : data_(0){};
9191

92-
template <typename T, typename = std::enable_if_t<std::is_scalar<T>::value>>
92+
template <typename T,
93+
typename = std::enable_if_t<std::is_scalar<T>::value ||
94+
std::is_same_v<T, half>>>
9395
bfloat16(const T& val) : data_(0)
9496
{
9597
this->float2bfloat16(static_cast<float>(val));
@@ -135,12 +137,16 @@ class alignas(std::uint16_t) bfloat16 {
135137

136138
// Do operation with different type
137139
// If it is floating point, using floating point as type.
138-
// If it is integer, using bfloat16 as type
140+
// If it is bfloat16, using float as type.
141+
// If it is integer, using bfloat16 as type.
139142
#define BFLOAT16_FRIEND_OPERATOR(_op, _opeq) \
140143
template <typename T> \
141144
friend std::enable_if_t< \
142-
!std::is_same<T, bfloat16>::value && std::is_scalar<T>::value, \
143-
std::conditional_t<std::is_floating_point<T>::value, T, bfloat16>> \
145+
!std::is_same<T, bfloat16>::value && \
146+
(std::is_scalar<T>::value || std::is_same_v<T, half>), \
147+
std::conditional_t< \
148+
std::is_floating_point<T>::value, T, \
149+
std::conditional_t<std::is_same_v<T, half>, float, bfloat16>>> \
144150
operator _op(const bfloat16& hf, const T& val) \
145151
{ \
146152
using type = \
@@ -151,8 +157,11 @@ class alignas(std::uint16_t) bfloat16 {
151157
} \
152158
template <typename T> \
153159
friend std::enable_if_t< \
154-
!std::is_same<T, bfloat16>::value && std::is_scalar<T>::value, \
155-
std::conditional_t<std::is_floating_point<T>::value, T, bfloat16>> \
160+
!std::is_same<T, bfloat16>::value && \
161+
(std::is_scalar<T>::value || std::is_same_v<T, half>), \
162+
std::conditional_t< \
163+
std::is_floating_point<T>::value, T, \
164+
std::conditional_t<std::is_same_v<T, half>, float, bfloat16>>> \
156165
operator _op(const T& val, const bfloat16& hf) \
157166
{ \
158167
using type = \
@@ -255,23 +264,29 @@ class complex<gko::bfloat16> {
255264
: real_(real), imag_(imag)
256265
{}
257266

258-
template <typename T, typename U,
259-
typename = std::enable_if_t<std::is_scalar<T>::value &&
260-
std::is_scalar<U>::value>>
267+
template <
268+
typename T, typename U,
269+
typename = std::enable_if_t<
270+
(std::is_scalar<T>::value || std::is_same_v<T, gko::half>)&&(
271+
std::is_scalar<U>::value || std::is_same_v<U, gko::half>)>>
261272
explicit complex(const T& real, const U& imag)
262273
: real_(static_cast<value_type>(real)),
263274
imag_(static_cast<value_type>(imag))
264275
{}
265276

266-
template <typename T, typename = std::enable_if_t<std::is_scalar<T>::value>>
277+
template <typename T,
278+
typename = std::enable_if_t<std::is_scalar<T>::value ||
279+
std::is_same_v<T, gko::half>>>
267280
complex(const T& real)
268281
: real_(static_cast<value_type>(real)),
269282
imag_(static_cast<value_type>(0.f))
270283
{}
271284

272285
// When using complex(real, imag), MSVC with CUDA try to recognize the
273286
// complex is a member not constructor.
274-
template <typename T, typename = std::enable_if_t<std::is_scalar<T>::value>>
287+
template <typename T,
288+
typename = std::enable_if_t<std::is_scalar<T>::value ||
289+
std::is_same_v<T, gko::half>>>
275290
explicit complex(const complex<T>& other)
276291
: real_(static_cast<value_type>(other.real())),
277292
imag_(static_cast<value_type>(other.imag()))

include/ginkgo/core/base/half.hpp

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

@@ -25,6 +25,8 @@ class truncated;
2525

2626
class half;
2727

28+
class bfloat16;
29+
2830

2931
namespace detail {
3032

@@ -298,7 +300,9 @@ class alignas(std::uint16_t) half {
298300
// caused by something else in jacobi or isai.
299301
constexpr half() noexcept : data_(0){};
300302

301-
template <typename T, typename = std::enable_if_t<std::is_scalar<T>::value>>
303+
template <typename T,
304+
typename = std::enable_if_t<std::is_scalar<T>::value ||
305+
std::is_same_v<T, bfloat16>>>
302306
half(const T& val) : data_(0)
303307
{
304308
this->float2half(static_cast<float>(val));
@@ -345,6 +349,8 @@ class alignas(std::uint16_t) half {
345349
// Do operation with different type
346350
// If it is floating point, using floating point as type.
347351
// If it is integer, using half as type
352+
// Note: we do not define the operation with bfloat16, which is already
353+
// defined in bfloat16.hpp
348354
#define HALF_FRIEND_OPERATOR(_op, _opeq) \
349355
template <typename T> \
350356
friend std::enable_if_t< \
@@ -464,23 +470,29 @@ class complex<gko::half> {
464470
: real_(real), imag_(imag)
465471
{}
466472

467-
template <typename T, typename U,
468-
typename = std::enable_if_t<std::is_scalar<T>::value &&
469-
std::is_scalar<U>::value>>
473+
template <
474+
typename T, typename U,
475+
typename = std::enable_if_t<
476+
(std::is_scalar<T>::value || std::is_same_v<T, gko::bfloat16>)&&(
477+
std::is_scalar<U>::value || std::is_same_v<U, gko::bfloat16>)>>
470478
explicit complex(const T& real, const U& imag)
471479
: real_(static_cast<value_type>(real)),
472480
imag_(static_cast<value_type>(imag))
473481
{}
474482

475-
template <typename T, typename = std::enable_if_t<std::is_scalar<T>::value>>
483+
template <typename T,
484+
typename = std::enable_if_t<std::is_scalar<T>::value ||
485+
std::is_same_v<T, gko::bfloat16>>>
476486
complex(const T& real)
477487
: real_(static_cast<value_type>(real)),
478488
imag_(static_cast<value_type>(0.f))
479489
{}
480490

481491
// When using complex(real, imag), MSVC with CUDA try to recognize the
482492
// complex is a member not constructor.
483-
template <typename T, typename = std::enable_if_t<std::is_scalar<T>::value>>
493+
template <typename T,
494+
typename = std::enable_if_t<std::is_scalar<T>::value ||
495+
std::is_same_v<T, gko::bfloat16>>>
484496
explicit complex(const complex<T>& other)
485497
: real_(static_cast<value_type>(other.real())),
486498
imag_(static_cast<value_type>(other.imag()))

include/ginkgo/core/base/mpi.hpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,17 @@ GKO_REGISTER_MPI_TYPE(unsigned long long, MPI_UNSIGNED_LONG_LONG);
8989
GKO_REGISTER_MPI_TYPE(float, MPI_FLOAT);
9090
GKO_REGISTER_MPI_TYPE(double, MPI_DOUBLE);
9191
GKO_REGISTER_MPI_TYPE(long double, MPI_LONG_DOUBLE);
92-
#if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
92+
#if GINKGO_ENABLE_HALF
9393
// OpenMPI 5.0 have support from MPIX_C_FLOAT16 and MPICHv3.4a1 MPIX_C_FLOAT16
9494
// Only OpenMPI support complex float16
9595
// TODO: use native type when mpi is configured with half feature
96-
GKO_REGISTER_MPI_TYPE(float16, MPI_UNSIGNED_SHORT);
97-
GKO_REGISTER_MPI_TYPE(std::complex<float16>, MPI_FLOAT);
96+
GKO_REGISTER_MPI_TYPE(half, MPI_UNSIGNED_SHORT);
97+
GKO_REGISTER_MPI_TYPE(std::complex<half>, MPI_FLOAT);
9898
#endif // GKO_ENABLE_HALF
99+
#if GINKGO_ENABLE_BFLOAT16
100+
GKO_REGISTER_MPI_TYPE(bfloat16, MPI_UNSIGNED_SHORT);
101+
GKO_REGISTER_MPI_TYPE(std::complex<bfloat16>, MPI_FLOAT);
102+
#endif // GKO_ENABLE_BFLOAT16
99103
GKO_REGISTER_MPI_TYPE(std::complex<float>, MPI_C_FLOAT_COMPLEX);
100104
GKO_REGISTER_MPI_TYPE(std::complex<double>, MPI_C_DOUBLE_COMPLEX);
101105

include/ginkgo/core/base/types.hpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,11 +145,7 @@ using uintptr = std::uintptr_t;
145145
/**
146146
* 16 bit floating point type.
147147
*/
148-
#if !GINKGO_ENABLE_BFLOAT16
149148
using float16 = half;
150-
#else
151-
using float16 = bfloat16;
152-
#endif
153149

154150

155151
/**

include/ginkgo/core/log/logger.hpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,7 @@ public: \
579579
{}
580580

581581

582-
#if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
582+
#if GINKGO_ENABLE_HALF
583583

584584

585585
/**
@@ -598,6 +598,25 @@ public: \
598598
#endif
599599

600600

601+
#if GINKGO_ENABLE_BFLOAT16
602+
603+
604+
/**
605+
* Batch solver's event that records the iteration count and the residual
606+
* norm.
607+
*
608+
* @param iters the array of iteration counts.
609+
* @param residual_norms the array storing the residual norms.
610+
*/
611+
virtual void on_batch_solver_completed(
612+
const array<int>& iters,
613+
const array<gko::bfloat16>& residual_norms) const
614+
{}
615+
616+
617+
#endif
618+
619+
601620
public:
602621
#undef GKO_LOGGER_REGISTER_EVENT
603622

omp/components/atomic.hpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,40 @@ inline void atomic_add(float16& out, float16 val)
8484
}
8585

8686

87+
template <>
88+
inline void atomic_add(bfloat16& out, bfloat16 val)
89+
{
90+
#ifdef __NVCOMPILER
91+
// NVC++ uses atomic capture on uint16 leads the following error.
92+
// use of undefined value '%L.B*' br label %L.B* !llvm.loop !*, !dbg !*
93+
#pragma omp critical
94+
{
95+
out += val;
96+
}
97+
#else
98+
static_assert(
99+
sizeof(bfloat16) == sizeof(uint16_t) &&
100+
std::alignment_of_v<uint16_t> == std::alignment_of_v<bfloat16>,
101+
"half does not fulfill the requirement of reinterpret_cast to half or "
102+
"vice versa.");
103+
// It is undefined behavior with reinterpret_cast, but we do not have any
104+
// workaround when the #omp atomic does not support custom precision
105+
uint16_t* address_as_converter = reinterpret_cast<uint16_t*>(&out);
106+
uint16_t old = *address_as_converter;
107+
uint16_t assumed;
108+
do {
109+
assumed = old;
110+
auto answer = copy_cast<uint16_t>(copy_cast<bfloat16>(assumed) + val);
111+
#pragma omp atomic capture
112+
{
113+
old = *address_as_converter;
114+
*address_as_converter = (old == assumed) ? answer : old;
115+
}
116+
} while (assumed != old);
117+
#endif
118+
}
119+
120+
87121
// There is an error in Clang 17 which prevents us from merging the
88122
// implementation of double and float. The compiler will throw an error if the
89123
// templated version is implemented. GCC doesn't throw an error.
@@ -119,6 +153,14 @@ inline void store(float16* addr, float16 val)
119153
*uint_addr = uint_val;
120154
}
121155

156+
inline void store(bfloat16* addr, bfloat16 val)
157+
{
158+
auto uint_addr = copy_cast<uint16_t*>(addr);
159+
auto uint_val = copy_cast<uint16_t>(val);
160+
#pragma omp atomic write
161+
*uint_addr = uint_val;
162+
}
163+
122164
template <typename T>
123165
inline void store(std::complex<T>* addr, std::complex<T> val)
124166
{
@@ -170,6 +212,15 @@ inline float16 load(float16* addr)
170212
return copy_cast<float16>(uint_val);
171213
}
172214

215+
inline bfloat16 load(bfloat16* addr)
216+
{
217+
uint16_t uint_val;
218+
auto uint_addr = copy_cast<uint16_t*>(addr);
219+
#pragma omp atomic read
220+
uint_val = *uint_addr;
221+
return copy_cast<bfloat16>(uint_val);
222+
}
223+
173224
template <typename T>
174225
inline std::complex<T> load(std::complex<T>* addr)
175226
{

0 commit comments

Comments
 (0)