22from cocotb .clock import Clock
33from cocotb .triggers import RisingEdge , ClockCycles
44import numpy as np
5+ from cocotb .utils import get_sim_time
56import math
67import struct
78import itertools
@@ -97,14 +98,20 @@ async def read_output(dut, hadamard=0):
9798 results .append (float_val )
9899 return results
99100
100- async def parallel_load_read (dut , A , B , hadamard = 0 , transpose = 0 , relu = 0 ):
101+ async def parallel_load_read (dut , A , B , instr = ( 0 , 0 , 0 ), next_instr = ( 0 , 0 , 0 ) ):
101102 results = []
103+ hadamard , transpose , relu = instr
104+ next_hadamard , next_transpose , next_relu = next_instr
102105 dut .uio_in .value = (1 << 4 ) | (hadamard << 3 ) | (transpose << 1 ) | (relu << 2 ) | 1
103-
106+ cycle = 0
107+
104108 for inputs in [A , B ]:
105109 for i in range (2 ):
110+ cycle += 1
106111 idx0 = i * 2
107112 idx1 = i * 2 + 1
113+ if cycle == 3 :
114+ dut .uio_in .value = (1 << 4 ) | (next_hadamard << 3 ) | (next_transpose << 1 ) | (next_relu << 2 ) | 1
108115 # Feed either real data or dummy zeros
109116 dut .ui_in .value = fp8_e4m3_encode (inputs [idx0 ]) if inputs else 0
110117 await ClockCycles (dut .clk , 1 )
@@ -114,9 +121,6 @@ async def parallel_load_read(dut, A, B, hadamard=0, transpose=0, relu=0):
114121 await ClockCycles (dut .clk , 1 )
115122 low = dut .uo_out .value .integer
116123
117- misc = dut .uio_out .value .integer
118- dut ._log .info (f"Misc output: { misc } " )
119-
120124 combined = (high << 8 ) | low
121125 float_val = bf16_to_float (combined )
122126
@@ -183,8 +187,6 @@ async def test_gemm(dut):
183187 # Read test 1 matrices
184188 results = await parallel_load_read (dut , A , B )
185189
186- print (results )
187- print (expected )
188190 for i in range (4 ):
189191 rel_err = abs (results [i ] - expected [i ]) / abs (expected [i ])
190192 assert rel_err <= 0.12 , (
@@ -195,10 +197,11 @@ async def test_gemm(dut):
195197
196198 expected = get_expected_output (A , B )
197199
198- results = await parallel_load_read (dut , [], [])
200+ A = [5 , - 6 , 7 , 8 ] # row-major
201+ B = [8 , 9 , 6 , 8 ] # row-major: [B00, B01, B10, B11]
202+
203+ results = await parallel_load_read (dut , A , B , next_instr = (0 , 1 , 1 ))
199204
200- print (results )
201- print (expected )
202205 for i in range (4 ):
203206 rel_err = abs (results [i ] - expected [i ]) / abs (expected [i ])
204207 assert rel_err <= 0.12 , (
@@ -207,6 +210,14 @@ async def test_gemm(dut):
207210 )
208211 dut ._log .info ("Test 2 passed" )
209212
213+ expected = get_expected_output (A , B , transpose = True , relu = True )
214+ results = await parallel_load_read (dut , [], [], instr = (0 , 1 , 1 ))
215+
216+ for i in range (4 ):
217+ assert results [i ] == expected [i ], f"C[{ i // 2 } ][{ i % 2 } ] = { results [i ]} != expected { expected [i ]} "
218+
219+ dut ._log .info ("ReLU + Transpose test passed!" )
220+
210221async def load_stationary_weights (dut , weights ):
211222 """Load weights in stationary mode (stat_weights=1, load_weights=1)"""
212223 for i in range (4 ):
@@ -317,7 +328,7 @@ async def accumulate_matrix_output(dut, results_large, i, j, transpose=0, A_bloc
317328 else :
318329 input_stream = [0 ] * 8
319330
320- dut .uio_in .value = (transpose << 1 ) | 1 # load_en=1
331+ dut .uio_in .value = (1 << 4 ) | ( transpose << 1 ) | 1 # load_en=1
321332
322333 partial_outputs = []
323334
@@ -351,7 +362,6 @@ async def matmul(dut, A, B, transpose=False, relu=False):
351362 Accumulates partial results across k dimension for each (i,j) tile.
352363 Loads A and B in parallel with reading previous output.
353364 """
354- print ("REACHED CHIP KERNEL!!!!" )
355365 m , n = A .shape
356366 n_b , p = B .shape
357367 if (transpose ):
@@ -416,6 +426,163 @@ async def matmul(dut, A, B, transpose=False, relu=False):
416426
417427 # Apply ReLU if enabled
418428 if relu :
419- results_large = torch .maximum (results_large , 0 )
429+ results_large = torch .maximum (results_large , torch .tensor (0.0 ))
430+
431+ return results_large [:m , :n_b ] if transpose else results_large [:m , :p ]
432+
async def matmul_faster(dut, A, B, transpose=False, relu=False):
    """Accelerated matmul using A-stationary mode when legal (n <= 2),
    with pipelined B streaming + output readback.

    Args:
        dut: cocotb handle to the device under test.
        A: torch tensor of shape (m, n).
        B: torch tensor of shape (n, p), or (p, n) when ``transpose`` is
           set — the result is then A @ B.T, matching ``matmul``.
        transpose: compute A @ B.T instead of A @ B.
        relu: clamp negative outputs to zero.

    Returns:
        torch tensor cropped to the logical output shape
        ((m, n_b) when transpose, else (m, p)) — same contract as matmul().
    """
    import torch

    m, n = A.shape
    n_b, p = B.shape

    # Same dimension contract as the reference matmul().
    if transpose:
        assert n == p
    else:
        assert n == n_b

    # The stationary fast path holds only a 2-wide A tile; wider inner
    # dimensions must go through the general tiled path.
    if n > 2:
        return await matmul(dut, A, B, transpose=transpose, relu=relu)

    # BUG FIX: the original fast path streamed B untransposed even when
    # transpose=True, computing A @ B instead of A @ B.T (and its
    # B_p[:n, :p] = B assignment would raise a shape mismatch whenever
    # n_b != n). Latent only because every transpose case in the suite
    # has n > 2 and falls back to matmul(). Stream B.T so the stationary
    # path agrees with matmul() for every flag combination.
    B_eff = B.t() if transpose else B
    out_cols = n_b if transpose else p

    # Pad both output dimensions up to multiples of the 2x2 systolic tile.
    m_p = ((m + 1) // 2) * 2
    p_p = ((out_cols + 1) // 2) * 2

    A_p = torch.zeros((m_p, 2), dtype=torch.float32)
    B_p = torch.zeros((2, p_p), dtype=torch.float32)
    A_p[:m, :n] = A
    B_p[:n, :out_cols] = B_eff

    C = torch.zeros((m_p, p_p), dtype=torch.float32)

    for i in range(0, m_p, 2):
        # Hold one 2x2 tile of A stationary while streaming all B columns.
        A_block = A_p[i:i + 2, :2].flatten().tolist()
        await load_stationary_weights(dut, A_block)

        # ---- Prime pipeline with first B tile ----
        B_block = B_p[:2, 0:2].flatten().tolist()
        await load_inputs_stationary(dut, B_block)

        # ---- Pipelined loop: load next B tile while reading previous ----
        for j in range(2, p_p, 2):
            B_next = B_p[:2, j:j + 2].flatten().tolist()
            results = await parallel_rw_stationary(dut, B_next)

            C[i, j - 2] += results[0]
            C[i, j - 1] += results[1]
            C[i + 1, j - 2] += results[2]
            C[i + 1, j - 1] += results[3]

        # ---- Drain final output with a zero tile ----
        results = await parallel_rw_stationary(dut, [0, 0, 0, 0])

        C[i, p_p - 2] += results[0]
        C[i, p_p - 1] += results[1]
        C[i + 1, p_p - 2] += results[2]
        C[i + 1, p_p - 1] += results[3]

    if relu:
        C = torch.maximum(C, torch.tensor(0.0))

    return C[:m, :out_cols]
495+
@cocotb.test()
async def test_matmul_faster_matches_reference(dut):
    """Run matmul_faster and the reference matmul on identical random
    inputs, assert element-wise equality, and log sim-time speedups."""
    import torch

    dut._log.info("Testing matmul_faster vs reference matmul")

    # Clock + reset
    cocotb.start_soon(Clock(dut.clk, 20, units="ns").start())
    await reset_dut(dut)

    # (m, n, p, transpose, relu)
    test_cases = [
        (4, 4, 4, False, False),
        (5, 3, 6, False, False),
        (6, 5, 3, False, True),
        (4, 6, 5, True, False),
        (7, 7, 7, True, True),
    ]

    torch.manual_seed(0)

    for case_no, (m, n, p, transpose, relu) in enumerate(test_cases, start=1):
        dut._log.info(
            f"Case {case_no}: A={m}x{n}, B={n}x{p}, "
            f"transpose={transpose}, relu={relu}"
        )

        # Random operands in an FP8-safe range; transposed cases take B
        # shaped (p, n) so the product is A @ B.T.
        A = (torch.rand(m, n) * 6.0 - 3.0).float()
        b_shape = (p, n) if transpose else (n, p)
        B = (torch.rand(*b_shape) * 6.0 - 3.0).float()

        # Reference path, timed in simulation nanoseconds.
        start = get_sim_time(units="ns")
        ref_out = await matmul(dut, A, B, transpose=transpose, relu=relu)
        ref_time = get_sim_time(units="ns") - start

        # Reset DUT to avoid state contamination between runs.
        await reset_dut(dut)

        # Fast path, timed the same way.
        start = get_sim_time(units="ns")
        fast_out = await matmul_faster(dut, A, B, transpose=transpose, relu=relu)
        fast_time = get_sim_time(units="ns") - start

        ref_out = ref_out.cpu()
        fast_out = fast_out.cpu()

        assert ref_out.shape == fast_out.shape

        rows, cols = ref_out.shape
        for i in range(rows):
            for j in range(cols):
                ref_val = ref_out[i, j].item()
                fast_val = fast_out[i, j].item()
                assert ref_val == fast_val, (
                    f"Mismatch at ({i},{j}): ref={ref_val}, fast={fast_val}"
                )

        speedup = ref_time / fast_time if fast_time > 0 else float("inf")

        dut._log.info(
            f"Case {case_no} timing: "
            f"ref={ref_time:.0f}ns, "
            f"fast={fast_time:.0f}ns, "
            f"speedup={speedup:.2f}×"
        )

        # Sanity check: the fast path only engages when n <= 2, and then
        # it should not be slower than the reference.
        if n <= 2:
            assert speedup >= 1.0, (
                f"Expected speedup for n<=2, got {speedup:.2f}×"
            )

        dut._log.info(f"Case {case_no} passed ✔")

    dut._log.info("All matmul_faster correctness + benchmark tests passed 🎉")
0 commit comments