@@ -123,7 +123,7 @@ MSCCLPP_DEVICE_INLINE __fp8_e4m3 operator+(const __fp8_e4m3& a, const __fp8_e4m3
123123 asm volatile (" v_pk_add_f32 %0, %1, %2"
124124 : " =v" (v)
125125 : " v" (__builtin_amdgcn_cvt_pk_f32_fp8 (a.__x , 0 )), " v" (__builtin_amdgcn_cvt_pk_f32_fp8 (b.__x , 0 )));
126- return __builtin_amdgcn_cvt_pk_fp8_f32 (v.x , v.x , ival, false );
126+ return static_cast < __hip_fp8_storage_t >( __builtin_amdgcn_cvt_pk_fp8_f32 (v.x , v.x , ival, false ) );
127127#elif defined(MSCCLPP_DEVICE_CUDA)
128128 // NVIDIA CUDA FP8 addition (CUDA 11.8+)
129129 __fp8_e4m3 result = __fp8_e4m3 (__hadd (__half (a), __half (b)));
@@ -142,8 +142,9 @@ MSCCLPP_DEVICE_INLINE __fp8x2_e4m3 operator+(const __fp8x2_e4m3& a, const __fp8x
142142 uint32_t ival = 0 ;
143143 asm volatile (" v_pk_add_f32 %0, %1, %2"
144144 : " =v" (v)
145- : " v" (__builtin_amdgcn_cvt_pk_f32_fp8 (a, 0 )), " v" (__builtin_amdgcn_cvt_pk_f32_fp8 (b, 0 )));
146- return __builtin_amdgcn_cvt_pk_fp8_f32 (v.x , v.y , ival, false );
145+ : " v" (__builtin_amdgcn_cvt_pk_f32_fp8 (a.__x , 0 )), " v" (__builtin_amdgcn_cvt_pk_f32_fp8 (b.__x , 0 )));
146+ return bit_cast<__fp8x2_e4m3>(
147+ static_cast <__hip_fp8x2_storage_t >(__builtin_amdgcn_cvt_pk_fp8_f32 (v.x , v.y , ival, false )));
147148#elif defined(MSCCLPP_DEVICE_CUDA)
148149 // CUDA: Convert to half2, add using optimized __hadd2, convert back
149150 return __fp8x2_e4m3 (__hadd2 (__half2 (a), __half2 (b)));
@@ -200,7 +201,7 @@ MSCCLPP_DEVICE_INLINE __fp8_e5m2 operator+(const __fp8_e5m2& a, const __fp8_e5m2
200201 asm volatile (" v_pk_add_f32 %0, %1, %2"
201202 : " =v" (v)
202203 : " v" (__builtin_amdgcn_cvt_pk_f32_bf8 (a.__x , 0 )), " v" (__builtin_amdgcn_cvt_pk_f32_bf8 (b.__x , 0 )));
203- return __builtin_amdgcn_cvt_pk_bf8_f32 (v.x , v.x , ival, false );
204+ return static_cast < __hip_fp8_storage_t >( __builtin_amdgcn_cvt_pk_bf8_f32 (v.x , v.x , ival, false ) );
204205#elif defined(MSCCLPP_DEVICE_CUDA)
205206 // NVIDIA CUDA FP8 addition
206207 __fp8_e5m2 result = __fp8_e5m2 (__hadd (__half (a), __half (b)));
@@ -226,8 +227,9 @@ MSCCLPP_DEVICE_INLINE __fp8x2_e5m2 operator+(const __fp8x2_e5m2& a, const __fp8x
226227 uint32_t ival = 0 ;
227228 asm volatile (" v_pk_add_f32 %0, %1, %2"
228229 : " =v" (v)
229- : " v" (__builtin_amdgcn_cvt_pk_f32_bf8 (a, 0 )), " v" (__builtin_amdgcn_cvt_pk_f32_bf8 (b, 0 )));
230- return __builtin_amdgcn_cvt_pk_bf8_f32 (v.x , v.y , ival, false );
230+ : " v" (__builtin_amdgcn_cvt_pk_f32_bf8 (a.__x , 0 )), " v" (__builtin_amdgcn_cvt_pk_f32_bf8 (b.__x , 0 )));
231+ return bit_cast<__fp8x2_e5m2>(
232+ static_cast <__hip_fp8x2_storage_t >(__builtin_amdgcn_cvt_pk_bf8_f32 (v.x , v.y , ival, false )));
231233#else
232234 // Fallback: element-wise using single-element operations
233235 union {
@@ -313,8 +315,8 @@ MSCCLPP_DEVICE_INLINE __fp8x2_e4m3 min(const __fp8x2_e4m3& a, const __fp8x2_e4m3
313315 } ua{}, ub{}, result{};
314316 ua.fp8x2 = a;
315317 ub.fp8x2 = b;
316- result.fp8 [0 ] = min (ua.fp8 [0 ], ub.fp8 [0 ]);
317- result.fp8 [1 ] = min (ua.fp8 [1 ], ub.fp8 [1 ]);
318+ result.fp8 [0 ] = mscclpp:: min (ua.fp8 [0 ], ub.fp8 [0 ]);
319+ result.fp8 [1 ] = mscclpp:: min (ua.fp8 [1 ], ub.fp8 [1 ]);
318320 return result.fp8x2 ;
319321}
320322
@@ -327,8 +329,8 @@ MSCCLPP_DEVICE_INLINE fp8_e4m3x4 min(const fp8_e4m3x4& a, const fp8_e4m3x4& b) {
327329 ua.vec4 = bit_cast<__fp8x4_e4m3>(a);
328330 ub.vec4 = bit_cast<__fp8x4_e4m3>(b);
329331
330- uresult.vec2 [0 ] = min (ua.vec2 [0 ], ub.vec2 [0 ]);
331- uresult.vec2 [1 ] = min (ua.vec2 [1 ], ub.vec2 [1 ]);
332+ uresult.vec2 [0 ] = mscclpp:: min (ua.vec2 [0 ], ub.vec2 [0 ]);
333+ uresult.vec2 [1 ] = mscclpp:: min (ua.vec2 [1 ], ub.vec2 [1 ]);
332334
333335 return bit_cast<fp8_e4m3x4>(uresult.vec4 ) ;
334336}
@@ -350,8 +352,8 @@ MSCCLPP_DEVICE_INLINE __fp8x2_e5m2 min(const __fp8x2_e5m2& a, const __fp8x2_e5m2
350352 } ua{}, ub{}, result{};
351353 ua.fp8x2 = a;
352354 ub.fp8x2 = b;
353- result.fp8 [0 ] = min (ua.fp8 [0 ], ub.fp8 [0 ]);
354- result.fp8 [1 ] = min (ua.fp8 [1 ], ub.fp8 [1 ]);
355+ result.fp8 [0 ] = mscclpp:: min (ua.fp8 [0 ], ub.fp8 [0 ]);
356+ result.fp8 [1 ] = mscclpp:: min (ua.fp8 [1 ], ub.fp8 [1 ]);
355357 return result.fp8x2 ;
356358}
357359
@@ -364,8 +366,8 @@ MSCCLPP_DEVICE_INLINE fp8_e5m2x4 min(const fp8_e5m2x4& a, const fp8_e5m2x4& b) {
364366 ua.vec4 = bit_cast<__fp8x4_e5m2>(a);
365367 ub.vec4 = bit_cast<__fp8x4_e5m2>(b);
366368
367- uresult.vec2 [0 ] = min (ua.vec2 [0 ], ub.vec2 [0 ]);
368- uresult.vec2 [1 ] = min (ua.vec2 [1 ], ub.vec2 [1 ]);
369+ uresult.vec2 [0 ] = mscclpp:: min (ua.vec2 [0 ], ub.vec2 [0 ]);
370+ uresult.vec2 [1 ] = mscclpp:: min (ua.vec2 [1 ], ub.vec2 [1 ]);
369371
370372 return bit_cast<fp8_e5m2x4>(uresult.vec4 );
371373}
@@ -377,7 +379,7 @@ MSCCLPP_DEVICE_INLINE T cal_elements(const T& a, const T& b) {
377379 if constexpr (OpType == SUM) {
378380 return a + b;
379381 } else if constexpr (OpType == MIN) {
380- return min (a, b);
382+ return mscclpp:: min (a, b);
381383 }
382384 static_assert (OpType == SUM || OpType == MIN, " Unsupported ReduceOp" );
383385}
0 commit comments