Skip to content

Commit 6dd43c4

Browse files
erwei-xilinxclaude
andcommitted
rms_norm: inner tiling at 16, no LLVM errors, blocked on L1 DMA staging
- Added inner vectorization tiling at 16-lane width (fixes 64xf32 G_FMUL) - Added fuse_multi_op_linalg sq→reduce (reduces L1 alloc count) - Added promote_tensor to 2 for reduce input - No more LLVM legalization errors (f32 mulf at 16-wide is legal) - Blocked: __air_herd_arg linker error from direct L2 subview access inside herd body. Need explicit L1 staging via DMA. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent deaf3e7 commit 6dd43c4

18 files changed

Lines changed: 1984 additions & 6 deletions
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
#loop_annotation = #llvm.loop_annotation<mustProgress = true>
2+
module {
3+
aie.device(npu2) @rms_norm_kernel_0 {
4+
%shim_noc_tile_0_0 = aie.tile(0, 0)
5+
%shim_noc_tile_1_0 = aie.tile(1, 0)
6+
%mem_tile_0_1 = aie.tile(0, 1)
7+
%mem_tile_1_1 = aie.tile(1, 1)
8+
%tile_0_2 = aie.tile(0, 2)
9+
%tile_0_3 = aie.tile(0, 3)
10+
%lock_1_1 = aie.lock(%mem_tile_1_1, 1) {init = 1 : i32}
11+
%lock_1_1_0 = aie.lock(%mem_tile_1_1, 0) {init = 0 : i32}
12+
%buf5 = aie.buffer(%mem_tile_0_1) {sym_name = "buf5"} : memref<2x64xbf16, 1 : i32>
13+
%buf4 = aie.buffer(%mem_tile_1_1) {sym_name = "buf4"} : memref<2x64xbf16, 1>
14+
%buf3 = aie.buffer(%tile_0_3) {sym_name = "buf3"} : memref<1xf32, 2>
15+
%buf2 = aie.buffer(%tile_0_3) {sym_name = "buf2"} : memref<1x64xbf16, 2>
16+
%buf1 = aie.buffer(%tile_0_2) {sym_name = "buf1"} : memref<1xf32, 2>
17+
%buf0 = aie.buffer(%tile_0_2) {sym_name = "buf0"} : memref<1x64xbf16, 2>
18+
memref.global "public" @__air_herd_arg_1 : memref<2x64xbf16, 1 : i32>
19+
%core_0_3 = aie.core(%tile_0_3) {
20+
%c64 = arith.constant 64 : index
21+
%cst = arith.constant 0.000000e+00 : f32
22+
%cst_1 = arith.constant 6.400000e+01 : f32
23+
%cst_2 = arith.constant 9.99999974E-6 : f32
24+
%c1 = arith.constant 1 : index
25+
%c0 = arith.constant 0 : index
26+
cf.br ^bb1
27+
^bb1: // 2 preds: ^bb0, ^bb2
28+
%0 = memref.get_global @__air_herd_arg_1 : memref<2x64xbf16, 1 : i32>
29+
cf.br ^bb2
30+
^bb2: // pred: ^bb1
31+
%subview = memref.subview %0[%c1, 0] [1, 64] [1, 1] : memref<2x64xbf16, 1 : i32> to memref<1x64xbf16, strided<[64, 1], offset: ?>, 1 : i32>
32+
memref.store %cst, %buf3[%c0] : memref<1xf32, 2>
33+
scf.for %arg0 = %c0 to %c64 step %c1 {
34+
%1 = memref.load %subview[%c0, %arg0] : memref<1x64xbf16, strided<[64, 1], offset: ?>, 1 : i32>
35+
%2 = memref.load %buf3[%c0] : memref<1xf32, 2>
36+
%3 = arith.extf %1 : bf16 to f32
37+
%4 = arith.mulf %3, %3 : f32
38+
%5 = arith.addf %4, %2 : f32
39+
memref.store %5, %buf3[%c0] : memref<1xf32, 2>
40+
} {loop_annotation = #loop_annotation}
41+
scf.for %arg0 = %c0 to %c64 step %c1 {
42+
%1 = memref.load %subview[%c0, %arg0] : memref<1x64xbf16, strided<[64, 1], offset: ?>, 1 : i32>
43+
%2 = memref.load %buf3[%c0] : memref<1xf32, 2>
44+
%3 = arith.divf %2, %cst_1 : f32
45+
%4 = arith.addf %3, %cst_2 : f32
46+
%5 = math.rsqrt %4 : f32
47+
%6 = arith.extf %1 : bf16 to f32
48+
%7 = arith.mulf %6, %5 : f32
49+
%8 = arith.truncf %7 : f32 to bf16
50+
memref.store %8, %buf2[%c0, %arg0] : memref<1x64xbf16, 2>
51+
} {loop_annotation = #loop_annotation}
52+
cf.br ^bb1
53+
}
54+
memref.global "public" @__air_herd_arg : memref<2x64xbf16, 1 : i32>
55+
%core_0_2 = aie.core(%tile_0_2) {
56+
%c64 = arith.constant 64 : index
57+
%cst = arith.constant 0.000000e+00 : f32
58+
%cst_1 = arith.constant 6.400000e+01 : f32
59+
%cst_2 = arith.constant 9.99999974E-6 : f32
60+
%c1 = arith.constant 1 : index
61+
%c0 = arith.constant 0 : index
62+
cf.br ^bb1
63+
^bb1: // 2 preds: ^bb0, ^bb2
64+
%0 = memref.get_global @__air_herd_arg : memref<2x64xbf16, 1 : i32>
65+
cf.br ^bb2
66+
^bb2: // pred: ^bb1
67+
%subview = memref.subview %0[%c0, 0] [1, 64] [1, 1] : memref<2x64xbf16, 1 : i32> to memref<1x64xbf16, strided<[64, 1], offset: ?>, 1 : i32>
68+
memref.store %cst, %buf1[%c0] : memref<1xf32, 2>
69+
scf.for %arg0 = %c0 to %c64 step %c1 {
70+
%1 = memref.load %subview[%c0, %arg0] : memref<1x64xbf16, strided<[64, 1], offset: ?>, 1 : i32>
71+
%2 = memref.load %buf1[%c0] : memref<1xf32, 2>
72+
%3 = arith.extf %1 : bf16 to f32
73+
%4 = arith.mulf %3, %3 : f32
74+
%5 = arith.addf %4, %2 : f32
75+
memref.store %5, %buf1[%c0] : memref<1xf32, 2>
76+
} {loop_annotation = #loop_annotation}
77+
scf.for %arg0 = %c0 to %c64 step %c1 {
78+
%1 = memref.load %subview[%c0, %arg0] : memref<1x64xbf16, strided<[64, 1], offset: ?>, 1 : i32>
79+
%2 = memref.load %buf1[%c0] : memref<1xf32, 2>
80+
%3 = arith.divf %2, %cst_1 : f32
81+
%4 = arith.addf %3, %cst_2 : f32
82+
%5 = math.rsqrt %4 : f32
83+
%6 = arith.extf %1 : bf16 to f32
84+
%7 = arith.mulf %6, %5 : f32
85+
%8 = arith.truncf %7 : f32 to bf16
86+
memref.store %8, %buf0[%c0, %arg0] : memref<1x64xbf16, 2>
87+
} {loop_annotation = #loop_annotation}
88+
cf.br ^bb1
89+
}
90+
air.channel @channel_0 []
91+
air.channel @channel_1 []
92+
aie.flow(%shim_noc_tile_0_0, DMA : 0, %mem_tile_0_1, DMA : 0)
93+
aie.flow(%mem_tile_1_1, DMA : 0, %shim_noc_tile_1_0, DMA : 0)
94+
%memtile_dma_1_1 = aie.memtile_dma(%mem_tile_1_1) {
95+
%0 = aie.dma_start(MM2S, 0, ^bb1, ^bb2)
96+
^bb1: // 2 preds: ^bb0, ^bb1
97+
aie.use_lock(%lock_1_1_0, AcquireGreaterEqual, 1)
98+
aie.dma_bd(%buf4 : memref<2x64xbf16, 1>, 0, 128) {task_id = 0 : i32}
99+
aie.use_lock(%lock_1_1, Release, 1)
100+
aie.next_bd ^bb1
101+
^bb2: // pred: ^bb0
102+
aie.end
103+
}
104+
aie.shim_dma_allocation @air_channel_1(%shim_noc_tile_1_0, S2MM, 0)
105+
aie.shim_dma_allocation @air_channel_0(%shim_noc_tile_0_0, MM2S, 0)
106+
} {dlti.dl_spec = #dlti.dl_spec<index = 32 : i64>}
107+
airrt.module_metadata{
108+
airrt.segment_metadata attributes {dma_allocations = [{channel = 2 : i64, col = 0 : i64, id = 3 : i64, location = 0 : i64, row = -1 : i64}], sym_name = "rms_norm_kernel_0"}{
109+
airrt.herd_metadata {dma_allocations = [], loc_x = 0 : i64, loc_y = 2 : i64, size_x = 1 : i64, size_y = 2 : i64, sym_name = "herd_0"}
110+
}
111+
}
112+
air.channel @channel_0 []
113+
air.channel @channel_1 []
114+
func.func @rms_norm_kernel(%arg0: memref<*xbf16> {tt.divisibility = 16 : i32}, %arg1: memref<*xbf16> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
115+
%c1 = arith.constant 1 : index
116+
%c16 = arith.constant 16 : index
117+
%0 = air.launch async (%arg9, %arg10, %arg11) in (%arg12=%c16, %arg13=%c1, %arg14=%c1) args(%arg15=%arg0, %arg16=%arg1) : memref<*xbf16>, memref<*xbf16> attributes {id = 1 : i32} {
118+
%c0 = arith.constant 0 : index
119+
%c64 = arith.constant 64 : index
120+
%c2 = arith.constant 2 : index
121+
%c1_0 = arith.constant 1 : index
122+
%c128 = arith.constant 128 : index
123+
%1 = arith.muli %arg10, %c128 : index
124+
%2 = air.channel.put async @channel_0[] (%arg15[%c0, %1] [%c2, %c64] [%c64, %c1_0]) {id = 1 : i32, metadataArray = [{base = "air_channel_0", index = 0 : i32}]} : (memref<*xbf16>)
125+
%3 = air.channel.get async @channel_1[] (%arg16[%c0, %1] [%c2, %c64] [%c64, %c1_0]) {id = 2 : i32, metadataArray = [{base = "air_channel_1", index = 0 : i32}]} : (memref<*xbf16>)
126+
%4 = air.segment @rms_norm_kernel_0 async attributes {id = 2 : i32, x_loc = 0 : i64, x_size = 8 : i64, y_loc = 2 : i64, y_size = 6 : i64} {
127+
%c2_1 = arith.constant 2 : index
128+
%c1_2 = arith.constant 1 : index
129+
%async_token, %results = air.execute -> (memref<2x64xbf16, 1 : i32>) {
130+
%alloc = memref.alloc() : memref<2x64xbf16, 1 : i32>
131+
air.execute_terminator %alloc : memref<2x64xbf16, 1 : i32>
132+
}
133+
%5 = air.channel.get async [%async_token] @channel_0[] (%results[] [] []) {id = 3 : i32} : (memref<2x64xbf16, 1 : i32>)
134+
%async_token_3, %results_4 = air.execute -> (memref<2x64xbf16, 1>) {
135+
%alloc = memref.alloc() : memref<2x64xbf16, 1>
136+
air.execute_terminator %alloc : memref<2x64xbf16, 1>
137+
}
138+
%6 = air.herd @herd_0 async [%5] tile (%arg17, %arg18) in (%arg19=%c1_2, %arg20=%c2_1) args(%arg21=%results) : memref<2x64xbf16, 1 : i32> attributes {id = 3 : i32, x_loc = 0 : i64, y_loc = 2 : i64} {
139+
%cst = arith.constant 9.99999974E-6 : f32
140+
%cst_6 = arith.constant 6.400000e+01 : f32
141+
%cst_7 = arith.constant 0.000000e+00 : f32
142+
%c0_8 = arith.constant 0 : index
143+
%c64_9 = arith.constant 64 : index
144+
%c1_10 = arith.constant 1 : index
145+
%subview = memref.subview %arg21[%arg18, 0] [1, 64] [1, 1] : memref<2x64xbf16, 1 : i32> to memref<1x64xbf16, strided<[64, 1], offset: ?>, 1 : i32>
146+
%async_token_11, %results_12 = air.execute -> (memref<1xf32, 2>) {
147+
%alloc = memref.alloc() : memref<1xf32, 2>
148+
air.execute_terminator %alloc : memref<1xf32, 2>
149+
}
150+
%async_token_13 = air.execute [%async_token_11] {
151+
memref.store %cst_7, %results_12[%c0_8] : memref<1xf32, 2>
152+
}
153+
%async_token_14 = air.execute [%async_token_13] {
154+
scf.for %arg22 = %c0_8 to %c64_9 step %c1_10 {
155+
%8 = memref.load %subview[%c0_8, %arg22] : memref<1x64xbf16, strided<[64, 1], offset: ?>, 1 : i32>
156+
%9 = memref.load %results_12[%c0_8] : memref<1xf32, 2>
157+
%10 = arith.extf %8 : bf16 to f32
158+
%11 = arith.mulf %10, %10 : f32
159+
%12 = arith.addf %11, %9 : f32
160+
memref.store %12, %results_12[%c0_8] : memref<1xf32, 2>
161+
}
162+
}
163+
%async_token_15, %results_16 = air.execute -> (memref<1x64xbf16, 2>) {
164+
%alloc = memref.alloc() : memref<1x64xbf16, 2>
165+
air.execute_terminator %alloc : memref<1x64xbf16, 2>
166+
}
167+
%async_token_17 = air.execute [%async_token_15, %async_token_14] {
168+
scf.for %arg22 = %c0_8 to %c64_9 step %c1_10 {
169+
%8 = memref.load %subview[%c0_8, %arg22] : memref<1x64xbf16, strided<[64, 1], offset: ?>, 1 : i32>
170+
%9 = memref.load %results_12[%c0_8] : memref<1xf32, 2>
171+
%10 = arith.divf %9, %cst_6 : f32
172+
%11 = arith.addf %10, %cst : f32
173+
%12 = math.rsqrt %11 : f32
174+
%13 = arith.extf %8 : bf16 to f32
175+
%14 = arith.mulf %13, %12 : f32
176+
%15 = arith.truncf %14 : f32 to bf16
177+
memref.store %15, %results_16[%c0_8, %arg22] : memref<1x64xbf16, 2>
178+
}
179+
}
180+
%async_token_18 = air.execute [%async_token_17] {
181+
memref.dealloc %results_12 : memref<1xf32, 2>
182+
}
183+
%async_token_19 = air.execute [%async_token_17] {
184+
memref.dealloc %results_16 : memref<1x64xbf16, 2>
185+
}
186+
}
187+
%7 = air.channel.put async [%async_token_3] @channel_1[] (%results_4[] [] []) {id = 4 : i32} : (memref<2x64xbf16, 1>)
188+
%async_token_5 = air.execute [%7] {
189+
memref.dealloc %results_4 : memref<2x64xbf16, 1>
190+
}
191+
air.wait_all [%6, %async_token_5] {air.segment_end}
192+
}
193+
}
194+
return
195+
}
196+
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#map = affine_map<(d0, d1) -> (d0, d1)>
2+
#map1 = affine_map<(d0, d1) -> (d0)>
3+
module {
4+
func.func @rms_norm_kernel(%arg0: memref<*xbf16> {tt.divisibility = 16 : i32}, %arg1: memref<*xbf16> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
5+
%c1 = arith.constant 1 : index
6+
%c16 = arith.constant 16 : index
7+
air.launch (%arg9, %arg10, %arg11) in (%arg12=%c16, %arg13=%c1, %arg14=%c1) args(%arg15=%arg0, %arg16=%arg1) : memref<*xbf16>, memref<*xbf16> {
8+
air.segment @rms_norm_kernel_0 args(%arg17=%arg10, %arg18=%arg15, %arg19=%arg16) : index, memref<*xbf16>, memref<*xbf16> {
9+
%c0 = arith.constant 0 : index
10+
%c64 = arith.constant 64 : index
11+
%c2 = arith.constant 2 : index
12+
%c1_0 = arith.constant 1 : index
13+
%c128 = arith.constant 128 : index
14+
%0 = arith.muli %arg17, %c128 : index
15+
%alloc = memref.alloc() : memref<2x64xbf16, 1 : i32>
16+
air.dma_memcpy_nd (%alloc[] [] [], %arg18[%c0, %0] [%c2, %c64] [%c64, %c1_0]) {id = 1 : i32} : (memref<2x64xbf16, 1 : i32>, memref<*xbf16>)
17+
%alloc_1 = memref.alloc() : memref<2x64xbf16, 1>
18+
air.herd @herd_0 tile (%arg20, %arg21) in (%arg22=%c2, %arg23=%c1_0) args(%arg24=%alloc, %arg25=%alloc_1) : memref<2x64xbf16, 1 : i32>, memref<2x64xbf16, 1> {
19+
%c64_2 = arith.constant 64 : index
20+
%c1_3 = arith.constant 1 : index
21+
%c0_4 = arith.constant 0 : index
22+
%cst = arith.constant 0.000000e+00 : f32
23+
%cst_5 = arith.constant 6.400000e+01 : f32
24+
%cst_6 = arith.constant 9.99999974E-6 : f32
25+
%subview = memref.subview %arg24[%arg20, 0] [1, 64] [1, 1] : memref<2x64xbf16, 1 : i32> to memref<1x64xbf16, strided<[64, 1], offset: ?>, 1 : i32>
26+
%alloc_7 = memref.alloc() : memref<1xf32, 2>
27+
memref.store %cst, %alloc_7[%c0_4] : memref<1xf32, 2>
28+
linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%subview : memref<1x64xbf16, strided<[64, 1], offset: ?>, 1 : i32>) outs(%alloc_7 : memref<1xf32, 2>) {
29+
^bb0(%in: bf16, %out: f32):
30+
%1 = arith.extf %in : bf16 to f32
31+
%2 = arith.mulf %1, %1 : f32
32+
%3 = arith.addf %2, %out : f32
33+
linalg.yield %3 : f32
34+
}
35+
%alloc_8 = memref.alloc() : memref<1x64xbf16, 2>
36+
linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]} ins(%subview, %alloc_7 : memref<1x64xbf16, strided<[64, 1], offset: ?>, 1 : i32>, memref<1xf32, 2>) outs(%alloc_8 : memref<1x64xbf16, 2>) {
37+
^bb0(%in: bf16, %in_9: f32, %out: bf16):
38+
%1 = arith.divf %in_9, %cst_5 : f32
39+
%2 = arith.addf %1, %cst_6 : f32
40+
%3 = math.rsqrt %2 : f32
41+
%4 = arith.extf %in : bf16 to f32
42+
%5 = arith.mulf %4, %3 : f32
43+
%6 = arith.truncf %5 : f32 to bf16
44+
linalg.yield %6 : bf16
45+
}
46+
memref.dealloc %alloc_7 : memref<1xf32, 2>
47+
memref.dealloc %alloc_8 : memref<1x64xbf16, 2>
48+
}
49+
air.dma_memcpy_nd (%arg19[%c0, %0] [%c2, %c64] [%c64, %c1_0], %alloc_1[] [] []) {id = 2 : i32} : (memref<*xbf16>, memref<2x64xbf16, 1>)
50+
memref.dealloc %alloc_1 : memref<2x64xbf16, 1>
51+
}
52+
}
53+
return
54+
}
55+
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#map = affine_map<(d0, d1) -> (d0, d1)>
2+
#map1 = affine_map<(d0, d1) -> (d0)>
3+
module {
4+
func.func @rms_norm_kernel(%arg0: memref<*xbf16> {tt.divisibility = 16 : i32}, %arg1: memref<*xbf16> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
5+
%c1 = arith.constant 1 : index
6+
%c16 = arith.constant 16 : index
7+
air.launch (%arg9, %arg10, %arg11) in (%arg12=%c16, %arg13=%c1, %arg14=%c1) args(%arg15=%arg0, %arg16=%arg1) : memref<*xbf16>, memref<*xbf16> {
8+
air.segment @rms_norm_kernel_0 args(%arg17=%arg10, %arg18=%arg15, %arg19=%arg16) : index, memref<*xbf16>, memref<*xbf16> {
9+
%c0 = arith.constant 0 : index
10+
%c64 = arith.constant 64 : index
11+
%c2 = arith.constant 2 : index
12+
%c1_0 = arith.constant 1 : index
13+
%c128 = arith.constant 128 : index
14+
%0 = arith.muli %arg17, %c128 : index
15+
%alloc = memref.alloc() : memref<2x64xbf16, 1 : i32>
16+
air.dma_memcpy_nd (%alloc[] [] [], %arg18[%c0, %0] [%c2, %c64] [%c64, %c1_0]) {id = 1 : i32} : (memref<2x64xbf16, 1 : i32>, memref<*xbf16>)
17+
%alloc_1 = memref.alloc() : memref<2x64xbf16, 1>
18+
air.herd @herd_0 tile (%arg20, %arg21) in (%arg22=%c2, %arg23=%c1_0) args(%arg24=%alloc, %arg25=%alloc_1) : memref<2x64xbf16, 1 : i32>, memref<2x64xbf16, 1> {
19+
%c64_2 = arith.constant 64 : index
20+
%c1_3 = arith.constant 1 : index
21+
%c0_4 = arith.constant 0 : index
22+
%cst = arith.constant 0.000000e+00 : f32
23+
%cst_5 = arith.constant 6.400000e+01 : f32
24+
%cst_6 = arith.constant 9.99999974E-6 : f32
25+
%subview = memref.subview %arg24[%arg20, 0] [1, 64] [1, 1] : memref<2x64xbf16, 1 : i32> to memref<1x64xbf16, strided<[64, 1], offset: ?>, 1 : i32>
26+
%alloc_7 = memref.alloc() : memref<1xf32, 2>
27+
memref.store %cst, %alloc_7[%c0_4] : memref<1xf32, 2>
28+
linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%subview : memref<1x64xbf16, strided<[64, 1], offset: ?>, 1 : i32>) outs(%alloc_7 : memref<1xf32, 2>) {
29+
^bb0(%in: bf16, %out: f32):
30+
%1 = arith.extf %in : bf16 to f32
31+
%2 = arith.mulf %1, %1 : f32
32+
%3 = arith.addf %2, %out : f32
33+
linalg.yield %3 : f32
34+
}
35+
%alloc_8 = memref.alloc() : memref<1x64xbf16, 2>
36+
linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]} ins(%subview, %alloc_7 : memref<1x64xbf16, strided<[64, 1], offset: ?>, 1 : i32>, memref<1xf32, 2>) outs(%alloc_8 : memref<1x64xbf16, 2>) {
37+
^bb0(%in: bf16, %in_9: f32, %out: bf16):
38+
%1 = arith.divf %in_9, %cst_5 : f32
39+
%2 = arith.addf %1, %cst_6 : f32
40+
%3 = math.rsqrt %2 : f32
41+
%4 = arith.extf %in : bf16 to f32
42+
%5 = arith.mulf %4, %3 : f32
43+
%6 = arith.truncf %5 : f32 to bf16
44+
linalg.yield %6 : bf16
45+
}
46+
memref.dealloc %alloc_7 : memref<1xf32, 2>
47+
memref.dealloc %alloc_8 : memref<1x64xbf16, 2>
48+
}
49+
air.dma_memcpy_nd (%arg19[%c0, %0] [%c2, %c64] [%c64, %c1_0], %alloc_1[] [] []) {id = 2 : i32} : (memref<*xbf16>, memref<2x64xbf16, 1>)
50+
memref.dealloc %alloc_1 : memref<2x64xbf16, 1>
51+
}
52+
}
53+
return
54+
}
55+
}

0 commit comments

Comments
 (0)