11// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved.
22// SPDX-License-Identifier: MIT
33
4- ////////////////////////////////////////////////////////////////////////////////
5- // Transform Script for Average Pooling (AIE2P )
4+ // Mean subtraction transform for AIE2P.
5+ // y = x - mean(x, dim=-1 )
66//
7- // avg_pool(x) = mean(x, dim=-1) per row
8- //
9- // 2D kernel [BLOCK_M, BLOCK_N] with reduction over columns.
10- // Uses the rms_norm reduction pattern with linalg_promote for L1 staging.
11- // Requires mlir-air >= 4bc5734 (fix for linalg_promote memref.cast #1399).
12- ////////////////////////////////////////////////////////////////////////////////
7+ // 2D kernel (BLOCK_M=2 x BLOCK_N=256) with 2D output [BLOCK_M, BLOCK_N].
8+ // Follows the rms_norm reduction pattern exactly: tile [1], fuse_multi_op,
9+ // linalg_promote, herd + DMA, vectorize, type casts.
1310
1411module attributes {transform.with_named_sequence } {
1512 transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) {
16-
17- // Phase 1: Canonicalization
1813 %func0 = transform.structured.match ops {[" func.func" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
19- transform.apply_patterns to %func0 {
20- transform.apply_patterns.canonicalization
21- transform.apply_patterns.linalg.fold_unit_extent_dims_via_reshapes
22- } : !transform.any_op
14+ transform.apply_patterns to %func0 { transform.apply_patterns.canonicalization
15+ transform.apply_patterns.linalg.fold_unit_extent_dims_via_reshapes } : !transform.any_op
2316 transform.apply_cse to %func0 : !transform.any_op
24-
25- // Phase 2: Transpose reduce + fuse elementwise
2617 %reduces = transform.structured.match ops {[" linalg.reduce" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
2718 %tr = transform.air.transpose_reduce %reduces : (!transform.any_op ) -> !transform.any_op
2819 %func1a = transform.structured.match ops {[" func.func" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
2920 transform.apply_patterns to %func1a { transform.apply_patterns.canonicalization } : !transform.any_op
3021 transform.apply_cse to %func1a : !transform.any_op
31-
3222 %func1 = transform.structured.match ops {[" func.func" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
3323 %f = transform.air.fuse_elementwise_linalg %func1 : (!transform.any_op ) -> !transform.any_op
3424 %fa = transform.structured.match ops {[" func.func" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
3525 transform.apply_patterns to %fa { transform.apply_patterns.canonicalization } : !transform.any_op
3626 transform.apply_cse to %fa : !transform.any_op
3727
38- // Phase 3: Match, tile, fuse
39- // After fusion: 1 generic (fused extf+divf+truncf), 1 reduce, 1 fill
40- %generic = transform.structured.match ops {[" linalg.generic" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
28+ // After fusion: "out" generic [2, 256] (x - mean broadcast), reduce, fill
29+ // No "sq" generic (unlike rms_norm) -- the reduce directly sums x.
30+ %out = transform.structured.match ops {[" linalg.generic" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
4131 %reduce = transform.structured.match ops {[" linalg.reduce" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
4232 %fill = transform.structured.match ops {[" linalg.fill" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
4333
4434 // L2 output alloc
45- %ob , %nb = transform.structured.bufferize_to_allocation %generic
46- {memory_space = 1 , bufferize_destination_only , emit_dealloc } : !transform.any_op
47- // Tile at [2] not [1]: single bf16 = 2 bytes, below 4-byte DMA alignment
48- %t , %fl = transform.structured.tile_using_forall %generic tile_sizes [2 ]
49- : (!transform.any_op ) -> (!transform.any_op , !transform.any_op )
50- // Fuse into forall
51- %f1 , %fl1 = transform.structured.fuse_into_containing_op %reduce into %fl
52- : (!transform.any_op , !transform.any_op ) -> (!transform.any_op , !transform.any_op )
53- %f2 , %fl2 = transform.structured.fuse_into_containing_op %fill into %fl1
54- : (!transform.any_op , !transform.any_op ) -> (!transform.any_op , !transform.any_op )
55-
56- // Phase 4: Fill dest to L1
35+ %ob , %nb = transform.structured.bufferize_to_allocation %out {memory_space = 1 , bufferize_destination_only , emit_dealloc } : !transform.any_op
36+ // Tile at [1] on row dim (same as rms_norm)
37+ %t , %fl = transform.structured.tile_using_forall %out tile_sizes [1 ] : (!transform.any_op ) -> (!transform.any_op , !transform.any_op )
38+ // Fuse all into forall
39+ %f1 , %fl1 = transform.structured.fuse_into_containing_op %reduce into %fl : (!transform.any_op , !transform.any_op ) -> (!transform.any_op , !transform.any_op )
40+ %f2 , %fl2 = transform.structured.fuse_into_containing_op %fill into %fl1 : (!transform.any_op , !transform.any_op ) -> (!transform.any_op , !transform.any_op )
41+
42+ // L1 for fills only (destination-only)
5743 %fills3 = transform.structured.match ops {[" linalg.fill" ]} in %fl2 : (!transform.any_op ) -> !transform.any_op
5844 %fill_buf , %fill_new = transform.structured.bufferize_to_allocation %fills3
5945 {memory_space = 2 , bufferize_destination_only , emit_dealloc } : !transform.any_op
6046
61- // Phase 5: Canonicalize + bufferize
47+ // Canonicalize + bufferize
6248 %f2c = transform.structured.match ops {[" func.func" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
6349 transform.apply_patterns to %f2c { transform.apply_patterns.canonicalization } : !transform.any_op
6450 transform.apply_cse to %f2c : !transform.any_op
@@ -72,14 +58,21 @@ module attributes {transform.with_named_sequence} {
7258 %fu = transform.air.remove_uninitialized_copy %f6 : (!transform.any_op ) -> (!transform.any_op )
7359 %fu2 = transform.air.eliminate_cascade_memcpy %fu : (!transform.any_op ) -> (!transform.any_op )
7460
75- // Phase 6: L1 promote (linalg_promote with fix from mlir-air #1399)
61+ // L1 promote
7662 %forall_op = transform.structured.match ops {[" scf.forall" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
7763 %gens_f = transform.structured.match ops {[" linalg.generic" ]} in %forall_op : (!transform.any_op ) -> !transform.any_op
7864 %reds_f = transform.structured.match ops {[" linalg.reduce" ]} in %forall_op : (!transform.any_op ) -> !transform.any_op
7965 %all_linalg_f = transform.merge_handles %reds_f , %gens_f { deduplicate } : !transform.any_op
8066 %promoted = transform.air.linalg_promote %all_linalg_f {memory_space = " L1" } : (!transform.any_op ) -> !transform.any_op
8167
82- // Phase 7: Herd + DMA
68+ // Post-promote cleanup
69+ %f_pp = transform.structured.match ops {[" func.func" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
70+ transform.apply_patterns to %f_pp { transform.apply_patterns.canonicalization } : !transform.any_op
71+ transform.apply_cse to %f_pp : !transform.any_op
72+ %f_pp2 = transform.air.remove_uninitialized_copy %f_pp : (!transform.any_op ) -> (!transform.any_op )
73+ %f_pp3 = transform.air.eliminate_cascade_memcpy %f_pp2 : (!transform.any_op ) -> (!transform.any_op )
74+
75+ // Herd + DMA
8376 %fh = transform.structured.match ops {[" scf.forall" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
8477 %pa = transform.loop.forall_to_parallel %fh : (!transform.any_op ) -> !transform.any_op
8578 %h = transform.air.par_to_herd %pa : (!transform.any_op ) -> !transform.any_op
@@ -89,25 +82,17 @@ module attributes {transform.with_named_sequence} {
8982 %ac = transform.merge_handles %mc2 , %mc3 { deduplicate } : !transform.any_op
9083 %dm = transform.air.copy_to_dma %ac : (!transform.any_op ) -> !transform.any_op
9184
92- // Phase 8: Vectorization
85+ // Vectorization (same as rms_norm)
9386 %h2 = transform.structured.match ops {[" air.herd" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
94-
95- // Tile reduce at [0, 16] for vectorization
96- %reds_h = transform.structured.match ops {[" linalg.reduce" ]} in %h2 : (!transform.any_op ) -> !transform.any_op
97- %inner_r , %inner_rl:1 = transform.structured.tile_using_for %reds_h tile_sizes [0 , 16 ]
98- : (!transform.any_op ) -> (!transform.any_op , !transform.any_op )
99-
100- // Generic is scalar (divf per row) -- convert to loops
10187 %gens_h = transform.structured.match ops {[" linalg.generic" ]} in %h2 : (!transform.any_op ) -> !transform.any_op
102- %gen_scl = transform.structured.convert_to_loops %gens_h : (!transform.any_op ) -> !transform.any_op
103-
104- // Fill is scalar -- convert to loops
88+ %inner_g , %inner_gl:1 = transform.structured.tile_using_for %gens_h tile_sizes [ 0 , 16 ] : (!transform.any_op ) -> ( !transform.any_op , !transform.any_op )
89+ %reds_h = transform.structured.match ops {[ " linalg.reduce " ]} in %h2 : ( !transform.any_op ) -> !transform.any_op
90+ %inner_r , %inner_rl:1 = transform.structured.tile_using_for %reds_h tile_sizes [ 0 , 16 ] : ( !transform.any_op ) -> ( !transform.any_op , !transform.any_op )
10591 %fills_h = transform.structured.match ops {[" linalg.fill" ]} in %h2 : (!transform.any_op ) -> !transform.any_op
10692 %fill_scl = transform.structured.convert_to_loops %fills_h : (!transform.any_op ) -> !transform.any_op
107-
10893 %vh = transform.air.herd_vectorize %h2 : (!transform.any_op ) -> !transform.any_op
10994
110- // Phase 9: Lower reductions + type casts
95+ // Lower vector reductions
11196 %func_final = transform.structured.match ops {[" func.func" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
11297 transform.apply_patterns to %func_final {
11398 transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = " innerreduction"
@@ -116,19 +101,19 @@ module attributes {transform.with_named_sequence} {
116101 } : !transform.any_op
117102 transform.apply_cse to %func_final : !transform.any_op
118103
119- // addf -> bf16 (from reduction lowering)
104+ // AIE2P type casts: mulf/addf/subf bf16-only, divf f32-only
120105 %vh2 = transform.structured.match ops {[" air.herd" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
121106 %vector_adds = transform.structured.match ops {[" arith.addf" ]} in %vh2 : (!transform.any_op ) -> !transform.any_op
122107 %add_cast = transform.air.vector_type_cast %vector_adds {target_element_type = bf16 } : (!transform.any_op ) -> !transform.any_op
123-
108+ %vector_subs = transform.structured.match ops {[" arith.subf" ]} in %vh2 : (!transform.any_op ) -> !transform.any_op
109+ %sub_cast = transform.air.vector_type_cast %vector_subs {target_element_type = bf16 } : (!transform.any_op ) -> !transform.any_op
124110 %func_s1 = transform.structured.match ops {[" func.func" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
125111 %func_s1_done = transform.air.convert_size1_vector_to_scalar %func_s1 : (!transform.any_op ) -> !transform.any_op
126112 transform.apply_patterns to %func_s1_done {
127113 transform.apply_patterns.vector.cast_away_vector_leading_one_dim
128114 transform.apply_patterns.canonicalization
129115 } : !transform.any_op
130116 transform.apply_cse to %func_s1_done : !transform.any_op
131-
132117 transform.yield
133118 }
134119}
0 commit comments