99#include < cmath>
1010
1111#include < sycl/bit_cast.hpp>
12- #include < sycl/ext/oneapi/bfloat16.hpp>
1312#include < sycl/half_type.hpp>
1413
1514#include < ginkgo/core/base/math.hpp>
1615
16+ #include " dpcpp/base/bf16_alias.hpp"
1717#include " dpcpp/base/complex.hpp"
1818#include " dpcpp/base/dpct.hpp"
1919
@@ -31,8 +31,8 @@ struct basic_float_traits<sycl::half> {
3131};
3232
3333template <>
34- struct basic_float_traits <sycl::ext::oneapi::bfloat16 > {
35- using type = sycl::ext::oneapi::bfloat16 ;
34+ struct basic_float_traits <vendor_bf16 > {
35+ using type = vendor_bf16 ;
3636 static constexpr int sign_bits = 1 ;
3737 static constexpr int significand_bits = 7 ;
3838 static constexpr int exponent_bits = 8 ;
@@ -44,8 +44,7 @@ template <>
4444struct is_complex_or_scalar_impl <sycl::half> : public std::true_type {};
4545
4646template <>
47- struct is_complex_or_scalar_impl <sycl::ext::oneapi::bfloat16>
48- : public std::true_type {};
47+ struct is_complex_or_scalar_impl <vendor_bf16> : public std::true_type {};
4948
5049template <typename ValueType>
5150struct complex_helper {
@@ -58,8 +57,8 @@ struct complex_helper<sycl::half> {
5857};
5958
6059template <>
61- struct complex_helper <sycl::ext::oneapi::bfloat16 > {
62- using type = gko::complex <sycl::ext::oneapi::bfloat16 >;
60+ struct complex_helper <vendor_bf16 > {
61+ using type = gko::complex <vendor_bf16 >;
6362};
6463
6564
@@ -105,22 +104,22 @@ struct device_numeric_limits {
105104// constructor. we use sycl::bit_cast (not guarenteed be constexpr) to create
106105// the corresponding bfloat16
107106template <>
108- struct device_numeric_limits <sycl::ext::oneapi::bfloat16 > {
107+ struct device_numeric_limits <vendor_bf16 > {
109108 static GKO_ATTRIBUTES GKO_INLINE auto inf ()
110109 {
111- return sycl::bit_cast<sycl::ext::oneapi::bfloat16 >(
110+ return sycl::bit_cast<vendor_bf16 >(
112111 static_cast <unsigned short >(0b0'11111111'0000000u ));
113112 }
114113
115114 static GKO_ATTRIBUTES GKO_INLINE auto max ()
116115 {
117- return sycl::bit_cast<sycl::ext::oneapi::bfloat16 >(
116+ return sycl::bit_cast<vendor_bf16 >(
118117 static_cast <unsigned short >(0b0'11111110'1111111u ));
119118 }
120119
121120 static GKO_ATTRIBUTES GKO_INLINE auto min ()
122121 {
123- return sycl::bit_cast<sycl::ext::oneapi::bfloat16 >(
122+ return sycl::bit_cast<vendor_bf16 >(
124123 static_cast <unsigned short >(0b0'00000001'0000000u ));
125124 }
126125};
@@ -170,51 +169,45 @@ bool __dpct_inline__ is_finite(const gko::complex<sycl::half>& value)
170169}
171170
172171
173- bool __dpct_inline__ is_nan (const sycl::ext::oneapi::bfloat16 & val)
172+ bool __dpct_inline__ is_nan (const vendor_bf16 & val)
174173{
175174 return std::isnan (static_cast <float >(val));
176175}
177176
178- bool __dpct_inline__
179- is_nan (const gko::complex <sycl::ext::oneapi::bfloat16>& val)
177+ bool __dpct_inline__ is_nan (const gko::complex <vendor_bf16>& val)
180178{
181179 return is_nan (val.real ()) || is_nan (val.imag ());
182180}
183181
184182
185- sycl::ext::oneapi::bfloat16 __dpct_inline__
186- abs (const sycl::ext::oneapi::bfloat16& val)
183+ vendor_bf16 __dpct_inline__ abs (const vendor_bf16& val)
187184{
188185 return abs (static_cast <float >(val));
189186}
190187
191- sycl::ext::oneapi::bfloat16 __dpct_inline__
192- abs (const gko::complex <sycl::ext::oneapi::bfloat16>& val)
188+ vendor_bf16 __dpct_inline__ abs (const gko::complex <vendor_bf16>& val)
193189{
194190 return abs (static_cast <std::complex <float >>(val));
195191}
196192
197- sycl::ext::oneapi::bfloat16 __dpct_inline__
198- sqrt (const sycl::ext::oneapi::bfloat16& val)
193+ vendor_bf16 __dpct_inline__ sqrt (const vendor_bf16& val)
199194{
200195 return sqrt (static_cast <float >(val));
201196}
202197
203- gko::complex <sycl::ext::oneapi::bfloat16 > __dpct_inline__
204- sqrt (const gko::complex <sycl::ext::oneapi::bfloat16 >& val)
198+ gko::complex <vendor_bf16 > __dpct_inline__
199+ sqrt (const gko::complex <vendor_bf16 >& val)
205200{
206201 return sqrt (static_cast <std::complex <float >>(val));
207202}
208203
209204
210- bool __dpct_inline__ is_finite (const sycl::ext::oneapi::bfloat16 & value)
205+ bool __dpct_inline__ is_finite (const vendor_bf16 & value)
211206{
212- return abs (value) <
213- device_numeric_limits<sycl::ext::oneapi::bfloat16>::inf ();
207+ return abs (value) < device_numeric_limits<vendor_bf16>::inf ();
214208}
215209
216- bool __dpct_inline__
217- is_finite (const gko::complex <sycl::ext::oneapi::bfloat16>& value)
210+ bool __dpct_inline__ is_finite (const gko::complex <vendor_bf16>& value)
218211{
219212 return is_finite (value.real ()) && is_finite (value.imag ());
220213}
0 commit comments