1- // Transform Script for F32 Matmul with BF16 Emulation
1+ // Auto-generated by matmul_transform.py — do not edit manually.
2+ // Parameters: l1_m=64, l1_n=32, l2_k=16, pack=[8,8,8], accum=f32, contract_in=bf16
23//
3- // Starting IR: Full-K matmul (no K-loop), all f32, generated from asm_src params.
4- // - func @matmul_padding_kernel(memref<*xf32>*3, i32*6)
5- // - linalg.matmul(64xK @ Kx32 → 64x32), f32 accumulation
6- // - A in K×M layout (strides [1, M_alloc]), B in K×N (strides [N_alloc, 1])
7- //
8- // Follows test 53's transform pattern: tile copies, pack [8,8,8], tile K,
9- // tile forall for multi-core, vectorize, hoist.
10- //
11- // Target: 4×8 AIE core array (Strix/NPU2), BF16 emulation
12- // Tile sizes: M=64, N=32, K_L2=16, pack [8,8,8]
4+ // Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved.
5+ // SPDX-License-Identifier: MIT
136
147module attributes {transform.with_named_sequence } {
158 transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) {
169
1710 //==========================================================================
18- // PHASE 1: TILE L3→L2 MEMORY COPIES
11+ // PHASE 1: TILE L3->L2 MEMORY COPIES
12+ // Tile memref copies for streaming data from DDR (L3) to MemTile (L2).
1913 //==========================================================================
2014
2115 %func10 = transform.structured.match ops {[" func.func" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
2216 %func10_updated = transform.air.convert_memref_copy_to_linalg_copy %func10 : (!transform.any_op ) -> !transform.any_op
2317 %copies = transform.structured.match ops {[" linalg.copy" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
2418 %copy1 , %copy2 = transform.split_handle %copies : (!transform.any_op ) -> (!transform.any_op , !transform.any_op )
25- // Tile A copy: 64×K → 64×16 tiles (K_L2_TILE=16)
2619 %tiled_copy1 , %tile_copy_loop1 =
2720 transform.structured.tile_using_for %copy1 tile_sizes [0 , 16 ]
2821 : (!transform.any_op ) -> (!transform.any_op , !transform.any_op )
2922 transform.annotate %tile_copy_loop1 " copy_a_loop" : !transform.any_op
30- // Tile B copy: K×32 → 16×32 tiles
3123 %tiled_copy2 , %tile_copy_loop2 =
3224 transform.structured.tile_using_for %copy2 tile_sizes [16 ]
3325 : (!transform.any_op ) -> (!transform.any_op , !transform.any_op )
3426 transform.annotate %tile_copy_loop2 " copy_b_loop" : !transform.any_op
3527
3628 //==========================================================================
3729 // PHASE 2: PROMOTE OUTPUT TO L2
38- // No truncf fusion needed (output is f32) .
30+ // Allocate output buffer (C) in L2 for accumulation.
3931 //==========================================================================
4032
4133 %result_l2 = transform.structured.match ops {[" linalg.fill" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
@@ -44,43 +36,47 @@ module attributes {transform.with_named_sequence} {
4436
4537 //==========================================================================
4638 // PHASE 3: PACK MATMUL FOR VECTORIZED COMPUTATION
47- // Pack sizes [8, 8, 8] for M, N, K dimensions .
39+ // Pack [8, 8, 8], transpose A/B/C, promote C pack to L1.
4840 //==========================================================================
4941
5042 %matmul_to_pack = transform.structured.match ops {[" linalg.matmul" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
5143 %packed = transform.structured.pack %matmul_to_pack packed_sizes = [8 , 8 , 8 ]
5244 : (!transform.any_op ) -> (!transform.any_op )
5345
46+ // Transpose A: outer_perm [1,0]
5447 %pack_producer_a = transform.get_producer_of_operand %packed [0 ]
5548 : (!transform.any_op ) -> (!transform.any_op )
5649 %packed_a , %pack_a , %empty_unpack_a =
5750 transform.structured.pack_transpose %pack_producer_a with_compute_op (%packed )
5851 outer_perm = [1 , 0 ] : (!transform.any_op , !transform.any_op )
5952 -> (!transform.any_op , !transform.any_op , !transform.any_op )
6053
54+ // Transpose B: outer_perm [1,0] + inner_perm [1,0]
6155 %pack_producer_b = transform.get_producer_of_operand %packed_a [1 ]
6256 : (!transform.any_op ) -> (!transform.any_op )
6357 %packed_b , %pack_b , %empty_unpack_b =
6458 transform.structured.pack_transpose %pack_producer_b with_compute_op (%packed_a )
6559 outer_perm = [1 , 0 ] inner_perm = [1 , 0 ] : (!transform.any_op , !transform.any_op )
6660 -> (!transform.any_op , !transform.any_op , !transform.any_op )
6761
62+ // Transpose C: outer_perm [1,0]
6863 %unpack = transform.get_consumers_of_result %packed_b [0 ]
6964 : (!transform.any_op ) -> (!transform.any_op )
7065 %packed_c , %pack_c , %unpack_c =
7166 transform.structured.pack_transpose %unpack with_compute_op (%packed_b )
7267 outer_perm = [1 , 0 ] : (!transform.any_op , !transform.any_op )
7368 -> (!transform.any_op , !transform.any_op , !transform.any_op )
7469
70+ // Promote C pack to L1
7571 %output_l1_pack_op_source_buffer , %output_l1_pack_op_new = transform.structured.bufferize_to_allocation %pack_c
7672 {memory_space = 2 , bufferize_destination_only , memcpy_op = " linalg.copy" , emit_dealloc } : !transform.any_op
7773
78- // Annotate the packed matmul so we can find it after K-tiling
74+ // Tag the packed matmul with "packed_matmul" so it can be re-matched by attribute after K-tiling
7975 transform.annotate %packed_c " packed_matmul" : !transform.any_op
8076
8177 //==========================================================================
8278 // PHASE 4: TILE K REDUCTION AND FUSE PACK OPERATIONS
83- // K/8 packed K- dim. Tile by 2 (= 16 raw K elements = K_L2_TILE ).
79+ // Tile packed K dim by 2 (= 16 raw K elements).
8480 //==========================================================================
8581
8682 %tiled_reduction , %outer_for_loop =
@@ -93,9 +89,7 @@ module attributes {transform.with_named_sequence} {
9389
9490 //==========================================================================
9591 // PHASE 5: TILE FOR MULTI-CORE PARALLELISM
96- // Packed C dims after pack [8,8,8] + outer_perm [1,0]:
97- // [N/8, M/8, K/8] = [16, 32, K/8] → tile [8, 4, 0] → forall(2, 8)
98- // par_to_herd maps to herd(8, 2) → collapse to 4×4
92+ // Tile [8, 4, 0] for herd distribution.
9993 //==========================================================================
10094
10195 %matmul_1 = transform.structured.match ops {[" linalg.generic" ]} attributes {packed_matmul } in %arg1 : (!transform.any_op ) -> !transform.any_op
@@ -119,15 +113,13 @@ module attributes {transform.with_named_sequence} {
119113 // PHASE 6: PROMOTE INPUTS TO L1 AND TILE PROLOGUE/EPILOGUE
120114 //==========================================================================
121115
116+ // Promote A and B to L1
122117 %buffer_a , %new_a = transform.structured.bufferize_to_allocation %fused_lhs_l1_pack2
123118 {memory_space = 2 , bufferize_destination_only , emit_dealloc } : !transform.any_op
124119 %buffer_b , %new_b = transform.structured.bufferize_to_allocation %fused_rhs_l1_pack2
125120 {memory_space = 2 , bufferize_destination_only , emit_dealloc } : !transform.any_op
126121
127- // Prologue: fill → generalize → interchange → tile_using_forall
128- // After packing, fill is on packed 4D tensor [N/8, M/8, 8, 8] = [16, 32, 8, 8].
129- // Interchange [1,0,2,3] swaps N/M dims → [32, 16, 8, 8].
130- // Tile [8, 4] → forall(4, 4) matching herd.
122+ // Prologue: fill -> generalize -> interchange -> tile for herd
131123 %fill_op = transform.structured.match ops {[" linalg.fill" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
132124 %generic_fill_op = transform.structured.generalize %fill_op
133125 : (!transform.any_op ) -> !transform.any_op
@@ -140,7 +132,7 @@ module attributes {transform.with_named_sequence} {
140132 : (!transform.any_op ) -> (!transform.any_op , !transform.any_op )
141133 transform.annotate %prologue_forall " prologue_forall" : !transform.any_op
142134
143- // Epilogue: unpack → tile_using_forall [64, 32] for 4×4 herd
135+ // Epilogue: unpack -> tile for L2 write-back
144136 %unpack_op = transform.structured.match ops {[" linalg.unpack" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
145137 %epilogue_tiled_unpack , %epilogue_forall =
146138 transform.structured.tile_using_forall %unpack_op tile_sizes [64 , 32 ]
@@ -195,8 +187,6 @@ module attributes {transform.with_named_sequence} {
195187
196188 %generic1 = transform.structured.match ops {[" linalg.generic" ]} attributes {init_fill } in %arg1 : (!transform.any_op ) -> !transform.any_op
197189 %generic2 = transform.structured.match ops {[" linalg.generic" ]} attributes {matmul_compute } in %arg1 : (!transform.any_op ) -> !transform.any_op
198- // Per-core packed matmul: [4, 8, K/8, 8, 8, 8].
199- // Tile for vectorization: [2, 2, 1, 0, 0, 0] then unroll.
200190 %inner_most_generics , %vec_loops:3 =
201191 transform.structured.tile_using_for %generic2 tile_sizes [2 , 2 , 1 , 0 , 0 , 0 ]
202192 : (!transform.any_op ) -> (!transform.any_op , !transform.any_op , !transform.any_op , !transform.any_op )
@@ -252,9 +242,12 @@ module attributes {transform.with_named_sequence} {
252242 %scf_fors_1 = transform.structured.match ops {[" scf.for" ]} in %herd2_1 : (!transform.any_op ) -> !transform.any_op
253243 %innermost_for , %outer_fors = transform.split_handle %scf_fors_1 {overflow_result = 1 } : (!transform.any_op ) -> (!transform.any_op , !transform.any_op )
254244
255- // Cast vector.contract input types: inputs 0,1 to bf16, accumulator 2 and output to f32
245+ // Cast accumulator (input[2]) and output[0] to f32
256246 %vector_contracts = transform.structured.match ops {[" vector.contract" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
257247 %result11 = transform.air.vector_type_cast %vector_contracts {target_element_type = f32 , input_indices = [2 ], output_indices = [0 ]} : (!transform.any_op ) -> !transform.any_op
248+
249+ // Cast vector.contract inputs 0,1 to bf16
250+ // (matches hardware MAC unit native input type)
258251 %vector_contracts_2 = transform.structured.match ops {[" vector.contract" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
259252 %result11b = transform.air.vector_type_cast %vector_contracts_2 {target_element_type = bf16 , input_indices = [0 , 1 ], output_indices = []} : (!transform.any_op ) -> !transform.any_op
260253
0 commit comments