@@ -852,27 +852,17 @@ MSCCLPP_DEVICE_INLINE f32x4 to<f32x4, f8_e5m2x4>(const f8_e5m2x4& v) {
852852
853853// / f32x2 -> f8_e4m3x2.
854854// / HIP gfx942: float -> fp8 (via __builtin_amdgcn_cvt_pk_fp8_f32).
855- // / NVIDIA SM90+: float -> half -> fp8 (via __nv_cvt_halfraw2_to_fp8x2).
856- // / NVIDIA pre-SM90: float -> half -> fp8 (via __nv_cvt_halfraw_to_fp8, element-wise).
855+ // / NVIDIA: float -> fp8 directly (via __nv_cvt_float2_to_fp8x2). On SM89+ this maps to a
856+ // / single hardware round-to-nearest-even instruction; on older architectures it falls back
857+ // / to a direct software conversion.
857858template <>
858859MSCCLPP_DEVICE_INLINE f8_e4m3x2 to<f8_e4m3x2, f32x2>(const f32x2& v) {
859860#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
860861 uint32_t packed = __builtin_amdgcn_cvt_pk_fp8_f32 (v.data [0 ], v.data [1 ], 0 , false );
861862 return bit_cast<f8_e4m3x2>(static_cast <__hip_fp8x2_storage_t >(packed));
862- #elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900
863- __half2_raw h2;
864- h2.x = bit_cast<unsigned short >(__float2half_rn (v.data [0 ]));
865- h2.y = bit_cast<unsigned short >(__float2half_rn (v.data [1 ]));
866- __nv_fp8x2_storage_t fp8x2 = __nv_cvt_halfraw2_to_fp8x2 (h2, __NV_SATFINITE, __NV_E4M3);
867- return bit_cast<f8_e4m3x2>(fp8x2);
868863#elif defined(MSCCLPP_DEVICE_CUDA)
869- __half_raw h0, h1;
870- h0.x = bit_cast<unsigned short >(__float2half_rn (v.data [0 ]));
871- h1.x = bit_cast<unsigned short >(__float2half_rn (v.data [1 ]));
872- f8_e4m3x2 result;
873- result.data [0 ] = bit_cast<__fp8_e4m3>(__nv_cvt_halfraw_to_fp8 (h0, __NV_SATFINITE, __NV_E4M3));
874- result.data [1 ] = bit_cast<__fp8_e4m3>(__nv_cvt_halfraw_to_fp8 (h1, __NV_SATFINITE, __NV_E4M3));
875- return result;
864+ __nv_fp8x2_storage_t fp8x2 = __nv_cvt_float2_to_fp8x2 (make_float2 (v.data [0 ], v.data [1 ]), __NV_SATFINITE, __NV_E4M3);
865+ return bit_cast<f8_e4m3x2>(fp8x2);
876866#else
877867 f8_e4m3x2 result;
878868 result.data [0 ] = static_cast <__fp8_e4m3>(v.data [0 ]);
@@ -909,27 +899,17 @@ MSCCLPP_DEVICE_INLINE f8_e4m3x4 to<f8_e4m3x4, f32x4>(const f32x4& v) {
909899
910900// / f32x2 -> f8_e5m2x2.
911901// / HIP gfx942: float -> bf8 (via __builtin_amdgcn_cvt_pk_bf8_f32).
912- // / NVIDIA SM90+: float -> half -> fp8 (via __nv_cvt_halfraw2_to_fp8x2 with __NV_E5M2).
913- // / NVIDIA pre-SM90: float -> half -> fp8 (via __nv_cvt_halfraw_to_fp8, element-wise).
902+ // / NVIDIA: float -> fp8 directly (via __nv_cvt_float2_to_fp8x2 with __NV_E5M2). On SM89+ this
903+ // / maps to a single hardware round-to-nearest-even instruction; on older architectures it
904+ // / falls back to a direct software conversion.
914905template <>
915906MSCCLPP_DEVICE_INLINE f8_e5m2x2 to<f8_e5m2x2, f32x2>(const f32x2& v) {
916907#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
917908 uint32_t packed = __builtin_amdgcn_cvt_pk_bf8_f32 (v.data [0 ], v.data [1 ], 0 , false );
918909 return bit_cast<f8_e5m2x2>(static_cast <__hip_fp8x2_storage_t >(packed));
919- #elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900
920- __half2_raw h2;
921- h2.x = bit_cast<unsigned short >(__float2half_rn (v.data [0 ]));
922- h2.y = bit_cast<unsigned short >(__float2half_rn (v.data [1 ]));
923- __nv_fp8x2_storage_t fp8x2 = __nv_cvt_halfraw2_to_fp8x2 (h2, __NV_SATFINITE, __NV_E5M2);
924- return bit_cast<f8_e5m2x2>(fp8x2);
925910#elif defined(MSCCLPP_DEVICE_CUDA)
926- __half_raw h0, h1;
927- h0.x = bit_cast<unsigned short >(__float2half_rn (v.data [0 ]));
928- h1.x = bit_cast<unsigned short >(__float2half_rn (v.data [1 ]));
929- f8_e5m2x2 result;
930- result.data [0 ] = bit_cast<__fp8_e5m2>(__nv_cvt_halfraw_to_fp8 (h0, __NV_SATFINITE, __NV_E5M2));
931- result.data [1 ] = bit_cast<__fp8_e5m2>(__nv_cvt_halfraw_to_fp8 (h1, __NV_SATFINITE, __NV_E5M2));
932- return result;
911+ __nv_fp8x2_storage_t fp8x2 = __nv_cvt_float2_to_fp8x2 (make_float2 (v.data [0 ], v.data [1 ]), __NV_SATFINITE, __NV_E5M2);
912+ return bit_cast<f8_e5m2x2>(fp8x2);
933913#else
934914 f8_e5m2x2 result;
935915 result.data [0 ] = static_cast <__fp8_e5m2>(v.data [0 ]);
0 commit comments