Skip to content

Commit af2cb4b

Browse files
antiagainst and AlexAUT authored
[AMD][gfx1250] Adjust (BLOCK) M/N in f16 gemm examples (#9421)
Co-authored-by: Alexander Weinrauch <Alexander.Weinrauch@amd.com>
1 parent 21a5f44 commit af2cb4b

1 file changed

Lines changed: 9 additions & 0 deletions

File tree

third_party/amd/python/examples/gluon/f16_gemm_gfx1250.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,9 @@ def persistent_gemm_tdm_pipelined_lds_prefetch_kernel(a_ptr, b_ptr, c_ptr, #
193193

194194

195195
def _build_gemm_layouts(BLOCK_M, BLOCK_N, BLOCK_K, cga_layout_a, cga_layout_b, cga_layout_c, WARP_BASES, TRANSPOSE_B):
196+
"""
197+
Build all layouts for the GEMM kernel.
198+
"""
196199
# If TRANSPOSE_B we need to transpose each basis vector of the CGALayout for the
197200
# shared allocation because the permute will transpose the basis vectors before we
198201
# load them for wmmas.
@@ -388,6 +391,12 @@ def _run_runtime_gemm_tdm_pipelined(BLOCK_M, BLOCK_N, BLOCK_K, NUM_BUFFERS, TRAN
388391
if num_ctas > 1 and PERSISTENT:
389392
pytest.skip("Skip tests with multiple CTAs and persistent or prefetch")
390393

394+
# We scale the problem size and block dims by ctas_per_cga so each CTA works on BLOCK_M/BLOCK_N sized tile
395+
M *= ctas_per_cga[0]
396+
N *= ctas_per_cga[1]
397+
BLOCK_M *= ctas_per_cga[0]
398+
BLOCK_N *= ctas_per_cga[1]
399+
391400
torch.manual_seed(42)
392401

393402
a = torch.randn((M, K), dtype=torch.float16)

0 commit comments

Comments (0)