Skip to content

Commit 0ee2ec2

Browse files
authored
[AMD][gfx1250] Improve gluon f16 gemm kernel pipeline (#10057)
Improve f16 gemm gfx1250-gluon performance. Improves gemm_tdm_pipelined_single_warp_per_simd_schedule_kernel by moving tdm.load earlier; from the top of the loop (which hides 3/4th of a loop-iteration's worth of cycles) to right after the wait (which hides a full loop-iteration's worth of cycles). This only fixes the mentioned kernel; other kernels need independent benchmarking and improving.
1 parent 2796cea commit 0ee2ec2

1 file changed

Lines changed: 12 additions & 7 deletions

File tree

third_party/amd/python/examples/gluon/f16_gemm_gfx1250.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -342,10 +342,14 @@ def gemm_tdm_pipelined_single_warp_per_simd_schedule_kernel(a_ptr, b_ptr, c_ptr,
342342

343343
loop_ub = ttgl.cdiv(K, BLOCK_K)
344344
epilogue_lb = loop_ub - (NUM_BUFFERS - 1)
345+
346+
pred = 0 - epilogue_lb
347+
pred = (pred >> 31) & 1
348+
producer = issue_loads(producer, a_desc, b_desc, 0, 0, a_buffer, b_buffer, BLOCK_K, NUM_BUFFERS, TRANSPOSE_B,
349+
pred=pred)
350+
345351
ttgl.assume(loop_ub > 0)
346352
for i in range(0, loop_ub):
347-
pred = i - epilogue_lb
348-
pred = (pred >> 31) & 1
349353
# SubIteration0
350354
# LDS load SubIteration1
351355
a1, b1 = lds_subtile_load(consumer, SUBTILE_LEN, a_buffer, OPERAND_LAYOUT_A, b_buffer, OPERAND_LAYOUT_B,
@@ -354,11 +358,6 @@ def gemm_tdm_pipelined_single_warp_per_simd_schedule_kernel(a_ptr, b_ptr, c_ptr,
354358
accumulator = ttgl.amd.gfx1250.wmma(a0, b0, accumulator)
355359

356360
# SubIteration1
357-
# TDM load for next tile
358-
# If we are in epilogue, we have already issued our tile loads
359-
producer = issue_loads(producer, a_desc, b_desc, 0, 0, a_buffer, b_buffer, BLOCK_K, NUM_BUFFERS, TRANSPOSE_B,
360-
pred=pred)
361-
362361
# We prefetch distance - 1 iterations ahead because producer is already incremented by 1
363362
issue_l2_prefetches(L2_PREFETCH_DISTANCE - 1, producer, a_desc, b_desc, 0, 0, BLOCK_K, TRANSPOSE_B)
364363

@@ -378,6 +377,12 @@ def gemm_tdm_pipelined_single_warp_per_simd_schedule_kernel(a_ptr, b_ptr, c_ptr,
378377
# SubIteration3
379378
consumer += 1
380379
ttgl.amd.gfx1250.tdm.async_wait((NUM_BUFFERS - 2) * 2)
380+
# TDM load for next tile
381+
# If we are in epilogue, we have already issued our tile loads
382+
pred = (i + 1) - epilogue_lb
383+
pred = (pred >> 31) & 1
384+
producer = issue_loads(producer, a_desc, b_desc, 0, 0, a_buffer, b_buffer, BLOCK_K, NUM_BUFFERS, TRANSPOSE_B,
385+
pred=pred)
381386
# LDS load SubIteration0 for next tile
382387
a0, b0 = lds_subtile_load(consumer, 0, a_buffer, OPERAND_LAYOUT_A, b_buffer, OPERAND_LAYOUT_B, NUM_BUFFERS,
383388
TRANSPOSE_B, SUBTILE_LEN)

0 commit comments

Comments
 (0)