@@ -71,7 +71,7 @@ using __bfloat162 = __nv_bfloat162;
7171
7272// / Software float8 with 4 exponent bits, 3 mantissa bits, exponent bias = 15.
7373// / Format (MSB first): [sign:1][exponent:4][mantissa:3]
74- // / No infinities, no NaN. Encode saturates to ±1.75 (0x7e/0xfe ).
74+ // / No infinities, no NaN. Encode saturates to ±1.875 (0x7f/0xff ).
7575// / Adapted from the Triton compiler's fp8e4b15 format.
7676struct alignas (1 ) __fp8_e4m3b15 {
7777 uint8_t __x;
@@ -103,7 +103,7 @@ struct alignas(1) __fp8_e4m3b15 {
103103 // / then convert fp16 → float32.
104104 static MSCCLPP_HOST_DEVICE_INLINE float toFloat (uint8_t bits) {
105105 // Branch-free decode: fp8 → fp16 → fp32, no special-case handling.
106- // Encode saturates to ±1.75 , so 0x7f/0xff are never produced .
106+ // Every byte maps to a finite value; encode saturates at ±1.875 , so 0x7f/0xff decode to ±1.875 .
107107 // Refer:
108108 // https://github.com/triton-lang/triton/blob/cf34004b8a67d290a962da166f5aa2fc66751326/python/triton/language/extra/cuda/utils.py#L34
109109 uint16_t h = (uint16_t )bits << 8 ; // place fp8 in upper byte of fp16
@@ -132,10 +132,9 @@ struct alignas(1) __fp8_e4m3b15 {
132132 } cvt = {h_val};
133133 uint16_t fp16_bits = cvt.u ;
134134
135- // Clamp abs to max encodable value: 1.75 → fp16 = 0x3F00.
136- // Matches Triton: encode saturates, 0x7f/0xff are never produced.
135+ // Clamp abs to max encodable value: 1.875 → fp16 = 0x3F80 (largest byte 0x7f/0xff).
137136 uint16_t abs_fp16 = fp16_bits & 0x7FFFu ;
138- if (abs_fp16 > 0x3F00u ) abs_fp16 = 0x3F00u ;
137+ if (abs_fp16 > 0x3F80u ) abs_fp16 = 0x3F80u ;
139138
140139 // Reconstruct with sign.
141140 uint16_t sign16 = fp16_bits & 0x8000u ;
@@ -852,27 +851,17 @@ MSCCLPP_DEVICE_INLINE f32x4 to<f32x4, f8_e5m2x4>(const f8_e5m2x4& v) {
852851
853852// / f32x2 -> f8_e4m3x2.
854853// / HIP gfx942: float -> fp8 (via __builtin_amdgcn_cvt_pk_fp8_f32).
855- // / NVIDIA SM90+: float -> half -> fp8 (via __nv_cvt_halfraw2_to_fp8x2).
856- // / NVIDIA pre-SM90: float -> half -> fp8 (via __nv_cvt_halfraw_to_fp8, element-wise).
854+ // / NVIDIA: float -> fp8 directly (via __nv_cvt_float2_to_fp8x2). On SM89+ this maps to a
855+ // / single hardware round-to-nearest-even instruction; on older architectures it falls back to a
856+ // / software direct conversion.
857857template <>
858858MSCCLPP_DEVICE_INLINE f8_e4m3x2 to<f8_e4m3x2, f32x2>(const f32x2& v) {
859859#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
860860 uint32_t packed = __builtin_amdgcn_cvt_pk_fp8_f32 (v.data [0 ], v.data [1 ], 0 , false );
861861 return bit_cast<f8_e4m3x2>(static_cast <__hip_fp8x2_storage_t >(packed));
862- #elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900
863- __half2_raw h2;
864- h2.x = bit_cast<unsigned short >(__float2half_rn (v.data [0 ]));
865- h2.y = bit_cast<unsigned short >(__float2half_rn (v.data [1 ]));
866- __nv_fp8x2_storage_t fp8x2 = __nv_cvt_halfraw2_to_fp8x2 (h2, __NV_SATFINITE, __NV_E4M3);
867- return bit_cast<f8_e4m3x2>(fp8x2);
868862#elif defined(MSCCLPP_DEVICE_CUDA)
869- __half_raw h0, h1;
870- h0.x = bit_cast<unsigned short >(__float2half_rn (v.data [0 ]));
871- h1.x = bit_cast<unsigned short >(__float2half_rn (v.data [1 ]));
872- f8_e4m3x2 result;
873- result.data [0 ] = bit_cast<__fp8_e4m3>(__nv_cvt_halfraw_to_fp8 (h0, __NV_SATFINITE, __NV_E4M3));
874- result.data [1 ] = bit_cast<__fp8_e4m3>(__nv_cvt_halfraw_to_fp8 (h1, __NV_SATFINITE, __NV_E4M3));
875- return result;
863+ __nv_fp8x2_storage_t fp8x2 = __nv_cvt_float2_to_fp8x2 (make_float2 (v.data [0 ], v.data [1 ]), __NV_SATFINITE, __NV_E4M3);
864+ return bit_cast<f8_e4m3x2>(fp8x2);
876865#else
877866 f8_e4m3x2 result;
878867 result.data [0 ] = static_cast <__fp8_e4m3>(v.data [0 ]);
@@ -909,27 +898,17 @@ MSCCLPP_DEVICE_INLINE f8_e4m3x4 to<f8_e4m3x4, f32x4>(const f32x4& v) {
909898
910899// / f32x2 -> f8_e5m2x2.
911900// / HIP gfx942: float -> bf8 (via __builtin_amdgcn_cvt_pk_bf8_f32).
912- // / NVIDIA SM90+: float -> half -> fp8 (via __nv_cvt_halfraw2_to_fp8x2 with __NV_E5M2).
913- // / NVIDIA pre-SM90: float -> half -> fp8 (via __nv_cvt_halfraw_to_fp8, element-wise).
901+ // / NVIDIA: float -> fp8 directly (via __nv_cvt_float2_to_fp8x2 with __NV_E5M2). On SM89+ this
902+ // / maps to a single hardware round-to-nearest-even instruction; on older architectures it falls back to a
903+ // / software direct conversion.
914904template <>
915905MSCCLPP_DEVICE_INLINE f8_e5m2x2 to<f8_e5m2x2, f32x2>(const f32x2& v) {
916906#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
917907 uint32_t packed = __builtin_amdgcn_cvt_pk_bf8_f32 (v.data [0 ], v.data [1 ], 0 , false );
918908 return bit_cast<f8_e5m2x2>(static_cast <__hip_fp8x2_storage_t >(packed));
919- #elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900
920- __half2_raw h2;
921- h2.x = bit_cast<unsigned short >(__float2half_rn (v.data [0 ]));
922- h2.y = bit_cast<unsigned short >(__float2half_rn (v.data [1 ]));
923- __nv_fp8x2_storage_t fp8x2 = __nv_cvt_halfraw2_to_fp8x2 (h2, __NV_SATFINITE, __NV_E5M2);
924- return bit_cast<f8_e5m2x2>(fp8x2);
925909#elif defined(MSCCLPP_DEVICE_CUDA)
926- __half_raw h0, h1;
927- h0.x = bit_cast<unsigned short >(__float2half_rn (v.data [0 ]));
928- h1.x = bit_cast<unsigned short >(__float2half_rn (v.data [1 ]));
929- f8_e5m2x2 result;
930- result.data [0 ] = bit_cast<__fp8_e5m2>(__nv_cvt_halfraw_to_fp8 (h0, __NV_SATFINITE, __NV_E5M2));
931- result.data [1 ] = bit_cast<__fp8_e5m2>(__nv_cvt_halfraw_to_fp8 (h1, __NV_SATFINITE, __NV_E5M2));
932- return result;
910+ __nv_fp8x2_storage_t fp8x2 = __nv_cvt_float2_to_fp8x2 (make_float2 (v.data [0 ], v.data [1 ]), __NV_SATFINITE, __NV_E5M2);
911+ return bit_cast<f8_e5m2x2>(fp8x2);
933912#else
934913 f8_e5m2x2 result;
935914 result.data [0 ] = static_cast <__fp8_e5m2>(v.data [0 ]);
@@ -1103,11 +1082,11 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x2 to<f8_e4m3b15x2, f16x2>(const f16x2& v) {
11031082#if defined(MSCCLPP_DEVICE_CUDA)
11041083 uint32_t in0;
11051084 asm (" mov.b32 %0, %1;" : " =r" (in0) : " r" (*reinterpret_cast <const uint32_t *>(&v)));
1106- // Clamp abs to max encodable e4m3b15 (0x3F00 = 1.75 in fp16).
1085+ // Clamp abs to max encodable e4m3b15 (0x3F80 = 1.875 in fp16).
11071086 uint32_t lo = in0 & 0xFFFFu , hi = in0 >> 16 ;
11081087 uint32_t alo = lo & 0x7FFFu , ahi = hi & 0x7FFFu ;
1109- alo = alo < 0x3F00u ? alo : 0x3F00u ;
1110- ahi = ahi < 0x3F00u ? ahi : 0x3F00u ;
1088+ alo = alo < 0x3F80u ? alo : 0x3F80u ;
1089+ ahi = ahi < 0x3F80u ? ahi : 0x3F80u ;
11111090 uint32_t a0 = alo | (ahi << 16 );
11121091 a0 = a0 * 2u + 0x00800080u ;
11131092 uint32_t b0 = a0 | (in0 & 0x80008000u );
@@ -1118,7 +1097,7 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x2 to<f8_e4m3b15x2, f16x2>(const f16x2& v) {
11181097 uint32_t in0 = v.words [0 ];
11191098 uint32_t abs0 = in0 & 0x7fff7fffu ;
11201099 uint32_t a0;
1121- asm volatile (" v_pk_min_u16 %0, %1, %2" : " =v" (a0) : " v" (abs0), " v" (0x3F003F00u ));
1100+ asm volatile (" v_pk_min_u16 %0, %1, %2" : " =v" (a0) : " v" (abs0), " v" (0x3F803F80u ));
11221101 a0 = a0 * 2u + 0x00800080u ;
11231102 uint32_t b0 = a0 | (in0 & 0x80008000u );
11241103 uint16_t packed = (uint16_t )(((b0 >> 8 ) & 0xFFu ) | ((b0 >> 16 ) & 0xFF00u ));
@@ -1141,8 +1120,8 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 to<f8_e4m3b15x4, f16x4>(const f16x4& v) {
11411120 asm (" mov.b32 %0, %1;" : " =r" (in1) : " r" (v.words [1 ]));
11421121 uint32_t abs0 = in0 & 0x7fff7fffu ;
11431122 uint32_t abs1 = in1 & 0x7fff7fffu ;
1144- uint32_t a0 = __vminu2 (abs0, 0x3F003F00u );
1145- uint32_t a1 = __vminu2 (abs1, 0x3F003F00u );
1123+ uint32_t a0 = __vminu2 (abs0, 0x3F803F80u );
1124+ uint32_t a1 = __vminu2 (abs1, 0x3F803F80u );
11461125 a0 = a0 * 2u + 0x00800080u ;
11471126 a1 = a1 * 2u + 0x00800080u ;
11481127 uint32_t b0, b1;
@@ -1155,8 +1134,8 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 to<f8_e4m3b15x4, f16x4>(const f16x4& v) {
11551134 uint32_t in0 = v.words [0 ], in1 = v.words [1 ];
11561135 uint32_t abs0 = in0 & 0x7fff7fffu , abs1 = in1 & 0x7fff7fffu ;
11571136 uint32_t a0, a1;
1158- asm volatile (" v_pk_min_u16 %0, %1, %2" : " =v" (a0) : " v" (abs0), " v" (0x3F003F00u ));
1159- asm volatile (" v_pk_min_u16 %0, %1, %2" : " =v" (a1) : " v" (abs1), " v" (0x3F003F00u ));
1137+ asm volatile (" v_pk_min_u16 %0, %1, %2" : " =v" (a0) : " v" (abs0), " v" (0x3F803F80u ));
1138+ asm volatile (" v_pk_min_u16 %0, %1, %2" : " =v" (a1) : " v" (abs1), " v" (0x3F803F80u ));
11601139 a0 = a0 * 2u + 0x00800080u ;
11611140 a1 = a1 * 2u + 0x00800080u ;
11621141 uint32_t b0 = a0 | (in0 & 0x80008000u );
@@ -1268,8 +1247,8 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 to<f8_e4m3b15x4, f32x4>(const f32x4& v) {
12681247 return to<f8_e4m3b15x4, f16x4>(h);
12691248#elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
12701249 f16x4 h;
1271- h.words [0 ] = __builtin_bit_cast (uint32_t , __builtin_amdgcn_cvt_pkrtz (v.data [0 ], v.data [1 ]));
1272- h.words [1 ] = __builtin_bit_cast (uint32_t , __builtin_amdgcn_cvt_pkrtz (v.data [2 ], v.data [3 ]));
1250+ h.words [0 ] = __builtin_bit_cast (uint32_t , __floats2half2_rn (v.data [0 ], v.data [1 ]));
1251+ h.words [1 ] = __builtin_bit_cast (uint32_t , __floats2half2_rn (v.data [2 ], v.data [3 ]));
12731252 return to<f8_e4m3b15x4, f16x4>(h);
12741253#else
12751254 f8_e4m3b15x4 result;
0 commit comments