Skip to content

Commit 19dd34d

Browse files
restore 2x2
1 parent 9111d65 commit 19dd34d

File tree

3 files changed

+103
-72
lines changed

3 files changed

+103
-72
lines changed

info.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ project:
88
clock_hz: 50000000 # Clock frequency in Hz (or 0 if not applicable)
99

1010
# How many tiles your design occupies? A single tile is about 167x108 uM.
11-
tiles: "1x2" # Valid values: 1x1, 1x2, 2x2, 3x2, 4x2, 6x2 or 8x2
11+
tiles: "2x2" # Valid values: 1x1, 1x2, 2x2, 3x2, 4x2, 6x2 or 8x2
1212

1313
# Your top module name must start with "tt_um_". Make it unique by including your github username:
1414
top_module: "tt_um_tpu"

src/PE.v

Lines changed: 102 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,104 +1,136 @@
11
module PE (
2-
input wire clk,
3-
input wire rst,
4-
input wire clear,
2+
input wire clk,
3+
input wire rst,
4+
input wire clear,
55

6-
input wire [7:0] a_in, // FP8 E4M3 (subnormals flushed)
7-
input wire [7:0] b_in,
6+
input wire [7:0] a_in, // FP8 E4M3
7+
input wire [7:0] b_in, // FP8 E4M3
88

9-
output reg [7:0] a_out,
10-
output reg [7:0] b_out,
11-
output reg [15:0] c_out // BF16 output after convert
9+
output reg [7:0] a_out,
10+
output reg [7:0] b_out,
11+
output reg [15:0] c_out // BF16 accumulator
1212
);
1313

14-
// --------------------------------------------------------------------
15-
// 1. Ultra-light FP8 Decode
16-
// Subnormals → zero
17-
// Mantissa = 1.xxx (4-bit int), Exponent = exp - 7
18-
// --------------------------------------------------------------------
14+
// ================================================================
15+
// 1. FP8 Decode → Integer mant, exp, sign
16+
// ================================================================
1917
wire sign_a = a_in[7];
2018
wire sign_b = b_in[7];
2119
wire sign_p = sign_a ^ sign_b;
2220

2321
wire [3:0] exp_a = a_in[6:3];
2422
wire [3:0] exp_b = b_in[6:3];
2523

