|
| 1 | +// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. |
| 2 | +// SPDX-License-Identifier: MIT |
| 3 | + |
| 4 | +//////////////////////////////////////////////////////////////////////////////// |
| 5 | +// Transform Script for SwiGLU (AIE2): out = SiLU(gate) * up |
| 6 | +// |
| 7 | +// SwiGLU(gate, up) = gate * sigmoid(gate) * up |
| 8 | +// |
| 9 | +// The Linalg IR has the silu chain (extf, negf/subf, exp, addf, divf, mulf) |
| 10 | +// plus an additional mulf for the final * up. After fuse_elementwise_linalg, |
| 11 | +// this becomes a single generic with 2 bf16 inputs (gate, up) and 1 bf16 |
| 12 | +// output (out). |
| 13 | +// |
| 14 | +// AIE2 type mapping: |
| 15 | +// - math.exp: bf16 ONLY -> needs vector_type_cast |
| 16 | +// - arith.divf: f32 ONLY -> keep as f32 |
| 17 | +// - arith.subf/addf/mulf: bf16 ONLY -> needs vector_type_cast |
| 18 | +// |
| 19 | +// Strategy: fuse_elementwise_linalg -> 3-operand tiling (like axpy) -> |
| 20 | +// vectorize at 16 -> cast exp, subf, addf, mulf to bf16; divf stays f32. |
| 21 | +// |
| 22 | +// AIE2 has no native bf16 exp intrinsic, so the herd is linked with extern_func.o (via a link_with annotation) to provide math.exp. |
| 23 | +//////////////////////////////////////////////////////////////////////////////// |
| 24 | + |
| 25 | +module attributes {transform.with_named_sequence} { |
| 26 | + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { |
| 27 | + |
| 28 | + //=================================================================== |
| 29 | + // PHASE 1: Initial canonicalization + CSE on func.func (also folds unit-extent dims via reshapes) |
| 30 | + //=================================================================== |
| 31 | + %func0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 32 | + transform.apply_patterns to %func0 { |
| 33 | + transform.apply_patterns.linalg.tiling_canonicalization |
| 34 | + transform.apply_patterns.scf.for_loop_canonicalization |
| 35 | + transform.apply_patterns.canonicalization |
| 36 | + transform.apply_patterns.linalg.fold_unit_extent_dims_via_reshapes |
| 37 | + } : !transform.any_op |
| 38 | + transform.apply_cse to %func0 : !transform.any_op |
| 39 | + |
| 40 | + //=================================================================== |
| 41 | + // PHASE 2: Fuse the elementwise chain (air.fuse_elementwise_linalg), then canonicalize + CSE |
| 42 | + //=================================================================== |
| 43 | + %func1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 44 | + %func1_fused = transform.air.fuse_elementwise_linalg %func1 : (!transform.any_op) -> !transform.any_op |
| 45 | + |
| 46 | + %func1a = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 47 | + transform.apply_patterns to %func1a { |
| 48 | + transform.apply_patterns.canonicalization |
| 49 | + } : !transform.any_op |
| 50 | + transform.apply_cse to %func1a : !transform.any_op |
| 51 | + |
| 52 | + //=================================================================== |
| 53 | + // PHASE 3: Flatten the generic, bufferize its destination into memory space 1, tile by 256 with scf.forall |
| 54 | + //=================================================================== |
| 55 | + %op = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 56 | + |
| 57 | + %op_flattened = transform.structured.flatten_elementwise %op |
| 58 | + : (!transform.any_op) -> !transform.any_op |
| 59 | + |
| 60 | + %op_res_shared, %new_op = transform.structured.bufferize_to_allocation %op_flattened |
| 61 | + {memory_space = 1, bufferize_destination_only, emit_dealloc} : !transform.any_op |
| 62 | + |
| 63 | + %op_1 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 64 | + %tiled_op_1, %forall_op_1 = |
| 65 | + transform.structured.tile_using_forall %op_1 tile_sizes [256] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) |
| 66 | + |
| 67 | + //=================================================================== |
| 68 | + // PHASE 4: Canonicalization + CSE after the forall tiling |
| 69 | + //=================================================================== |
| 70 | + %func_2 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 71 | + transform.apply_patterns to %func_2 { |
| 72 | + transform.apply_patterns.linalg.tiling_canonicalization |
| 73 | + transform.apply_patterns.scf.for_loop_canonicalization |
| 74 | + transform.apply_patterns.canonicalization |
| 75 | + } : !transform.any_op |
| 76 | + transform.apply_cse to %func_2 : !transform.any_op |
| 77 | + |
| 78 | + //=================================================================== |
| 79 | + // PHASE 5: Pad and promote to L1 (memory_space = 2); 3 operands: gate [0], up [1], out [2] |
| 80 | + //=================================================================== |
| 81 | + %op_2 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 82 | + |
| 83 | + %padded_op, %pad_op, %__ = transform.structured.pad %op_2 { |
| 84 | + padding_values=[0.0 : bf16, 0.0 : bf16, 0.0 : bf16], |
| 85 | + padding_dimensions=[0, 1, 2], |
| 86 | + nofold_flags=[1, 1, 1], |
| 87 | + copy_back_op="linalg.copy" |
| 88 | + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) |
| 89 | + |
| 90 | + %pad_dps = transform.structured.rewrite_in_destination_passing_style %pad_op : (!transform.any_op) -> !transform.any_op |
| 91 | + |
| 92 | + %padded_gate = transform.get_producer_of_operand %padded_op[0] : (!transform.any_op) -> (!transform.any_op) |
| 93 | + %padded_gate_buffer, %padded_gate_new = transform.structured.bufferize_to_allocation %padded_gate |
| 94 | + {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op |
| 95 | + |
| 96 | + %padded_up = transform.get_producer_of_operand %padded_op[1] : (!transform.any_op) -> (!transform.any_op) |
| 97 | + %padded_up_buffer, %padded_up_new = transform.structured.bufferize_to_allocation %padded_up |
| 98 | + {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op |
| 99 | + |
| 100 | + %padded_out = transform.get_producer_of_operand %padded_op[2] : (!transform.any_op) -> (!transform.any_op) |
| 101 | + %padded_out_buffer, %padded_out_new = transform.structured.bufferize_to_allocation %padded_out |
| 102 | + {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op |
| 103 | + |
| 104 | + //=================================================================== |
| 105 | + // PHASE 6: Canonicalization + CSE after padding/promotion |
| 106 | + //=================================================================== |
| 107 | + %func_3 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 108 | + transform.apply_patterns to %func_3 { |
| 109 | + transform.apply_patterns.linalg.tiling_canonicalization |
| 110 | + transform.apply_patterns.scf.for_loop_canonicalization |
| 111 | + transform.apply_patterns.canonicalization |
| 112 | + } : !transform.any_op |
| 113 | + transform.apply_cse to %func_3 : !transform.any_op |
| 114 | + |
| 115 | + //=================================================================== |
| 116 | + // PHASE 7: One-shot bufferization of func.func |
| 117 | + //=================================================================== |
| 118 | + %func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 119 | + %func_bufferized = transform.bufferization.one_shot_bufferize %func_op : (!transform.any_op) -> !transform.any_op |
| 120 | + |
| 121 | + //=================================================================== |
| 122 | + // PHASE 8: Post-bufferization cleanup (linalg_copy_to_memref, remove_uninitialized_copy, eliminate_cascade_memcpy) |
| 123 | + //=================================================================== |
| 124 | + %func6 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 125 | + transform.apply_patterns to %func6 { |
| 126 | + transform.apply_patterns.linalg.tiling_canonicalization |
| 127 | + transform.apply_patterns.scf.for_loop_canonicalization |
| 128 | + transform.apply_patterns.canonicalization |
| 129 | + } : !transform.any_op |
| 130 | + transform.apply_cse to %func6 : !transform.any_op |
| 131 | + transform.apply_patterns to %func6 { |
| 132 | + transform.apply_patterns.canonicalization |
| 133 | + } : !transform.any_op |
| 134 | + %linalg_copies = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 135 | + %memref_copies = transform.structured.linalg_copy_to_memref %linalg_copies : (!transform.any_op) -> !transform.any_op |
| 136 | + %func_op_updated = transform.air.remove_uninitialized_copy %func6 : (!transform.any_op) -> !transform.any_op |
| 137 | + %func_op_updated_1 = transform.air.eliminate_cascade_memcpy %func_op_updated : (!transform.any_op) -> !transform.any_op |
| 138 | + |
| 139 | + //=================================================================== |
| 140 | + // PHASE 9: Vectorization tiling: scf.for with tile size 16 (16-lane for bf16) |
| 141 | + //=================================================================== |
| 142 | + %linalg_generics = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 143 | + %inner_most_generics, %vec_loops:1 = |
| 144 | + transform.structured.tile_using_for %linalg_generics tile_sizes [16] |
| 145 | + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) |
| 146 | + |
| 147 | + //=================================================================== |
| 148 | + // PHASE 10: AIR mapping (forall -> parallel -> herd, copies -> DMA), herd vectorization, then bf16 type casts |
| 149 | + //=================================================================== |
| 150 | + %forall_as_herd = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 151 | + %parallel = transform.loop.forall_to_parallel %forall_as_herd : (!transform.any_op) -> !transform.any_op |
| 152 | + %herd = transform.air.par_to_herd %parallel : (!transform.any_op) -> !transform.any_op |
| 153 | + |
| 154 | + // Annotate the herd with link_with = "extern_func.o": AIE2 needs it for math.exp (no native bf16 exp intrinsic) |
| 155 | + %extern_func_param = transform.param.constant "extern_func.o" -> !transform.any_param |
| 156 | + transform.annotate %herd "link_with" = %extern_func_param : !transform.any_op, !transform.any_param |
| 157 | + |
| 158 | + %copies_in_herd = transform.structured.match ops{["memref.copy", "linalg.copy"]} in %herd : (!transform.any_op) -> !transform.any_op |
| 159 | + %dmas_from_copies = transform.air.copy_to_dma %copies_in_herd : (!transform.any_op) -> !transform.any_op |
| 160 | + |
| 161 | + %vectorized_herd = transform.air.herd_vectorize %herd : (!transform.any_op) -> !transform.any_op |
| 162 | + |
| 163 | + // math.exp -> bf16 (AIE2 supports exp in bf16 only) |
| 164 | + %vector_exps = transform.structured.match ops{["math.exp"]} in %vectorized_herd : (!transform.any_op) -> !transform.any_op |
| 165 | + %exp_cast = transform.air.vector_type_cast %vector_exps {target_element_type = bf16} : (!transform.any_op) -> !transform.any_op |
| 166 | + |
| 167 | + // arith.subf -> bf16. NOTE(review): the header lists "negf/subf"; arith.negf is not matched or cast here - confirm that is intended |
| 168 | + %vector_subs = transform.structured.match ops{["arith.subf"]} in %vectorized_herd : (!transform.any_op) -> !transform.any_op |
| 169 | + %sub_cast = transform.air.vector_type_cast %vector_subs {target_element_type = bf16} : (!transform.any_op) -> !transform.any_op |
| 170 | + |
| 171 | + // arith.addf -> bf16 (AIE2 supports addf in bf16 only) |
| 172 | + %vector_adds = transform.structured.match ops{["arith.addf"]} in %vectorized_herd : (!transform.any_op) -> !transform.any_op |
| 173 | + %add_cast = transform.air.vector_type_cast %vector_adds {target_element_type = bf16} : (!transform.any_op) -> !transform.any_op |
| 174 | + |
| 175 | + // arith.mulf -> bf16 (AIE2 supports mulf in bf16 only); this matches every mulf in the herd, including the final * up |
| 176 | + %vector_muls = transform.structured.match ops{["arith.mulf"]} in %vectorized_herd : (!transform.any_op) -> !transform.any_op |
| 177 | + %mul_cast = transform.air.vector_type_cast %vector_muls {target_element_type = bf16} : (!transform.any_op) -> !transform.any_op |
| 178 | + |
| 179 | + // arith.divf stays f32 (AIE2 supports divf in f32 only), so no cast is applied to it |
| 180 | + |
| 181 | + transform.yield |
| 182 | + } |
| 183 | +} |
0 commit comments