Skip to content

Commit 4d1bfdf

Browse files
committed
add bf16_alias.hpp
1 parent 23b2c27 commit 4d1bfdf

10 files changed

Lines changed: 131 additions & 100 deletions

File tree

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// SPDX-FileCopyrightText: 2025 The Ginkgo authors
2+
//
3+
// SPDX-License-Identifier: BSD-3-Clause
4+
5+
#ifndef GKO_COMMON_CUDA_HIP_BASE_BF16_ALIAS_HPP_
6+
#define GKO_COMMON_CUDA_HIP_BASE_BF16_ALIAS_HPP_
7+
8+
9+
#ifdef GKO_COMPILING_CUDA
10+
11+
12+
#include <cuda_bf16.h>
13+
14+
15+
namespace gko {
16+
17+
18+
using vendor_bf16 = __nv_bfloat16;
19+
20+
21+
}
22+
23+
24+
#elif defined(GKO_COMPILING_HIP)
25+
26+
27+
#if HIP_VERSION >= 60200000
28+
// HIP provides __hip_bfloat16 since ROCm 5.6.0, but it only has enough
29+
// functionality for us (conversions and operator overloads) since ROCm 6.2.0,
30+
// which also provides more native operation support.
31+
#include <hip/hip_bf16.h>
32+
33+
namespace gko {
34+
35+
36+
using vendor_bf16 = __hip_bfloat16;
37+
38+
39+
}
40+
41+
42+
#else
43+
44+
45+
// HIP has hip_bfloat16, but it only provides the type, with operations
46+
// falling back to single precision.
47+
#include <hip/hip_bfloat16.h>
48+
49+
50+
namespace gko {
51+
52+
53+
using vendor_bf16 = hip_bfloat16;
54+
55+
56+
}
57+
58+
59+
#endif
60+
#endif
61+
#endif // GKO_COMMON_CUDA_HIP_BASE_BF16_ALIAS_HPP_

common/cuda_hip/base/math.hpp

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,35 +14,18 @@
1414
#ifdef GKO_COMPILING_CUDA
1515

1616

17-
#include <cuda_bf16.h>
1817
#include <cuda_fp16.h>
1918

20-
using vendor_bf16 = __nv_bfloat16;
21-
2219

2320
#elif defined(GKO_COMPILING_HIP)
2421

2522

26-
#if HIP_VERSION >= 60200000
27-
// HIP has __hip_bfloat16 after ROCM 5.6.0 but enough implementation for us
28-
// (conversion and operation overload) after ROCM 6.2.0 which provides more
29-
// native operations support.
30-
#include <hip/hip_bf16.h>
31-
using vendor_bf16 = __hip_bfloat16;
32-
#else
33-
// HIP has hip_bfloat16 but only the type with the operation fallback to the
34-
// single precision
35-
#include <hip/hip_bfloat16.h>
36-
using vendor_bf16 = hip_bfloat16;
37-
#endif
38-
39-
4023
#include <hip/hip_fp16.h>
4124

4225

4326
#endif
4427

45-
28+
#include "common/cuda_hip/base/bf16_alias.hpp"
4629
#include "common/cuda_hip/base/thrust_macro.hpp"
4730

4831

@@ -242,16 +225,16 @@ GKO_ATTRIBUTES GKO_INLINE __half abs<__half>(const complex<__half>& z)
242225
#if GINKGO_ENABLE_BFLOAT16
243226

244227
template <>
245-
GKO_ATTRIBUTES GKO_INLINE complex<vendor_bf16> sqrt<vendor_bf16>(
246-
const complex<vendor_bf16>& a)
228+
GKO_ATTRIBUTES GKO_INLINE complex<gko::vendor_bf16> sqrt<gko::vendor_bf16>(
229+
const complex<gko::vendor_bf16>& a)
247230
{
248231
return sqrt(static_cast<complex<float>>(a));
249232
}
250233

251234

