Skip to content

Commit ba1c772

Browse files
add some data
1 parent 58d9651 commit ba1c772

File tree

4 files changed

+242
-1
lines changed

4 files changed

+242
-1
lines changed

docs/Block_Diagram.png

39.8 KB
Loading
79.1 KB
Loading

test/mac/test_mac.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from cocotb.clock import Clock
44
import random
55
import math
6+
import matplotlib.pyplot as plt
7+
import numpy as np
68
import struct
79
# ---------- helpers to emulate RTL bit-exact behavior ----------
810

@@ -110,3 +112,114 @@ async def test_pe_deviation(dut):
110112
f"[{i}] a={fa:.6f} b={fb:.6f} expected={expected:.6f} "
111113
f"bf16={got_float:.6f} abs_err={abs_err:.6e} rel_err={rel_err:.6e}"
112114
)
115+
116+
async def perform_multiplication(dut, fa: float, fb: float):
117+
expected = fa * fb
118+
119+
fp8_a = fp8_e4m3_encode(fa)
120+
fp8_b = fp8_e4m3_encode(fb)
121+
122+
await reset_accumulator(dut)
123+
124+
dut.a_in.value = fp8_a
125+
dut.b_in.value = fp8_b
126+
await RisingEdge(dut.clk) # one cycle if output is assigned combinationally
127+
await RisingEdge(dut.clk) # two cycles if output is registered inside PE
128+
129+
bf16_raw = int(dut.c_out.value) & 0xFFFF
130+
got_float = bf16_to_float(bf16_raw)
131+
132+
abs_err = abs(got_float - expected)
133+
# Using relative error or a variant of it (like ULP) is often better for floats,
134+
# but based on the previous test and request for 'error', we'll use Absolute Error for the heatmap.
135+
return abs_err
136+
137+
# Function to setup the clock and reset for the tests
138+
async def setup_dut(dut):
139+
cocotb.start_soon(Clock(dut.clk, 10, "ns").start())
140+
dut.rst.value = 1
141+
dut.clear.value = 1
142+
dut.a_in.value = 0
143+
dut.b_in.value = 0
144+
for _ in range(3):
145+
await RisingEdge(dut.clk)
146+
dut.rst.value = 0
147+
dut.clear.value = 0
148+
await RisingEdge(dut.clk)
149+
150+
@cocotb.test()
151+
async def test_pe_error_heatmap(dut):
152+
await setup_dut(dut)
153+
154+
# --- Plotting parameters ---
155+
MIN_VAL = -10.0
156+
MAX_VAL = 10.0
157+
NUM_STEPS = 100 # Increase for finer resolution, decrease for faster test
158+
159+
a_values = np.linspace(MIN_VAL, MAX_VAL, NUM_STEPS)
160+
b_values = np.linspace(MIN_VAL, MAX_VAL, NUM_STEPS)
161+
162+
epsilon_grid = 1e-6
163+
164+
# Initialize the error matrix
165+
error_matrix = np.zeros((NUM_STEPS, NUM_STEPS))
166+
167+
dut._log.info(f"Starting heatmap generation with {NUM_STEPS*NUM_STEPS} points...")
168+
169+
# --- Data Collection Loop ---
170+
for i in range(NUM_STEPS):
171+
fa = a_values[i]
172+
for j in range(NUM_STEPS):
173+
fb = b_values[j]
174+
175+
# The perform_multiplication function now handles the RTL interaction
176+
abs_err = await perform_multiplication(dut, fa, fb)
177+
178+
expected = fa * fb
179+
180+
if abs(expected) < epsilon_grid:
181+
# If the product is essentially zero, relative error is undefined/infinite.
182+
# We'll assign a max error value for plotting purposes.
183+
rel_err = 1.0 # Max expected relative error
184+
else:
185+
rel_err = abs_err / abs(expected)
186+
187+
error_matrix[i, j] = rel_err
188+
189+
dut._log.info("Data collection complete. Generating plot...")
190+
191+
# --- Matplotlib Plotting ---
192+
193+
# Create the X and Y meshgrid for the plot
194+
X, Y = np.meshgrid(a_values, b_values)
195+
196+
# Create the figure and axes
197+
plt.figure(figsize=(10, 8))
198+
199+
# Plot the heatmap (using imshow for 2D array)
200+
epsilon = 1e-15
201+
log_error_matrix = error_matrix + epsilon
202+
203+
# Plot the log10 of the absolute error
204+
plt.imshow(log_error_matrix, origin='lower', aspect='auto',
205+
extent=[MIN_VAL, MAX_VAL, MIN_VAL, MAX_VAL],
206+
cmap='viridis') # 'viridis' or 'inferno' are good choices
207+
208+
# Add a color bar to show the error scale
209+
cbar = plt.colorbar()
210+
cbar.set_label('Relative Error') #
211+
212+
# Add labels and title
213+
plt.xlabel('Multiplicand A (a_in)')
214+
plt.ylabel('Multiplicand B (b_in)')
215+
plt.title(f'FP8 E4M3 Multiplication Relative Error Heatmap ({NUM_STEPS}x{NUM_STEPS} Grid)')
216+
217+
# Save the plot
218+
plot_filename = "fp8_multiplication_error_heatmap.png"
219+
plt.savefig(plot_filename)
220+
dut._log.info(f"Error heatmap saved to: {plot_filename}")
221+
222+
# Optionally, you can assert that the max error is below a certain threshold
223+
max_err = np.max(error_matrix)
224+
dut._log.info(f"Maximum absolute error in grid: {max_err:.6e}")
225+
# assert max_err < 1e-2, f"Max absolute error {max_err} exceeds threshold"

test/tpu/test.py

