Skip to content

Commit 1a939c9

Browse files
solidify floating point encoding
1 parent 7858929 commit 1a939c9

File tree

2 files changed

+35
-17
lines changed

2 files changed

+35
-17
lines changed

test/mac/test_mac.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def bf16_to_float(bf16: int) -> float:
1616

1717
def fp8_e4m3_encode(x: float) -> int:
1818
if math.isnan(x):
19-
return 0x7F # closest representation
19+
return 0x7F
2020
if math.isinf(x):
2121
return 0x7F if x > 0 else 0xFF
2222

@@ -29,18 +29,27 @@ def fp8_e4m3_encode(x: float) -> int:
2929
return sign << 7
3030

3131
exp = math.floor(math.log2(x))
32-
mant = x / (2 ** exp) - 1.0
33-
34-
# FP8 exponent bias = 7
32+
33+
# FP8 E4M3 bias = 7, min normal exp = -6
3534
exp_fp8 = exp + 7
3635

37-
# Handle underflow/overflow
36+
# Denormal handling
3837
if exp_fp8 <= 0:
39-
return sign << 7
38+
# Denormal: exp_fp8 = 0, effective exponent = -6
39+
# Value = 2^(-6) * (0.mantissa)
40+
# So: x = 2^(-6) * (mantissa_bits / 8)
41+
# => mantissa_bits = x * 2^6 * 8 = x * 512
42+
mant_fp8 = int(round(x * 512))
43+
if mant_fp8 == 0 or mant_fp8 >= 8:
44+
return sign << 7 # underflow to zero
45+
return (sign << 7) | mant_fp8
46+
47+
# Normal numbers
4048
if exp_fp8 >= 0xF:
41-
return (sign << 7) | 0x7F # max finite
49+
return (sign << 7) | 0x7F # overflow
4250

43-
mant_fp8 = int(round(mant * 8)) # 3 bits mantissa (2^3=8)
51+
mant = x / (2 ** exp) - 1.0
52+
mant_fp8 = int(round(mant * 8))
4453

4554
if mant_fp8 == 8: # rounding overflow
4655
mant_fp8 = 0
@@ -72,7 +81,7 @@ async def test_pe_deviation(dut):
7281
await RisingEdge(dut.clk)
7382

7483
NUM_TESTS = 100
75-
random.seed(2025)
84+
random.seed(42)
7685

7786
for i in range(NUM_TESTS):
7887
fa = random.uniform(-10.0, 10.0)

test/tpu/test.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def bf16_to_float(bf16: int) -> float:
1212

1313
def fp8_e4m3_encode(x: float) -> int:
1414
if math.isnan(x):
15-
return 0x7F # closest representation
15+
return 0x7F
1616
if math.isinf(x):
1717
return 0x7F if x > 0 else 0xFF
1818

@@ -25,18 +25,27 @@ def fp8_e4m3_encode(x: float) -> int:
2525
return sign << 7
2626

2727
exp = math.floor(math.log2(x))
28-
mant = x / (2 ** exp) - 1.0
29-
30-
# FP8 exponent bias = 7
28+
29+
# FP8 E4M3 bias = 7, min normal exp = -6
3130
exp_fp8 = exp + 7
3231

33-
# Handle underflow/overflow
32+
# Denormal handling
3433
if exp_fp8 <= 0:
35-
return sign << 7
34+
# Denormal: exp_fp8 = 0, effective exponent = -6
35+
# Value = 2^(-6) * (0.mantissa)
36+
# So: x = 2^(-6) * (mantissa_bits / 8)
37+
# => mantissa_bits = x * 2^6 * 8 = x * 512
38+
mant_fp8 = int(round(x * 512))
39+
if mant_fp8 == 0 or mant_fp8 >= 8:
40+
return sign << 7 # underflow to zero
41+
return (sign << 7) | mant_fp8
42+
43+
# Normal numbers
3644
if exp_fp8 >= 0xF:
37-
return (sign << 7) | 0x7F # max finite
45+
return (sign << 7) | 0x7F # overflow
3846

39-
mant_fp8 = int(round(mant * 8)) # 3 bits mantissa (2^3=8)
47+
mant = x / (2 ** exp) - 1.0
48+
mant_fp8 = int(round(mant * 8))
4049

4150
if mant_fp8 == 8: # rounding overflow
4251
mant_fp8 = 0

0 commit comments

Comments
 (0)