@@ -3,40 +3,60 @@ module bf16_adder (
33 input logic [15 :0 ] b,
44 output logic [15 :0 ] sum
55);
6- logic a_sign = a[15 ];
7- logic b_sign = b[15 ];
8- logic [7 :0 ] a_exp = a[14 :7 ];
9- logic [7 :0 ] b_exp = b[14 :7 ];
10- logic [6 :0 ] a_frac = a[6 :0 ];
11- logic [6 :0 ] b_frac = b[6 :0 ];
12-
13- logic [8 :0 ] exp_diff;
14- logic [8 :0 ] big_exp, small_exp;
15- logic [7 :0 ] big_frac, small_frac;
16- logic swap;
17-
18- logic [9 :0 ] big_mant = {1'b1 , big_frac};
19- logic [9 :0 ] small_mant_aligned;
20- logic [10 :0 ] sum_mant;
21- logic sum_sign;
22-
23- always @(* ) begin
24- // Handle specials quickly
25- if (a_exp == '1 ) sum = a;
26- else if (b_exp == '1 ) sum = b;
27- else if (a_exp == 0 && a_frac == 0 ) sum = b;
28- else if (b_exp == 0 && b_frac == 0 ) sum = a;
29- else begin
30- swap = (a_exp < b_exp) || (a_exp == b_exp && a_frac < b_frac);
31- big_exp = swap ? b_exp : a_exp;
32- small_exp = swap ? a_exp : b_exp;
33- big_frac = swap ? b_frac : a_frac;
34- small_frac= swap ? a_frac : b_frac;
6+
7+ logic a_sign, b_sign;
8+ logic [7 :0 ] a_exp, b_exp;
9+ logic [6 :0 ] a_frac, b_frac;
10+
11+ // unpacked fields
12+ always @(* ) begin
13+ a_sign = a[15 ];
14+ b_sign = b[15 ];
15+ a_exp = a[14 :7 ];
16+ b_exp = b[14 :7 ];
17+ a_frac = a[6 :0 ];
18+ b_frac = b[6 :0 ];
19+ end
20+
21+ always @(* ) begin
22+ // default to avoid latches
23+ sum = a; // safe default
24+
25+ // special cases
26+ if (a_exp == 8'hFF ) begin
27+ sum = a; // inf or NaN
28+ end else if (b_exp == 8'hFF ) begin
29+ sum = b;
30+ end else if (a_exp == 0 && a_frac == 0 ) begin
31+ sum = b;
32+ end else if (b_exp == 0 && b_frac == 0 ) begin
33+ sum = a;
34+ end else begin
35+ // normal add
36+ logic swap;
37+ logic [7 :0 ] big_exp, small_exp;
38+ logic [6 :0 ] big_frac, small_frac;
39+ logic [8 :0 ] exp_diff;
40+ logic [9 :0 ] big_mant;
41+ logic [9 :0 ] small_mant_aligned;
42+ logic [10 :0 ] sum_mant;
43+ logic sum_sign;
44+
45+ swap = (a_exp < b_exp) || ((a_exp == b_exp) && (a_frac < b_frac));
46+ big_exp = swap ? b_exp : a_exp;
47+ small_exp = swap ? a_exp : b_exp;
48+ big_frac = swap ? b_frac : a_frac;
49+ small_frac = swap ? a_frac : b_frac;
3550
3651 exp_diff = big_exp - small_exp;
3752
38- small_mant_aligned = (exp_diff >= 10 ) ? 0 :
39- {2'b1 , small_frac, 1'b0 } >> exp_diff;
53+ big_mant = {1'b1 , big_frac}; // hidden bit
54+
55+ if (exp_diff >= 10 ) begin
56+ small_mant_aligned = 0 ;
57+ end else begin
58+ small_mant_aligned = {1'b1 , small_frac, 2'b00 } >> exp_diff; // +2 bits for guard+round
59+ end
4060
4161 if (a_sign == b_sign) begin
4262 sum_mant = {1'b0 , big_mant} + small_mant_aligned;
@@ -47,12 +67,13 @@ module bf16_adder (
4767 end
4868
4969 if (sum_mant[10 ]) begin
50- sum = {sum_sign, big_exp+ 1 , sum_mant[9 :3 ]};
70+ sum = {sum_sign, big_exp + 1'd1 , sum_mant[9 :3 ]};
5171 end else if (sum_mant[9 :0 ] == 0 ) begin
52- sum = 16'b0 ;
72+ sum = 16'h0000 ;
5373 end else begin
5474 sum = {sum_sign, big_exp, sum_mant[9 :3 ]};
5575 end
5676 end
5777 end
78+
5879endmodule
0 commit comments