Skip to content

Commit 31e978b

Browse files
authored
Merge pull request #31 from amd/add-matmul-transform-generator
Add matmul transform generator and i8 matmul examples
2 parents c3c665a + 2753f54 commit 31e978b

10 files changed

Lines changed: 1338 additions & 299 deletions

File tree

examples/matmul/matmul.py renamed to examples/matmul_bf16_m64_n64_k64/matmul_bf16_m64_n64_k64.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,9 @@ def bench_matmul(M, N, K, provider):
8787

8888
if __name__ == "__main__":
8989
benchmark.select_npu_backend()
90-
for M in [2**i for i in range(8, 14, 2)]: # change to "in range(9, 14, 2)" if BLOCK_SIZE_M=512
90+
for M in [
91+
2**i for i in range(8, 14, 2)
92+
]: # change to "in range(9, 14, 2)" if BLOCK_SIZE_M=512
9193
for N in [2**i for i in range(8, 14, 2)]:
9294
for K in [2**i for i in range(8, 14, 2)]:
9395
bench_matmul(M, N, K, "test")

examples/matmul/transform_aie2.mlir renamed to examples/matmul_bf16_m64_n64_k64/transform_aie2.mlir

Lines changed: 66 additions & 135 deletions
Large diffs are not rendered by default.

examples/matmul/transform_aie2p.mlir renamed to examples/matmul_bf16_m64_n64_k64/transform_aie2p.mlir

Lines changed: 65 additions & 134 deletions
Large diffs are not rendered by default.

examples/padded_matmul/padded_matmul.py renamed to examples/matmul_f32_m64_n32_k16_padded_atransposed/matmul_f32_m64_n32_k16_padded_atransposed.py

File renamed without changes.

examples/padded_matmul/transform_aie2p.mlir renamed to examples/matmul_f32_m64_n32_k16_padded_atransposed/transform_aie2p.mlir

Lines changed: 22 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,33 @@
1-
// Transform Script for F32 Matmul with BF16 Emulation
1+
// Auto-generated by matmul_transform.py — do not edit manually.
2+
// Parameters: l1_m=64, l1_n=32, l2_k=16, pack=[8,8,8], accum=f32, contract_in=bf16
23
//
3-
// Starting IR: Full-K matmul (no K-loop), all f32, generated from asm_src params.
4-
// - func @matmul_padding_kernel(memref<*xf32>*3, i32*6)
5-
// - linalg.matmul(64xK @ Kx32 → 64x32), f32 accumulation
6-
// - A in K×M layout (strides [1, M_alloc]), B in K×N (strides [N_alloc, 1])
7-
//
8-
// Follows test 53's transform pattern: tile copies, pack [8,8,8], tile K,
9-
// tile forall for multi-core, vectorize, hoist.
10-
//
11-
// Target: 4×8 AIE core array (Strix/NPU2), BF16 emulation
12-
// Tile sizes: M=64, N=32, K_L2=16, pack [8,8,8]
4+
// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved.
5+
// SPDX-License-Identifier: MIT
136

