|
| 1 | +#loop_annotation = #llvm.loop_annotation<mustProgress = true> |
| 2 | +module { |
| 3 | + aie.device(npu2) @rms_norm_kernel_0 { |
| 4 | + %shim_noc_tile_0_0 = aie.tile(0, 0) |
| 5 | + %shim_noc_tile_1_0 = aie.tile(1, 0) |
| 6 | + %mem_tile_0_1 = aie.tile(0, 1) |
| 7 | + %mem_tile_1_1 = aie.tile(1, 1) |
| 8 | + %tile_0_2 = aie.tile(0, 2) |
| 9 | + %tile_0_3 = aie.tile(0, 3) |
| 10 | + %lock_1_1 = aie.lock(%mem_tile_1_1, 1) {init = 1 : i32} |
| 11 | + %lock_1_1_0 = aie.lock(%mem_tile_1_1, 0) {init = 0 : i32} |
| 12 | + %buf5 = aie.buffer(%mem_tile_0_1) {sym_name = "buf5"} : memref<2x64xbf16, 1 : i32> |
| 13 | + %buf4 = aie.buffer(%mem_tile_1_1) {sym_name = "buf4"} : memref<2x64xbf16, 1> |
| 14 | + %buf3 = aie.buffer(%tile_0_3) {sym_name = "buf3"} : memref<1xf32, 2> |
| 15 | + %buf2 = aie.buffer(%tile_0_3) {sym_name = "buf2"} : memref<1x64xbf16, 2> |
| 16 | + %buf1 = aie.buffer(%tile_0_2) {sym_name = "buf1"} : memref<1xf32, 2> |
| 17 | + %buf0 = aie.buffer(%tile_0_2) {sym_name = "buf0"} : memref<1x64xbf16, 2> |
| 18 | + memref.global "public" @__air_herd_arg_1 : memref<2x64xbf16, 1 : i32> |
| 19 | + %core_0_3 = aie.core(%tile_0_3) { |
| 20 | + %c64 = arith.constant 64 : index |
| 21 | + %cst = arith.constant 0.000000e+00 : f32 |
| 22 | + %cst_1 = arith.constant 6.400000e+01 : f32 |
| 23 | + %cst_2 = arith.constant 9.99999974E-6 : f32 |
| 24 | + %c1 = arith.constant 1 : index |
| 25 | + %c0 = arith.constant 0 : index |
| 26 | + cf.br ^bb1 |
| 27 | + ^bb1: // 2 preds: ^bb0, ^bb2 |
| 28 | + %0 = memref.get_global @__air_herd_arg_1 : memref<2x64xbf16, 1 : i32> |
| 29 | + cf.br ^bb2 |
| 30 | + ^bb2: // pred: ^bb1 |
| 31 | + %subview = memref.subview %0[%c1, 0] [1, 64] [1, 1] : memref<2x64xbf16, 1 : i32> to memref<1x64xbf16, strided<[64, 1], offset: ?>, 1 : i32> |
| 32 | + memref.store %cst, %buf3[%c0] : memref<1xf32, 2> |
| 33 | + scf.for %arg0 = %c0 to %c64 step %c1 { |
| 34 | + %1 = memref.load %subview[%c0, %arg0] : memref<1x64xbf16, strided<[64, 1], offset: ?>, 1 : i32> |
| 35 | + %2 = memref.load %buf3[%c0] : memref<1xf32, 2> |
| 36 | + %3 = arith.extf %1 : bf16 to f32 |
| 37 | + %4 = arith.mulf %3, %3 : f32 |
| 38 | + %5 = arith.addf %4, %2 : f32 |
| 39 | + memref.store %5, %buf3[%c0] : memref<1xf32, 2> |
| 40 | + } {loop_annotation = #loop_annotation} |
| 41 | + scf.for %arg0 = %c0 to %c64 step %c1 { |
| 42 | + %1 = memref.load %subview[%c0, %arg0] : memref<1x64xbf16, strided<[64, 1], offset: ?>, 1 : i32> |
| 43 | + %2 = memref.load %buf3[%c0] : memref<1xf32, 2> |
| 44 | + %3 = arith.divf %2, %cst_1 : f32 |
| 45 | + %4 = arith.addf %3, %cst_2 : f32 |
| 46 | + %5 = math.rsqrt %4 : f32 |
| 47 | + %6 = arith.extf %1 : bf16 to f32 |
| 48 | + %7 = arith.mulf %6, %5 : f32 |
| 49 | + %8 = arith.truncf %7 : f32 to bf16 |
| 50 | + memref.store %8, %buf2[%c0, %arg0] : memref<1x64xbf16, 2> |
| 51 | + } {loop_annotation = #loop_annotation} |
| 52 | + cf.br ^bb1 |
| 53 | + } |
| 54 | + memref.global "public" @__air_herd_arg : memref<2x64xbf16, 1 : i32> |
| 55 | + %core_0_2 = aie.core(%tile_0_2) { |
| 56 | + %c64 = arith.constant 64 : index |
| 57 | + %cst = arith.constant 0.000000e+00 : f32 |
| 58 | + %cst_1 = arith.constant 6.400000e+01 : f32 |
| 59 | + %cst_2 = arith.constant 9.99999974E-6 : f32 |
| 60 | + %c1 = arith.constant 1 : index |
| 61 | + %c0 = arith.constant 0 : index |
| 62 | + cf.br ^bb1 |
| 63 | + ^bb1: // 2 preds: ^bb0, ^bb2 |
| 64 | + %0 = memref.get_global @__air_herd_arg : memref<2x64xbf16, 1 : i32> |
| 65 | + cf.br ^bb2 |
| 66 | + ^bb2: // pred: ^bb1 |
| 67 | + %subview = memref.subview %0[%c0, 0] [1, 64] [1, 1] : memref<2x64xbf16, 1 : i32> to memref<1x64xbf16, strided<[64, 1], offset: ?>, 1 : i32> |
| 68 | + memref.store %cst, %buf1[%c0] : memref<1xf32, 2> |
| 69 | + scf.for %arg0 = %c0 to %c64 step %c1 { |
| 70 | + %1 = memref.load %subview[%c0, %arg0] : memref<1x64xbf16, strided<[64, 1], offset: ?>, 1 : i32> |
| 71 | + %2 = memref.load %buf1[%c0] : memref<1xf32, 2> |
| 72 | + %3 = arith.extf %1 : bf16 to f32 |
| 73 | + %4 = arith.mulf %3, %3 : f32 |
| 74 | + %5 = arith.addf %4, %2 : f32 |
| 75 | + memref.store %5, %buf1[%c0] : memref<1xf32, 2> |
| 76 | + } {loop_annotation = #loop_annotation} |
| 77 | + scf.for %arg0 = %c0 to %c64 step %c1 { |
| 78 | + %1 = memref.load %subview[%c0, %arg0] : memref<1x64xbf16, strided<[64, 1], offset: ?>, 1 : i32> |
| 79 | + %2 = memref.load %buf1[%c0] : memref<1xf32, 2> |
| 80 | + %3 = arith.divf %2, %cst_1 : f32 |
| 81 | + %4 = arith.addf %3, %cst_2 : f32 |
| 82 | + %5 = math.rsqrt %4 : f32 |
| 83 | + %6 = arith.extf %1 : bf16 to f32 |
| 84 | + %7 = arith.mulf %6, %5 : f32 |
| 85 | + %8 = arith.truncf %7 : f32 to bf16 |
| 86 | + memref.store %8, %buf0[%c0, %arg0] : memref<1x64xbf16, 2> |
| 87 | + } {loop_annotation = #loop_annotation} |
| 88 | + cf.br ^bb1 |
| 89 | + } |
| 90 | + air.channel @channel_0 [] |
| 91 | + air.channel @channel_1 [] |
| 92 | + aie.flow(%shim_noc_tile_0_0, DMA : 0, %mem_tile_0_1, DMA : 0) |
| 93 | + aie.flow(%mem_tile_1_1, DMA : 0, %shim_noc_tile_1_0, DMA : 0) |
| 94 | + %memtile_dma_1_1 = aie.memtile_dma(%mem_tile_1_1) { |
| 95 | + %0 = aie.dma_start(MM2S, 0, ^bb1, ^bb2) |
| 96 | + ^bb1: // 2 preds: ^bb0, ^bb1 |
| 97 | + aie.use_lock(%lock_1_1_0, AcquireGreaterEqual, 1) |
| 98 | + aie.dma_bd(%buf4 : memref<2x64xbf16, 1>, 0, 128) {task_id = 0 : i32} |
| 99 | + aie.use_lock(%lock_1_1, Release, 1) |
| 100 | + aie.next_bd ^bb1 |
| 101 | + ^bb2: // pred: ^bb0 |
| 102 | + aie.end |
| 103 | + } |
| 104 | + aie.shim_dma_allocation @air_channel_1(%shim_noc_tile_1_0, S2MM, 0) |
| 105 | + aie.shim_dma_allocation @air_channel_0(%shim_noc_tile_0_0, MM2S, 0) |
| 106 | + } {dlti.dl_spec = #dlti.dl_spec<index = 32 : i64>} |
| 107 | + airrt.module_metadata{ |
| 108 | + airrt.segment_metadata attributes {dma_allocations = [{channel = 2 : i64, col = 0 : i64, id = 3 : i64, location = 0 : i64, row = -1 : i64}], sym_name = "rms_norm_kernel_0"}{ |
| 109 | + airrt.herd_metadata {dma_allocations = [], loc_x = 0 : i64, loc_y = 2 : i64, size_x = 1 : i64, size_y = 2 : i64, sym_name = "herd_0"} |
| 110 | + } |
| 111 | + } |
| 112 | + air.channel @channel_0 [] |
| 113 | + air.channel @channel_1 [] |
| 114 | + func.func @rms_norm_kernel(%arg0: memref<*xbf16> {tt.divisibility = 16 : i32}, %arg1: memref<*xbf16> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) { |
| 115 | + %c1 = arith.constant 1 : index |
| 116 | + %c16 = arith.constant 16 : index |
| 117 | + %0 = air.launch async (%arg9, %arg10, %arg11) in (%arg12=%c16, %arg13=%c1, %arg14=%c1) args(%arg15=%arg0, %arg16=%arg1) : memref<*xbf16>, memref<*xbf16> attributes {id = 1 : i32} { |
| 118 | + %c0 = arith.constant 0 : index |
| 119 | + %c64 = arith.constant 64 : index |
| 120 | + %c2 = arith.constant 2 : index |
| 121 | + %c1_0 = arith.constant 1 : index |
| 122 | + %c128 = arith.constant 128 : index |
| 123 | + %1 = arith.muli %arg10, %c128 : index |
| 124 | + %2 = air.channel.put async @channel_0[] (%arg15[%c0, %1] [%c2, %c64] [%c64, %c1_0]) {id = 1 : i32, metadataArray = [{base = "air_channel_0", index = 0 : i32}]} : (memref<*xbf16>) |
| 125 | + %3 = air.channel.get async @channel_1[] (%arg16[%c0, %1] [%c2, %c64] [%c64, %c1_0]) {id = 2 : i32, metadataArray = [{base = "air_channel_1", index = 0 : i32}]} : (memref<*xbf16>) |
| 126 | + %4 = air.segment @rms_norm_kernel_0 async attributes {id = 2 : i32, x_loc = 0 : i64, x_size = 8 : i64, y_loc = 2 : i64, y_size = 6 : i64} { |
| 127 | + %c2_1 = arith.constant 2 : index |
| 128 | + %c1_2 = arith.constant 1 : index |
| 129 | + %async_token, %results = air.execute -> (memref<2x64xbf16, 1 : i32>) { |
| 130 | + %alloc = memref.alloc() : memref<2x64xbf16, 1 : i32> |
| 131 | + air.execute_terminator %alloc : memref<2x64xbf16, 1 : i32> |
| 132 | + } |
| 133 | + %5 = air.channel.get async [%async_token] @channel_0[] (%results[] [] []) {id = 3 : i32} : (memref<2x64xbf16, 1 : i32>) |
| 134 | + %async_token_3, %results_4 = air.execute -> (memref<2x64xbf16, 1>) { |
| 135 | + %alloc = memref.alloc() : memref<2x64xbf16, 1> |
| 136 | + air.execute_terminator %alloc : memref<2x64xbf16, 1> |
| 137 | + } |
| 138 | + %6 = air.herd @herd_0 async [%5] tile (%arg17, %arg18) in (%arg19=%c1_2, %arg20=%c2_1) args(%arg21=%results) : memref<2x64xbf16, 1 : i32> attributes {id = 3 : i32, x_loc = 0 : i64, y_loc = 2 : i64} { |
| 139 | + %cst = arith.constant 9.99999974E-6 : f32 |
| 140 | + %cst_6 = arith.constant 6.400000e+01 : f32 |
| 141 | + %cst_7 = arith.constant 0.000000e+00 : f32 |
| 142 | + %c0_8 = arith.constant 0 : index |
| 143 | + %c64_9 = arith.constant 64 : index |
| 144 | + %c1_10 = arith.constant 1 : index |
| 145 | + %subview = memref.subview %arg21[%arg18, 0] [1, 64] [1, 1] : memref<2x64xbf16, 1 : i32> to memref<1x64xbf16, strided<[64, 1], offset: ?>, 1 : i32> |
| 146 | + %async_token_11, %results_12 = air.execute -> (memref<1xf32, 2>) { |
| 147 | + %alloc = memref.alloc() : memref<1xf32, 2> |
| 148 | + air.execute_terminator %alloc : memref<1xf32, 2> |
| 149 | + } |
| 150 | + %async_token_13 = air.execute [%async_token_11] { |
| 151 | + memref.store %cst_7, %results_12[%c0_8] : memref<1xf32, 2> |
| 152 | + } |
| 153 | + %async_token_14 = air.execute [%async_token_13] { |
| 154 | + scf.for %arg22 = %c0_8 to %c64_9 step %c1_10 { |
| 155 | + %8 = memref.load %subview[%c0_8, %arg22] : memref<1x64xbf16, strided<[64, 1], offset: ?>, 1 : i32> |
| 156 | + %9 = memref.load %results_12[%c0_8] : memref<1xf32, 2> |
| 157 | + %10 = arith.extf %8 : bf16 to f32 |
| 158 | + %11 = arith.mulf %10, %10 : f32 |
| 159 | + %12 = arith.addf %11, %9 : f32 |
| 160 | + memref.store %12, %results_12[%c0_8] : memref<1xf32, 2> |
| 161 | + } |
| 162 | + } |
| 163 | + %async_token_15, %results_16 = air.execute -> (memref<1x64xbf16, 2>) { |
| 164 | + %alloc = memref.alloc() : memref<1x64xbf16, 2> |
| 165 | + air.execute_terminator %alloc : memref<1x64xbf16, 2> |
| 166 | + } |
| 167 | + %async_token_17 = air.execute [%async_token_15, %async_token_14] { |
| 168 | + scf.for %arg22 = %c0_8 to %c64_9 step %c1_10 { |
| 169 | + %8 = memref.load %subview[%c0_8, %arg22] : memref<1x64xbf16, strided<[64, 1], offset: ?>, 1 : i32> |
| 170 | + %9 = memref.load %results_12[%c0_8] : memref<1xf32, 2> |
| 171 | + %10 = arith.divf %9, %cst_6 : f32 |
| 172 | + %11 = arith.addf %10, %cst : f32 |
| 173 | + %12 = math.rsqrt %11 : f32 |
| 174 | + %13 = arith.extf %8 : bf16 to f32 |
| 175 | + %14 = arith.mulf %13, %12 : f32 |
| 176 | + %15 = arith.truncf %14 : f32 to bf16 |
| 177 | + memref.store %15, %results_16[%c0_8, %arg22] : memref<1x64xbf16, 2> |
| 178 | + } |
| 179 | + } |
| 180 | + %async_token_18 = air.execute [%async_token_17] { |
| 181 | + memref.dealloc %results_12 : memref<1xf32, 2> |
| 182 | + } |
| 183 | + %async_token_19 = air.execute [%async_token_17] { |
| 184 | + memref.dealloc %results_16 : memref<1x64xbf16, 2> |
| 185 | + } |
| 186 | + } |
| 187 | + %7 = air.channel.put async [%async_token_3] @channel_1[] (%results_4[] [] []) {id = 4 : i32} : (memref<2x64xbf16, 1>) |
| 188 | + %async_token_5 = air.execute [%7] { |
| 189 | + memref.dealloc %results_4 : memref<2x64xbf16, 1> |
| 190 | + } |
| 191 | + air.wait_all [%6, %async_token_5] {air.segment_end} |
| 192 | + } |
| 193 | + } |
| 194 | + return |
| 195 | + } |
| 196 | +} |
0 commit comments