Commit 486103f

cleanup the rest
1 parent a90e9c0 commit 486103f

2 files changed: +12, -201 lines

src/tpu.v

Lines changed: 6 additions & 6 deletions
@@ -81,10 +81,10 @@ module tt_um_tpu (
     reg [15:0] hold2;
 
     always @(posedge clk) begin
-        if (mem_addr == 3'b101 || (load_weights && mem_addr == 3'b001)) begin
+        if (mem_addr == 3'b101) begin
             hold2 <= outputs[2];
             tail_hold <= {8'b0, outputs[3][7:0]};
-        end else if (mem_addr == 3'b110 || (load_weights && mem_addr == 3'b010)) begin
+        end else if (mem_addr == 3'b110) begin
             if (stat_weights) begin
                 tail_hold <= outputs[3];
             end else begin
@@ -114,19 +114,19 @@ module tt_um_tpu (
         end else begin
             uio_oe = 8'b11111111;
             case (mem_addr)
-                3'b000, 3'b100: begin
+                3'b100: begin
                     uo_out = outputs[0][15:8];
                     uio_out = outputs[0][7:0];
                 end
-                3'b001, 3'b101: begin
+                3'b101: begin
                     uo_out = outputs[1][15:8];
                     uio_out = outputs[1][7:0];
                 end
-                3'b010, 3'b110: begin
+                3'b110: begin
                     uo_out = hold2[15:8];
                     uio_out = hold2[7:0];
                 end
-                3'b011, 3'b111: begin
+                3'b111: begin
                     uo_out = tail_hold[15:8];
                     uio_out = tail_hold[7:0];
                 end
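With this change the explicit read-back case arms cover only 3'b100 through 3'b111: 3'b100 and 3'b101 return the live accumulator words, while 3'b110 and 3'b111 return the hold2 and tail_hold registers latched in the always block above. A minimal software model of that decode, written as a Python sketch rather than RTL (register names follow the hunk; the dictionary lookup is only an illustration of the case statement, not the repository's code):

    def readback(mem_addr, outputs, hold2, tail_hold):
        """Post-cleanup read mux: address -> (uo_out, uio_out) byte pair."""
        mux = {
            0b100: outputs[0],   # accumulator word 0
            0b101: outputs[1],   # accumulator word 1
            0b110: hold2,        # latched from outputs[2] when mem_addr == 3'b101
            0b111: tail_hold,    # latched from outputs[3] (full word or low byte)
        }
        word = mux[mem_addr]     # 3'b000..3'b011 no longer alias these entries
        return (word >> 8) & 0xFF, word & 0xFF

    # Example: a read at 3'b110 returns hold2, high byte on uo_out, low byte on uio_out.
    assert readback(0b110, [0, 0, 0, 0], 0xBEEF, 0x1234) == (0xBE, 0xEF)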

test/tpu/test_tpu.py

Lines changed: 6 additions & 195 deletions
@@ -275,8 +275,6 @@ async def test_stationary_weights(dut):
         [0.5, 1.5, 2.5, 3.5],
         [-1.0, 2.0, -3.0, 4.0],
     ]
-
-    weights2 = [5.0, -1.0, 0.0, 2.0]
 
     for idx, inputs in enumerate(test_inputs):
         dut._log.info(f"Testing with input matrix {idx + 1}")
@@ -309,8 +307,9 @@ async def test_stationary_weights(dut):
 
         dut._log.info(f"Input matrix {idx + 1} passed")
 
-    results = await parallel_rw_stationary(dut, weights2, load_weights=1)
+    results = await parallel_rw_stationary(dut, [0, 0, 0, 0])
     expected = get_expected_output(weights, test_inputs[-1])
+
     for i in range(4):
         if expected[i] == 0:
             # For zero expected values, check absolute error
@@ -447,32 +446,13 @@ async def matmul(dut, A, B, transpose=False, relu=False):
 
     return results_large[:m, :n_b] if transpose else results_large[:m, :p]
 
-async def parallel_rw_stationary_optimized(dut, inputs, load_weights=0, get_output=True):
-    """Optimized version that can skip output reading when not needed"""
-    # stat_weights=1 (bit 5), enable=1 (bit 4)
-    dut.uio_in.value = (load_weights << 6) | (1 << 5) | (1 << 4) | 1
-    results = []
-
-    for i in range(4):
-        dut.ui_in.value = fp8_e4m3_encode(inputs[i])
-        await RisingEdge(dut.clk)
-        if get_output:
-            high = dut.uo_out.value.integer
-            low = dut.uio_out.value.integer
-            combined = (high << 8) | low
-            float_val = bf16_to_float(combined)
-            results.append(float_val)
-
-    return results if get_output else []
-
 async def matmul_faster(dut, A, B, transpose=False, relu=False):
     """
-    True dyadic A-stationary matmul with debugging.
+    True dyadic A-stationary matmul.
     Reuses stationary A tiles across all j tiles.
     Legal and faster when n <= 2.
     """
     import torch
-    from cocotb.utils import get_sim_time
 
     m, n = A.shape
     n_b, p = B.shape
@@ -498,36 +478,17 @@ async def matmul_faster(dut, A, B, transpose=False, relu=False):
 
     C = torch.zeros((m_p, p_p), dtype=torch.float32)
 
-    # Debug counters
-    total_weight_loads = 0
-    total_tile_reuses = 0
-    setup_cycles = 0
-    compute_cycles = 0
-    teardown_cycles = 0
-
     # ---- Dyadic schedule ----
     for i in range(0, m_p, 2):
-        t_start = get_sim_time(units="ns")
-
         # A(i,k) is stationary for entire j sweep
         A_block = A_p[i:i+2, :2].flatten().tolist()
         await load_stationary_weights(dut, A_block)
-        total_weight_loads += 1
-
-        t_setup = get_sim_time(units="ns")
-        setup_cycles += (t_setup - t_start)
-
-        # Count how many j tiles we'll process with this A tile
-        j_tiles = p_p // 2
-        total_tile_reuses += j_tiles
 
         # ---- Prime pipeline with first B tile ----
         j0 = 0
         B_block = B_p[:2, j0:j0+2].flatten().tolist()
         await load_inputs_stationary(dut, B_block)
 
-        t_compute_start = get_sim_time(units="ns")
-
         # ---- Stream remaining B tiles ----
         for j in range(2, p_p, 2):
             B_next = B_p[:2, j:j+2].flatten().tolist()
@@ -539,171 +500,18 @@ async def matmul_faster(dut, A, B, transpose=False, relu=False):
             C[i+1, j-1] += results[3]
 
         # ---- Drain final output ----
-        t_teardown_start = get_sim_time(units="ns")
         results = await parallel_rw_stationary(dut, [0, 0, 0, 0])
 
         C[i, p_p-2] += results[0]
         C[i, p_p-1] += results[1]
         C[i+1, p_p-2] += results[2]
         C[i+1, p_p-1] += results[3]
 
-        t_end = get_sim_time(units="ns")
-        compute_cycles += (t_teardown_start - t_compute_start)
-        teardown_cycles += (t_end - t_teardown_start)
-
     if relu:
         C = torch.maximum(C, torch.tensor(0.0))
 
-    # Debug output
-    dut._log.info(f"matmul_faster debug:")
-    dut._log.info(f" Weight loads: {total_weight_loads}")
-    dut._log.info(f" Tile reuses: {total_tile_reuses}")
-    dut._log.info(f" Setup cycles: {setup_cycles:.0f} ns")
-    dut._log.info(f" Compute cycles: {compute_cycles:.0f} ns")
-    dut._log.info(f" Teardown cycles: {teardown_cycles:.0f} ns")
-    dut._log.info(f" Reuse ratio: {total_tile_reuses/total_weight_loads:.2f}")
-
     return C[:m, :p]
 
-async def matmul_faster_optimized(dut, A, B, transpose=False, relu=False):
-    """
-    Optimized version that eliminates setup/teardown bottlenecks
-    """
-    import torch
-
-    m, n = A.shape
-    n_b, p = B.shape
-
-    if transpose:
-        assert n == p
-    else:
-        assert n == n_b
-
-    # Fallback if we cannot exploit stationarity
-    if n > 2:
-        return await matmul(dut, A, B, transpose=transpose, relu=relu)
-
-    # ---- Padding ----
-    m_p = ((m + 1) // 2) * 2
-    p_p = ((p + 1) // 2) * 2
-
-    A_p = torch.zeros((m_p, 2), dtype=torch.float32)
-    B_p = torch.zeros((2, p_p), dtype=torch.float32)
-
-    A_p[:m, :n] = A
-    B_p[:n, :p] = B
-
-    C = torch.zeros((m_p, p_p), dtype=torch.float32)
-
-    # ---- Optimized dyadic schedule ----
-    for i in range(0, m_p, 2):
-        # A(i,k) is stationary for entire j sweep
-        A_block = A_p[i:i+2, :2].flatten().tolist()
-        await load_stationary_weights(dut, A_block)
-
-        # Process all B tiles in one go - no separate priming/draining
-        for j in range(0, p_p, 2):
-            B_block = B_p[:2, j:j+2].flatten().tolist()
-
-            if j == 0:
-                # First tile: load inputs without getting output
-                await load_inputs_stationary(dut, B_block)
-            else:
-                # Subsequent tiles: load next inputs while getting previous output
-                results = await parallel_rw_stationary(dut, B_block)
-
-                # Store results from previous computation
-                prev_j = j - 2
-                C[i, prev_j] += results[0]
-                C[i, prev_j+1] += results[1]
-                C[i+1, prev_j] += results[2]
-                C[i+1, prev_j+1] += results[3]
-
-        # Get final output
-        results = await parallel_rw_stationary(dut, [0, 0, 0, 0])
-        final_j = p_p - 2
-        C[i, final_j] += results[0]
-        C[i, final_j+1] += results[1]
-        C[i+1, final_j] += results[2]
-        C[i+1, final_j+1] += results[3]
-
-    if relu:
-        C = torch.maximum(C, torch.tensor(0.0))
-
-    return C[:m, :p]
-
-@cocotb.test()
-async def test_matmul_performance_analysis(dut):
-    """Detailed performance analysis of different matmul approaches"""
-    import torch
-    dut._log.info("=== MATMUL PERFORMANCE ANALYSIS ===")
-
-    # Clock
-    clock = Clock(dut.clk, 20, units="ns")
-    cocotb.start_soon(clock.start())
-
-    # Test small matrices where stationary should help
-    test_cases = [
-        (4, 2, 4, "Small n=2 case"),
-        (6, 2, 6, "Medium n=2 case"),
-        (8, 2, 8, "Large n=2 case"),
-        (4, 4, 4, "n>2 fallback case")
-    ]
-
-    torch.manual_seed(42)
-
-    for m, n, p, desc in test_cases:
-        dut._log.info(f"\n--- {desc}: A={m}x{n}, B={n}x{p} ---")
-
-        # Generate test matrices
-        A = (torch.rand(m, n) * 4.0 - 2.0).float()
-        B = (torch.rand(n, p) * 4.0 - 2.0).float()
-
-        # Test 1: Reference matmul
-        await reset_dut(dut)
-        t0 = get_sim_time(units="ns")
-        ref_out = await matmul(dut, A, B)
-        t1 = get_sim_time(units="ns")
-        ref_time = t1 - t0
-
-        # Test 2: Original matmul_faster (with debug)
-        await reset_dut(dut)
-        t2 = get_sim_time(units="ns")
-        fast_out = await matmul_faster(dut, A, B)
-        t3 = get_sim_time(units="ns")
-        fast_time = t3 - t2
-
-        # Test 3: Optimized version
-        await reset_dut(dut)
-        t4 = get_sim_time(units="ns")
-        opt_out = await matmul_faster_optimized(dut, A, B)
-        t5 = get_sim_time(units="ns")
-        opt_time = t5 - t4
-
-        # Verify correctness
-        assert torch.allclose(ref_out, fast_out, atol=1e-6), "matmul_faster mismatch"
-        assert torch.allclose(ref_out, opt_out, atol=1e-6), "matmul_faster_optimized mismatch"
-
-        # Performance analysis
-        fast_speedup = ref_time / fast_time if fast_time > 0 else float("inf")
-        opt_speedup = ref_time / opt_time if opt_time > 0 else float("inf")
-        opt_vs_fast = fast_time / opt_time if opt_time > 0 else float("inf")
-
-        dut._log.info(f"Performance results:")
-        dut._log.info(f" Reference: {ref_time:.0f} ns")
-        dut._log.info(f" Fast: {fast_time:.0f} ns (speedup: {fast_speedup:.2f}x)")
-        dut._log.info(f" Optimized: {opt_time:.0f} ns (speedup: {opt_speedup:.2f}x)")
-        dut._log.info(f" Opt vs Fast: {opt_vs_fast:.2f}x improvement")
-
-        # Calculate theoretical speedup for n=2 cases
-        if n == 2:
-            total_tiles = (m * p) // 4  # Total 2x2 output tiles
-            weight_tiles = m // 2  # Unique A tiles
-            theoretical_reuse = total_tiles / weight_tiles
-            dut._log.info(f" Theoretical tile reuse: {theoretical_reuse:.2f}x")
-
-    dut._log.info("\n=== ANALYSIS COMPLETE ===")
-
 @cocotb.test()
 async def test_matmul_faster_matches_reference(dut):
     import torch
@@ -720,8 +528,11 @@ async def test_matmul_faster_matches_reference(dut):
     test_cases = [
         # (m, n, p, transpose, relu)
         (4, 4, 4, False, False),
+        (5, 3, 6, False, False),
         (6, 5, 3, False, True),
         (4, 6, 5, True, False),
+        (7, 7, 7, True, True),
+        (10, 10, 10, False, False),
         (20, 20, 20, False, False)
     ]
 
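The docstring kept in matmul_faster states the point of the schedule: one stationary A tile is reused across every column tile of B, which is only legal when n <= 2. A self-contained software model of that dyadic tiling (pure torch, no DUT; names here are illustrative and not taken from the test file):

    import torch

    def dyadic_a_stationary(A, B):
        """Software model of the 2x2-tile, A-stationary schedule (valid for n <= 2)."""
        m, n = A.shape
        n_b, p = B.shape
        assert n == n_b and n <= 2
        # Pad both operands to even tile boundaries, as matmul_faster does.
        m_p, p_p = ((m + 1) // 2) * 2, ((p + 1) // 2) * 2
        A_p = torch.zeros((m_p, 2)); A_p[:m, :n] = A
        B_p = torch.zeros((2, p_p)); B_p[:n, :p] = B
        C = torch.zeros((m_p, p_p))
        for i in range(0, m_p, 2):        # one weight load per 2x2 A tile ...
            A_tile = A_p[i:i+2, :2]
            for j in range(0, p_p, 2):    # ... reused across every 2x2 B tile
                C[i:i+2, j:j+2] = A_tile @ B_p[:2, j:j+2]
        return C[:m, :p]

    A, B = torch.rand(6, 2), torch.rand(2, 8)
    assert torch.allclose(dyadic_a_stationary(A, B), A @ B, atol=1e-6)

On the hardware side the inner loop is pipelined rather than sequential: each parallel_rw_stationary call streams the next B tile in while reading the previous tile's products out, which is why the kept code primes with load_inputs_stationary and drains with a final parallel_rw_stationary(dut, [0, 0, 0, 0]) read.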

0 commit comments
