Skip to content

Commit 6416e9b

Browse files
dramatically simplify
1 parent 18ac448 commit 6416e9b

File tree

5 files changed

+79
-157
lines changed

5 files changed

+79
-157
lines changed

src/PE.v

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,29 +46,25 @@ module PE (
4646
// ----------------------- Accumulator (2's complement) -----------------------
4747
reg signed [17:0] acc;
4848

49-
// Compute next accumulator value (combinational)
50-
wire signed [17:0] abs_aligned = aligned_prod;
5149
wire signed [17:0] signed_prod =
5250
prod_sign ? -aligned_prod : aligned_prod;
5351

54-
// Accumulator input: on clear, load product; else add to accumulator
55-
wire signed [17:0] acc_in =
56-
clear ? signed_prod : (acc + signed_prod);
57-
5852
always @(posedge clk) begin
5953
a_out <= a_in;
6054
b_out <= b_in;
6155

6256
if (rst)
6357
acc <= 18'sd0;
58+
else if (clear)
59+
acc <= signed_prod;
6460
else
65-
acc <= acc_in;
61+
acc <= acc + signed_prod;
6662
end
6763

6864
// ----------------------- INT18 → BF16 (combinational) -----------------------
6965
wire [15:0] bf16_c;
7066
int18_to_bf16_lzd #(.FRAC_BITS(FRAC_BITS)) convert (
71-
.acc(acc_in),
67+
.acc(acc),
7268
.bf16(bf16_c)
7369
);
7470

src/control_unit.v

