@@ -201,4 +201,132 @@ async def test_gemm(dut):
201201 f"C[{ i // 2 } ][{ i % 2 } ] = { results [i ]} "
202202 f"!= expected { expected [i ]} (relative error { rel_err :.4f} )"
203203 )
204- dut ._log .info ("Test 2 passed" )
204+ dut ._log .info ("Test 2 passed" )
205+
206+ def get_expected_large_matmul (A , B , transpose = 0 , relu = 0 ):
207+ if transpose :
208+ B = B .T
209+
210+ result = A @ B
211+
212+ if relu :
213+ result = np .maximum (result , 0 )
214+
215+ return result
216+
217+ def check_expected (A , B , result , transpose = 0 , relu = 0 ):
218+ """
219+ Check DUT results against expected matrix multiplication, for big matrices
220+ """
221+ expected = get_expected_large_matmul (A , B , transpose , relu )
222+ np .testing .assert_array_equal (result , expected , err_msg = "Matrix multiplication result does not match expected" )
223+
224+ async def accumulate_matrix_output (dut , results_large , i , j , transpose = 0 , A_block = None , B_block = None ):
225+ """
226+ Serially loads A_block and B_block (1 value per cycle),
227+ and reads interleaved output (1 byte per cycle: high, low, high, low, ...).
228+ Accumulates output into results_large[i:i+2, j:j+2].
229+ """
230+ # Full interleaved stream of 8 input values: A0-A3, then B0-B3
231+ input_stream = (A_block + B_block ) if (A_block and B_block ) else [0 ]* 8
232+
233+ dut .uio_in .value = (transpose << 1 ) | 1 # load_en=1
234+
235+ partial_outputs = []
236+
237+ for idx in range (8 ):
238+ dut .ui_in .value = input_stream [idx ]
239+ await ClockCycles (dut .clk , 1 )
240+ val = dut .uo_out .value .integer
241+ partial_outputs .append (val )
242+
243+ # Now decode high/low bytes
244+ combined_outputs = []
245+ for ii in range (0 , 8 , 2 ):
246+ high = partial_outputs [ii ]
247+ low = partial_outputs [ii + 1 ]
248+ val = (high << 8 ) | low
249+ if val >= 0x8000 :
250+ val -= 0x10000
251+ combined_outputs .append (val )
252+
253+ results_large [i , j ] += combined_outputs [0 ] # C00
254+ results_large [i , j + 1 ] += combined_outputs [1 ] # C01
255+ results_large [i + 1 , j ] += combined_outputs [2 ] # C10
256+ results_large [i + 1 , j + 1 ] += combined_outputs [3 ] # C11
257+
258+ return combined_outputs
259+
260+ async def matmul (dut , A , B , transpose = False , relu = False ):
261+ """
262+ Fully pipelined systolic matrix multiplication using 2x2 blocks.
263+ Accumulates partial results across k dimension for each (i,j) tile.
264+ Loads A and B in parallel with reading previous output.
265+ """
266+ m , n = A .shape
267+ n_b , p = B .shape
268+ if (transpose ):
269+ assert n == p , "Reminder: you are computing A*B^T"
270+ else :
271+ assert n == n_b , "Matrix dimension mismatch"
272+
273+ # Pad dimensions to multiples of 2
274+ m_p = ((m + 1 ) // 2 ) * 2
275+ n_p = ((n + 1 ) // 2 ) * 2
276+ n_bp = ((n_b + 1 ) // 2 ) * 2
277+ p_p = ((p + 1 ) // 2 ) * 2
278+
279+ A_padded = np .zeros ((m_p , n_p ), dtype = int )
280+ B_padded = np .zeros ((n_bp , p_p ), dtype = int )
281+
282+ A_padded [:m , :n ] = A
283+ B_padded [:n_b , :p ] = B
284+ results_large = np .zeros ((m_p , n_bp ), dtype = int ) if transpose else np .zeros ((m_p , p_p ), dtype = int )
285+
286+ # Generate tile coordinates (i, j, k)
287+ if transpose :
288+ # Order: j, i, k for transpose case
289+ tile_coords = [
290+ (i , j , k )
291+ for i in range (0 , m_p , 2 )
292+ for j in range (0 , n_bp , 2 )
293+ for k in range (0 , p_p , 2 )
294+ ]
295+ else :
296+ # Original order: i, j, k
297+ tile_coords = [
298+ (i , j , k )
299+ for i in range (0 , m_p , 2 )
300+ for j in range (0 , p_p , 2 )
301+ for k in range (0 , n_p , 2 )
302+ ]
303+
304+ # Step 1: Load first tile only (no output yet)
305+ i0 , j0 , k0 = tile_coords [0 ]
306+ A_block = A_padded [i0 :i0 + 2 , k0 :k0 + 2 ].flatten ().tolist ()
307+ B_block = B_padded [k0 :k0 + 2 , j0 :j0 + 2 ].flatten ().tolist ()
308+
309+ await load_matrix (dut , A_block , transpose = 0 , relu = relu )
310+ await load_matrix (dut , B_block , transpose = transpose , relu = relu )
311+
312+ # Step 2: Pipelined main loop
313+ for coord in tile_coords [1 :]:
314+ i1 , j1 , k1 = coord
315+ A_next = A_padded [i1 :i1 + 2 , k1 :k1 + 2 ].flatten ().tolist ()
316+ B_next = B_padded [j1 :j1 + 2 , k1 :k1 + 2 ].flatten ().tolist () if transpose else B_padded [k1 :k1 + 2 , j1 :j1 + 2 ].flatten ().tolist ()
317+ # Read output from previous tile while loading next
318+ await accumulate_matrix_output (dut , results_large , i0 , j0 , transpose , A_next , B_next )
319+
320+ # Slide to next
321+ i0 , j0 , k0 = i1 , j1 , k1
322+ A_block = A_next
323+ B_block = B_next
324+
325+ # Final tile read (no further input)
326+ await accumulate_matrix_output (dut , results_large , i0 , j0 , transpose )
327+
328+ # Apply ReLU if enabled
329+ if relu :
330+ results_large = np .maximum (results_large , 0 )
331+
332+ return results_large [:m , :n_b ] if transpose else results_large [:m , :p ]
0 commit comments