|
1 | 1 | module PE ( |
2 | | - input wire clk, |
3 | | - input wire rst, |
4 | | - input wire clear, |
| 2 | + input wire clk, |
| 3 | + input wire rst, |
| 4 | + input wire clear, |
5 | 5 |
|
6 | | - input wire [7:0] a_in, // FP8 E4M3 (subnormals flushed) |
7 | | - input wire [7:0] b_in, |
| 6 | + input wire [7:0] a_in, // FP8 E4M3 |
| 7 | + input wire [7:0] b_in, // FP8 E4M3 |
8 | 8 |
|
9 | | - output reg [7:0] a_out, |
10 | | - output reg [7:0] b_out, |
11 | | - output reg [15:0] c_out // BF16 output after convert |
| 9 | + output reg [7:0] a_out, |
| 10 | + output reg [7:0] b_out, |
| 11 | + output reg [15:0] c_out // BF16 accumulator |
12 | 12 | ); |
13 | 13 |
|
14 | | - // -------------------------------------------------------------------- |
15 | | - // 1. Ultra-light FP8 Decode |
16 | | - // Subnormals → zero |
17 | | - // Mantissa = 1.xxx (4-bit int), Exponent = exp - 7 |
18 | | - // -------------------------------------------------------------------- |
| 14 | + // ================================================================ |
| 15 | + // 1. FP8 Decode → Integer mant, exp, sign |
| 16 | + // ================================================================ |
19 | 17 | wire sign_a = a_in[7]; |
20 | 18 | wire sign_b = b_in[7]; |
21 | 19 | wire sign_p = sign_a ^ sign_b; |
22 | 20 |
|
23 | 21 | wire [3:0] exp_a = a_in[6:3]; |
24 | 22 | wire [3:0] exp_b = b_in[6:3]; |
25 | 23 |
|
26 | | - wire [3:0] mant_a = (exp_a == 4'd0) ? 4'd0 : {1'b1, a_in[2:0]}; |
27 | | - wire [3:0] mant_b = (exp_b == 4'd0) ? 4'd0 : {1'b1, b_in[2:0]}; |
| 24 | + wire [3:0] mant_a = (exp_a == 0) ? 0 : {1'b1, a_in[2:0]}; // 8..15 |
| 25 | + wire [3:0] mant_b = (exp_b == 0) ? 0 : {1'b1, b_in[2:0]}; |
28 | 26 |
|
29 | | - wire signed [5:0] e_a = $signed({1'b0, exp_a}) - 6'd7; |
30 | | - wire signed [5:0] e_b = $signed({1'b0, exp_b}) - 6'd7; |
| 27 | + // 4×4 → 8-bit multiply |
| 28 | + wire [7:0] mant_prod_raw = mant_a * mant_b; |
31 | 29 |
|
32 | | - // -------------------------------------------------------------------- |
33 | | - // 2. INT multiply + shift exponent sum |
34 | | - // -------------------------------------------------------------------- |
35 | | - wire [7:0] mant_prod = mant_a * mant_b; // 8-bit |
| 30 | + // FP8 bias = 7 |
| 31 | + wire [9:0] exp_sum_raw = exp_a + exp_b - 7; |
36 | 32 |
|
37 | | - wire signed [6:0] shift_amt = e_a + e_b; |
| 33 | + wire prod_zero = (mant_prod_raw == 0) | (exp_a == 0) | (exp_b == 0); |
38 | 34 |
|
39 | | - // Shift in a bounded range (-7..+8 typical) |
40 | | - wire signed [23:0] prod_shifted = |
41 | | - (shift_amt >= 0) ? |
42 | | - ({{16{1'b0}}, mant_prod} << shift_amt) : |
43 | | - ({{16{1'b0}}, mant_prod} >> (-shift_amt)); |
| 35 | + // ================================================================ |
| 36 | + // 2. Normalize INT mantissa product → BF16 format |
| 37 | + // mant_prod_raw is 8–225 |
| 38 | + // => represent as 1.xxxxx BF16 mantissa (7 bits) |
| 39 | + // ================================================================ |
| 40 | + reg [7:0] mant_norm; |
| 41 | + reg [7:0] exp_norm; |
44 | 42 |
|
45 | | - wire signed [23:0] prod = sign_p ? -prod_shifted : prod_shifted; |
46 | | - |
47 | | - // -------------------------------------------------------------------- |
48 | | - // 3. INT accumulator |
49 | | - // -------------------------------------------------------------------- |
50 | | - reg signed [23:0] acc; |
| 43 | + always @(*) begin |
| 44 | + if (prod_zero) begin |
| 45 | + mant_norm = 0; |
| 46 | + exp_norm = 0; |
| 47 | + end |
| 48 | + else if (mant_prod_raw[7]) begin |
| 49 | + // 1xx.xxxxx range → shift down 1 bit |
| 50 | + // mantissa: top 7 bits |
| 51 | + mant_norm = mant_prod_raw[7:1]; |
| 52 | + exp_norm = exp_sum_raw + 127 + 1; |
| 53 | + end |
| 54 | + else begin |
| 55 | + // 0xx.xxxxx range → already normalized |
| 56 | + mant_norm = mant_prod_raw[6:0]; |
| 57 | + exp_norm = exp_sum_raw + 127; |
| 58 | + end |
| 59 | + end |
51 | 60 |
|
52 | | - // Convert INT24 → BF16 (approx, efficient) |
53 | | - function automatic [15:0] int24_to_bf16(input signed [23:0] x); |
54 | | - reg sign; |
55 | | - reg [23:0] mag; |
56 | | - reg [7:0] exponent; |
57 | | - reg [6:0] mant; |
| 61 | + // ================================================================ |
| 62 | + // 3. Reconstruct BF16 product |
| 63 | + // ================================================================ |
| 64 | + wire [15:0] bf16_prod = {sign_p, exp_norm, mant_norm}; |
| 65 | + |
| 66 | + // ================================================================ |
| 67 | + // 4. BF16 adder (your original one) |
| 68 | + // ================================================================ |
| 69 | + function automatic [15:0] bf16_add( |
| 70 | + input [15:0] a, |
| 71 | + input [15:0] b |
| 72 | + ); |
| 73 | + reg signa, signb, signr; |
| 74 | + reg [7:0] expa, expb, expr, ediff; |
| 75 | + reg [9:0] ma, mb, ms; |
58 | 76 | begin |
59 | | - if (x == 0) begin |
60 | | - int24_to_bf16 = 16'h0000; |
| 77 | + // unpack |
| 78 | + signa = a[15]; |
| 79 | + expa = a[14:7]; |
| 80 | + ma = (expa == 8'd0) ? 10'd0 : {1'b1, a[6:0], 1'b0}; |
| 81 | + |
| 82 | + signb = b[15]; |
| 83 | + expb = b[14:7]; |
| 84 | + mb = (expb == 8'd0) ? 10'd0 : {1'b1, b[6:0], 1'b0}; |
| 85 | + |
| 86 | + // exponent align |
| 87 | + if (expa > expb) begin |
| 88 | + ediff = expa - expb; |
| 89 | + expr = expa; |
| 90 | + mb = mb >> ediff; |
61 | 91 | end else begin |
62 | | - sign = x[23]; |
63 | | - mag = sign ? -x : x; |
64 | | - |
65 | | - // Normalize by finding highest bit |
66 | | - // Since ACC_WIDTH small, simple if-chain |
67 | | - if (mag[23]) begin exponent = 127+23; mant = mag[22:16]; end |
68 | | - else if (mag[22]) begin exponent = 127+22; mant = mag[21:15]; end |
69 | | - else if (mag[21]) begin exponent = 127+21; mant = mag[20:14]; end |
70 | | - else if (mag[20]) begin exponent = 127+20; mant = mag[19:13]; end |
71 | | - else if (mag[19]) begin exponent = 127+19; mant = mag[18:12]; end |
72 | | - else if (mag[18]) begin exponent = 127+18; mant = mag[17:11]; end |
73 | | - else if (mag[17]) begin exponent = 127+17; mant = mag[16:10]; end |
74 | | - else if (mag[16]) begin exponent = 127+16; mant = mag[15:9]; end |
75 | | - else if (mag[15]) begin exponent = 127+15; mant = mag[14:8]; end |
76 | | - else if (mag[14]) begin exponent = 127+14; mant = mag[13:7]; end |
77 | | - else if (mag[13]) begin exponent = 127+13; mant = mag[12:6]; end |
78 | | - else begin exponent = 0; mant = 0; end |
79 | | - |
80 | | - int24_to_bf16 = {sign, exponent, mant}; |
| 92 | + ediff = expb - expa; |
| 93 | + expr = expb; |
| 94 | + ma = ma >> ediff; |
| 95 | + end |
| 96 | + |
| 97 | + // add/sub mantissa |
| 98 | + if (signa == signb) begin |
| 99 | + ms = ma + mb; |
| 100 | + signr = signa; |
| 101 | + end else if (ma >= mb) begin |
| 102 | + ms = ma - mb; |
| 103 | + signr = signa; |
| 104 | + end else begin |
| 105 | + ms = mb - ma; |
| 106 | + signr = signb; |
| 107 | + end |
| 108 | + |
| 109 | + // normalize result |
| 110 | + if (ms[9]) begin |
| 111 | + // carry-out → shift right, add 1 to exponent |
| 112 | + bf16_add = {signr, expr + 8'd1, ms[9:3]}; |
| 113 | + end else if (ms[8]) begin |
| 114 | + bf16_add = {signr, expr, ms[8:2]}; |
| 115 | + end else begin |
| 116 | + bf16_add = 16'h0000; |
81 | 117 | end |
82 | 118 | end |
83 | 119 | endfunction |
84 | 120 |
|
85 | | - // -------------------------------------------------------------------- |
86 | | - // 4. Pipeline |
87 | | - // -------------------------------------------------------------------- |
| 121 | + // ================================================================ |
| 122 | + // 5. Pipeline + BF16 accumulation (unchanged) |
| 123 | + // ================================================================ |
88 | 124 | always @(posedge clk) begin |
89 | 125 | a_out <= a_in; |
90 | 126 | b_out <= b_in; |
91 | 127 |
|
92 | 128 | if (rst) |
93 | | - acc <= 24'd0; |
| 129 | + c_out <= 16'd0; |
94 | 130 | else if (clear) |
95 | | - acc <= prod; |
| 131 | + c_out <= bf16_prod; |
96 | 132 | else |
97 | | - acc <= acc + prod; |
98 | | - |
99 | | - // Convert accumulator to BF16 each cycle (or only at readout) |
100 | | - c_out <= int24_to_bf16(acc); |
| 133 | + c_out <= bf16_add(c_out, bf16_prod); |
101 | 134 | end |
102 | 135 |
|
103 | 136 | endmodule |
104 | | - |
|
0 commit comments