Skip to content

Commit dfa8d73

Browse files
move PE logic + test but no run
1 parent cbcb171 commit dfa8d73

File tree

5 files changed

+230
-332
lines changed

5 files changed

+230
-332
lines changed

src/PE.v

Lines changed: 149 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,160 @@
11
module PE #(
    parameter WIDTH = 8
)(
    input  wire                    clk,
    input  wire                    rst,
    input  wire                    clear,
    input  wire signed [WIDTH-1:0] a_in,   // FP8 E4M3
    input  wire signed [WIDTH-1:0] b_in,   // FP8 E4M3
    output reg  signed [WIDTH-1:0] a_out,
    output reg  signed [WIDTH-1:0] b_out,
    output reg  signed [15:0]      c_out   // BF16 accumulator
);

    // =================================================================
    // Processing element: multiply-accumulate on FP8 operands with a
    // BF16 accumulator.
    //
    //   a_in, b_in : FP8 operands {sign, exp[3:0], mant[2:0]}, bias 7.
    //                The conversion function takes 8 bits, so WIDTH is
    //                expected to be 8.
    //   a_out/b_out: a_in / b_in delayed by one clock (not reset).
    //   c_out      : BF16 {sign, exp[7:0], frac[6:0]}, bias 127.
    //                rst   -> 0
    //                clear -> a_in * b_in
    //                else  -> c_out + a_in * b_in
    //
    // Arithmetic policy shared by the helpers below:
    //   * BF16 subnormals are flushed to zero (inputs and results).
    //   * NaN is not propagated by mul/add; any operand with exponent
    //     255 yields a signed infinity (area saving).
    //   * Rounding is round-to-nearest, ties-to-even.
    // =================================================================

    // =================================================================
    // 1. FP8 (E4M3) -> BF16 conversion (bias 7 -> 127, i.e. +120)
    //    The conversion is exact: every FP8 value, including FP8
    //    subnormals, is representable as a normal BF16 number.
    // =================================================================
    function automatic [15:0] fp8_to_bf16(input [7:0] fp8);
        reg       sign;
        reg [3:0] exp8;
        reg [2:0] mant8;
        reg [7:0] exp16;
        reg [6:0] mant16;
        begin
            sign  = fp8[7];
            exp8  = fp8[6:3];
            mant8 = fp8[2:0];

            if (exp8 == 4'd0) begin
                // Zero or FP8 subnormal: value = mant8 * 2^-9.
                // Normalise on the leading one so the result is a normal
                // BF16 number (2^-7 -> 120, 2^-8 -> 119, 2^-9 -> 118).
                if (mant8[2]) begin
                    exp16  = 8'd120;
                    mant16 = {mant8[1:0], 5'b0};
                end else if (mant8[1]) begin
                    exp16  = 8'd119;
                    mant16 = {mant8[0], 6'b0};
                end else if (mant8[0]) begin
                    exp16  = 8'd118;
                    mant16 = 7'd0;
                end else begin
                    exp16  = 8'd0;              // signed zero
                    mant16 = 7'd0;
                end
            end else if (exp8 == 4'd15) begin
                // Exponent 15 is decoded as Inf (mant == 0) or quiet NaN.
                // NOTE(review): the OCP "E4M3" variant has no Inf and only
                // one NaN encoding; confirm which variant the array uses.
                exp16  = 8'd255;
                mant16 = (mant8 == 3'd0) ? 7'd0 : {1'b1, mant8, 3'b0};
            end else begin
                // Normal number: re-bias the exponent and left-align the
                // 3 fraction bits. The hidden bit is NOT stored in BF16.
                exp16  = {4'b0, exp8} + 8'd120;
                mant16 = {mant8, 4'b0};
            end

            fp8_to_bf16 = {sign, exp16, mant16};
        end
    endfunction

    // =================================================================
    // 2. BF16 x BF16 -> BF16 multiplier
    //    Zero/subnormal operand -> signed zero.
    //    Inf/NaN operand        -> signed Inf (checked after zero, so
    //                              0 * Inf yields 0, not NaN).
    //    Exponent underflow     -> signed zero, overflow -> signed Inf.
    // =================================================================
    function automatic [15:0] bf16_mul(input [15:0] a, input [15:0] b);
        reg        sp;
        reg [7:0]  ea, eb;
        reg [15:0] mp;        // full 8x8 significand product, in [2^14, 2^16)
        reg [7:0]  sig;       // {hidden, frac[6:0]} after normalisation
        reg        guard;
        reg        sticky;
        reg [8:0]  rounded;   // sig + rounding increment, bit 8 = carry-out
        reg [9:0]  e_wide;    // result exponent + 127 (cannot wrap in 10 bits)
        reg [9:0]  e_res;
        begin
            sp = a[15] ^ b[15];
            ea = a[14:7];
            eb = b[14:7];

            if (ea == 8'd0 || eb == 8'd0) begin
                bf16_mul = {sp, 15'd0};
            end else if (ea == 8'd255 || eb == 8'd255) begin
                bf16_mul = {sp, 8'd255, 7'd0};
            end else begin
                mp     = {1'b1, a[6:0]} * {1'b1, b[6:0]};
                e_wide = {2'b0, ea} + {2'b0, eb};

                // Product of two values in [1,2) lies in [1,4): at most a
                // one-bit normalising shift.
                if (mp[15]) begin
                    sig    = mp[15:8];
                    guard  = mp[7];
                    sticky = |mp[6:0];
                    e_wide = e_wide + 10'd1;
                end else begin
                    sig    = mp[14:7];
                    guard  = mp[6];
                    sticky = |mp[5:0];
                end

                // Round to nearest, ties to even.
                rounded = {1'b0, sig} + {8'b0, (guard & (sticky | sig[0]))};
                if (rounded[8]) begin
                    // 1.1111111 rounded up to 10.0000000: renormalise.
                    rounded = {1'b0, rounded[8:1]};
                    e_wide  = e_wide + 10'd1;
                end

                if (e_wide <= 10'd127) begin
                    bf16_mul = {sp, 15'd0};               // biased exp <= 0
                end else if (e_wide >= 10'd382) begin
                    bf16_mul = {sp, 8'd255, 7'd0};        // biased exp >= 255
                end else begin
                    e_res    = e_wide - 10'd127;
                    bf16_mul = {sp, e_res[7:0], rounded[6:0]};
                end
            end
        end
    endfunction

    // =================================================================
    // 3. BF16 + BF16 -> BF16 adder
    //    Operands are ordered by magnitude so the significand
    //    subtraction never goes negative and the result takes the sign
    //    of the larger operand. Three extra low bits (guard, round,
    //    sticky) are carried through alignment for correct rounding.
    // =================================================================
    function automatic [15:0] bf16_add(input [15:0] a, input [15:0] b);
        reg        s_res;
        reg [7:0]  ea, eb;
        reg [7:0]  e_big;
        reg [7:0]  shift;
        reg [10:0] sig_big;    // {hidden, frac[6:0], g, r, s}
        reg [10:0] sig_small;
        reg [21:0] ext;        // sig_small with room for shifted-out bits
        reg [11:0] sum;        // bit 11 = carry-out of the addition
        reg [8:0]  e_res;
        reg [8:0]  rounded;
        integer    i;
        begin
            ea = a[14:7];
            eb = b[14:7];

            if (ea == 8'd255) begin
                bf16_add = {a[15], 8'd255, 7'd0};
            end else if (eb == 8'd255) begin
                bf16_add = {b[15], 8'd255, 7'd0};
            end else if (ea == 8'd0) begin
                // a is zero (or flushed subnormal): result is b, or a
                // signed zero when both are zero.
                bf16_add = (eb == 8'd0) ? {a[15] & b[15], 15'd0} : b;
            end else if (eb == 8'd0) begin
                bf16_add = a;
            end else begin
                // Order by magnitude ({exp, frac} compares as an integer).
                if (a[14:0] >= b[14:0]) begin
                    s_res     = a[15];
                    e_big     = ea;
                    shift     = ea - eb;
                    sig_big   = {1'b1, a[6:0], 3'b000};
                    sig_small = {1'b1, b[6:0], 3'b000};
                end else begin
                    s_res     = b[15];
                    e_big     = eb;
                    shift     = eb - ea;
                    sig_big   = {1'b1, b[6:0], 3'b000};
                    sig_small = {1'b1, a[6:0], 3'b000};
                end

                // Align the smaller operand, folding every shifted-out
                // bit into the sticky position (bit 0).
                if (shift > 8'd11) begin
                    sig_small = 11'd1;
                end else begin
                    ext       = {sig_small, 11'b0} >> shift;
                    sig_small = ext[21:11] | {10'b0, |ext[10:0]};
                end

                if (a[15] == b[15])
                    sum = {1'b0, sig_big} + {1'b0, sig_small};
                else
                    sum = {1'b0, sig_big} - {1'b0, sig_small};

                e_res = {1'b0, e_big};

                if (sum == 12'd0) begin
                    bf16_add = 16'd0;                     // exact cancellation
                end else begin
                    if (sum[11]) begin
                        // Carry-out: shift right one place, keep sticky.
                        sum   = {1'b0, sum[11:2], sum[1] | sum[0]};
                        e_res = e_res + 9'd1;
                    end else begin
                        // Cancellation: shift left until the hidden bit
                        // is back at bit 10 or the exponent runs out.
                        for (i = 0; i < 10; i = i + 1) begin
                            if (!sum[10] && e_res != 9'd0) begin
                                sum   = sum << 1;
                                e_res = e_res - 9'd1;
                            end
                        end
                    end

                    if (e_res == 9'd0) begin
                        bf16_add = {s_res, 15'd0};        // underflow: flush
                    end else begin
                        // Round to nearest, ties to even.
                        rounded = {1'b0, sum[10:3]}
                                + {8'b0, (sum[2] & (sum[1] | sum[0] | sum[3]))};
                        if (rounded[8]) begin
                            rounded = {1'b0, rounded[8:1]};
                            e_res   = e_res + 9'd1;
                        end

                        if (e_res >= 9'd255)
                            bf16_add = {s_res, 8'd255, 7'd0};
                        else
                            bf16_add = {s_res, e_res[7:0], rounded[6:0]};
                    end
                end
            end
        end
    endfunction

    // =================================================================
    // Pipeline and accumulation
    // =================================================================
    wire [15:0] a_bf16  = fp8_to_bf16(a_in);
    wire [15:0] b_bf16  = fp8_to_bf16(b_in);
    wire [15:0] product = bf16_mul(a_bf16, b_bf16);

    always @(posedge clk) begin
        // Operands are forwarded to the neighbouring PE every cycle.
        a_out <= a_in;
        b_out <= b_in;

        if (rst)
            c_out <= 16'd0;
        else if (clear)
            c_out <= product;                       // start a new dot product
        else
            c_out <= bf16_add(c_out, product);      // accumulate
    end

endmodule

src/fp_PE.v

Lines changed: 0 additions & 161 deletions
This file was deleted.

test/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ FST ?= -fst # Use more efficient FST format
77
TOPLEVEL_LANG ?= verilog
88
SRC_DIR = $(PWD)/../src
99
PROJECT_SOURCES = tpu.v \
10-
systolic_array_2x2.v \
10+
systolic_array_2x2.v \
1111
control_unit.v \
1212
PE.v \
1313
mmu_feeder.v \

test/tb.v

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ module tb ();
88

99
// Dump the signals to a waveform file (tb.fst). You can view it with gtkwave or surfer.
1010
initial begin
11-
$dumpfile("tb.vcd");
11+
$dumpfile("tb.fst");
1212
$dumpvars(0, tb);
1313
#1;
1414
end
@@ -22,9 +22,17 @@ module tb ();
2222
wire [7:0] uo_out;
2323
wire [7:0] uio_out;
2424
wire [7:0] uio_oe;
25+
`ifdef GL_TEST
26+
wire VPWR = 1'b1;
27+
wire VGND = 1'b0;
28+
`endif
2529

2630
// Replace tt_um_example with your module name:
2731
tt_um_tpu tpu_project (
32+
`ifdef GL_TEST
33+
.VPWR(VPWR),
34+
.VGND(VGND),
35+
`endif
2836
.ui_in (ui_in), // Dedicated inputs
2937
.uo_out (uo_out), // Dedicated outputs
3038
.uio_in (uio_in), // IOs: Input path

0 commit comments

Comments
 (0)