Commit 966c427

accurate torch+basic stat weight matmul
1 parent 869613c commit 966c427

6 files changed (+218, -33 lines)

src/tpu.v

Lines changed: 4 additions & 0 deletions
@@ -128,6 +128,10 @@ module tt_um_tpu (
         uo_out = tail_hold[15:8];
         uio_out = tail_hold[7:0];
       end
+      default: begin
+        uo_out = 8'b0;
+        uio_out = 8'b0;
+      end
     endcase
   end
 end

test/Makefile

Lines changed: 2 additions & 2 deletions
@@ -43,10 +43,10 @@ endif
 # Allow sharing configuration between design and testbench via `include`:
 COMPILE_ARGS += -I$(SRC_DIR)
 
-.PHONY: all test-mac test-top
+.PHONY: all test-mac test-top test-nn
 
 all:
-	$(MAKE) test-mac test-top
+	$(MAKE) test-mac test-top test-nn
 
 test-mac:
 	$(MAKE) clean

test/requirements.txt

Lines changed: 3 additions & 1 deletion
@@ -1,4 +1,6 @@
 pytest==8.3.4
 cocotb==1.9.2
 numpy==2.1.3
-matplotlib
+matplotlib
+torch
+torchao
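torch (and torchao) are added as test dependencies so the cocotb testbench can build host-side reference results to compare against the DUT. A minimal sketch of such a reference, assuming it mirrors what the in-tree get_expected_output helper is used for in test_tpu.py (A @ B, or A @ B.T when transpose is set, with negatives clamped under ReLU); the helper name below is illustrative, not the one in the repo:

import torch

# Illustrative host-side reference only; the testbench's actual helper is
# get_expected_output, whose body is not part of this diff.
def reference_matmul(A: torch.Tensor, B: torch.Tensor, transpose=False, relu=False):
    C = A @ (B.T if transpose else B)   # m x n_b when transposed, m x p otherwise
    return torch.clamp(C, min=0.0) if relu else C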

test/tpu/test_tpu.py

Lines changed: 181 additions & 14 deletions
@@ -2,6 +2,7 @@
 from cocotb.clock import Clock
 from cocotb.triggers import RisingEdge, ClockCycles
 import numpy as np
+from cocotb.utils import get_sim_time
 import math
 import struct
 import itertools
@@ -97,14 +98,20 @@ async def read_output(dut, hadamard=0):
         results.append(float_val)
     return results
 
-async def parallel_load_read(dut, A, B, hadamard=0, transpose=0, relu=0):
+async def parallel_load_read(dut, A, B, instr=(0, 0, 0), next_instr=(0, 0, 0)):
     results = []
+    hadamard, transpose, relu = instr
+    next_hadamard, next_transpose, next_relu = next_instr
     dut.uio_in.value = (1 << 4) | (hadamard << 3) | (transpose << 1) | (relu << 2) | 1
-
+    cycle = 0
+
     for inputs in [A, B]:
         for i in range(2):
+            cycle += 1
             idx0 = i * 2
             idx1 = i * 2 + 1
+            if cycle == 3:
+                dut.uio_in.value = (1 << 4) | (next_hadamard << 3) | (next_transpose << 1) | (next_relu << 2) | 1
             # Feed either real data or dummy zeros
             dut.ui_in.value = fp8_e4m3_encode(inputs[idx0]) if inputs else 0
             await ClockCycles(dut.clk, 1)
@@ -114,9 +121,6 @@ async def parallel_load_read(dut, A, B, hadamard=0, transpose=0, relu=0):
             await ClockCycles(dut.clk, 1)
             low = dut.uo_out.value.integer
 
-            misc = dut.uio_out.value.integer
-            dut._log.info(f"Misc output: {misc}")
-
             combined = (high << 8) | low
             float_val = bf16_to_float(combined)
 
@@ -183,8 +187,6 @@ async def test_gemm(dut):
     # Read test 1 matrices
     results = await parallel_load_read(dut, A, B)
 
-    print(results)
-    print(expected)
     for i in range(4):
         rel_err = abs(results[i] - expected[i]) / abs(expected[i])
         assert rel_err <= 0.12, (
@@ -195,10 +197,11 @@ async def test_gemm(dut):
 
     expected = get_expected_output(A, B)
 
-    results = await parallel_load_read(dut, [], [])
+    A = [5, -6, 7, 8] # row-major
+    B = [8, 9, 6, 8] # row-major: [B00, B01, B10, B11]
+
+    results = await parallel_load_read(dut, A, B, next_instr=(0, 1, 1))
 
-    print(results)
-    print(expected)
     for i in range(4):
         rel_err = abs(results[i] - expected[i]) / abs(expected[i])
         assert rel_err <= 0.12, (
@@ -207,6 +210,14 @@ async def test_gemm(dut):
     )
     dut._log.info("Test 2 passed")
 
+    expected = get_expected_output(A, B, transpose=True, relu=True)
+    results = await parallel_load_read(dut, [], [], instr=(0, 1, 1))
+
+    for i in range(4):
+        assert results[i] == expected[i], f"C[{i//2}][{i%2}] = {results[i]} != expected {expected[i]}"
+
+    dut._log.info("ReLU + Transpose test passed!")
+
 async def load_stationary_weights(dut, weights):
     """Load weights in stationary mode (stat_weights=1, load_weights=1)"""
     for i in range(4):
@@ -317,7 +328,7 @@ async def accumulate_matrix_output(dut, results_large, i, j, transpose=0, A_bloc
     else:
         input_stream = [0] * 8
 
-    dut.uio_in.value = (transpose << 1) | 1 # load_en=1
+    dut.uio_in.value = (1 << 4) | (transpose << 1) | 1 # load_en=1
 
     partial_outputs = []
 
@@ -351,7 +362,6 @@ async def matmul(dut, A, B, transpose=False, relu=False):
     Accumulates partial results across k dimension for each (i,j) tile.
     Loads A and B in parallel with reading previous output.
     """
-    print("REACHED CHIP KERNEL!!!!")
     m, n = A.shape
     n_b, p = B.shape
     if (transpose):
@@ -416,6 +426,163 @@ async def matmul(dut, A, B, transpose=False, relu=False):
 
     # Apply ReLU if enabled
     if relu:
-        results_large = torch.maximum(results_large, 0)
+        results_large = torch.maximum(results_large, torch.tensor(0.0))
+
+    return results_large[:m, :n_b] if transpose else results_large[:m, :p]
+
+async def matmul_faster(dut, A, B, transpose=False, relu=False):
+    """
+    Accelerated matmul using A-stationary mode when legal (n <= 2),
+    with pipelined B streaming + output readback.
+    """
+    import torch
+
+    m, n = A.shape
+    n_b, p = B.shape
+
+    if transpose:
+        assert n == p
+    else:
+        assert n == n_b
+
+    if n > 2:
+        return await matmul(dut, A, B, transpose=transpose, relu=relu)
+
+    m_p = ((m + 1) // 2) * 2
+    p_p = ((p + 1) // 2) * 2
+
+    A_p = torch.zeros((m_p, 2), dtype=torch.float32)
+    B_p = torch.zeros((2, p_p), dtype=torch.float32)
+
+    A_p[:m, :n] = A
+    B_p[:n, :p] = B
+
+    C = torch.zeros((m_p, p_p), dtype=torch.float32)
+
+    for i in range(0, m_p, 2):
+        A_block = A_p[i:i+2, :2].flatten().tolist()
+        await load_stationary_weights(dut, A_block)
+
+        # ---- Prime pipeline with first B tile ----
+        j0 = 0
+        B_block = B_p[:2, j0:j0+2].flatten().tolist()
+        await load_inputs_stationary(dut, B_block)
+
+        # ---- Pipelined loop ----
+        for j in range(2, p_p, 2):
+            # Load next B while reading previous output
+            B_next = B_p[:2, j:j+2].flatten().tolist()
+
+            results = await parallel_rw_stationary(dut, B_next)
+
+            C[i, j-2] += results[0]
+            C[i, j-1] += results[1]
+            C[i+1, j-2] += results[2]
+            C[i+1, j-1] += results[3]
+
+        # ---- Drain final output ----
+        results = await parallel_rw_stationary(dut, [0, 0, 0, 0])
+
+        C[i, p_p-2] += results[0]
+        C[i, p_p-1] += results[1]
+        C[i+1, p_p-2] += results[2]
+        C[i+1, p_p-1] += results[3]
+
+    if relu:
+        C = torch.maximum(C, torch.tensor(0.0))
+
+    return C[:m, :p]
+
+@cocotb.test()
+async def test_matmul_faster_matches_reference(dut):
+    import torch
+    dut._log.info("Testing matmul_faster vs reference matmul")
+
+    # Clock
+    clock = Clock(dut.clk, 20, units="ns")
+    cocotb.start_soon(clock.start())
+
+    # Reset
+    await reset_dut(dut)
+
+    # Test configurations
+    test_cases = [
+        # (m, n, p, transpose, relu)
+        (4, 4, 4, False, False),
+        (5, 3, 6, False, False),
+        (6, 5, 3, False, True),
+        (4, 6, 5, True, False),
+        (7, 7, 7, True, True),
+    ]
+
+    torch.manual_seed(0)
+
+    for idx, (m, n, p, transpose, relu) in enumerate(test_cases):
+        dut._log.info(
+            f"Case {idx+1}: A={m}x{n}, B={n}x{p}, "
+            f"transpose={transpose}, relu={relu}"
+        )
+
+        # Generate random inputs (FP8-safe range)
+        A = (torch.rand(m, n) * 6.0 - 3.0).float()
+        if transpose:
+            B = (torch.rand(p, n) * 6.0 - 3.0).float()
+        else:
+            B = (torch.rand(n, p) * 6.0 - 3.0).float()
+
+        t0 = get_sim_time(units="ns")
+        ref_out = await matmul(
+            dut,
+            A,
+            B,
+            transpose=transpose,
+            relu=relu,
+        )
+        t1 = get_sim_time(units="ns")
+        ref_time = t1 - t0
+
+        # Reset DUT to avoid state contamination
+        await reset_dut(dut)
+
+        t2 = get_sim_time(units="ns")
+        fast_out = await matmul_faster(
+            dut,
+            A,
+            B,
+            transpose=transpose,
+            relu=relu,
+        )
+        t3 = get_sim_time(units="ns")
+        fast_time = t3 - t2
+
+        ref_out = ref_out.cpu()
+        fast_out = fast_out.cpu()
+
+        assert ref_out.shape == fast_out.shape
+
+        for i in range(ref_out.shape[0]):
+            for j in range(ref_out.shape[1]):
+                ref_val = ref_out[i, j].item()
+                fast_val = fast_out[i, j].item()
+                assert ref_val == fast_val, (
+                    f"Mismatch at ({i},{j}): ref={ref_val}, fast={fast_val}"
+                )
+
+        speedup = ref_time / fast_time if fast_time > 0 else float("inf")
+
+        dut._log.info(
+            f"Case {idx+1} timing: "
+            f"ref={ref_time:.0f} ns, "
+            f"fast={fast_time:.0f} ns, "
+            f"speedup={speedup:.2f}×"
+        )
+
+        # Optional sanity check: fast path should not be slower
+        if n <= 2:
+            assert speedup >= 1.0, (
+                f"Expected speedup for n<=2, got {speedup:.2f}×"
+            )
+
+        dut._log.info(f"Case {idx+1} passed ✔")
 
-    return results_large[:m, :n_b] if transpose else results_large[:m, :p]
+    dut._log.info("All matmul_faster correctness + benchmark tests passed 🎉")
