Skip to content

Commit cbfe3c3

Browse files
committed
WIP
1 parent 04c995f commit cbfe3c3

4 files changed

Lines changed: 25 additions & 24 deletions

File tree

include/mscclpp/gpu_data_types.hpp

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ using __bfloat162 = __nv_bfloat162;
7171

7272
/// Software float8 with 4 exponent bits, 3 mantissa bits, exponent bias = 15.
7373
/// Format (MSB first): [sign:1][exponent:4][mantissa:3]
74-
/// No infinities, no NaN. Encode saturates to ±1.75 (0x7e/0xfe).
74+
/// No infinities, no NaN. Encode saturates to ±1.875 (0x7f/0xff).
7575
/// Adapted from the Triton compiler's fp8e4b15 format.
7676
struct alignas(1) __fp8_e4m3b15 {
7777
uint8_t __x;
@@ -103,7 +103,7 @@ struct alignas(1) __fp8_e4m3b15 {
103103
/// then convert fp16 → float32.
104104
static MSCCLPP_HOST_DEVICE_INLINE float toFloat(uint8_t bits) {
105105
// Branch-free decode: fp8 → fp16 → fp32, no special-case handling.
106-
// Encode saturates to ±1.75, so 0x7f/0xff are never produced.
106+
// Every byte decodes to a finite value; 0x7f/0xff decode to ±1.875, the magnitude at which encode saturates.
107107
// Refer:
108108
// https://github.com/triton-lang/triton/blob/cf34004b8a67d290a962da166f5aa2fc66751326/python/triton/language/extra/cuda/utils.py#L34
109109
uint16_t h = (uint16_t)bits << 8; // place fp8 in upper byte of fp16
@@ -132,10 +132,9 @@ struct alignas(1) __fp8_e4m3b15 {
132132
} cvt = {h_val};
133133
uint16_t fp16_bits = cvt.u;
134134

135-
// Clamp abs to max encodable value: 1.75 → fp16 = 0x3F00.
136-
// Matches Triton: encode saturates, 0x7f/0xff are never produced.
135+
// Clamp abs to max encodable value: 1.875 → fp16 = 0x3F80 (largest byte 0x7f/0xff).
137136
uint16_t abs_fp16 = fp16_bits & 0x7FFFu;
138-
if (abs_fp16 > 0x3F00u) abs_fp16 = 0x3F00u;
137+
if (abs_fp16 > 0x3F80u) abs_fp16 = 0x3F80u;
139138

140139
// Reconstruct with sign.
141140
uint16_t sign16 = fp16_bits & 0x8000u;
@@ -1083,11 +1082,11 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x2 to<f8_e4m3b15x2, f16x2>(const f16x2& v) {
10831082
#if defined(MSCCLPP_DEVICE_CUDA)
10841083
uint32_t in0;
10851084
asm("mov.b32 %0, %1;" : "=r"(in0) : "r"(*reinterpret_cast<const uint32_t*>(&v)));
1086-
// Clamp abs to max encodable e4m3b15 (0x3F00 = 1.75 in fp16).
1085+
// Clamp abs to max encodable e4m3b15 (0x3F80 = 1.875 in fp16).
10871086
uint32_t lo = in0 & 0xFFFFu, hi = in0 >> 16;
10881087
uint32_t alo = lo & 0x7FFFu, ahi = hi & 0x7FFFu;
1089-
alo = alo < 0x3F00u ? alo : 0x3F00u;
1090-
ahi = ahi < 0x3F00u ? ahi : 0x3F00u;
1088+
alo = alo < 0x3F80u ? alo : 0x3F80u;
1089+
ahi = ahi < 0x3F80u ? ahi : 0x3F80u;
10911090
uint32_t a0 = alo | (ahi << 16);
10921091
a0 = a0 * 2u + 0x00800080u;
10931092
uint32_t b0 = a0 | (in0 & 0x80008000u);
@@ -1098,7 +1097,7 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x2 to<f8_e4m3b15x2, f16x2>(const f16x2& v) {
10981097
uint32_t in0 = v.words[0];
10991098
uint32_t abs0 = in0 & 0x7fff7fffu;
11001099
uint32_t a0;
1101-
asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3F003F00u));
1100+
asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3F803F80u));
11021101
a0 = a0 * 2u + 0x00800080u;
11031102
uint32_t b0 = a0 | (in0 & 0x80008000u);
11041103
uint16_t packed = (uint16_t)(((b0 >> 8) & 0xFFu) | ((b0 >> 16) & 0xFF00u));
@@ -1121,8 +1120,8 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 to<f8_e4m3b15x4, f16x4>(const f16x4& v) {
11211120
asm("mov.b32 %0, %1;" : "=r"(in1) : "r"(v.words[1]));
11221121
uint32_t abs0 = in0 & 0x7fff7fffu;
11231122
uint32_t abs1 = in1 & 0x7fff7fffu;
1124-
uint32_t a0 = __vminu2(abs0, 0x3F003F00u);
1125-
uint32_t a1 = __vminu2(abs1, 0x3F003F00u);
1123+
uint32_t a0 = __vminu2(abs0, 0x3F803F80u);
1124+
uint32_t a1 = __vminu2(abs1, 0x3F803F80u);
11261125
a0 = a0 * 2u + 0x00800080u;
11271126
a1 = a1 * 2u + 0x00800080u;
11281127
uint32_t b0, b1;
@@ -1135,8 +1134,8 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 to<f8_e4m3b15x4, f16x4>(const f16x4& v) {
11351134
uint32_t in0 = v.words[0], in1 = v.words[1];
11361135
uint32_t abs0 = in0 & 0x7fff7fffu, abs1 = in1 & 0x7fff7fffu;
11371136
uint32_t a0, a1;
1138-
asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3F003F00u));
1139-
asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a1) : "v"(abs1), "v"(0x3F003F00u));
1137+
asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3F803F80u));
1138+
asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a1) : "v"(abs1), "v"(0x3F803F80u));
11401139
a0 = a0 * 2u + 0x00800080u;
11411140
a1 = a1 * 2u + 0x00800080u;
11421141
uint32_t b0 = a0 | (in0 & 0x80008000u);