252235
template <>
253-
GKO_ATTRIBUTES GKO_INLINE vendor_bf16
254-
abs<vendor_bf16>(const complex<vendor_bf16>& z)
236+
GKO_ATTRIBUTES GKO_INLINE gko::vendor_bf16 abs<gko::vendor_bf16>(
237+
const complex<gko::vendor_bf16>& z)
255238
{
256239
return abs(static_cast<complex<float>>(z));
257240
}

common/cuda_hip/base/types.hpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#ifndef GKO_COMMON_CUDA_HIP_BASE_TYPES_HPP_
66
#define GKO_COMMON_CUDA_HIP_BASE_TYPES_HPP_
77

8+
#include "common/cuda_hip/base/bf16_alias.hpp"
89
#include "common/cuda_hip/base/math.hpp"
910
#if defined(GKO_COMPILING_CUDA)
1011
#include "cuda/base/types.hpp"
@@ -34,13 +35,13 @@ THRUST_HALF_FRIEND_OPERATOR(/, /=)
3435
#undef THRUST_HALF_FRIEND_OPERATOR
3536

3637

37-
#define THRUST_BF16_FRIEND_OPERATOR(_op, _opeq) \
38-
GKO_ATTRIBUTES GKO_INLINE GKO_THRUST_QUALIFIER::complex<vendor_bf16> \
39-
operator _op(const GKO_THRUST_QUALIFIER::complex<vendor_bf16> lhs, \
40-
const GKO_THRUST_QUALIFIER::complex<vendor_bf16> rhs) \
41-
{ \
42-
return GKO_THRUST_QUALIFIER::complex<float>{ \
43-
lhs} _op GKO_THRUST_QUALIFIER::complex<float>(rhs); \
38+
#define THRUST_BF16_FRIEND_OPERATOR(_op, _opeq) \
39+
GKO_ATTRIBUTES GKO_INLINE GKO_THRUST_QUALIFIER::complex<gko::vendor_bf16> \
40+
operator _op(const GKO_THRUST_QUALIFIER::complex<gko::vendor_bf16> lhs, \
41+
const GKO_THRUST_QUALIFIER::complex<gko::vendor_bf16> rhs) \
42+
{ \
43+
return GKO_THRUST_QUALIFIER::complex<float>{ \
44+
lhs} _op GKO_THRUST_QUALIFIER::complex<float>(rhs); \
4445
}
4546

4647
THRUST_BF16_FRIEND_OPERATOR(+, +=)

core/test/base/bfloat16.cpp

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -134,22 +134,6 @@ TEST(FloatToBFloat16, TruncatesLargeNumberRoundToEven)
134134
}
135135

136136

