Skip to content

Commit 57f4f5c

Browse files
committed
fix build issue
1 parent d828c44 commit 57f4f5c

2 files changed

Lines changed: 22 additions & 22 deletions

File tree

include/mscclpp/gpu_data_types.hpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,29 +16,27 @@ using __bfloat16 = __hip_bfloat16;
1616
using __bfloat162 = __hip_bfloat162;
1717
#define __CUDA_BF16_TYPES_EXIST__
1818

19-
// AMD FP8 support - hip_fp8.h provides __hip_fp8_e4m3_fnuz and __hip_fp8_e5m2_fnuz
20-
// Only available on gfx942 and newer architectures (ROCm 6.0+)
19+
// AMD FP8 support - Use fnuz types for HIP 6.0 or when HIP_FP8_TYPE_FNUZ is enabled and HIP_FP8_TYPE_OCP is not enabled.
20+
// Otherwise, use the standard AMD FP8 types.
2121
#if defined(HIP_VERSION_MAJOR) && (HIP_VERSION_MAJOR >= 6)
2222
#include <hip/hip_fp8.h>
2323

2424
// Create aliases matching CUDA naming convention for cross-platform compatibility
25-
#if HIP_FP8_TYPE_FNUZ
25+
#if (HIP_VERSION_MAJOR == 6) || (HIP_VERSION_MAJOR > 6 && HIP_FP8_TYPE_FNUZ && !HIP_FP8_TYPE_OCP)
2626
using __fp8_e4m3 = __hip_fp8_e4m3_fnuz;
2727
using __fp8_e5m2 = __hip_fp8_e5m2_fnuz;
2828
using __fp8x2_e4m3 = __hip_fp8x2_e4m3_fnuz;
2929
using __fp8x2_e5m2 = __hip_fp8x2_e5m2_fnuz;
3030
using __fp8x4_e4m3 = __hip_fp8x4_e4m3_fnuz;
3131
using __fp8x4_e5m2 = __hip_fp8x4_e5m2_fnuz;
32-
#endif // HIP_FP8_TYPE_FNUZ
33-
34-
#if HIP_FP8_TYPE_OCP
32+
#else
3533
using __fp8_e4m3 = __hip_fp8_e4m3;
3634
using __fp8_e5m2 = __hip_fp8_e5m2;
3735
using __fp8x2_e4m3 = __hip_fp8x2_e4m3;
3836
using __fp8x2_e5m2 = __hip_fp8x2_e5m2;
3937
using __fp8x4_e4m3 = __hip_fp8x4_e4m3;
4038
using __fp8x4_e5m2 = __hip_fp8x4_e5m2;
41-
#endif // HIP_FP8_TYPE_OCP
39+
#endif
4240

4341
#define __FP8_TYPES_EXIST__
4442
#endif // HIP_VERSION_MAJOR >= 6

