Skip to content

Commit 4deec0a

Browse files
committed
prepare for fp16
1 parent 6c1c667 commit 4deec0a

2 files changed

Lines changed: 365 additions & 0 deletions

File tree

bench/xjb/test/f16_to_decimal.cpp

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
#include <iostream>
2+
#include <fstream>
3+
#include <cstdint>
4+
#include <cassert>
5+
#include <algorithm>
6+
#include <vector>
7+
#include <string>
8+
9+
// -------------------- 128 位分数类 --------------------
10+
struct Fraction {
11+
__int128 num; // 分子
12+
__int128 den; // 分母 (恒正)
13+
14+
Fraction(__int128 n = 0, __int128 d = 1) : num(n), den(d) {
15+
if (den < 0) { num = -num; den = -den; }
16+
reduce();
17+
}
18+
19+
// 约分
20+
void reduce() {
21+
if (den == 0) throw std::runtime_error("Denominator zero");
22+
__int128 g = gcd(num < 0 ? -num : num, den);
23+
num /= g;
24+
den /= g;
25+
}
26+
27+
static __int128 gcd(__int128 a, __int128 b) {
28+
while (b != 0) { __int128 t = b; b = a % b; a = t; }
29+
return a < 0 ? -a : a;
30+
}
31+
32+
// 比较运算符
33+
bool operator<(const Fraction& rhs) const {
34+
return num * rhs.den < rhs.num * den;
35+
}
36+
bool operator<=(const Fraction& rhs) const {
37+
return num * rhs.den <= rhs.num * den;
38+
}
39+
bool operator==(const Fraction& rhs) const {
40+
return num == rhs.num && den == rhs.den;
41+
}
42+
bool operator!=(const Fraction& rhs) const { return !(*this == rhs); }
43+
bool operator>(const Fraction& rhs) const { return rhs < *this; }
44+
bool operator>=(const Fraction& rhs) const { return rhs <= *this; }
45+
46+
// 算术运算
47+
Fraction operator+(const Fraction& rhs) const {
48+
return Fraction(num * rhs.den + rhs.num * den, den * rhs.den);
49+
}
50+
Fraction operator-(const Fraction& rhs) const {
51+
return Fraction(num * rhs.den - rhs.num * den, den * rhs.den);
52+
}
53+
Fraction operator*(const Fraction& rhs) const {
54+
return Fraction(num * rhs.num, den * rhs.den);
55+
}
56+
Fraction operator/(const Fraction& rhs) const {
57+
return Fraction(num * rhs.den, den * rhs.num);
58+
}
59+
60+
// 取整 (floor)
61+
__int128 floor() const {
62+
if (num >= 0) return num / den;
63+
else return (num - den + 1) / den;
64+
}
65+
66+
// 小数部分
67+
Fraction frac() const {
68+
return *this - Fraction(floor(), 1);
69+
}
70+
71+
// 转换为 double (仅用于调试输出,算法中不使用)
72+
double toDouble() const {
73+
return (double)num / (double)den;
74+
}
75+
};
76+
77+
// 幂函数
78+
Fraction pow(int base, int exp) {
79+
if (exp == 0) return Fraction(1);
80+
if (exp < 0) return Fraction(1) / pow(base, -exp);
81+
__int128 result = 1;
82+
for (int i = 0; i < exp; ++i) result *= base;
83+
return Fraction(result);
84+
}
85+
86+
// -------------------- FP16 解析 --------------------
87+
struct FP16Components {
88+
uint16_t bits;
89+
__int128 c; // 有效数字 (整数)
90+
int q; // 指数
91+
bool is_regular; // f != 0 (包括次正规数)
92+
bool is_irregular; // f == 0 且 e != 0 (2的幂)
93+
};
94+
95+
FP16Components f16_to_components(uint16_t bits) {
96+
int exp = (bits >> 10) & 0x1F;
97+
int frac = bits & 0x3FF;
98+
99+
if (exp == 0 && frac == 0) throw std::invalid_argument("+0 excluded");
100+
if (exp == 31) throw std::invalid_argument("inf/NaN excluded");
101+
102+
FP16Components comp;
103+
comp.bits = bits;
104+
105+
if (exp == 0) { // 次正规数
106+
comp.c = frac;
107+
comp.q = -24;
108+
comp.is_regular = true;
109+
comp.is_irregular = false;
110+
} else { // 常规数
111+
comp.c = 1024 + frac;
112+
comp.q = exp - 25;
113+
comp.is_regular = (frac != 0);
114+
comp.is_irregular = (frac == 0);
115+
}
116+
return comp;
117+
}
118+
119+
// -------------------- 计算 k (完全整数逻辑) --------------------
120+
int compute_k(int q, bool is_regular) {
121+
int k;
122+
if(is_regular){
123+
k = (q * 1233) >> 12;
124+
}else{
125+
k = ((q * 1233) - 512) >> 12;
126+
}
127+
return k;
128+
}
129+
130+
// -------------------- 算法1:计算 (d, k) --------------------
131+
std::pair<__int128, int> f16_to_decimal(uint16_t bits) {
132+
auto comp = f16_to_components(bits);
133+
__int128 c = comp.c;
134+
int q = comp.q;
135+
bool is_regular = comp.is_regular;
136+
bool is_irregular = comp.is_irregular;
137+
138+
int k = compute_k(q, is_regular);
139+
140+
// v = c * 2^q
141+
Fraction v = Fraction(c) * pow(2, q);
142+
143+
// R = v * 10^{-k-1}
144+
Fraction R = v * pow(10, -k-1);
145+
__int128 m = R.floor();
146+
Fraction n = R.frac();
147+
148+
__int128 ten = 10 * m;
149+
150+
// 10n = 10 * (R - m) = 10*R - 10*m
151+
Fraction tenR = v * pow(10, -k); // v * 10^{-k}
152+
Fraction ten_n = tenR - Fraction(ten);
153+
__int128 floor_ten_n = ten_n.floor();
154+
Fraction delta = ten_n.frac();
155+
156+
__int128 one;
157+
158+
// Step 11-21: 根据 δ 确定 one 初值
159+
if (delta == Fraction(1, 2)) {
160+
if (floor_ten_n % 2 == 0)
161+
one = floor_ten_n;
162+
else
163+
one = floor_ten_n + 1;
164+
} else if (delta < Fraction(1, 2)) {
165+
one = floor_ten_n;
166+
} else {
167+
one = floor_ten_n + 1;
168+
}
169+
170+
// Step 22-28: irregular 的特殊处理
171+
if (is_irregular) {
172+
Fraction cond1_val = pow(2, q-2) * pow(10, -k);
173+
Fraction cond2_val = pow(2, q-2) * pow(10, -k-1);
174+
175+
if (delta > cond1_val) {
176+
one = floor_ten_n + 1;
177+
}
178+
if (cond2_val >= n) {
179+
one = 0;
180+
}
181+
} else {
182+
// Step 30-35: regular 情况下的最短表示检查
183+
Fraction A = pow(2, q-1) * pow(10, -k-1);
184+
185+
if (A > n || (A == n && c % 2 == 0)) {
186+
one = 0;
187+
} else if (A > (Fraction(1) - n) || (A == (Fraction(1) - n) && c % 2 == 0)) {
188+
one = 10;
189+
}
190+
}
191+
192+
__int128 d = ten + one;
193+
return {d, k};
194+
}
195+
196+
// -------------------- 主程序 --------------------
197+
int main() {
198+
std::ofstream out("f16_decimal_results.txt");
199+
if (!out) {
200+
std::cerr << "Cannot open output file." << std::endl;
201+
return 1;
202+
}
203+
204+
out << "# bits(hex) d k\n";
205+
206+
// 遍历所有正 FP16 数值 (排除 0x0000, 0x7C00..0x7FFF)
207+
for (uint32_t bits = 0x0001; bits <= 0x7BFF; ++bits) {
208+
try {
209+
auto [d, k] = f16_to_decimal(static_cast<uint16_t>(bits));
210+
out << "0x" << std::hex << std::uppercase << bits << std::dec
211+
<< " " << (int64_t)d << " " << k << "\n";
212+
} catch (const std::invalid_argument&) {
213+
// 跳过 +0, inf, NaN
214+
continue;
215+
} catch (const std::exception& e) {
216+
std::cerr << "Error processing 0x" << std::hex << bits << ": " << e.what() << std::endl;
217+
return 1;
218+
}
219+
}
220+
221+
out.close();
222+
std::cout << "Results written to f16_decimal_results.txt" << std::endl;
223+
return 0;
224+
}

