// -----------------------------------------------------------------------------
// PE: one processing element of a systolic multiply-accumulate array.
//
// Every clock the element forwards a_in/b_in to a_out/b_out (one cycle of
// delay), converts both FP8 operands to BF16, multiplies them, and either
// loads the product into c_out (clear = 1) or adds it to c_out (clear = 0).
//
// Number formats
//   a_in, b_in : FP8 E4M3, 1 sign / 4 exponent (bias 7) / 3 mantissa bits
//   c_out      : BF16,     1 sign / 8 exponent (bias 127) / 7 mantissa bits
//
// Arithmetic conventions used by the helper functions below
//   * BF16 subnormals (exponent field 0) are treated as zero on input and
//     results that would be subnormal are flushed to signed zero.
//   * Multiplier and adder round to nearest, ties to even.
//
// NOTE: the datapath is written for WIDTH == 8. The conversion function
// takes an 8-bit argument, so other WIDTH values are truncated/extended by
// the usual Verilog assignment rules before conversion.
// -----------------------------------------------------------------------------
module PE #(
    parameter WIDTH = 8
)(
    input  wire                    clk,
    input  wire                    rst,    // synchronous, active high
    input  wire                    clear,  // 1: c_out <= product, 0: accumulate
    input  wire signed [WIDTH-1:0] a_in,   // FP8 E4M3
    input  wire signed [WIDTH-1:0] b_in,   // FP8 E4M3
    output reg  signed [WIDTH-1:0] a_out,
    output reg  signed [WIDTH-1:0] b_out,
    output reg  signed [15:0]      c_out   // BF16 accumulator (bit pattern)
);

    localparam [7:0]  BF16_EXP_SPECIAL = 8'd255;   // exponent field of Inf/NaN
    localparam [7:0]  FP8_TO_BF16_BIAS = 8'd120;   // 127 - 7
    localparam [15:0] BF16_QNAN        = 16'h7FC0; // canonical quiet NaN

    // =================================================================
    // 1. FP8 (E4M3) -> BF16 conversion. Exact: every FP8 value is
    //    representable in BF16, so no rounding is needed.
    // =================================================================
    function automatic [15:0] fp8_to_bf16(input [7:0] fp8);
        reg       sign;
        reg [3:0] exp8;
        reg [2:0] mant8;
        reg [7:0] exp16;
        reg [6:0] mant16;
        begin
            sign  = fp8[7];
            exp8  = fp8[6:3];
            mant8 = fp8[2:0];

            if (exp8 == 4'd0) begin
                // Zero or FP8 subnormal. A subnormal equals mant8 * 2^-9,
                // which is a *normal* BF16 number: locate the leading one
                // and re-bias (2^-7 -> 120, 2^-8 -> 119, 2^-9 -> 118).
                if (mant8[2]) begin
                    exp16  = 8'd120;
                    mant16 = {mant8[1:0], 5'b0};
                end else if (mant8[1]) begin
                    exp16  = 8'd119;
                    mant16 = {mant8[0], 6'b0};
                end else if (mant8[0]) begin
                    exp16  = 8'd118;
                    mant16 = 7'd0;
                end else begin
                    exp16  = 8'd0;      // signed zero
                    mant16 = 7'd0;
                end
            end else if (exp8 == 4'd15) begin
                // Inf (mantissa 0) or NaN (mantissa != 0), IEEE-style.
                // NOTE(review): the OCP "E4M3" encoding has no infinity and
                // uses exponent 15 for ordinary values except S.1111.111.
                // The IEEE-style reading is kept as written; confirm it
                // against whatever produces the FP8 data.
                exp16  = BF16_EXP_SPECIAL;
                mant16 = (mant8 == 3'd0) ? 7'd0 : {1'b1, mant8, 3'b0};
            end else begin
                // Normal number: re-bias the exponent and left-align the
                // 3 explicit mantissa bits in the 7-bit BF16 field. The
                // implicit leading 1 is NOT stored.
                exp16  = {4'd0, exp8} + FP8_TO_BF16_BIAS;
                mant16 = {mant8, 4'b0};
            end

            fp8_to_bf16 = {sign, exp16, mant16};
        end
    endfunction

    // =================================================================
    // 2. BF16 x BF16 -> BF16 multiplier, round-to-nearest-even.
    //    Special cases are kept deliberately cheap: a zero/subnormal
    //    operand yields signed zero, otherwise an Inf/NaN operand yields
    //    signed Inf (NaN is not propagated).
    // =================================================================
    function automatic [15:0] bf16_mul(input [15:0] a, input [15:0] b);
        reg              sa, sb, sp;
        reg [7:0]        ea, eb;
        reg [6:0]        ma, mb;
        reg [15:0]       mp;        // 1.7 x 1.7 significand product (2.14)
        reg signed [9:0] ep;        // biased result exponent, may go <= 0 or >= 255
        reg [6:0]        frac;
        reg              round_bit;
        reg              sticky;
        reg [7:0]        rounded;   // {carry, 7-bit fraction} after rounding
        begin
            sa = a[15]; ea = a[14:7]; ma = a[6:0];
            sb = b[15]; eb = b[14:7]; mb = b[6:0];
            sp = sa ^ sb;

            if (ea == 8'd0 || eb == 8'd0) begin
                bf16_mul = {sp, 15'd0};                       // zero
            end else if (ea == BF16_EXP_SPECIAL || eb == BF16_EXP_SPECIAL) begin
                bf16_mul = {sp, BF16_EXP_SPECIAL, 7'd0};      // Inf
            end else begin
                // Signed, widened exponent so that underflow and overflow
                // are detectable instead of wrapping around.
                ep = $signed({2'b00, ea}) + $signed({2'b00, eb}) - 10'sd127;

                // Two 8-bit significands give a 16-bit product in [2^14, 2^16).
                mp = {1'b1, ma} * {1'b1, mb};

                // Normalize (at most one position).
                if (mp[15]) begin
                    frac      = mp[14:8];
                    round_bit = mp[7];
                    sticky    = |mp[6:0];
                    ep        = ep + 10'sd1;
                end else begin
                    frac      = mp[13:7];
                    round_bit = mp[6];
                    sticky    = |mp[5:0];
                end

                // Round to nearest, ties to even. A carry out of the
                // fraction means the significand became 10.0000000, i.e.
                // 1.0000000 with the exponent bumped by one.
                rounded = {1'b0, frac} + {7'd0, round_bit & (sticky | frac[0])};
                if (rounded[7]) begin
                    ep   = ep + 10'sd1;
                    frac = 7'd0;
                end else begin
                    frac = rounded[6:0];
                end

                if (ep >= 10'sd255)
                    bf16_mul = {sp, BF16_EXP_SPECIAL, 7'd0};  // overflow -> Inf
                else if (ep <= 10'sd0)
                    bf16_mul = {sp, 15'd0};                   // underflow -> zero
                else
                    bf16_mul = {sp, ep[7:0], frac};
            end
        end
    endfunction

    // =================================================================
    // 3. BF16 + BF16 -> BF16 adder, round-to-nearest-even.
    //    Operands are ordered by magnitude so the subtraction never goes
    //    negative and the result takes the sign of the larger operand.
    //    Significands carry three extra low bits (guard, round, sticky).
    // =================================================================
    function automatic [15:0] bf16_add(input [15:0] a, input [15:0] b);
        reg        sa, sb;
        reg [7:0]  ea, eb;
        reg        s_big;
        reg [7:0]  e_big, e_small;
        reg [10:0] sig_big, sig_small;  // {1, frac[6:0], G, R, S}
        reg [7:0]  diff;
        reg [3:0]  sh;
        reg [22:0] ext;                 // sig_small with room to shift right
        reg [10:0] small_al;            // aligned smaller significand
        reg [11:0] sum;                 // one extra bit for the add carry
        reg [8:0]  e_res;
        reg [8:0]  rounded;
        reg        round_up;
        integer    i;
        begin
            sa = a[15]; ea = a[14:7];
            sb = b[15]; eb = b[14:7];

            if (ea == BF16_EXP_SPECIAL || eb == BF16_EXP_SPECIAL) begin
                // NaN wins, Inf - Inf is NaN, otherwise Inf is sticky.
                if (ea == BF16_EXP_SPECIAL && a[6:0] != 7'd0)
                    bf16_add = a;
                else if (eb == BF16_EXP_SPECIAL && b[6:0] != 7'd0)
                    bf16_add = b;
                else if (ea == BF16_EXP_SPECIAL && eb == BF16_EXP_SPECIAL && sa != sb)
                    bf16_add = BF16_QNAN;
                else
                    bf16_add = (ea == BF16_EXP_SPECIAL) ? a : b;
            end else if (ea == 8'd0 && eb == 8'd0) begin
                bf16_add = {sa & sb, 15'd0};    // -0 only when both are negative
            end else if (ea == 8'd0) begin
                bf16_add = b;                   // 0 + b
            end else if (eb == 8'd0) begin
                bf16_add = a;                   // a + 0
            end else begin
                // Order by magnitude (exponent and fraction compare as an
                // unsigned 15-bit integer).
                if (a[14:0] >= b[14:0]) begin
                    s_big     = sa;
                    e_big     = ea;
                    e_small   = eb;
                    sig_big   = {1'b1, a[6:0], 3'b000};
                    sig_small = {1'b1, b[6:0], 3'b000};
                end else begin
                    s_big     = sb;
                    e_big     = eb;
                    e_small   = ea;
                    sig_big   = {1'b1, b[6:0], 3'b000};
                    sig_small = {1'b1, a[6:0], 3'b000};
                end

                // Align the smaller operand. Shifts of 12 or more push every
                // bit below the sticky position, so the distance is clamped.
                diff     = e_big - e_small;
                sh       = (diff > 8'd12) ? 4'd12 : diff[3:0];
                ext      = {sig_small, 12'b0} >> sh;
                small_al = ext[22:12] | {10'b0, |ext[11:0]};  // fold lost bits into S

                if (sa == sb)
                    sum = {1'b0, sig_big} + {1'b0, small_al};
                else
                    sum = {1'b0, sig_big} - {1'b0, small_al};

                e_res = {1'b0, e_big};

                if (sum == 12'd0) begin
                    bf16_add = 16'd0;           // exact cancellation -> +0
                end else begin
                    // Carry out of the addition: shift right, keep sticky.
                    if (sum[11]) begin
                        sum   = {1'b0, sum[11:2], sum[1] | sum[0]};
                        e_res = e_res + 9'd1;
                    end

                    // Cancellation: shift left until the leading one is back
                    // at bit 10, but never below the smallest normal exponent.
                    for (i = 0; i < 10; i = i + 1) begin
                        if (!sum[10] && e_res > 9'd1) begin
                            sum   = sum << 1;
                            e_res = e_res - 9'd1;
                        end
                    end

                    if (!sum[10]) begin
                        bf16_add = {s_big, 15'd0};   // would be subnormal -> flush
                    end else begin
                        // Round to nearest, ties to even.
                        round_up = sum[2] & (sum[1] | sum[0] | sum[3]);
                        rounded  = {1'b0, sum[10:3]} + {8'd0, round_up};
                        if (rounded[8])
                            e_res = e_res + 9'd1;    // 1.1111111 rounded up to 10.0

                        if (e_res >= 9'd255)
                            bf16_add = {s_big, BF16_EXP_SPECIAL, 7'd0};  // overflow -> Inf
                        else
                            // rounded[6:0] is already zero when rounded[8] is set.
                            bf16_add = {s_big, e_res[7:0], rounded[6:0]};
                    end
                end
            end
        end
    endfunction

    // =================================================================
    // Pipeline and accumulation
    // =================================================================
    wire [15:0] a_bf16  = fp8_to_bf16(a_in);
    wire [15:0] b_bf16  = fp8_to_bf16(b_in);
    wire [15:0] product = bf16_mul(a_bf16, b_bf16);

    always @(posedge clk) begin
        if (rst) begin
            // Clear the forwarding registers too, so neighbouring elements
            // see FP8 zero (not stale or X data) right after reset.
            a_out <= {WIDTH{1'b0}};
            b_out <= {WIDTH{1'b0}};
            c_out <= 16'd0;
        end else begin
            a_out <= a_in;
            b_out <= b_in;
            if (clear)
                c_out <= product;                     // start a new accumulation
            else
                c_out <= bf16_add(c_out, product);    // accumulate
        end
    end

endmodule
// 0 commit comments