python/mscclpp_benchmark/correctness.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,9 +332,11 @@ def _fp8_max_abs_value(fp8_format: str) -> float:
332332

333333

334334
def _encode_e4m3b15_values(values):
335+
# Mirrors the device e4m3b15 encode (gpu_data_types.hpp): clamp the fp16 intermediate
336+
# to 0x3F80 (+/-1.875) so the max encodable byte is 0x7F/0xFF.
335337
fp16_bits = values.astype(cp.float16).view(cp.uint16)
336338
abs_fp16 = fp16_bits & cp.uint16(0x7FFF)
337-
abs_fp16 = cp.minimum(abs_fp16, cp.uint16(0x3F00)).astype(cp.uint32)
339+
abs_fp16 = cp.minimum(abs_fp16, cp.uint16(0x3F80)).astype(cp.uint32)
338340
sign16 = (fp16_bits & cp.uint16(0x8000)).astype(cp.uint32)
339341
adjusted = abs_fp16 * cp.uint32(2) + cp.uint32(0x0080)
340342
return (((sign16 | adjusted) >> cp.uint32(8)) & cp.uint32(0xFF)).astype(cp.uint8)

python/test/test_fp8_accum.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def float_to_e4m3fnuz(f32_array, chunk_size=65536):
167167

168168

169169
# ---------------------------------------------------------------------------
170-
# FP8 E4M3B15 helpers (bias=15, encode saturates to ±1.75, no NaN)
170+
# FP8 E4M3B15 helpers (bias=15, float source saturates to ±1.875, no NaN)
171171
# Matches Triton's fp8e4b15: all 256 bit patterns are finite.
172172
# ---------------------------------------------------------------------------
173173