147
module attributes {transform.with_named_sequence} {
158
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
169

1710
//==========================================================================
18-
// PHASE 1: TILE L3→L2 MEMORY COPIES
11+
// PHASE 1: TILE L3->L2 MEMORY COPIES
12+
// Tile memref copies for streaming data from DDR (L3) to MemTile (L2).
1913
//==========================================================================
2014

2115
%func10 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
2216
%func10_updated = transform.air.convert_memref_copy_to_linalg_copy %func10 : (!transform.any_op) -> !transform.any_op
2317
%copies = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
2418
%copy1, %copy2 = transform.split_handle %copies : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
25-
// Tile A copy: 64×K → 64×16 tiles (K_L2_TILE=16)
2619
%tiled_copy1, %tile_copy_loop1 =
2720
transform.structured.tile_using_for %copy1 tile_sizes [0, 16]
2821
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
2922
transform.annotate %tile_copy_loop1 "copy_a_loop" : !transform.any_op
30-
// Tile B copy: K×32 → 16×32 tiles
3123
%tiled_copy2, %tile_copy_loop2 =
3224
transform.structured.tile_using_for %copy2 tile_sizes [16]
3325
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
3426
transform.annotate %tile_copy_loop2 "copy_b_loop" : !transform.any_op
3527

3628
//==========================================================================
3729
// PHASE 2: PROMOTE OUTPUT TO L2
38-
// No truncf fusion needed (output is f32).
30+
// Allocate output buffer (C) in L2 for accumulation.
3931
//==========================================================================
4032

4133
%result_l2 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
@@ -44,43 +36,47 @@ module attributes {transform.with_named_sequence} {
4436

4537
//==========================================================================
4638
// PHASE 3: PACK MATMUL FOR VECTORIZED COMPUTATION
47-
// Pack sizes [8, 8, 8] for M, N, K dimensions.
39+
// Pack [8, 8, 8], transpose A/B/C, promote C pack to L1.
4840
//==========================================================================
4941

5042
%matmul_to_pack = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
5143
%packed = transform.structured.pack %matmul_to_pack packed_sizes = [8, 8, 8]
5244
: (!transform.any_op) -> (!transform.any_op)
5345

46+
// Transpose A: outer_perm [1,0]
5447
%pack_producer_a = transform.get_producer_of_operand %packed[0]
5548
: (!transform.any_op) -> (!transform.any_op)
5649
%packed_a, %pack_a, %empty_unpack_a =
5750
transform.structured.pack_transpose %pack_producer_a with_compute_op(%packed)
5851
outer_perm = [1, 0] : (!transform.any_op, !transform.any_op)
5952
-> (!transform.any_op, !transform.any_op, !transform.any_op)
6053

54+
// Transpose B: outer_perm [1,0] + inner_perm [1,0]
6155
%pack_producer_b = transform.get_producer_of_operand %packed_a[1]
6256
: (!transform.any_op) -> (!transform.any_op)
6357
%packed_b, %pack_b, %empty_unpack_b =
6458
transform.structured.pack_transpose %pack_producer_b with_compute_op(%packed_a)
6559
outer_perm = [1, 0] inner_perm = [1, 0] : (!transform.any_op, !transform.any_op)
6660
-> (!transform.any_op, !transform.any_op, !transform.any_op)
6761

62+
// Transpose C: outer_perm [1,0]
6863
%unpack = transform.get_consumers_of_result %packed_b[0]
6964
: (!transform.any_op) -> (!transform.any_op)
7065
%packed_c, %pack_c, %unpack_c =
7166
transform.structured.pack_transpose %unpack with_compute_op(%packed_b)
7267
outer_perm = [1, 0] : (!transform.any_op, !transform.any_op)
7368
-> (!transform.any_op, !transform.any_op, !transform.any_op)
7469

70+
// Promote C pack to L1
7571
%output_l1_pack_op_source_buffer, %output_l1_pack_op_new = transform.structured.bufferize_to_allocation %pack_c
7672
{memory_space = 2, bufferize_destination_only, memcpy_op = "linalg.copy", emit_dealloc} : !transform.any_op
7773

78-
// Annotate the packed matmul so we can find it after K-tiling
74+
// Annotate for robust matching after K-tiling
7975
transform.annotate %packed_c "packed_matmul" : !transform.any_op
8076

8177
//==========================================================================
8278
// PHASE 4: TILE K REDUCTION AND FUSE PACK OPERATIONS
83-
// K/8 packed K-dim. Tile by 2 (= 16 raw K elements = K_L2_TILE).
79+
// Tile packed K dim by 2 (= 16 raw K elements).
8480
//==========================================================================
8581

8682
%tiled_reduction, %outer_for_loop =
@@ -93,9 +89,7 @@ module attributes {transform.with_named_sequence} {
9389

9490
//==========================================================================
9591
// PHASE 5: TILE FOR MULTI-CORE PARALLELISM
96-
// Packed C dims after pack [8,8,8] + outer_perm [1,0]:
97-
// [N/8, M/8, K/8] = [16, 32, K/8] → tile [8, 4, 0] → forall(2, 8)
98-
// par_to_herd maps to herd(8, 2) → collapse to 4×4
92+
// Tile [8, 4, 0] for herd distribution.
9993
//==========================================================================
10094

10195
%matmul_1 = transform.structured.match ops{["linalg.generic"]} attributes{packed_matmul} in %arg1 : (!transform.any_op) -> !transform.any_op
@@ -119,15 +113,13 @@ module attributes {transform.with_named_sequence} {
119113
// PHASE 6: PROMOTE INPUTS TO L1 AND TILE PROLOGUE/EPILOGUE
120114
//==========================================================================
121115

116+
// Promote A and B to L1
122117
%buffer_a, %new_a = transform.structured.bufferize_to_allocation %fused_lhs_l1_pack2
123118
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
124119
%buffer_b, %new_b = transform.structured.bufferize_to_allocation %fused_rhs_l1_pack2
125120
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
126121

127-
// Prologue: fill → generalize → interchange → tile_using_forall
128-
// After packing, fill is on packed 4D tensor [N/8, M/8, 8, 8] = [16, 32, 8, 8].
129-
// Interchange [1,0,2,3] swaps N/M dims → [32, 16, 8, 8].
130-
// Tile [8, 4] → forall(4, 4) matching herd.
122+
// Prologue: fill -> generalize -> interchange -> tile for herd
131123
%fill_op = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
132124
%generic_fill_op = transform.structured.generalize %fill_op
133125
: (!transform.any_op) -> !transform.any_op
@@ -140,7 +132,7 @@ module attributes {transform.with_named_sequence} {
140132
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
141133
transform.annotate %prologue_forall "prologue_forall" : !transform.any_op
142134

143-
// Epilogue: unpack → tile_using_forall [64, 32] for 4×4 herd
135+
// Epilogue: unpack -> tile for L2 write-back
144136
%unpack_op = transform.structured.match ops{["linalg.unpack"]} in %arg1 : (!transform.any_op) -> !transform.any_op
145137
%epilogue_tiled_unpack, %epilogue_forall =
146138
transform.structured.tile_using_forall %unpack_op tile_sizes [64, 32]
@@ -195,8 +187,6 @@ module attributes {transform.with_named_sequence} {
195187

196188
%generic1 = transform.structured.match ops{["linalg.generic"]} attributes{init_fill} in %arg1 : (!transform.any_op) -> !transform.any_op
197189
%generic2 = transform.structured.match ops{["linalg.generic"]} attributes{matmul_compute} in %arg1 : (!transform.any_op) -> !transform.any_op
198-
// Per-core packed matmul: [4, 8, K/8, 8, 8, 8].
199-
// Tile for vectorization: [2, 2, 1, 0, 0, 0] then unroll.
200190
%inner_most_generics, %vec_loops:3 =
201191
transform.structured.tile_using_for %generic2 tile_sizes [2, 2, 1, 0, 0, 0]
202192
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
@@ -252,9 +242,12 @@ module attributes {transform.with_named_sequence} {
252242
%scf_fors_1 = transform.structured.match ops{["scf.for"]} in %herd2_1 : (!transform.any_op) -> !transform.any_op
253243
%innermost_for, %outer_fors = transform.split_handle %scf_fors_1 {overflow_result = 1} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
254244

255-
// Cast vector.contract input types: inputs 0,1 to bf16, accumulator 2 and output to f32
245+
// Cast accumulator (input[2]) and output[0] to f32
256246
%vector_contracts = transform.structured.match ops{["vector.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
257247
%result11 = transform.air.vector_type_cast %vector_contracts {target_element_type = f32, input_indices = [2], output_indices = [0]} : (!transform.any_op) -> !transform.any_op
248+
249+
// Cast vector.contract inputs 0,1 to bf16
250+
// (matches hardware MAC unit native input type)
258251
%vector_contracts_2 = transform.structured.match ops{["vector.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
259252
%result11b = transform.air.vector_type_cast %vector_contracts_2 {target_element_type = bf16, input_indices = [0, 1], output_indices = []} : (!transform.any_op) -> !transform.any_op
260253

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved.
2+
# SPDX-License-Identifier: MIT
3+
4+
# INT8 matmul with l1_m=128, l1_n=64, l2_k=64.
5+
# L1 budget (with pingpong): 16K(2*A) + 8K(2*B) + 32K(C) = 56KB / 64KB (88%).
6+
# BLOCK_SIZE_M=1024, BLOCK_SIZE_N=256 to fit 8x4 herd with per-core 128x64.
7+
#
8+
# Transform script generated by:
9+
# python examples/matmul_transform.py --l1-m 128 --l1-n 64 --l2-k 64 \
10+
# --pack-sizes 8 8 8 --accum-type i32 --contract-input-type i16 \
11+
# -o examples/matmul_i8_m128_n64_k64/transform_aie2p.mlir
12+
13+
import torch
14+
import triton
15+
import triton.language as tl
16+
import sys, os
17+
18+
sys.path.append(os.path.abspath(".."))
19+
import benchmark
20+
21+
22+
@triton.jit
23+
def bare_matmul_i8(
24+
A,
25+
B,
26+
C,
27+
M: tl.constexpr,
28+
N: tl.constexpr,
29+
K: tl.constexpr,
30+
stride_am: tl.constexpr,
31+
stride_ak: tl.constexpr,
32+
stride_bk: tl.constexpr,
33+
stride_bn: tl.constexpr,
34+
stride_cm: tl.constexpr,
35+
stride_cn: tl.constexpr,
36+
BLOCK_SIZE_M: tl.constexpr,
37+
BLOCK_SIZE_N: tl.constexpr,
38+
BLOCK_SIZE_K: tl.constexpr,
39+
):
40+
pid_m = tl.program_id(0)
41+
pid_n = tl.program_id(1)
42+
43+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
44+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
45+
offs_k = tl.arange(0, BLOCK_SIZE_K)
46+
47+
a_block = tl.load(A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
48+
b_block = tl.load(B + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)
49+
50+
c_block = tl.dot(a_block, b_block)
51+
52+
tl.store(C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn, c_block)
53+
54+
55+
def bench_matmul_i8(M, N, K, provider):
56+
device = "cpu"
57+
dtype_in = torch.int8
58+
dtype_out = torch.int32
59+
a = torch.randint(-8, 8, (M, K), device=device, dtype=dtype_in)
60+
b = torch.randint(-8, 8, (K, N), device=device, dtype=dtype_in)
61+
c = torch.empty((M, N), device=device, dtype=dtype_out)
62+
if provider == "torch" or provider == "test":
63+
c_ref = torch.matmul(a.to(dtype_out), b.to(dtype_out))
64+
if provider == "triton" or provider == "test":
65+
grid = lambda META: (
66+
triton.cdiv(M, META["BLOCK_SIZE_M"]),
67+
triton.cdiv(N, META["BLOCK_SIZE_N"]),
68+
)
69+
compiled_kernel = bare_matmul_i8[grid](
70+
a,
71+
b,
72+
c,
73+
M,
74+
N,
75+
K,
76+
a.stride(0),
77+
a.stride(1),
78+
b.stride(0),
79+
b.stride(1),
80+
c.stride(0),
81+
c.stride(1),
82+
BLOCK_SIZE_M=1024,
83+
BLOCK_SIZE_N=256,
84+
BLOCK_SIZE_K=K,
85+
)
86+
with open("tt.shared.mlir", "w") as f:
87+
f.write(str(compiled_kernel.asm["ttsharedir"]))
88+
if provider == "test":
89+
torch.testing.assert_close(c, c_ref, atol=0, rtol=0)
90+
91+
92+
if __name__ == "__main__":
93+
benchmark.select_npu_backend()
94+
for M in [1024, 2048, 4096]:
95+
for N in [1024, 2048]:
96+
for K in [256, 512, 1024]:
97+
bench_matmul_i8(M, N, K, "test")
98+
bench_matmul_i8(M, N, K, "test")

0 commit comments

Comments
 (0)