@@ -71,7 +71,7 @@ using __bfloat162 = __nv_bfloat162;
7171
7272// / Software float8 with 4 exponent bits, 3 mantissa bits, exponent bias = 15.
7373// / Format (MSB first): [sign:1][exponent:4][mantissa:3]
74- // / No infinities, no NaN. Encode saturates to ±1.75 (0x7e/0xfe ).
74+ // / No infinities, no NaN. Encode saturates to ±1.875 (0x7f/0xff ).
7575// / Adapted from the Triton compiler's fp8e4b15 format.
7676struct alignas (1 ) __fp8_e4m3b15 {
7777 uint8_t __x;
@@ -103,7 +103,7 @@ struct alignas(1) __fp8_e4m3b15 {
103103 // / then convert fp16 → float32.
104104 static MSCCLPP_HOST_DEVICE_INLINE float toFloat (uint8_t bits) {
105105 // Branch-free decode: fp8 → fp16 → fp32, no special-case handling.
106- // Encode saturates to ±1.75 , so 0x7f/0xff are never produced .
106+ // Every byte maps to a finite value; encode saturates at ±1.875 , so 0x7f/0xff decode to ±1.875 .
107107 // Refer:
108108 // https://github.com/triton-lang/triton/blob/cf34004b8a67d290a962da166f5aa2fc66751326/python/triton/language/extra/cuda/utils.py#L34
109109 uint16_t h = (uint16_t )bits << 8 ; // place fp8 in upper byte of fp16
@@ -132,10 +132,9 @@ struct alignas(1) __fp8_e4m3b15 {
132132 } cvt = {h_val};
133133 uint16_t fp16_bits = cvt.u ;
134134
135- // Clamp abs to max encodable value: 1.75 → fp16 = 0x3F00.
136- // Matches Triton: encode saturates, 0x7f/0xff are never produced.
135+ // Clamp abs to max encodable value: 1.875 → fp16 = 0x3F80 (largest byte 0x7f/0xff).
137136 uint16_t abs_fp16 = fp16_bits & 0x7FFFu ;
138- if (abs_fp16 > 0x3F00u ) abs_fp16 = 0x3F00u ;
137+ if (abs_fp16 > 0x3F80u ) abs_fp16 = 0x3F80u ;
139138
140139 // Reconstruct with sign.
141140 uint16_t sign16 = fp16_bits & 0x8000u ;
@@ -1083,11 +1082,11 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x2 to<f8_e4m3b15x2, f16x2>(const f16x2& v) {
10831082#if defined(MSCCLPP_DEVICE_CUDA)
10841083 uint32_t in0;
10851084 asm (" mov.b32 %0, %1;" : " =r" (in0) : " r" (*reinterpret_cast <const uint32_t *>(&v)));
1086- // Clamp abs to max encodable e4m3b15 (0x3F00 = 1.75 in fp16).
1085+ // Clamp abs to max encodable e4m3b15 (0x3F80 = 1.875 in fp16).
10871086 uint32_t lo = in0 & 0xFFFFu , hi = in0 >> 16 ;
10881087 uint32_t alo = lo & 0x7FFFu , ahi = hi & 0x7FFFu ;
1089- alo = alo < 0x3F00u ? alo : 0x3F00u ;
1090- ahi = ahi < 0x3F00u ? ahi : 0x3F00u ;
1088+ alo = alo < 0x3F80u ? alo : 0x3F80u ;
1089+ ahi = ahi < 0x3F80u ? ahi : 0x3F80u ;
10911090 uint32_t a0 = alo | (ahi << 16 );
10921091 a0 = a0 * 2u + 0x00800080u ;
10931092 uint32_t b0 = a0 | (in0 & 0x80008000u );
@@ -1098,7 +1097,7 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x2 to<f8_e4m3b15x2, f16x2>(const f16x2& v) {
10981097 uint32_t in0 = v.words [0 ];
10991098 uint32_t abs0 = in0 & 0x7fff7fffu ;
11001099 uint32_t a0;
1101- asm volatile (" v_pk_min_u16 %0, %1, %2" : " =v" (a0) : " v" (abs0), " v" (0x3F003F00u ));
1100+ asm volatile (" v_pk_min_u16 %0, %1, %2" : " =v" (a0) : " v" (abs0), " v" (0x3F803F80u ));
11021101 a0 = a0 * 2u + 0x00800080u ;
11031102 uint32_t b0 = a0 | (in0 & 0x80008000u );
11041103 uint16_t packed = (uint16_t )(((b0 >> 8 ) & 0xFFu ) | ((b0 >> 16 ) & 0xFF00u ));
@@ -1121,8 +1120,8 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 to<f8_e4m3b15x4, f16x4>(const f16x4& v) {
11211120 asm (" mov.b32 %0, %1;" : " =r" (in1) : " r" (v.words [1 ]));
11221121 uint32_t abs0 = in0 & 0x7fff7fffu ;
11231122 uint32_t abs1 = in1 & 0x7fff7fffu ;
1124- uint32_t a0 = __vminu2 (abs0, 0x3F003F00u );
1125- uint32_t a1 = __vminu2 (abs1, 0x3F003F00u );
1123+ uint32_t a0 = __vminu2 (abs0, 0x3F803F80u );
1124+ uint32_t a1 = __vminu2 (abs1, 0x3F803F80u );
11261125 a0 = a0 * 2u + 0x00800080u ;
11271126 a1 = a1 * 2u + 0x00800080u ;
11281127 uint32_t b0, b1;
@@ -1135,8 +1134,8 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 to<f8_e4m3b15x4, f16x4>(const f16x4& v) {
11351134 uint32_t in0 = v.words [0 ], in1 = v.words [1 ];
11361135 uint32_t abs0 = in0 & 0x7fff7fffu , abs1 = in1 & 0x7fff7fffu ;
11371136 uint32_t a0, a1;
1138- asm volatile (" v_pk_min_u16 %0, %1, %2" : " =v" (a0) : " v" (abs0), " v" (0x3F003F00u ));
1139- asm volatile (" v_pk_min_u16 %0, %1, %2" : " =v" (a1) : " v" (abs1), " v" (0x3F003F00u ));
1137+ asm volatile (" v_pk_min_u16 %0, %1, %2" : " =v" (a0) : " v" (abs0), " v" (0x3F803F80u ));
1138+ asm volatile (" v_pk_min_u16 %0, %1, %2" : " =v" (a1) : " v" (abs1), " v" (0x3F803F80u ));
11401139 a0 = a0 * 2u + 0x00800080u ;
11411140 a1 = a1 * 2u + 0x00800080u ;
11421141 uint32_t b0 = a0 | (in0 & 0x80008000u );
0 commit comments