@@ -193,7 +193,7 @@ def float_to_e4m3b15(f32_array, chunk_size=65536):
193193
"""Encode a cupy float32 array to uint8 E4M3B15 bit patterns.
194194
195195
Same lookup-table approach as float_to_e4m3fn.
196-
Saturates to ±1.75 (0x7e/0xfe), matching Triton's fp8e4b15.
196+
Saturates to ±1.875 (0x7f/0xff), matching the device float32 → e4m3b15 path.
197197
"""
198198
# Build lookup table of all 128 positive E4M3B15 values (0x00..0x7F)
199199
all_bytes = cp.arange(128, dtype=cp.uint8)
@@ -203,7 +203,7 @@ def float_to_e4m3b15(f32_array, chunk_size=65536):
203203
values = f32_array.astype(cp.float32)
204204
signs = cp.signbit(values).astype(cp.uint8)
205205
absval = cp.abs(values)
206-
absval = cp.clip(absval, cp.float32(0.0), cp.float32(1.75))
206+
absval = cp.clip(absval, cp.float32(0.0), cp.float32(1.875))
207207

208208
result = cp.zeros(absval.shape, dtype=cp.uint8)
209209
n = absval.size
@@ -442,8 +442,8 @@ def test_fp8_e4m3b15_accum(mpi_group: MpiGroup, algo_name: str, size: int):
442442
bits_r = cp.asarray(rng_r.randint(0, 256, (size,)).astype(np.uint8))
443443
ref_f32 += e4m3b15_to_float(bits_r)
444444

445-
# Clamp reference to e4m3b15 representable range
446-
ref_f32 = cp.clip(ref_f32, -1.75, 1.75)
445+
# Clamp reference to e4m3b15 representable range (float source saturates at ±1.875)
446+
ref_f32 = cp.clip(ref_f32, -1.875, 1.875)
447447

448448
# Compute errors
449449
abs_err = cp.abs(result_f32 - ref_f32)

test/unit/gpu_data_types_tests.cu

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ TEST(GpuDataTypesTest, E4m3b15TypeConvert) {
9797
const float maxFloat = std::numeric_limits<float>::max();
9898

9999
// Each input value maps to the byte at the same index in expectedEncoded. The fp8_e4m3b15 format has no
100-
// NaN/Inf encoding, so NaN, Inf, and overflow inputs saturate to +/-1.75.
100+
// NaN/Inf encoding, so NaN, Inf, and overflow inputs saturate to +/-1.875 (max byte 0x7f/0xff).
101101
const auto input = makeArray<float>(0.0f, -0.0f, // +/-0
102102
0x1.0p-19f, -0x1.0p-19f, // +/-2^-19: underflows to signed 0
103103
0x1.0p-18f, -0x1.0p-18f, // +/-2^-18: rounds to min subnormal
@@ -119,10 +119,10 @@ TEST(GpuDataTypesTest, E4m3b15TypeConvert) {
119119
0x68, 0xe8, // Boundary rounds to +/-0.25
120120
0x69, 0xe9, // Boundary rounds to +/-0.28125
121121
0x6f, 0xef, // Boundary rounds to +/-0.46875
122-
0x7e, 0xfe, // Max signed finite
123-
0x7e, 0xfe, // Overflow saturation
124-
0x7e, 0xfe, // Inf saturation
125-
0x7e, 0xfe); // NaN / large negative saturation
122+
0x7e, 0xfe, // +/-1.75: one step below the max finite magnitude (1.875)
123+
0x7f, 0xff, // Overflow saturation (1.875)
124+
0x7f, 0xff, // Inf saturation (1.875)
125+
0x7f, 0xff); // NaN / large negative saturation (1.875)
126126

127127
// Raw bytes to decode, with expectedDecoded giving the exact float value at the same index.
128128
const auto raw = makeArray<uint8_t>(0x00, 0x80, // +/-0

0 commit comments

Comments
 (0)