Lines changed: 129 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,4 +201,132 @@ async def test_gemm(dut):
201201
f"C[{i//2}][{i%2}] = {results[i]} "
202202
f"!= expected {expected[i]} (relative error {rel_err:.4f})"
203203
)
204-
dut._log.info("Test 2 passed")
204+
dut._log.info("Test 2 passed")
205+
206+
def get_expected_large_matmul(A, B, transpose=0, relu=0):
207+
if transpose:
208+
B = B.T
209+
210+
result = A @ B
211+
212+
if relu:
213+
result = np.maximum(result, 0)
214+
215+
return result
216+
217+
def check_expected(A, B, result, transpose=0, relu=0):
218+
"""
219+
Check DUT results against expected matrix multiplication, for big matrices
220+
"""
221+
expected = get_expected_large_matmul(A, B, transpose, relu)
222+
np.testing.assert_array_equal(result, expected, err_msg="Matrix multiplication result does not match expected")
223+
224+
async def accumulate_matrix_output(dut, results_large, i, j, transpose=0, A_block=None, B_block=None):
225+
"""
226+
Serially loads A_block and B_block (1 value per cycle),
227+
and reads interleaved output (1 byte per cycle: high, low, high, low, ...).
228+
Accumulates output into results_large[i:i+2, j:j+2].
229+
"""
230+
# Full interleaved stream of 8 input values: A0-A3, then B0-B3
231+
input_stream = (A_block + B_block) if (A_block and B_block) else [0]*8
232+
233+
dut.uio_in.value = (transpose << 1) | 1 # load_en=1
234+
235+
partial_outputs = []
236+
237+
for idx in range(8):
238+
dut.ui_in.value = input_stream[idx]
239+
await ClockCycles(dut.clk, 1)
240+
val = dut.uo_out.value.integer
241+
partial_outputs.append(val)
242+
243+
# Now decode high/low bytes
244+
combined_outputs = []
245+
for ii in range(0, 8, 2):
246+
high = partial_outputs[ii]
247+
low = partial_outputs[ii + 1]
248+
val = (high << 8) | low
249+
if val >= 0x8000:
250+
val -= 0x10000
251+
combined_outputs.append(val)
252+
253+
results_large[i, j ] += combined_outputs[0] # C00
254+
results_large[i, j+1] += combined_outputs[1] # C01
255+
results_large[i+1, j ] += combined_outputs[2] # C10
256+
results_large[i+1, j+1] += combined_outputs[3] # C11
257+
258+
return combined_outputs
259+
260+
async def matmul(dut, A, B, transpose=False, relu=False):
261+
"""
262+
Fully pipelined systolic matrix multiplication using 2x2 blocks.
263+
Accumulates partial results across k dimension for each (i,j) tile.
264+
Loads A and B in parallel with reading previous output.
265+
"""
266+
m, n = A.shape
267+
n_b, p = B.shape
268+
if (transpose):
269+
assert n == p, "Reminder: you are computing A*B^T"
270+
else:
271+
assert n == n_b, "Matrix dimension mismatch"
272+
273+
# Pad dimensions to multiples of 2
274+
m_p = ((m + 1) // 2) * 2
275+
n_p = ((n + 1) // 2) * 2
276+
n_bp = ((n_b + 1) // 2) * 2
277+
p_p = ((p + 1) // 2) * 2
278+
279+
A_padded = np.zeros((m_p, n_p), dtype=int)
280+
B_padded = np.zeros((n_bp, p_p), dtype=int)
281+
282+
A_padded[:m, :n] = A
283+
B_padded[:n_b, :p] = B
284+
results_large = np.zeros((m_p, n_bp), dtype=int) if transpose else np.zeros((m_p, p_p), dtype=int)
285+
286+
# Generate tile coordinates (i, j, k)
287+
if transpose:
288+
# Order: j, i, k for transpose case
289+
tile_coords = [
290+
(i, j, k)
291+
for i in range(0, m_p, 2)
292+
for j in range(0, n_bp, 2)
293+
for k in range(0, p_p, 2)
294+
]
295+
else:
296+
# Original order: i, j, k
297+
tile_coords = [
298+
(i, j, k)
299+
for i in range(0, m_p, 2)
300+
for j in range(0, p_p, 2)
301+
for k in range(0, n_p, 2)
302+
]
303+
304+
# Step 1: Load first tile only (no output yet)
305+
i0, j0, k0 = tile_coords[0]
306+
A_block = A_padded[i0:i0+2, k0:k0+2].flatten().tolist()
307+
B_block = B_padded[k0:k0+2, j0:j0+2].flatten().tolist()
308+
309+
await load_matrix(dut, A_block, transpose=0, relu=relu)
310+
await load_matrix(dut, B_block, transpose=transpose, relu=relu)
311+
312+
# Step 2: Pipelined main loop
313+
for coord in tile_coords[1:]:
314+
i1, j1, k1 = coord
315+
A_next = A_padded[i1:i1+2, k1:k1+2].flatten().tolist()
316+
B_next = B_padded[j1:j1+2, k1:k1+2].flatten().tolist() if transpose else B_padded[k1:k1+2, j1:j1+2].flatten().tolist()
317+
# Read output from previous tile while loading next
318+
await accumulate_matrix_output(dut, results_large, i0, j0, transpose, A_next, B_next)
319+
320+
# Slide to next
321+
i0, j0, k0 = i1, j1, k1
322+
A_block = A_next
323+
B_block = B_next
324+
325+
# Final tile read (no further input)
326+
await accumulate_matrix_output(dut, results_large, i0, j0, transpose)
327+
328+
# Apply ReLU if enabled
329+
if relu:
330+
results_large = np.maximum(results_large, 0)
331+
332+
return results_large[:m, :n_b] if transpose else results_large[:m, :p]

0 commit comments

Comments
 (0)