@@ -275,8 +275,6 @@ async def test_stationary_weights(dut):
         [0.5, 1.5, 2.5, 3.5],
         [-1.0, 2.0, -3.0, 4.0],
     ]
-
-    weights2 = [5.0, -1.0, 0.0, 2.0]
 
     for idx, inputs in enumerate(test_inputs):
         dut._log.info(f"Testing with input matrix {idx + 1}")
@@ -309,8 +307,9 @@ async def test_stationary_weights(dut):
 
         dut._log.info(f"Input matrix {idx + 1} passed")
 
-    results = await parallel_rw_stationary(dut, weights2, load_weights=1)
+    results = await parallel_rw_stationary(dut, [0, 0, 0, 0])
     expected = get_expected_output(weights, test_inputs[-1])
+
     for i in range(4):
         if expected[i] == 0:
             # For zero expected values, check absolute error
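Note on the comment above: relative error is undefined when the expected value is zero, so those entries get an absolute bound instead. A minimal sketch of that pattern follows; the helper name and tolerance values are illustrative, not taken from the test itself.

def check_outputs(results, expected, abs_tol=1e-2, rel_tol=5e-2):
    """Hypothetical helper: absolute bound for zero expectations, relative bound otherwise."""
    for i, (got, exp) in enumerate(zip(results, expected)):
        if exp == 0:
            assert abs(got) < abs_tol, f"Output {i}: {got} is not close to 0"
        else:
            rel_err = abs(got - exp) / abs(exp)
            assert rel_err < rel_tol, f"Output {i}: relative error {rel_err:.3f} too large"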
@@ -447,32 +446,13 @@ async def matmul(dut, A, B, transpose=False, relu=False):
 
     return results_large[:m, :n_b] if transpose else results_large[:m, :p]
 
-async def parallel_rw_stationary_optimized(dut, inputs, load_weights=0, get_output=True):
-    """Optimized version that can skip output reading when not needed"""
-    # stat_weights=1 (bit 5), enable=1 (bit 4)
-    dut.uio_in.value = (load_weights << 6) | (1 << 5) | (1 << 4) | 1
-    results = []
-
-    for i in range(4):
-        dut.ui_in.value = fp8_e4m3_encode(inputs[i])
-        await RisingEdge(dut.clk)
-        if get_output:
-            high = dut.uo_out.value.integer
-            low = dut.uio_out.value.integer
-            combined = (high << 8) | low
-            float_val = bf16_to_float(combined)
-            results.append(float_val)
-
-    return results if get_output else []
-
 async def matmul_faster(dut, A, B, transpose=False, relu=False):
     """
-    True dyadic A-stationary matmul with debugging.
+    True dyadic A-stationary matmul.
     Reuses stationary A tiles across all j tiles.
     Legal and faster when n <= 2.
     """
     import torch
-    from cocotb.utils import get_sim_time
 
     m, n = A.shape
     n_b, p = B.shape
@@ -498,36 +478,17 @@ async def matmul_faster(dut, A, B, transpose=False, relu=False):
 
     C = torch.zeros((m_p, p_p), dtype=torch.float32)
 
-    # Debug counters
-    total_weight_loads = 0
-    total_tile_reuses = 0
-    setup_cycles = 0
-    compute_cycles = 0
-    teardown_cycles = 0
-
     # ---- Dyadic schedule ----
     for i in range(0, m_p, 2):
-        t_start = get_sim_time(units="ns")
-
         # A(i,k) is stationary for entire j sweep
         A_block = A_p[i:i + 2, :2].flatten().tolist()
         await load_stationary_weights(dut, A_block)
-        total_weight_loads += 1
-
-        t_setup = get_sim_time(units="ns")
-        setup_cycles += (t_setup - t_start)
-
-        # Count how many j tiles we'll process with this A tile
-        j_tiles = p_p // 2
-        total_tile_reuses += j_tiles
 
         # ---- Prime pipeline with first B tile ----
         j0 = 0
         B_block = B_p[:2, j0:j0 + 2].flatten().tolist()
         await load_inputs_stationary(dut, B_block)
 
-        t_compute_start = get_sim_time(units="ns")
-
         # ---- Stream remaining B tiles ----
         for j in range(2, p_p, 2):
             B_next = B_p[:2, j:j + 2].flatten().tolist()
@@ -539,171 +500,18 @@ async def matmul_faster(dut, A, B, transpose=False, relu=False):
             C[i + 1, j - 1] += results[3]
 
         # ---- Drain final output ----
-        t_teardown_start = get_sim_time(units="ns")
         results = await parallel_rw_stationary(dut, [0, 0, 0, 0])
 
         C[i, p_p - 2] += results[0]
         C[i, p_p - 1] += results[1]
         C[i + 1, p_p - 2] += results[2]
         C[i + 1, p_p - 1] += results[3]
 
-        t_end = get_sim_time(units="ns")
-        compute_cycles += (t_teardown_start - t_compute_start)
-        teardown_cycles += (t_end - t_teardown_start)
-
     if relu:
         C = torch.maximum(C, torch.tensor(0.0))
 
-    # Debug output
-    dut._log.info(f"matmul_faster debug:")
-    dut._log.info(f"  Weight loads: {total_weight_loads}")
-    dut._log.info(f"  Tile reuses: {total_tile_reuses}")
-    dut._log.info(f"  Setup cycles: {setup_cycles:.0f} ns")
-    dut._log.info(f"  Compute cycles: {compute_cycles:.0f} ns")
-    dut._log.info(f"  Teardown cycles: {teardown_cycles:.0f} ns")
-    dut._log.info(f"  Reuse ratio: {total_tile_reuses / total_weight_loads:.2f}")
-
     return C[:m, :p]
 
-async def matmul_faster_optimized(dut, A, B, transpose=False, relu=False):
-    """
-    Optimized version that eliminates setup/teardown bottlenecks
-    """
-    import torch
-
-    m, n = A.shape
-    n_b, p = B.shape
-
-    if transpose:
-        assert n == p
-    else:
-        assert n == n_b
-
-    # Fallback if we cannot exploit stationarity
-    if n > 2:
-        return await matmul(dut, A, B, transpose=transpose, relu=relu)
-
-    # ---- Padding ----
-    m_p = ((m + 1) // 2) * 2
-    p_p = ((p + 1) // 2) * 2
-
-    A_p = torch.zeros((m_p, 2), dtype=torch.float32)
-    B_p = torch.zeros((2, p_p), dtype=torch.float32)
-
-    A_p[:m, :n] = A
-    B_p[:n, :p] = B
-
-    C = torch.zeros((m_p, p_p), dtype=torch.float32)
-
-    # ---- Optimized dyadic schedule ----
-    for i in range(0, m_p, 2):
-        # A(i,k) is stationary for entire j sweep
-        A_block = A_p[i:i + 2, :2].flatten().tolist()
-        await load_stationary_weights(dut, A_block)
-
-        # Process all B tiles in one go - no separate priming/draining
-        for j in range(0, p_p, 2):
-            B_block = B_p[:2, j:j + 2].flatten().tolist()
-
-            if j == 0:
-                # First tile: load inputs without getting output
-                await load_inputs_stationary(dut, B_block)
-            else:
-                # Subsequent tiles: load next inputs while getting previous output
-                results = await parallel_rw_stationary(dut, B_block)
-
-                # Store results from previous computation
-                prev_j = j - 2
-                C[i, prev_j] += results[0]
-                C[i, prev_j + 1] += results[1]
-                C[i + 1, prev_j] += results[2]
-                C[i + 1, prev_j + 1] += results[3]
-
-        # Get final output
-        results = await parallel_rw_stationary(dut, [0, 0, 0, 0])
-        final_j = p_p - 2
-        C[i, final_j] += results[0]
-        C[i, final_j + 1] += results[1]
-        C[i + 1, final_j] += results[2]
-        C[i + 1, final_j + 1] += results[3]
-
-    if relu:
-        C = torch.maximum(C, torch.tensor(0.0))
-
-    return C[:m, :p]
-
-@cocotb.test()
-async def test_matmul_performance_analysis(dut):
-    """Detailed performance analysis of different matmul approaches"""
-    import torch
-    dut._log.info("=== MATMUL PERFORMANCE ANALYSIS ===")
-
-    # Clock
-    clock = Clock(dut.clk, 20, units="ns")
-    cocotb.start_soon(clock.start())
-
-    # Test small matrices where stationary should help
-    test_cases = [
-        (4, 2, 4, "Small n=2 case"),
-        (6, 2, 6, "Medium n=2 case"),
-        (8, 2, 8, "Large n=2 case"),
-        (4, 4, 4, "n>2 fallback case")
-    ]
-
-    torch.manual_seed(42)
-
-    for m, n, p, desc in test_cases:
-        dut._log.info(f"\n--- {desc}: A={m}x{n}, B={n}x{p} ---")
-
-        # Generate test matrices
-        A = (torch.rand(m, n) * 4.0 - 2.0).float()
-        B = (torch.rand(n, p) * 4.0 - 2.0).float()
-
-        # Test 1: Reference matmul
-        await reset_dut(dut)
-        t0 = get_sim_time(units="ns")
-        ref_out = await matmul(dut, A, B)
-        t1 = get_sim_time(units="ns")
-        ref_time = t1 - t0
-
-        # Test 2: Original matmul_faster (with debug)
-        await reset_dut(dut)
-        t2 = get_sim_time(units="ns")
-        fast_out = await matmul_faster(dut, A, B)
-        t3 = get_sim_time(units="ns")
-        fast_time = t3 - t2
-
-        # Test 3: Optimized version
-        await reset_dut(dut)
-        t4 = get_sim_time(units="ns")
-        opt_out = await matmul_faster_optimized(dut, A, B)
-        t5 = get_sim_time(units="ns")
-        opt_time = t5 - t4
-
-        # Verify correctness
-        assert torch.allclose(ref_out, fast_out, atol=1e-6), "matmul_faster mismatch"
-        assert torch.allclose(ref_out, opt_out, atol=1e-6), "matmul_faster_optimized mismatch"
-
-        # Performance analysis
-        fast_speedup = ref_time / fast_time if fast_time > 0 else float("inf")
-        opt_speedup = ref_time / opt_time if opt_time > 0 else float("inf")
-        opt_vs_fast = fast_time / opt_time if opt_time > 0 else float("inf")
-
-        dut._log.info(f"Performance results:")
-        dut._log.info(f"  Reference: {ref_time:.0f} ns")
-        dut._log.info(f"  Fast: {fast_time:.0f} ns (speedup: {fast_speedup:.2f}x)")
-        dut._log.info(f"  Optimized: {opt_time:.0f} ns (speedup: {opt_speedup:.2f}x)")
-        dut._log.info(f"  Opt vs Fast: {opt_vs_fast:.2f}x improvement")
-
-        # Calculate theoretical speedup for n=2 cases
-        if n == 2:
-            total_tiles = (m * p) // 4  # Total 2x2 output tiles
-            weight_tiles = m // 2  # Unique A tiles
-            theoretical_reuse = total_tiles / weight_tiles
-            dut._log.info(f"  Theoretical tile reuse: {theoretical_reuse:.2f}x")
-
-    dut._log.info("\n=== ANALYSIS COMPLETE ===")
-
 @cocotb.test()
 async def test_matmul_faster_matches_reference(dut):
     import torch
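For readers checking the prime/stream/drain structure kept in matmul_faster above, the same schedule can be modeled in pure software with no DUT in the loop. The sketch below is an assumption (the helper name emulate_dyadic_schedule is mine, not from the repo); it shows that each stationary A tile is reused across every 2x2 column tile of B, and that the result retired at each stream step belongs to the tile loaded one step earlier.

import torch

def emulate_dyadic_schedule(A, B):
    """Software-only model of the prime/stream/drain schedule used by matmul_faster."""
    m, n = A.shape
    _, p = B.shape
    assert n <= 2, "schedule is only legal for n <= 2"
    m_p, p_p = ((m + 1) // 2) * 2, ((p + 1) // 2) * 2
    A_p = torch.zeros((m_p, 2)); A_p[:m, :n] = A
    B_p = torch.zeros((2, p_p)); B_p[:n, :p] = B
    C = torch.zeros((m_p, p_p))
    for i in range(0, m_p, 2):
        A_tile = A_p[i:i + 2, :]            # stationary for the whole j sweep
        pending = 0                          # prime: first B tile loaded, result not yet available
        for j in range(2, p_p, 2):           # stream: each step retires the previous tile's result
            C[i:i + 2, pending:pending + 2] = A_tile @ B_p[:, pending:pending + 2]
            pending = j
        C[i:i + 2, pending:pending + 2] = A_tile @ B_p[:, pending:pending + 2]  # drain last tile
    return C[:m, :p]

# Quick self-check against torch's reference matmul
A, B = torch.rand(5, 2), torch.rand(2, 7)
assert torch.allclose(emulate_dyadic_schedule(A, B), A @ B, atol=1e-6)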
@@ -720,8 +528,11 @@ async def test_matmul_faster_matches_reference(dut):
     test_cases = [
         # (m, n, p, transpose, relu)
         (4, 4, 4, False, False),
+        (5, 3, 6, False, False),
         (6, 5, 3, False, True),
         (4, 6, 5, True, False),
+        (7, 7, 7, True, True),
+        (10, 10, 10, False, False),
         (20, 20, 20, False, False)
     ]
 
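The body of test_matmul_faster_matches_reference is not part of this hunk. A plausible shape for the loop over the expanded case list, sitting inside the test coroutine and modeled on the allclose checks in the deleted performance-analysis test, is sketched below; the seed, value ranges, and tolerance are assumptions, and the transpose cases are skipped because B's layout for that path is defined outside this diff.

# Hypothetical driver loop over test_cases; values and tolerance are illustrative.
torch.manual_seed(0)
for m, n, p, transpose, relu in test_cases:
    if transpose:
        continue  # B's layout for the transpose path isn't shown in this diff
    A = (torch.rand(m, n) * 4.0 - 2.0).float()
    B = (torch.rand(n, p) * 4.0 - 2.0).float()
    ref_out = await matmul(dut, A, B, relu=relu)
    fast_out = await matmul_faster(dut, A, B, relu=relu)
    assert torch.allclose(ref_out, fast_out, atol=1e-6), f"Mismatch for {(m, n, p)}"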