bench/xjb/test/f16_to_decimal.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
import sys
2+
from fractions import Fraction
3+
import math
4+
5+
# -------------------- 工具函数 --------------------
6+
7+
def f16_to_components(bits: int):
8+
"""
9+
解析 IEEE 754 binary16 正数 (sign=0),返回 (c, q, is_regular, is_irregular)。
10+
排除 +0, inf, NaN。
11+
"""
12+
exp = (bits >> 10) & 0x1F # 5 位指数
13+
frac = bits & 0x3FF # 10 位尾数
14+
15+
if exp == 0 and frac == 0:
16+
raise ValueError("+0 excluded")
17+
if exp == 31:
18+
raise ValueError("inf/NaN excluded")
19+
20+
if exp == 0: # subnormal
21+
c = frac
22+
q = -24
23+
is_regular = True # subnormal 属于 regular (因为 f≠0)
24+
is_irregular = False
25+
else: # normal
26+
c = 1024 + frac
27+
q = exp - 25
28+
if frac == 0:
29+
is_regular = False
30+
is_irregular = True
31+
else:
32+
is_regular = True
33+
is_irregular = False
34+
35+
return c, q, is_regular, is_irregular
36+
37+
38+
def compute_k(q: int, is_regular: bool) -> int:
39+
"""
40+
对于 regular: 10^k <= 2^q < 10^{k+1}
41+
对于 irregular: (4/3)*10^k <= 2^q < (4/3)*10^{k+1}
42+
"""
43+
if is_regular:
44+
k = math.floor(q * math.log10(2))
45+
else:
46+
k = math.floor(q * math.log10(2) - math.log10(4/3))
47+
return int(k) # 取整
48+
49+
def f16_to_decimal(bits: int):
50+
"""
51+
根据算法 1 计算 binary16 正数的 (d, k),满足 SW 原则。
52+
"""
53+
c, q, is_regular, is_irregular = f16_to_components(bits)
54+
55+
# Step 2-5: 计算 k
56+
k = compute_k(q, is_regular)
57+
58+
# v = c * 2^q
59+
v = c * (Fraction(2) ** q)
60+
61+
# Step 7: m = floor(v * 10^{-k-1})
62+
# R = v * 10^{-k-1}
63+
R = v * (Fraction(10) ** (-k - 1))
64+
m = int(R) # floor(R)
65+
n = R - m # 小数部分
66+
67+
# Step 9: ten = 10*m
68+
ten = 10 * m
69+
70+
# Step 10: δ = fractional part of 10n
71+
# 10n = 10*(R - m) = 10*R - 10*m
72+
tenR = v * (Fraction(10) ** (-k)) # = v * 10^{-k}
73+
ten_n = tenR - ten
74+
floor_ten_n = int(ten_n) # floor(10n)
75+
delta = ten_n - floor_ten_n
76+
77+
# Step 11-21: 根据 δ 确定 one 初值
78+
if delta == Fraction(1, 2):
79+
# 0.5 的情况:round to even
80+
if floor_ten_n % 2 == 0:
81+
one = floor_ten_n
82+
else:
83+
one = floor_ten_n + 1
84+
elif delta < Fraction(1, 2):
85+
one = floor_ten_n
86+
else:
87+
one = floor_ten_n + 1
88+
89+
# Step 22-28: irregular 的特殊处理
90+
if is_irregular:
91+
# 2^{q-2} * 10^{-k}
92+
cond1_val = (Fraction(2) ** (q - 2)) * (Fraction(10) ** (-k))
93+
# 2^{q-2} * 10^{-k-1}
94+
cond2_val = (Fraction(2) ** (q - 2)) * (Fraction(10) ** (-k - 1))
95+
96+
if delta > cond1_val:
97+
one = floor_ten_n + 1
98+
if cond2_val >= n:
99+
one = 0
100+
else:
101+
# Step 30-35: regular 情况下检查是否 one=0 或 one=10 更短
102+
# 边界值 A = 2^{q-1} * 10^{-k-1}
103+
A = (Fraction(2) ** (q - 1)) * (Fraction(10) ** (-k - 1))
104+
105+
# 条件1: one = 0
106+
if A > n or (A == n and c % 2 == 0):
107+
one = 0
108+
# 条件2: one = 10
109+
elif A > (1 - n) or (A == 1 - n and c % 2 == 0):
110+
one = 10
111+
112+
# Step 37: d = ten + one
113+
d = ten + one
114+
115+
return d, k
116+
117+
118+
# -------------------- 主程序 --------------------
119+
120+
def main():
121+
output_filename = "f16_decimal_results.txt"
122+
123+
with open(output_filename, "w", encoding="utf-8") as f:
124+
f.write("# bits(hex) d k\n")
125+
# 正数范围:0x0001 .. 0x7BFF (排除 +0, inf, NaN)
126+
for bits in range(0x0001, 0x7C00):
127+
try:
128+
d, k = f16_to_decimal(bits)
129+
f.write(f"0x{bits:04X} {d} {k}\n")
130+
except ValueError:
131+
# 正常情况不会触发,因为我们跳过了 +0, inf, NaN
132+
continue
133+
except Exception as e:
134+
print(f"Error processing 0x{bits:04X}: {e}", file=sys.stderr)
135+
raise
136+
137+
print(f"Results written to {output_filename}")
138+
139+
140+
if __name__ == "__main__":
141+
main()

0 commit comments

Comments
 (0)