|
| 1 | +// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. |
| 2 | +// SPDX-License-Identifier: MIT |
| 3 | + |
| 4 | +// Mean subtraction transform for AIE2. |
| 5 | +// y = x - mean(x, dim=-1) |
| 6 | +// |
| 7 | +// 2D kernel (BLOCK_M=2 x BLOCK_N=256) with 2D output [BLOCK_M, BLOCK_N]. |
| 8 | +// Follows the rms_norm reduction pattern exactly: tile [1], fuse_multi_op, |
| 9 | +// linalg_promote, herd + DMA, vectorize, type casts. |
| 10 | + |
| 11 | +module attributes {transform.with_named_sequence} { |
| 12 | + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { |
| 13 | + %func0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 14 | + transform.apply_patterns to %func0 { transform.apply_patterns.canonicalization |
| 15 | + transform.apply_patterns.linalg.fold_unit_extent_dims_via_reshapes } : !transform.any_op |
| 16 | + transform.apply_cse to %func0 : !transform.any_op |
| 17 | + %reduces = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 18 | + %tr = transform.air.transpose_reduce %reduces : (!transform.any_op) -> !transform.any_op |
| 19 | + %func1a = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 20 | + transform.apply_patterns to %func1a { transform.apply_patterns.canonicalization } : !transform.any_op |
| 21 | + transform.apply_cse to %func1a : !transform.any_op |
| 22 | + %func1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 23 | + %f = transform.air.fuse_elementwise_linalg %func1 : (!transform.any_op) -> !transform.any_op |
| 24 | + %fa = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 25 | + transform.apply_patterns to %fa { transform.apply_patterns.canonicalization } : !transform.any_op |
| 26 | + transform.apply_cse to %fa : !transform.any_op |
| 27 | + |
| 28 | + // After fusion: "out" generic [2, 256] (x - mean broadcast), reduce, fill |
| 29 | + // No "sq" generic (unlike rms_norm) -- the reduce directly sums x. |
| 30 | + %out = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 31 | + %reduce = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 32 | + %fill = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 33 | + |
| 34 | + // L2 output alloc |
| 35 | + %ob, %nb = transform.structured.bufferize_to_allocation %out {memory_space = 1, bufferize_destination_only, emit_dealloc} : !transform.any_op |
| 36 | + // Tile at [1] on row dim (same as rms_norm) |
| 37 | + %t, %fl = transform.structured.tile_using_forall %out tile_sizes [1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) |
| 38 | + // Fuse all into forall |
| 39 | + %f1, %fl1 = transform.structured.fuse_into_containing_op %reduce into %fl : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) |
| 40 | + %f2, %fl2 = transform.structured.fuse_into_containing_op %fill into %fl1 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) |
| 41 | + |
| 42 | + // L1 for fills only (destination-only) |
| 43 | + %fills3 = transform.structured.match ops{["linalg.fill"]} in %fl2 : (!transform.any_op) -> !transform.any_op |
| 44 | + %fill_buf, %fill_new = transform.structured.bufferize_to_allocation %fills3 |
| 45 | + {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op |
| 46 | + |
| 47 | + // Canonicalize + bufferize |
| 48 | + %f2c = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 49 | + transform.apply_patterns to %f2c { transform.apply_patterns.canonicalization } : !transform.any_op |
| 50 | + transform.apply_cse to %f2c : !transform.any_op |
| 51 | + %fop = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 52 | + %fb = transform.bufferization.one_shot_bufferize %fop : (!transform.any_op) -> !transform.any_op |
| 53 | + %f6 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 54 | + transform.apply_patterns to %f6 { transform.apply_patterns.canonicalization } : !transform.any_op |
| 55 | + transform.apply_cse to %f6 : !transform.any_op |
| 56 | + %lc = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 57 | + %mc = transform.structured.linalg_copy_to_memref %lc : (!transform.any_op) -> !transform.any_op |
| 58 | + %fu = transform.air.remove_uninitialized_copy %f6 : (!transform.any_op) -> (!transform.any_op) |
| 59 | + %fu2 = transform.air.eliminate_cascade_memcpy %fu : (!transform.any_op) -> (!transform.any_op) |
| 60 | + |
| 61 | + // L1 promote |
| 62 | + %forall_op = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 63 | + %gens_f = transform.structured.match ops{["linalg.generic"]} in %forall_op : (!transform.any_op) -> !transform.any_op |
| 64 | + %reds_f = transform.structured.match ops{["linalg.reduce"]} in %forall_op : (!transform.any_op) -> !transform.any_op |
| 65 | + %all_linalg_f = transform.merge_handles %reds_f, %gens_f { deduplicate } : !transform.any_op |
| 66 | + %promoted = transform.air.linalg_promote %all_linalg_f {memory_space = "L1"} : (!transform.any_op) -> !transform.any_op |
| 67 | + |
| 68 | + // Post-promote cleanup |
| 69 | + %f_pp = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 70 | + transform.apply_patterns to %f_pp { transform.apply_patterns.canonicalization } : !transform.any_op |
| 71 | + transform.apply_cse to %f_pp : !transform.any_op |
| 72 | + %f_pp2 = transform.air.remove_uninitialized_copy %f_pp : (!transform.any_op) -> (!transform.any_op) |
| 73 | + %f_pp3 = transform.air.eliminate_cascade_memcpy %f_pp2 : (!transform.any_op) -> (!transform.any_op) |
| 74 | + |
| 75 | + // Herd + DMA |
| 76 | + %fh = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 77 | + %pa = transform.loop.forall_to_parallel %fh : (!transform.any_op) -> !transform.any_op |
| 78 | + %h = transform.air.par_to_herd %pa : (!transform.any_op) -> !transform.any_op |
| 79 | + %lc2 = transform.structured.match ops{["linalg.copy"]} in %h : (!transform.any_op) -> !transform.any_op |
| 80 | + %mc2 = transform.structured.match ops{["memref.copy"]} in %h : (!transform.any_op) -> !transform.any_op |
| 81 | + %mc3 = transform.structured.linalg_copy_to_memref %lc2 : (!transform.any_op) -> !transform.any_op |
| 82 | + %ac = transform.merge_handles %mc2, %mc3 { deduplicate } : !transform.any_op |
| 83 | + %dm = transform.air.copy_to_dma %ac : (!transform.any_op) -> !transform.any_op |
| 84 | + |
| 85 | + // Vectorization (same as rms_norm) |
| 86 | + %h2 = transform.structured.match ops{["air.herd"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 87 | + %gens_h = transform.structured.match ops{["linalg.generic"]} in %h2 : (!transform.any_op) -> !transform.any_op |
| 88 | + %inner_g, %inner_gl:1 = transform.structured.tile_using_for %gens_h tile_sizes [0, 16] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) |
| 89 | + %reds_h = transform.structured.match ops{["linalg.reduce"]} in %h2 : (!transform.any_op) -> !transform.any_op |
| 90 | + %inner_r, %inner_rl:1 = transform.structured.tile_using_for %reds_h tile_sizes [0, 16] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) |
| 91 | + %fills_h = transform.structured.match ops{["linalg.fill"]} in %h2 : (!transform.any_op) -> !transform.any_op |
| 92 | + %fill_scl = transform.structured.convert_to_loops %fills_h : (!transform.any_op) -> !transform.any_op |
| 93 | + %vh = transform.air.herd_vectorize %h2 : (!transform.any_op) -> !transform.any_op |
| 94 | + |
| 95 | + // Lower vector reductions |
| 96 | + %func_final = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 97 | + transform.apply_patterns to %func_final { |
| 98 | + transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerreduction" |
| 99 | + transform.apply_patterns.vector.lower_contraction |
| 100 | + transform.apply_patterns.vector.lower_transfer |
| 101 | + } : !transform.any_op |
| 102 | + transform.apply_cse to %func_final : !transform.any_op |
| 103 | + |
| 104 | + // AIE2 type casts: mulf/addf/subf bf16-only, divf f32-only |
| 105 | + %vh2 = transform.structured.match ops{["air.herd"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 106 | + %vector_adds = transform.structured.match ops{["arith.addf"]} in %vh2 : (!transform.any_op) -> !transform.any_op |
| 107 | + %add_cast = transform.air.vector_type_cast %vector_adds {target_element_type = bf16} : (!transform.any_op) -> !transform.any_op |
| 108 | + %vector_subs = transform.structured.match ops{["arith.subf"]} in %vh2 : (!transform.any_op) -> !transform.any_op |
| 109 | + %sub_cast = transform.air.vector_type_cast %vector_subs {target_element_type = bf16} : (!transform.any_op) -> !transform.any_op |
| 110 | + %func_s1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 111 | + %func_s1_done = transform.air.convert_size1_vector_to_scalar %func_s1 : (!transform.any_op) -> !transform.any_op |
| 112 | + transform.apply_patterns to %func_s1_done { |
| 113 | + transform.apply_patterns.vector.cast_away_vector_leading_one_dim |
| 114 | + transform.apply_patterns.canonicalization |
| 115 | + } : !transform.any_op |
| 116 | + transform.apply_cse to %func_s1_done : !transform.any_op |
| 117 | + transform.yield |
| 118 | + } |
| 119 | +} |
0 commit comments