26-
wire [3:0] mant_a = (exp_a == 4'd0) ? 4'd0 : {1'b1, a_in[2:0]};
27-
wire [3:0] mant_b = (exp_b == 4'd0) ? 4'd0 : {1'b1, b_in[2:0]};
24+
wire [3:0] mant_a = (exp_a == 0) ? 0 : {1'b1, a_in[2:0]}; // 8..15
25+
wire [3:0] mant_b = (exp_b == 0) ? 0 : {1'b1, b_in[2:0]};
2826

29-
wire signed [5:0] e_a = $signed({1'b0, exp_a}) - 6'd7;
30-
wire signed [5:0] e_b = $signed({1'b0, exp_b}) - 6'd7;
27+
// 4×4 → 8-bit multiply
28+
wire [7:0] mant_prod_raw = mant_a * mant_b;
3129

32-
// --------------------------------------------------------------------
33-
// 2. INT multiply + shift exponent sum
34-
// --------------------------------------------------------------------
35-
wire [7:0] mant_prod = mant_a * mant_b; // 8-bit
30+
// FP8 bias = 7
31+
wire [9:0] exp_sum_raw = exp_a + exp_b - 7;
3632

37-
wire signed [6:0] shift_amt = e_a + e_b;
33+
wire prod_zero = (mant_prod_raw == 0) | (exp_a == 0) | (exp_b == 0);
3834

39-
// Shift in a bounded range (-7..+8 typical)
40-
wire signed [23:0] prod_shifted =
41-
(shift_amt >= 0) ?
42-
({{16{1'b0}}, mant_prod} << shift_amt) :
43-
({{16{1'b0}}, mant_prod} >> (-shift_amt));
35+
// ================================================================
36+
// 2. Normalize INT mantissa product → BF16 format
37+
// mant_prod_raw is 8–225
38+
// => represent as 1.xxxxx BF16 mantissa (7 bits)
39+
// ================================================================
40+
reg [7:0] mant_norm;
41+
reg [7:0] exp_norm;
4442

45-
wire signed [23:0] prod = sign_p ? -prod_shifted : prod_shifted;
46-
47-
// --------------------------------------------------------------------
48-
// 3. INT accumulator
49-
// --------------------------------------------------------------------
50-
reg signed [23:0] acc;
43+
always @(*) begin
44+
if (prod_zero) begin
45+
mant_norm = 0;
46+
exp_norm = 0;
47+
end
48+
else if (mant_prod_raw[7]) begin
49+
// 1xx.xxxxx range → shift down 1 bit
50+
// mantissa: top 7 bits
51+
mant_norm = mant_prod_raw[7:1];
52+
exp_norm = exp_sum_raw + 127 + 1;
53+
end
54+
else begin
55+
// 0xx.xxxxx range → already normalized
56+
mant_norm = mant_prod_raw[6:0];
57+
exp_norm = exp_sum_raw + 127;
58+
end
59+
end
5160

52-
// Convert INT24 → BF16 (approx, efficient)
53-
function automatic [15:0] int24_to_bf16(input signed [23:0] x);
54-
reg sign;
55-
reg [23:0] mag;
56-
reg [7:0] exponent;
57-
reg [6:0] mant;
61+
// ================================================================
62+
// 3. Reconstruct BF16 product
63+
// ================================================================
64+
wire [15:0] bf16_prod = {sign_p, exp_norm, mant_norm};
65+
66+
// ================================================================
67+
// 4. BF16 adder (your original one)
68+
// ================================================================
69+
function automatic [15:0] bf16_add(
70+
input [15:0] a,
71+
input [15:0] b
72+
);
73+
reg signa, signb, signr;
74+
reg [7:0] expa, expb, expr, ediff;
75+
reg [9:0] ma, mb, ms;
5876
begin
59-
if (x == 0) begin
60-
int24_to_bf16 = 16'h0000;
77+
// unpack
78+
signa = a[15];
79+
expa = a[14:7];
80+
ma = (expa == 8'd0) ? 10'd0 : {1'b1, a[6:0], 1'b0};
81+
82+
signb = b[15];
83+
expb = b[14:7];
84+
mb = (expb == 8'd0) ? 10'd0 : {1'b1, b[6:0], 1'b0};
85+
86+
// exponent align
87+
if (expa > expb) begin
88+
ediff = expa - expb;
89+
expr = expa;
90+
mb = mb >> ediff;
6191
end else begin
62-
sign = x[23];
63-
mag = sign ? -x : x;
64-
65-
// Normalize by finding highest bit
66-
// Since ACC_WIDTH small, simple if-chain
67-
if (mag[23]) begin exponent = 127+23; mant = mag[22:16]; end
68-
else if (mag[22]) begin exponent = 127+22; mant = mag[21:15]; end
69-
else if (mag[21]) begin exponent = 127+21; mant = mag[20:14]; end
70-
else if (mag[20]) begin exponent = 127+20; mant = mag[19:13]; end
71-
else if (mag[19]) begin exponent = 127+19; mant = mag[18:12]; end
72-
else if (mag[18]) begin exponent = 127+18; mant = mag[17:11]; end
73-
else if (mag[17]) begin exponent = 127+17; mant = mag[16:10]; end
74-
else if (mag[16]) begin exponent = 127+16; mant = mag[15:9]; end
75-
else if (mag[15]) begin exponent = 127+15; mant = mag[14:8]; end
76-
else if (mag[14]) begin exponent = 127+14; mant = mag[13:7]; end
77-
else if (mag[13]) begin exponent = 127+13; mant = mag[12:6]; end
78-
else begin exponent = 0; mant = 0; end
79-
80-
int24_to_bf16 = {sign, exponent, mant};
92+
ediff = expb - expa;
93+
expr = expb;
94+
ma = ma >> ediff;
95+
end
96+
97+
// add/sub mantissa
98+
if (signa == signb) begin
99+
ms = ma + mb;
100+
signr = signa;
101+
end else if (ma >= mb) begin
102+
ms = ma - mb;
103+
signr = signa;
104+
end else begin
105+
ms = mb - ma;
106+
signr = signb;
107+
end
108+
109+
// normalize result
110+
if (ms[9]) begin
111+
// carry-out → shift right, add 1 to exponent
112+
bf16_add = {signr, expr + 8'd1, ms[9:3]};
113+
end else if (ms[8]) begin
114+
bf16_add = {signr, expr, ms[8:2]};
115+
end else begin
116+
bf16_add = 16'h0000;
81117
end
82118
end
83119
endfunction
84120

85-
// --------------------------------------------------------------------
86-
// 4. Pipeline
87-
// --------------------------------------------------------------------
121+
// ================================================================
122+
// 5. Pipeline + BF16 accumulation (unchanged)
123+
// ================================================================
88124
always @(posedge clk) begin
89125
a_out <= a_in;
90126
b_out <= b_in;
91127

92128
if (rst)
93-
acc <= 24'd0;
129+
c_out <= 16'd0;
94130
else if (clear)
95-
acc <= prod;
131+
c_out <= bf16_prod;
96132
else
97-
acc <= acc + prod;
98-
99-
// Convert accumulator to BF16 each cycle (or only at readout)
100-
c_out <= int24_to_bf16(acc);
133+
c_out <= bf16_add(c_out, bf16_prod);
101134
end
102135

103136
endmodule
104-

src/tpu.v

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ module tt_um_tpu (
3535
wire [7:0] a_data0, b_data0, a_data1, b_data1;
3636

3737
wire done;
38-
wire [1:0] state;
3938

4039
// Module Instantiations
4140
memory mem (

0 commit comments

Comments
 (0)