Skip to content

Commit 8ab3778

Browse files
patch hadamard test
1 parent a7f092b commit 8ab3778

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

test/tpu/test.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,20 +83,25 @@ async def load_matrix(dut, matrix, hadamard=0, transpose=0, relu=0):
8383
dut.uio_in.value = (hadamard << 3) | (transpose << 1) | (relu << 2) | 1
8484
await RisingEdge(dut.clk)
8585

86-
async def parallel_load_read(dut, A, B, transpose=0, relu=0):
86+
async def parallel_load_read(dut, A, B, hadamard=0, transpose=0, relu=0):
8787
results = []
88-
dut.uio_in.value = (transpose << 1) | (relu << 2) | 1
89-
88+
count = 0
89+
9090
for inputs in [A, B]:
9191
for i in range(2):
92+
count += 1
93+
94+
if count == 4:
95+
dut.uio_in.value = (hadamard << 3) | (transpose << 1) | (relu << 2) | 1
96+
9297
idx0 = i * 2
9398
idx1 = i * 2 + 1
9499
# Feed either real data or dummy zeros
95100
dut.ui_in.value = fp8_e4m3_encode(inputs[idx0]) if inputs else 0
96101
await ClockCycles(dut.clk, 1)
97102
high = dut.uo_out.value.integer
98103
dut._log.info(f"Read high value = {high}")
99-
104+
100105
dut.ui_in.value = fp8_e4m3_encode(inputs[idx1]) if inputs else 0
101106
await ClockCycles(dut.clk, 1)
102107
low = dut.uo_out.value.integer
@@ -129,7 +134,7 @@ async def test_hadamard(dut):
129134
results = []
130135

131136
# Read test 1 matrices
132-
results = await parallel_load_read(dut, [], [])
137+
results = await parallel_load_read(dut, A, B)
133138

134139
print(results)
135140
print(expected)

0 commit comments

Comments
 (0)