Lines changed: 52 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,3 @@
1-
/*
2-
* Monolithic Control Unit
3-
* Contains state machine control with data output selection
4-
* Data routing is handled directly between memory and systolic array
5-
* Simplified to 2-state machine: IDLE and ACTIVE
6-
*/
7-
81
`default_nettype none
92

103
module control_unit (
@@ -20,150 +13,88 @@ module control_unit (
2013

2114
// Systolic array control signals (lightweight!)
2215
output wire clear,
23-
output reg data_valid,
2416
output reg [1:0] a0_sel, a1_sel, b0_sel, b1_sel,
2517

2618
// Output interface
27-
output wire done,
2819
output reg [7:0] data_out
2920
);
3021

31-
// STATES - Simplified to just IDLE and ACTIVE
32-
localparam S_IDLE = 1'b0;
33-
localparam S_ACTIVE = 1'b1;
34-
35-
reg state, next_state;
3622
reg [2:0] mmu_cycle; // Counting Systolic Array Stages
3723

3824
// Done signal and clear signal
39-
assign done = data_valid && (mmu_cycle >= 3'b010);
4025
assign clear = (mmu_cycle == 3'b000);
4126

4227
// Buffer of output after clearing previous
4328
reg [7:0] tail_hold;
4429

45-
// Next state logic - very simple now!
46-
always @(*) begin
47-
next_state = state;
48-
49-
case (state)
50-
S_IDLE: begin
51-
if (load_en) begin
52-
next_state = S_ACTIVE;
53-
end
54-
end
55-
56-
S_ACTIVE: begin
57-
next_state = S_ACTIVE; // Stay active, cycles forever
58-
end
59-
60-
default: begin
61-
next_state = S_IDLE;
62-
end
63-
endcase
64-
end
65-
6630
// State machine and control signal generation
6731
always @(posedge clk) begin
6832
if (rst) begin
69-
state <= S_IDLE;
7033
mmu_cycle <= 0;
71-
data_valid <= 0;
7234
mem_addr <= 0;
7335
tail_hold <= 8'b0;
74-
a0_sel <= 2'b0;
75-
a1_sel <= 2'b0;
76-
b0_sel <= 2'b0;
77-
b1_sel <= 2'b0;
7836
end else begin
79-
state <= next_state;
80-
81-
case (state)
82-
S_IDLE: begin
83-
mem_addr <= 0;
84-
mmu_cycle <= 0;
85-
data_valid <= 0;
86-
a0_sel <= 2'b0;
87-
a1_sel <= 2'b0;
88-
b0_sel <= 2'b0;
89-
b1_sel <= 2'b0;
90-
91-
if (load_en) begin
92-
mem_addr <= mem_addr + 1;
93-
end
94-
end
95-
96-
S_ACTIVE: begin
97-
// Handle memory addressing
98-
if (load_en) begin
99-
mem_addr <= mem_addr + 1;
100-
data_valid <= 1;
101-
end
37+
// Handle memory addressing
38+
if (load_en) begin
39+
mem_addr <= mem_addr + 1;
40+
end else begin
41+
mem_addr <= 0;
42+
mmu_cycle <= 0;
43+
end
10244

103-
// The signal data_valid triggers systolic array computation, overlapping load & compute
104-
if (mem_addr == 3'b101) begin
105-
mmu_cycle <= 0; // systolic cycling begins at 5th load
106-
tail_hold <= c11[7:0];
107-
end else begin
108-
mmu_cycle <= mmu_cycle + 1;
109-
if (mem_addr == 3'b111) begin
110-
mem_addr <= 0;
111-
end
112-
end
113-
114-
// Generate mux selects based on mmu_cycle (same for all cycles)
115-
case (mmu_cycle)
116-
3'd0: begin
117-
a0_sel <= 2'd0; // weight0
118-
a1_sel <= 2'd2; // not used
119-
b0_sel <= 2'd0; // input0
120-
b1_sel <= 2'd2; // not used
121-
end
122-
3'd1: begin
123-
a0_sel <= 2'd1; // weight1
124-
a1_sel <= 2'd0; // weight2
125-
b0_sel <= 2'd1; // input1/input2 (transpose)
126-
b1_sel <= 2'd0; // input2/input1 (transpose)
127-
end
128-
3'd2: begin
129-
a0_sel <= 2'd2; // not used
130-
a1_sel <= 2'd1; // weight3
131-
b0_sel <= 2'd2; // not used
132-
b1_sel <= 2'd1; // input3
133-
end
134-
default: begin // by default turn everything off, i.e. set systolic inputs to 0
135-
a0_sel <= 2'd2;
136-
a1_sel <= 2'd2;
137-
b0_sel <= 2'd2;
138-
b1_sel <= 2'd2;
139-
end
140-
endcase
141-
end
142-
143-
default: begin
144-
mmu_cycle <= 0;
145-
data_valid <= 0;
45+
if (mem_addr == 3'b101) begin
46+
mmu_cycle <= 0; // systolic cycling begins at 5th load
47+
tail_hold <= c11[7:0];
48+
end else begin
49+
mmu_cycle <= mmu_cycle + 1;
50+
if (mem_addr == 3'b111) begin
14651
mem_addr <= 0;
14752
end
148-
endcase
53+
end
14954
end
15055
end
15156

15257
// Combinational logic for data_out
15358
always @(*) begin
15459
data_out = 8'b0;
155-
if (data_valid) begin
156-
case (mem_addr)
157-
3'b000: data_out = c00[15:8];
158-
3'b001: data_out = c00[7:0];
159-
3'b010: data_out = c01[15:8];
160-
3'b011: data_out = c01[7:0];
161-
3'b100: data_out = c10[15:8];
162-
3'b101: data_out = c10[7:0];
163-
3'b110: data_out = c11[15:8];
164-
3'b111: data_out = tail_hold;
165-
endcase
166-
end
60+
case (mem_addr)
61+
3'b000: data_out = c00[15:8];
62+
3'b001: data_out = c00[7:0];
63+
3'b010: data_out = c01[15:8];
64+
3'b011: data_out = c01[7:0];
65+
3'b100: data_out = c10[15:8];
66+
3'b101: data_out = c10[7:0];
67+
3'b110: data_out = c11[15:8];
68+
3'b111: data_out = tail_hold;
69+
endcase
70+
71+
// Generate mux selects based on mmu_cycle (same for all cycles)
72+
case (mmu_cycle)
73+
3'd0: begin
74+
a0_sel = 2'd0; // weight0
75+
a1_sel = 2'd2; // not used
76+
b0_sel = 2'd0; // input0
77+
b1_sel = 2'd2; // not used
78+
end
79+
3'd1: begin
80+
a0_sel = 2'd1; // weight1
81+
a1_sel = 2'd0; // weight2
82+
b0_sel = 2'd1; // input1/input2 (transpose)
83+
b1_sel = 2'd0; // input2/input1 (transpose)
84+
end
85+
3'd2: begin
86+
a0_sel = 2'd2; // not used
87+
a1_sel = 2'd1; // weight3
88+
b0_sel = 2'd2; // not used
89+
b1_sel = 2'd1; // input3
90+
end
91+
default: begin // by default turn everything off, i.e. set systolic inputs to 0
92+
a0_sel = 2'd2;
93+
a1_sel = 2'd2;
94+
b0_sel = 2'd2;
95+
b1_sel = 2'd2;
96+
end
97+
endcase
16798
end
16899

169100
endmodule

src/systolic_array_2x2.v

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ module systolic_array_2x2 #(
1111
input wire [7:0] input0, input1, input2, input3,
1212

1313
// Control signals from merged control unit (lightweight!)
14-
input wire data_valid,
1514
input wire [1:0] a0_sel, a1_sel, b0_sel, b1_sel,
1615
input wire transpose,
1716

@@ -24,21 +23,17 @@ module systolic_array_2x2 #(
2423
wire [WIDTH-1:0] b_wire [0:2][0:1];
2524
wire signed [WIDTH*2-1:0] c_array [0:1][0:1];
2625

27-
assign a_wire[0][0] = data_valid ?
28-
(a0_sel == 2'd0) ? weight0 :
29-
(a0_sel == 2'd1) ? weight1 : 8'b0 : 8'b0;
26+
assign a_wire[0][0] = (a0_sel == 2'd0) ? weight0 :
27+
(a0_sel == 2'd1) ? weight1 : 8'b0;
3028

31-
assign a_wire[1][0] = data_valid ?
32-
(a1_sel == 2'd0) ? weight2 :
33-
(a1_sel == 2'd1) ? weight3 : 8'b0 : 8'b0;
29+
assign a_wire[1][0] = (a1_sel == 2'd0) ? weight2 :
30+
(a1_sel == 2'd1) ? weight3 : 8'b0;
3431

35-
assign b_wire[0][0] = data_valid ?
36-
(b0_sel == 2'd0) ? input0 :
37-
(b0_sel == 2'd1) ? (transpose ? input1 : input2) : 8'b0 : 8'b0;
32+
assign b_wire[0][0] = (b0_sel == 2'd0) ? input0 :
33+
(b0_sel == 2'd1) ? (transpose ? input1 : input2) : 8'b0;
3834

39-
assign b_wire[0][1] = data_valid ?
40-
(b1_sel == 2'd0) ? (transpose ? input2 : input1) :
41-
(b1_sel == 2'd1) ? input3 : 8'b0 : 8'b0;
35+
assign b_wire[0][1] = (b1_sel == 2'd0) ? (transpose ? input2 : input1) :
36+
(b1_sel == 2'd1) ? input3 : 8'b0;
4237

4338
genvar i, j;
4439
generate

src/tpu.v

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ module tt_um_tpu (
1919
wire load_en = uio_in[0];
2020
wire transpose = uio_in[1];
2121
wire activation = uio_in[2];
22-
2322
wire [2:0] mem_addr; // 3-bit address for matrix and element selection
2423

2524
wire [7:0] weight0, weight1, weight2, weight3;
@@ -30,11 +29,8 @@ module tt_um_tpu (
3029

3130
// Control signals
3231
wire clear;
33-
wire data_valid;
3432
wire [1:0] a0_sel, a1_sel, b0_sel, b1_sel;
3533

36-
wire done;
37-
3834
// Module Instantiations
3935
memory mem (
4036
.clk(clk),
@@ -53,9 +49,7 @@ module tt_um_tpu (
5349
.c00(outputs[0]), .c01(outputs[1]), .c10(outputs[2]), .c11(outputs[3]),
5450
.mem_addr(mem_addr),
5551
.clear(clear),
56-
.data_valid(data_valid),
5752
.a0_sel(a0_sel), .a1_sel(a1_sel), .b0_sel(b0_sel), .b1_sel(b1_sel),
58-
.done(done),
5953
.data_out(out_data)
6054
);
6155

@@ -66,7 +60,6 @@ module tt_um_tpu (
6660
.activation(activation),
6761
.weight0(weight0), .weight1(weight1), .weight2(weight2), .weight3(weight3),
6862
.input0(input0), .input1(input1), .input2(input2), .input3(input3),
69-
.data_valid(data_valid),
7063
.a0_sel(a0_sel),
7164
.a1_sel(a1_sel),
7265
.b0_sel(b0_sel),
@@ -80,8 +73,8 @@ module tt_um_tpu (
8073

8174
assign uo_out = out_data;
8275

83-
assign uio_out = {done, 7'b0};
84-
assign uio_oe = 8'b10000000;
76+
assign uio_out = {8'b0};
77+
assign uio_oe = 8'b00000000;
8578

8679
wire _unused = &{ena, uio_in[7:3]};
8780

test/tpu/test.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -134,17 +134,24 @@ async def test_project(dut):
134134

135135
print(results)
136136
print(expected)
137-
# for i in range(4):
138-
# assert results[i] == expected[i], f"C[{i//2}][{i%2}] = {results[i]} != expected {expected[i]}"
137+
for i in range(4):
138+
rel_err = abs(results[i] - expected[i]) / abs(expected[i])
139+
assert rel_err <= 0.12, (
140+
f"C[{i//2}][{i%2}] = {results[i]} "
141+
f"!= expected {expected[i]} (relative error {rel_err:.4f})"
142+
)
143+
dut._log.info("Test 1 passed")
139144

140145
expected = get_expected_matmul(A, B)
141146

142147
results = await parallel_load_read(dut, [], [])
143148

144149
print(results)
145150
print(expected)
146-
"""
147151
for i in range(4):
148-
assert results[i] == expected[i], f"C[{i//2}][{i%2}] = {results[i]} != expected {expected[i]}"
149-
"""
150-
dut._log.info("End of TEST")
152+
rel_err = abs(results[i] - expected[i]) / abs(expected[i])
153+
assert rel_err <= 0.12, (
154+
f"C[{i//2}][{i%2}] = {results[i]} "
155+
f"!= expected {expected[i]} (relative error {rel_err:.4f})"
156+
)
157+
dut._log.info("Test 2 passed")

0 commit comments

Comments
 (0)