137-
// TEST(FloatToBFloat16, Convert)
138-
// {
139-
// float rho = 86.25;
140-
// float beta = 1110;
141-
// auto float_res = rho/beta;
142-
// gko::bfloat16 rho_h = rho;
143-
// gko::bfloat16 beta_h = beta;
144-
// auto bfloat16_res = rho_h/beta_h;
145-
// std::cout << float_res << std::endl;
146-
// std::cout << float(bfloat16_res) << std::endl;
147-
148-
// std::complex<gko::bfloat16> cpx{100.0, 0.0};
149-
// std::cout << float(gko::squared_norm(cpx)) << std::endl;
150-
// }
151-
152-
153137
TEST(Bfloat16ToFloat, ConvertsOne)
154138
{
155139
float x = create_from_bits<bfloat16>("0" "01111111" "0000000");

dpcpp/base/bf16_alias.hpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// SPDX-FileCopyrightText: 2025 The Ginkgo authors
2+
//
3+
// SPDX-License-Identifier: BSD-3-Clause
4+
5+
#ifndef GKO_DPCPP_BASE_BF16_ALIAS_HPP_
6+
#define GKO_DPCPP_BASE_BF16_ALIAS_HPP_
7+
8+
#include <sycl/ext/oneapi/bfloat16.hpp>
9+
10+
namespace gko {
11+
12+
13+
using vendor_bf16 = sycl::ext::oneapi::bfloat16;
14+
15+
16+
}
17+
18+
19+
#endif // GKO_DPCPP_BASE_BF16_ALIAS_HPP_

dpcpp/base/complex.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77

88
#include <complex>
99

10-
#include <sycl/ext/oneapi/bfloat16.hpp>
1110
#include <sycl/half_type.hpp>
1211

1312
#include <ginkgo/config.hpp>
1413

14+
#include "dpcpp/base/bf16_alias.hpp"
15+
1516

1617
namespace gko {
1718

@@ -204,9 +205,9 @@ class complex<sycl::half> {
204205

205206

206207
template <>
207-
class complex<sycl::ext::oneapi::bfloat16> {
208+
class complex<vendor_bf16> {
208209
public:
209-
using value_type = sycl::ext::oneapi::bfloat16;
210+
using value_type = vendor_bf16;
210211

211212
complex(const value_type& real = value_type(0.f),
212213
const value_type& imag = value_type(0.f))

dpcpp/base/math.hpp

Lines changed: 20 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99
#include <cmath>
1010

1111
#include <sycl/bit_cast.hpp>
12-
#include <sycl/ext/oneapi/bfloat16.hpp>
1312
#include <sycl/half_type.hpp>
1413

1514
#include <ginkgo/core/base/math.hpp>
1615

16+
#include "dpcpp/base/bf16_alias.hpp"
1717
#include "dpcpp/base/complex.hpp"
1818
#include "dpcpp/base/dpct.hpp"
1919

@@ -31,8 +31,8 @@ struct basic_float_traits<sycl::half> {
3131
};
3232

3333
template <>
34-
struct basic_float_traits<sycl::ext::oneapi::bfloat16> {
35-
using type = sycl::ext::oneapi::bfloat16;
34+
struct basic_float_traits<vendor_bf16> {
35+
using type = vendor_bf16;
3636
static constexpr int sign_bits = 1;
3737
static constexpr int significand_bits = 7;
3838
static constexpr int exponent_bits = 8;
@@ -44,8 +44,7 @@ template <>
4444
struct is_complex_or_scalar_impl<sycl::half> : public std::true_type {};
4545

4646
template <>
47-
struct is_complex_or_scalar_impl<sycl::ext::oneapi::bfloat16>
48-
: public std::true_type {};
47+
struct is_complex_or_scalar_impl<vendor_bf16> : public std::true_type {};
4948

5049
template <typename ValueType>
5150
struct complex_helper {
@@ -58,8 +57,8 @@ struct complex_helper<sycl::half> {
5857
};
5958

6059
template <>
61-
struct complex_helper<sycl::ext::oneapi::bfloat16> {
62-
using type = gko::complex<sycl::ext::oneapi::bfloat16>;
60+
struct complex_helper<vendor_bf16> {
61+
using type = gko::complex<vendor_bf16>;
6362
};
6463

6564

@@ -105,22 +104,22 @@ struct device_numeric_limits {
105104
// constructor. We use sycl::bit_cast (not guaranteed to be constexpr) to create
106105
// the corresponding bfloat16
107106
template <>
108-
struct device_numeric_limits<sycl::ext::oneapi::bfloat16> {
107+
struct device_numeric_limits<vendor_bf16> {
109108
static GKO_ATTRIBUTES GKO_INLINE auto inf()
110109
{
111-
return sycl::bit_cast<sycl::ext::oneapi::bfloat16>(
110+
return sycl::bit_cast<vendor_bf16>(
112111
static_cast<unsigned short>(0b0'11111111'0000000u));
113112
}
114113

115114
static GKO_ATTRIBUTES GKO_INLINE auto max()
116115
{
117-
return sycl::bit_cast<sycl::ext::oneapi::bfloat16>(
116+
return sycl::bit_cast<vendor_bf16>(
118117
static_cast<unsigned short>(0b0'11111110'1111111u));
119118
}
120119

121120
static GKO_ATTRIBUTES GKO_INLINE auto min()
122121
{
123-
return sycl::bit_cast<sycl::ext::oneapi::bfloat16>(
122+
return sycl::bit_cast<vendor_bf16>(
124123
static_cast<unsigned short>(0b0'00000001'0000000u));
125124
}
126125
};
@@ -170,51 +169,45 @@ bool __dpct_inline__ is_finite(const gko::complex<sycl::half>& value)
170169
}
171170

172171

173-
bool __dpct_inline__ is_nan(const sycl::ext::oneapi::bfloat16& val)
172+
bool __dpct_inline__ is_nan(const vendor_bf16& val)
174173
{
175174
return std::isnan(static_cast<float>(val));
176175
}
177176

178-
bool __dpct_inline__
179-
is_nan(const gko::complex<sycl::ext::oneapi::bfloat16>& val)
177+
bool __dpct_inline__ is_nan(const gko::complex<vendor_bf16>& val)
180178
{
181179
return is_nan(val.real()) || is_nan(val.imag());
182180
}
183181

184182

185-
sycl::ext::oneapi::bfloat16 __dpct_inline__
186-
abs(const sycl::ext::oneapi::bfloat16& val)
183+
vendor_bf16 __dpct_inline__ abs(const vendor_bf16& val)
187184
{
188185
return abs(static_cast<float>(val));
189186
}
190187

191-
sycl::ext::oneapi::bfloat16 __dpct_inline__
192-
abs(const gko::complex<sycl::ext::oneapi::bfloat16>& val)
188+
vendor_bf16 __dpct_inline__ abs(const gko::complex<vendor_bf16>& val)
193189
{
194190
return abs(static_cast<std::complex<float>>(val));
195191
}
196192

197-
sycl::ext::oneapi::bfloat16 __dpct_inline__
198-
sqrt(const sycl::ext::oneapi::bfloat16& val)
193+
vendor_bf16 __dpct_inline__ sqrt(const vendor_bf16& val)
199194
{
200195
return sqrt(static_cast<float>(val));
201196
}
202197

203-
gko::complex<sycl::ext::oneapi::bfloat16> __dpct_inline__
204-
sqrt(const gko::complex<sycl::ext::oneapi::bfloat16>& val)
198+
gko::complex<vendor_bf16> __dpct_inline__
199+
sqrt(const gko::complex<vendor_bf16>& val)
205200
{
206201
return sqrt(static_cast<std::complex<float>>(val));
207202
}
208203

209204

210-
bool __dpct_inline__ is_finite(const sycl::ext::oneapi::bfloat16& value)
205+
bool __dpct_inline__ is_finite(const vendor_bf16& value)
211206
{
212-
return abs(value) <
213-
device_numeric_limits<sycl::ext::oneapi::bfloat16>::inf();
207+
return abs(value) < device_numeric_limits<vendor_bf16>::inf();
214208
}
215209

216-
bool __dpct_inline__
217-
is_finite(const gko::complex<sycl::ext::oneapi::bfloat16>& value)
210+
bool __dpct_inline__ is_finite(const gko::complex<vendor_bf16>& value)
218211
{
219212
return is_finite(value.real()) && is_finite(value.imag());
220213
}

dpcpp/base/types.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <ginkgo/core/base/matrix_data.hpp>
1616
#include <ginkgo/core/base/types.hpp>
1717

18+
#include "dpcpp/base/bf16_alias.hpp"
1819
#include "dpcpp/base/complex.hpp"
1920

2021

@@ -56,7 +57,7 @@ struct sycl_type_impl<half> {
5657

5758
template <>
5859
struct sycl_type_impl<bfloat16> {
59-
using type = sycl::ext::oneapi::bfloat16;
60+
using type = vendor_bf16;
6061
};
6162

6263
template <typename T>

0 commit comments

Comments
 (0)