Skip to content

Commit 695daa9

Browse files
committed
some changes necessary for bf16/fp16 updates in rocm
1 parent 04c297b commit 695daa9

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

include/lbann/utils/impl/rocm.hpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include "hipcub/block/block_reduce.hpp"
3232
#endif // HYDROGEN_HAVE_CUB
3333
#include <hip/hip_fp16.h>
34+
#include <hip/hip_bf16.h>
3435
#include <limits>
3536
#endif // __HIPCC__
3637

@@ -165,7 +166,7 @@ __device__ __forceinline__ T gpu_lib::block_reduce(T val)
165166
#define WRAP_UNARY_ROCM_HALF_MATH_FUNCTION(func) \
166167
__device__ __forceinline__ __half gpu_lib::func(__half const& x) \
167168
{ \
168-
return ::h##func(x); \
169+
return h##func(x); \
169170
}
170171

171172
// FIXME (trb): This is maybe not the best long-term solution, but it
@@ -190,7 +191,7 @@ WRAP_UNARY_ROCM_HALF_MATH_FUNCTION(exp)
190191
// implementation could be:
191192
__device__ __forceinline__ __half gpu_lib::expm1(__half const& x)
192193
{
193-
return ::__hsub(::hexp(x), ::__float2half(1.f));
194+
return __hsub(hexp(x), __float2half(1.f));
194195
}
195196

196197
WRAP_UNARY_ROCM_HALF_MATH_FUNCTION(log)
@@ -204,7 +205,7 @@ WRAP_UNARY_ROCM_HALF_MATH_FUNCTION(sin)
204205
// accurate than a native implementation.
205206
__device__ __forceinline__ __half gpu_lib::tan(__half const& x)
206207
{
207-
return ::__hdiv(::hsin(x), ::hcos(x));
208+
return __hdiv(hsin(x), hcos(x));
208209
}
209210

210211
WRAP_UNARY_ROCM_HALF_CAST_TO_FLOAT_MATH_FUNCTION(acos)
@@ -242,12 +243,12 @@ __device__ __forceinline__ bool gpu_lib::isfinite(__half const& x)
242243
// Binary math functions
243244
__device__ __forceinline__ __half gpu_lib::min(const __half& x, const __half& y)
244245
{
245-
return ::__hle(x, y) ? x : y;
246+
return __hle(x, y) ? x : y;
246247
}
247248

248249
__device__ __forceinline__ __half gpu_lib::max(const __half& x, const __half& y)
249250
{
250-
return ::__hle(x, y) ? y : x;
251+
return __hle(x, y) ? y : x;
251252
}
252253

253254
// Numeric limits

0 commit comments

Comments
 (0)