Skip to content

Commit fac9467

Browse files
committed
WIP
1 parent f830639 commit fac9467

1 file changed

Lines changed: 10 additions & 30 deletions

File tree

include/mscclpp/gpu_data_types.hpp

Lines changed: 10 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -852,27 +852,17 @@ MSCCLPP_DEVICE_INLINE f32x4 to<f32x4, f8_e5m2x4>(const f8_e5m2x4& v) {
852852

853853
/// f32x2 -> f8_e4m3x2.
854854
/// HIP gfx942: float -> fp8 (via __builtin_amdgcn_cvt_pk_fp8_f32).
855-
/// NVIDIA SM90+: float -> half -> fp8 (via __nv_cvt_halfraw2_to_fp8x2).
856-
/// NVIDIA pre-SM90: float -> half -> fp8 (via __nv_cvt_halfraw_to_fp8, element-wise).
855+
/// NVIDIA: float -> fp8 directly (via __nv_cvt_float2_to_fp8x2). On SM89+ this maps to a
856+
/// single hardware round-to-nearest-even instruction; on older architectures it falls back to a
857+
/// direct conversion implemented in software.
857858
template <>
858859
MSCCLPP_DEVICE_INLINE f8_e4m3x2 to<f8_e4m3x2, f32x2>(const f32x2& v) {
859860
#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
860861
uint32_t packed = __builtin_amdgcn_cvt_pk_fp8_f32(v.data[0], v.data[1], 0, false);
861862
return bit_cast<f8_e4m3x2>(static_cast<__hip_fp8x2_storage_t>(packed));
862-
#elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900
863-
__half2_raw h2;
864-
h2.x = bit_cast<unsigned short>(__float2half_rn(v.data[0]));
865-
h2.y = bit_cast<unsigned short>(__float2half_rn(v.data[1]));
866-
__nv_fp8x2_storage_t fp8x2 = __nv_cvt_halfraw2_to_fp8x2(h2, __NV_SATFINITE, __NV_E4M3);
867-
return bit_cast<f8_e4m3x2>(fp8x2);
868863
#elif defined(MSCCLPP_DEVICE_CUDA)
869-
__half_raw h0, h1;
870-
h0.x = bit_cast<unsigned short>(__float2half_rn(v.data[0]));
871-
h1.x = bit_cast<unsigned short>(__float2half_rn(v.data[1]));
872-
f8_e4m3x2 result;
873-
result.data[0] = bit_cast<__fp8_e4m3>(__nv_cvt_halfraw_to_fp8(h0, __NV_SATFINITE, __NV_E4M3));
874-
result.data[1] = bit_cast<__fp8_e4m3>(__nv_cvt_halfraw_to_fp8(h1, __NV_SATFINITE, __NV_E4M3));
875-
return result;
864+
__nv_fp8x2_storage_t fp8x2 = __nv_cvt_float2_to_fp8x2(make_float2(v.data[0], v.data[1]), __NV_SATFINITE, __NV_E4M3);
865+
return bit_cast<f8_e4m3x2>(fp8x2);
876866
#else
877867
f8_e4m3x2 result;
878868
result.data[0] = static_cast<__fp8_e4m3>(v.data[0]);
@@ -909,27 +899,17 @@ MSCCLPP_DEVICE_INLINE f8_e4m3x4 to<f8_e4m3x4, f32x4>(const f32x4& v) {
909899

910900
/// f32x2 -> f8_e5m2x2.
911901
/// HIP gfx942: float -> bf8 (via __builtin_amdgcn_cvt_pk_bf8_f32).
912-
/// NVIDIA SM90+: float -> half -> fp8 (via __nv_cvt_halfraw2_to_fp8x2 with __NV_E5M2).
913-
/// NVIDIA pre-SM90: float -> half -> fp8 (via __nv_cvt_halfraw_to_fp8, element-wise).
902+
/// NVIDIA: float -> fp8 directly (via __nv_cvt_float2_to_fp8x2 with __NV_E5M2). On SM89+ this
903+
/// maps to a single hardware round-to-nearest-even instruction; on older architectures it falls back to a
904+
/// direct conversion implemented in software.
914905
template <>
915906
MSCCLPP_DEVICE_INLINE f8_e5m2x2 to<f8_e5m2x2, f32x2>(const f32x2& v) {
916907
#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
917908
uint32_t packed = __builtin_amdgcn_cvt_pk_bf8_f32(v.data[0], v.data[1], 0, false);
918909
return bit_cast<f8_e5m2x2>(static_cast<__hip_fp8x2_storage_t>(packed));
919-
#elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900
920-
__half2_raw h2;
921-
h2.x = bit_cast<unsigned short>(__float2half_rn(v.data[0]));
922-
h2.y = bit_cast<unsigned short>(__float2half_rn(v.data[1]));
923-
__nv_fp8x2_storage_t fp8x2 = __nv_cvt_halfraw2_to_fp8x2(h2, __NV_SATFINITE, __NV_E5M2);
924-
return bit_cast<f8_e5m2x2>(fp8x2);
925910
#elif defined(MSCCLPP_DEVICE_CUDA)
926-
__half_raw h0, h1;
927-
h0.x = bit_cast<unsigned short>(__float2half_rn(v.data[0]));
928-
h1.x = bit_cast<unsigned short>(__float2half_rn(v.data[1]));
929-
f8_e5m2x2 result;
930-
result.data[0] = bit_cast<__fp8_e5m2>(__nv_cvt_halfraw_to_fp8(h0, __NV_SATFINITE, __NV_E5M2));
931-
result.data[1] = bit_cast<__fp8_e5m2>(__nv_cvt_halfraw_to_fp8(h1, __NV_SATFINITE, __NV_E5M2));
932-
return result;
911+
__nv_fp8x2_storage_t fp8x2 = __nv_cvt_float2_to_fp8x2(make_float2(v.data[0], v.data[1]), __NV_SATFINITE, __NV_E5M2);
912+
return bit_cast<f8_e5m2x2>(fp8x2);
933913
#else
934914
f8_e5m2x2 result;
935915
result.data[0] = static_cast<__fp8_e5m2>(v.data[0]);

0 commit comments

Comments
 (0)