|
| 1 | +// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. |
| 2 | +// SPDX-License-Identifier: MIT |
| 3 | + |
| 4 | +// Weighted RMS Norm transform for AIE2P. |
| 5 | +// y = x * rsqrt(mean(x^2) + eps) * w |
| 6 | +// |
| 7 | +// No fuse_multi_op_linalg (creates L2 refs). CSE merges duplicate X |
| 8 | +// extract_slices so linalg_promote shares one L1 X buffer across ops. |
| 9 | + |
| 10 | +module attributes {transform.with_named_sequence} { |
| 11 | + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { |
| 12 | + %func0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 13 | + transform.apply_patterns to %func0 { transform.apply_patterns.canonicalization |
| 14 | + transform.apply_patterns.linalg.fold_unit_extent_dims_via_reshapes } : !transform.any_op |
| 15 | + transform.apply_cse to %func0 : !transform.any_op |
| 16 | + %reduces = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 17 | + %tr = transform.air.transpose_reduce %reduces : (!transform.any_op) -> !transform.any_op |
| 18 | + %func1a = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 19 | + transform.apply_patterns to %func1a { transform.apply_patterns.canonicalization } : !transform.any_op |
| 20 | + transform.apply_cse to %func1a : !transform.any_op |
| 21 | + %func1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 22 | + %f = transform.air.fuse_elementwise_linalg %func1 : (!transform.any_op) -> !transform.any_op |
| 23 | + %fa = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 24 | + transform.apply_patterns to %fa { transform.apply_patterns.canonicalization } : !transform.any_op |
| 25 | + transform.apply_cse to %fa : !transform.any_op |
| 26 | + |
| 27 | + %ag = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 28 | + %sq, %out = transform.split_handle %ag : (!transform.any_op) -> (!transform.any_op, !transform.any_op) |
| 29 | + %reduce = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 30 | + %fill = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 31 | + |
| 32 | + %ob, %nb = transform.structured.bufferize_to_allocation %out {memory_space = 1, bufferize_destination_only, emit_dealloc} : !transform.any_op |
| 33 | + %t, %fl = transform.structured.tile_using_forall %out tile_sizes [1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) |
| 34 | + %f1, %fl1 = transform.structured.fuse_into_containing_op %reduce into %fl : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) |
| 35 | + %f2, %fl2 = transform.structured.fuse_into_containing_op %sq into %fl1 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) |
| 36 | + %f3, %fl3 = transform.structured.fuse_into_containing_op %fill into %fl2 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) |
| 37 | + |
| 38 | + // Fill dest → L1 |
| 39 | + %fills3 = transform.structured.match ops{["linalg.fill"]} in %fl3 : (!transform.any_op) -> !transform.any_op |
| 40 | + %fill_buf, %fill_new = transform.structured.bufferize_to_allocation %fills3 |
| 41 | + {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op |
| 42 | + |
| 43 | + // CSE to merge duplicate X extract_slices from fuse_into_containing_op. |
| 44 | + // sq and out both slice from the same X tensor -- CSE merges them so |
| 45 | + // linalg_promote's promotedValueMap shares one L1 buffer. |
| 46 | + %func_cse = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 47 | + transform.apply_patterns to %func_cse { |
| 48 | + transform.apply_patterns.linalg.tiling_canonicalization |
| 49 | + transform.apply_patterns.scf.for_loop_canonicalization |
| 50 | + transform.apply_patterns.canonicalization |
| 51 | + } : !transform.any_op |
| 52 | + transform.apply_cse to %func_cse : !transform.any_op |
| 53 | + |
| 54 | + // Bufferize |
| 55 | + %fop = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 56 | + %fb = transform.bufferization.one_shot_bufferize %fop : (!transform.any_op) -> !transform.any_op |
| 57 | + %f6 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 58 | + transform.apply_patterns to %f6 { transform.apply_patterns.canonicalization } : !transform.any_op |
| 59 | + transform.apply_cse to %f6 : !transform.any_op |
| 60 | + %lc = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 61 | + %mc = transform.structured.linalg_copy_to_memref %lc : (!transform.any_op) -> !transform.any_op |
| 62 | + %fu = transform.air.remove_uninitialized_copy %f6 : (!transform.any_op) -> (!transform.any_op) |
| 63 | + %fu2 = transform.air.eliminate_cascade_memcpy %fu : (!transform.any_op) -> (!transform.any_op) |
| 64 | + |
| 65 | + // linalg_promote for ALL operands (X, weight, output) with consistent |
| 66 | + // memory space attributes. No tensor-domain promotion needed. |
| 67 | + %forall_op = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 68 | + %gens_f = transform.structured.match ops{["linalg.generic"]} in %forall_op : (!transform.any_op) -> !transform.any_op |
| 69 | + %reds_f = transform.structured.match ops{["linalg.reduce"]} in %forall_op : (!transform.any_op) -> !transform.any_op |
| 70 | + %all_linalg_f = transform.merge_handles %reds_f, %gens_f { deduplicate } : !transform.any_op |
| 71 | + %promoted = transform.air.linalg_promote %all_linalg_f {memory_space = "L1"} : (!transform.any_op) -> !transform.any_op |
| 72 | + |
| 73 | + // No post-promote canonicalization (it hoists allocs out of forall region) |
| 74 | + |
| 75 | + // Herd + DMA |
| 76 | + %fh = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 77 | + %pa = transform.loop.forall_to_parallel %fh : (!transform.any_op) -> !transform.any_op |
| 78 | + %h = transform.air.par_to_herd %pa : (!transform.any_op) -> !transform.any_op |
| 79 | + |
| 80 | + // Override herd-internal allocs to L1 |
| 81 | + %herd_ms = transform.structured.match ops{["air.herd"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 82 | + %herd_ms_done = transform.air.override_memref_memory_space %herd_ms {memory_space = 2 : i32} : (!transform.any_op) -> !transform.any_op |
| 83 | + |
| 84 | + %h_fresh = transform.structured.match ops{["air.herd"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 85 | + |
| 86 | + // Skip copy_to_dma here -- driver's step 3 runs air-copy-to-dma pass |
| 87 | + // which handles self-copy elimination (PR #1390). |
| 88 | + |
| 89 | + // No post-herd cleanup (canonicalization hoists allocs across regions) |
| 90 | + |
| 91 | + // STOP HERE to check domination |
| 92 | + transform.yield |
| 93 | + } |
| 94 | +} |
0 commit comments