@@ -83,20 +83,25 @@ async def load_matrix(dut, matrix, hadamard=0, transpose=0, relu=0):
8383 dut .uio_in .value = (hadamard << 3 ) | (transpose << 1 ) | (relu << 2 ) | 1
8484 await RisingEdge (dut .clk )
8585
86- async def parallel_load_read (dut , A , B , transpose = 0 , relu = 0 ):
86+ async def parallel_load_read (dut , A , B , hadamard = 0 , transpose = 0 , relu = 0 ):
8787 results = []
88- dut . uio_in . value = ( transpose << 1 ) | ( relu << 2 ) | 1
89-
88+ count = 0
89+
9090 for inputs in [A , B ]:
9191 for i in range (2 ):
92+ count += 1
93+
94+ if count == 4 :
95+ dut .uio_in .value = (hadamard << 3 ) | (transpose << 1 ) | (relu << 2 ) | 1
96+
9297 idx0 = i * 2
9398 idx1 = i * 2 + 1
9499 # Feed either real data or dummy zeros
95100 dut .ui_in .value = fp8_e4m3_encode (inputs [idx0 ]) if inputs else 0
96101 await ClockCycles (dut .clk , 1 )
97102 high = dut .uo_out .value .integer
98103 dut ._log .info (f"Read high value = { high } " )
99-
104+
100105 dut .ui_in .value = fp8_e4m3_encode (inputs [idx1 ]) if inputs else 0
101106 await ClockCycles (dut .clk , 1 )
102107 low = dut .uo_out .value .integer
@@ -129,7 +134,7 @@ async def test_hadamard(dut):
129134 results = []
130135
131136 # Read test 1 matrices
132- results = await parallel_load_read (dut , [], [] )
137+ results = await parallel_load_read (dut , A , B )
133138
134139 print (results )
135140 print (expected )
0 commit comments