src/core/include/reduce_kernel.hpp

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ MSCCLPP_DEVICE_INLINE __fp8_e4m3 operator+(const __fp8_e4m3& a, const __fp8_e4m3
123123
asm volatile("v_pk_add_f32 %0, %1, %2"
124124
: "=v"(v)
125125
: "v"(__builtin_amdgcn_cvt_pk_f32_fp8(a.__x, 0)), "v"(__builtin_amdgcn_cvt_pk_f32_fp8(b.__x, 0)));
126-
return __builtin_amdgcn_cvt_pk_fp8_f32(v.x, v.x, ival, false);
126+
return static_cast<__hip_fp8_storage_t>(__builtin_amdgcn_cvt_pk_fp8_f32(v.x, v.x, ival, false));
127127
#elif defined(MSCCLPP_DEVICE_CUDA)
128128
// NVIDIA CUDA FP8 addition (CUDA 11.8+)
129129
__fp8_e4m3 result = __fp8_e4m3(__hadd(__half(a), __half(b)));
@@ -142,8 +142,9 @@ MSCCLPP_DEVICE_INLINE __fp8x2_e4m3 operator+(const __fp8x2_e4m3& a, const __fp8x
142142
uint32_t ival = 0;
143143
asm volatile("v_pk_add_f32 %0, %1, %2"
144144
: "=v"(v)
145-
: "v"(__builtin_amdgcn_cvt_pk_f32_fp8(a, 0)), "v"(__builtin_amdgcn_cvt_pk_f32_fp8(b, 0)));
146-
return __builtin_amdgcn_cvt_pk_fp8_f32(v.x, v.y, ival, false);
145+
: "v"(__builtin_amdgcn_cvt_pk_f32_fp8(a.__x, 0)), "v"(__builtin_amdgcn_cvt_pk_f32_fp8(b.__x, 0)));
146+
return bit_cast<__fp8x2_e4m3>(
147+
static_cast<__hip_fp8x2_storage_t>(__builtin_amdgcn_cvt_pk_fp8_f32(v.x, v.y, ival, false)));
147148
#elif defined(MSCCLPP_DEVICE_CUDA)
148149
// CUDA: Convert to half2, add using optimized __hadd2, convert back
149150
return __fp8x2_e4m3(__hadd2(__half2(a), __half2(b)));
@@ -200,7 +201,7 @@ MSCCLPP_DEVICE_INLINE __fp8_e5m2 operator+(const __fp8_e5m2& a, const __fp8_e5m2
200201
asm volatile("v_pk_add_f32 %0, %1, %2"
201202
: "=v"(v)
202203
: "v"(__builtin_amdgcn_cvt_pk_f32_bf8(a.__x, 0)), "v"(__builtin_amdgcn_cvt_pk_f32_bf8(b.__x, 0)));
203-
return __builtin_amdgcn_cvt_pk_bf8_f32(v.x, v.x, ival, false);
204+
return static_cast<__hip_fp8_storage_t>(__builtin_amdgcn_cvt_pk_bf8_f32(v.x, v.x, ival, false));
204205
#elif defined(MSCCLPP_DEVICE_CUDA)
205206
// NVIDIA CUDA FP8 addition
206207
__fp8_e5m2 result = __fp8_e5m2(__hadd(__half(a), __half(b)));
@@ -226,8 +227,9 @@ MSCCLPP_DEVICE_INLINE __fp8x2_e5m2 operator+(const __fp8x2_e5m2& a, const __fp8x
226227
uint32_t ival = 0;
227228
asm volatile("v_pk_add_f32 %0, %1, %2"
228229
: "=v"(v)
229-
: "v"(__builtin_amdgcn_cvt_pk_f32_bf8(a, 0)), "v"(__builtin_amdgcn_cvt_pk_f32_bf8(b, 0)));
230-
return __builtin_amdgcn_cvt_pk_bf8_f32(v.x, v.y, ival, false);
230+
: "v"(__builtin_amdgcn_cvt_pk_f32_bf8(a.__x, 0)), "v"(__builtin_amdgcn_cvt_pk_f32_bf8(b.__x, 0)));
231+
return bit_cast<__fp8x2_e5m2>(
232+
static_cast<__hip_fp8x2_storage_t>(__builtin_amdgcn_cvt_pk_bf8_f32(v.x, v.y, ival, false)));
231233
#else
232234
// Fallback: element-wise using single-element operations
233235
union {
@@ -313,8 +315,8 @@ MSCCLPP_DEVICE_INLINE __fp8x2_e4m3 min(const __fp8x2_e4m3& a, const __fp8x2_e4m3
313315
} ua{}, ub{}, result{};
314316
ua.fp8x2 = a;
315317
ub.fp8x2 = b;
316-
result.fp8[0] = min(ua.fp8[0], ub.fp8[0]);
317-
result.fp8[1] = min(ua.fp8[1], ub.fp8[1]);
318+
result.fp8[0] = mscclpp::min(ua.fp8[0], ub.fp8[0]);
319+
result.fp8[1] = mscclpp::min(ua.fp8[1], ub.fp8[1]);
318320
return result.fp8x2;
319321
}
320322

@@ -327,8 +329,8 @@ MSCCLPP_DEVICE_INLINE fp8_e4m3x4 min(const fp8_e4m3x4& a, const fp8_e4m3x4& b) {
327329
ua.vec4 = bit_cast<__fp8x4_e4m3>(a);
328330
ub.vec4 = bit_cast<__fp8x4_e4m3>(b);
329331

330-
uresult.vec2[0] = min(ua.vec2[0], ub.vec2[0]);
331-
uresult.vec2[1] = min(ua.vec2[1], ub.vec2[1]);
332+
uresult.vec2[0] = mscclpp::min(ua.vec2[0], ub.vec2[0]);
333+
uresult.vec2[1] = mscclpp::min(ua.vec2[1], ub.vec2[1]);
332334

333335
return bit_cast<fp8_e4m3x4>(uresult.vec4) ;
334336
}
@@ -350,8 +352,8 @@ MSCCLPP_DEVICE_INLINE __fp8x2_e5m2 min(const __fp8x2_e5m2& a, const __fp8x2_e5m2
350352
} ua{}, ub{}, result{};
351353
ua.fp8x2 = a;
352354
ub.fp8x2 = b;
353-
result.fp8[0] = min(ua.fp8[0], ub.fp8[0]);
354-
result.fp8[1] = min(ua.fp8[1], ub.fp8[1]);
355+
result.fp8[0] = mscclpp::min(ua.fp8[0], ub.fp8[0]);
356+
result.fp8[1] = mscclpp::min(ua.fp8[1], ub.fp8[1]);
355357
return result.fp8x2;
356358
}
357359

@@ -364,8 +366,8 @@ MSCCLPP_DEVICE_INLINE fp8_e5m2x4 min(const fp8_e5m2x4& a, const fp8_e5m2x4& b) {
364366
ua.vec4 = bit_cast<__fp8x4_e5m2>(a);
365367
ub.vec4 = bit_cast<__fp8x4_e5m2>(b);
366368

367-
uresult.vec2[0] = min(ua.vec2[0], ub.vec2[0]);
368-
uresult.vec2[1] = min(ua.vec2[1], ub.vec2[1]);
369+
uresult.vec2[0] = mscclpp::min(ua.vec2[0], ub.vec2[0]);
370+
uresult.vec2[1] = mscclpp::min(ua.vec2[1], ub.vec2[1]);
369371

370372
return bit_cast<fp8_e5m2x4>(uresult.vec4);
371373
}
@@ -377,7 +379,7 @@ MSCCLPP_DEVICE_INLINE T cal_elements(const T& a, const T& b) {
377379
if constexpr (OpType == SUM) {
378380
return a + b;
379381
} else if constexpr (OpType == MIN) {
380-
return min(a, b);
382+
return mscclpp::min(a, b);
381383
}
382384
static_assert(OpType == SUM || OpType == MIN, "Unsupported ReduceOp");
383385
}

0 commit comments

Comments
 (0)