@@ -16,8 +16,14 @@ func.func @transpose(%buffer : memref<2x4x16xf32>) -> tensor<4x16x2xf32> {
1616// CHECK-NOT: iree_linalg_ext.map_load
1717// FOLD-LABEL: @transpose
1818// FOLD-SAME: %[[BUFFER:.+]]:
19+ // FOLD: %[[SOURCE:.+]] = iree_codegen.load_from_buffer %[[BUFFER]]
20+ // FOLD: %[[DEST:.+]] = tensor.empty() : tensor<4x16x2xf32>
1921// FOLD-NOT: linalg.transpose
20- // FOLD: iree_linalg_ext.map_load
22+ // FOLD: %[[MAP_LOAD:.+]] = iree_linalg_ext.map_load
23+ // FOLD-SAME: %[[SOURCE]] into %[[DEST]] {
24+ // FOLD-NEXT: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index, %[[IDX2:.+]]: index):
25+ // FOLD: iree_linalg_ext.yield %[[IDX2]], %[[IDX0]], %[[IDX1]], {{.*}} : index, index, index, f32
26+ // FOLD: } : tensor<2x4x16xf32> into tensor<4x16x2xf32> -> tensor<4x16x2xf32>
2127
2228// -----
2329
@@ -33,8 +39,15 @@ func.func @expand_shape(%buffer : memref<8x16xf32>) -> tensor<2x4x16xf32> {
3339// CHECK-NOT: iree_linalg_ext.map_load
3440// FOLD-LABEL: @expand_shape
3541// FOLD-SAME: %[[BUFFER:.+]]:
42+ // FOLD: %[[SOURCE:.+]] = iree_codegen.load_from_buffer %[[BUFFER]]
43+ // FOLD: %[[DEST:.+]] = tensor.empty() : tensor<2x4x16xf32>
3644// FOLD-NOT: tensor.expand_shape
37- // FOLD: iree_linalg_ext.map_load
45+ // FOLD: %[[MAP_LOAD:.+]] = iree_linalg_ext.map_load
46+ // FOLD-SAME: %[[SOURCE]] into %[[DEST]] {
47+ // FOLD-NEXT: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index, %[[IDX2:.+]]: index):
48+ // FOLD: %[[LINEARIZE:.+]] = affine.linearize_index disjoint [%[[IDX0]], %[[IDX1]]] by (2, 4) : index
49+ // FOLD: iree_linalg_ext.yield %[[LINEARIZE]], %[[IDX2]], {{.*}} : index, index, f32
50+ // FOLD: } : tensor<8x16xf32> into tensor<2x4x16xf32> -> tensor<2x4x16xf32>
3851
3952// -----
4053
@@ -50,8 +63,15 @@ func.func @collapse_shape(%buffer : memref<2x4x16xf32>) -> tensor<8x16xf32> {
5063// CHECK-NOT: iree_linalg_ext.map_load
5164// FOLD-LABEL: @collapse_shape
5265// FOLD-SAME: %[[BUFFER:.+]]:
66+ // FOLD: %[[SOURCE:.+]] = iree_codegen.load_from_buffer %[[BUFFER]]
67+ // FOLD: %[[DEST:.+]] = tensor.empty() : tensor<8x16xf32>
5368// FOLD-NOT: tensor.collapse_shape
54- // FOLD: iree_linalg_ext.map_load
69+ // FOLD: %[[MAP_LOAD:.+]] = iree_linalg_ext.map_load
70+ // FOLD-SAME: %[[SOURCE]] into %[[DEST]] {
71+ // FOLD-NEXT: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index):
72+ // FOLD: %[[DELINEARIZE:.+]]:2 = affine.delinearize_index %[[IDX0]] into (2, 4) : index, index
73+ // FOLD: iree_linalg_ext.yield %[[DELINEARIZE]]#0, %[[DELINEARIZE]]#1, %[[IDX1]], {{.*}} : index, index, index, f32
74+ // FOLD: } : tensor<2x4x16xf32> into tensor<8x16xf32> -> tensor<8x16xf32>
5575
5676// -----
5777
@@ -67,8 +87,16 @@ func.func @extract_slice(%buffer : memref<64xf32>) -> tensor<16xf32> {
6787// CHECK-NOT: iree_linalg_ext.map_load
6888// FOLD-LABEL: @extract_slice
6989// FOLD-SAME: %[[BUFFER:.+]]:
90+ // FOLD-DAG: %[[C8:.+]] = arith.constant 8 : index
91+ // FOLD: %[[SOURCE:.+]] = iree_codegen.load_from_buffer %[[BUFFER]]
92+ // FOLD: %[[DEST:.+]] = tensor.empty() : tensor<16xf32>
7093// FOLD-NOT: tensor.extract_slice
71- // FOLD: iree_linalg_ext.map_load
94+ // FOLD: %[[MAP_LOAD:.+]] = iree_linalg_ext.map_load
95+ // FOLD-SAME: %[[SOURCE]] into %[[DEST]] {
96+ // FOLD-NEXT: ^bb0(%[[IDX0:.+]]: index):
97+ // FOLD: %[[NEW_IDX:.+]] = arith.addi %[[IDX0]], %[[C8]] overflow<nsw> : index
98+ // FOLD: iree_linalg_ext.yield %[[NEW_IDX]], {{.*}} : index, f32
99+ // FOLD: } : tensor<64xf32> into tensor<16xf32> -> tensor<16xf32>
72100
73101// -----
74102
@@ -88,9 +116,15 @@ func.func @copy_transpose(%buffer : memref<4x16xf32>) -> tensor<16x4xf32> {
88116// CHECK-NOT: iree_linalg_ext.map_load
89117// FOLD-LABEL: @copy_transpose
90118// FOLD-SAME: %[[BUFFER:.+]]:
119+ // FOLD: %[[SOURCE:.+]] = iree_codegen.load_from_buffer %[[BUFFER]]
120+ // FOLD: %[[DEST:.+]] = tensor.empty() : tensor<16x4xf32>
91121// FOLD-NOT: linalg.copy
92122// FOLD-NOT: linalg.transpose
93- // FOLD: iree_linalg_ext.map_load
123+ // FOLD: %[[MAP_LOAD:.+]] = iree_linalg_ext.map_load
124+ // FOLD-SAME: %[[SOURCE]] into %[[DEST]] {
125+ // FOLD-NEXT: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index):
126+ // FOLD: iree_linalg_ext.yield %[[IDX1]], %[[IDX0]], {{.*}} : index, index, f32
127+ // FOLD: } : tensor<4x16xf32> into tensor<16x4xf32> -> tensor<16x4xf32>
94128
95129// -----
96130
@@ -110,8 +144,15 @@ func.func @pad_zero_low(%buffer : memref<1x50x64xf32>) -> tensor<1x64x64xf32> {
110144// CHECK-NOT: iree_linalg_ext.map_load
111145// FOLD-LABEL: @pad_zero_low
112146// FOLD-SAME: %[[BUFFER:.+]]:
147+ // FOLD-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
148+ // FOLD: %[[SOURCE:.+]] = iree_codegen.load_from_buffer %[[BUFFER]]
149+ // FOLD: %[[DEST:.+]] = tensor.empty() : tensor<1x64x64xf32>
113150// FOLD-NOT: tensor.pad
114- // FOLD: iree_linalg_ext.map_load
151+ // FOLD: %[[MAP_LOAD:.+]] = iree_linalg_ext.map_load
152+ // FOLD-SAME: %[[SOURCE]] into %[[DEST]] {
153+ // FOLD-NEXT: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index, %[[IDX2:.+]]: index):
154+ // FOLD: iree_linalg_ext.yield %[[IDX0]], %[[IDX1]], %[[IDX2]], %[[CST]] : index, index, index, f32
155+ // FOLD: } : tensor<1x50x64xf32> into tensor<1x64x64xf32> -> tensor<1x64x64xf32>
115156
116157// -----
117158
@@ -131,8 +172,19 @@ func.func @pad_non_zero_low(%buffer : memref<8x16xf32>) -> tensor<10x20xf32> {
131172// CHECK-NOT: iree_linalg_ext.map_load
132173// FOLD-LABEL: @pad_non_zero_low
133174// FOLD-SAME: %[[BUFFER:.+]]:
175+ // FOLD-DAG: %[[C1:.+]] = arith.constant 1 : index
176+ // FOLD-DAG: %[[C2:.+]] = arith.constant 2 : index
177+ // FOLD-DAG: %[[CST:.+]] = arith.constant 1.000000e+00 : f32
178+ // FOLD: %[[SOURCE:.+]] = iree_codegen.load_from_buffer %[[BUFFER]]
179+ // FOLD: %[[DEST:.+]] = tensor.empty() : tensor<10x20xf32>
134180// FOLD-NOT: tensor.pad
135- // FOLD: iree_linalg_ext.map_load
181+ // FOLD: %[[MAP_LOAD:.+]] = iree_linalg_ext.map_load
182+ // FOLD-SAME: %[[SOURCE]] into %[[DEST]] {
183+ // FOLD-NEXT: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index):
184+ // FOLD: %[[NEW_IDX0:.+]] = arith.subi %[[IDX0]], %[[C1]] overflow<nsw> : index
185+ // FOLD: %[[NEW_IDX1:.+]] = arith.subi %[[IDX1]], %[[C2]] overflow<nsw> : index
186+ // FOLD: iree_linalg_ext.yield %[[NEW_IDX0]], %[[NEW_IDX1]], %[[CST]] : index, index, f32
187+ // FOLD: } : tensor<8x16xf32> into tensor<10x20xf32> -> tensor<10x20xf32>
136188
137189// -----
138190
@@ -163,9 +215,19 @@ func.func @nested_pads_different_values(%buffer : memref<8x16xf32>) -> tensor<2x
163215// CHECK: tensor.pad
164216// CHECK: tensor.yield %[[CST1]] : f32
165217// FOLD-LABEL: @nested_pads_different_values
218+ // FOLD-SAME: %[[BUFFER:.+]]:
166219// With test-combine-non-complex-chains, the first pad folds into the map_load; the second pad remains.
167- // FOLD: iree_linalg_ext.map_load
168- // FOLD: tensor.pad
220+ // FOLD-DAG: %[[CST0:.+]] = arith.constant 0.000000e+00 : f32
221+ // FOLD: %[[SOURCE:.+]] = iree_codegen.load_from_buffer %[[BUFFER]]
222+ // FOLD: %[[DEST:.+]] = tensor.empty() : tensor<2x5x20xf32>
223+ // FOLD: %[[MAP_LOAD:.+]] = iree_linalg_ext.map_load
224+ // FOLD-SAME: %[[SOURCE]] into %[[DEST]] {
225+ // FOLD-NEXT: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index, %[[IDX2:.+]]: index):
226+ // FOLD: {{.*}} = affine.linearize_index disjoint [%[[IDX0]], %[[IDX1]]] by (2, 5) : index
227+ // FOLD: {{.*}} = arith.subi {{.*}} overflow<nsw> : index
228+ // FOLD: iree_linalg_ext.yield {{.*}}, %[[CST0]] : index, index, f32
229+ // FOLD: } : tensor<8x16xf32> into tensor<2x5x20xf32> -> tensor<2x5x20xf32>
230+ // FOLD: tensor.pad %[[MAP_LOAD]]
169231
170232// -----
171233
@@ -191,8 +253,14 @@ func.func @broadcast_generic(%buffer : memref<2x3xf32>) -> tensor<2x3x4x5xf32> {
191253// CHECK-NOT: iree_linalg_ext.map_load
192254// FOLD-LABEL: @broadcast_generic
193255// FOLD-SAME: %[[BUFFER:.+]]:
256+ // FOLD: %[[SOURCE:.+]] = iree_codegen.load_from_buffer %[[BUFFER]]
257+ // FOLD: %[[DEST:.+]] = tensor.empty() : tensor<2x3x4x5xf32>
194258// FOLD-NOT: linalg.generic
195- // FOLD: iree_linalg_ext.map_load
259+ // FOLD: %[[MAP_LOAD:.+]] = iree_linalg_ext.map_load
260+ // FOLD-SAME: %[[SOURCE]] into %[[DEST]] {
261+ // FOLD-NEXT: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index, %[[IDX2:.+]]: index, %[[IDX3:.+]]: index):
262+ // FOLD: iree_linalg_ext.yield %[[IDX0]], %[[IDX1]], {{.*}} : index, index, f32
263+ // FOLD: } : tensor<2x3xf32> into tensor<2x3x4x5xf32> -> tensor<2x3x4x5xf32>
196264
197265// -----
198266
@@ -209,8 +277,14 @@ func.func @broadcast_named(%buffer : memref<2x3xf32>) -> tensor<2x3x4x5xf32> {
209277// CHECK-NOT: iree_linalg_ext.map_load
210278// FOLD-LABEL: @broadcast_named
211279// FOLD-SAME: %[[BUFFER:.+]]:
280+ // FOLD: %[[SOURCE:.+]] = iree_codegen.load_from_buffer %[[BUFFER]]
281+ // FOLD: %[[DEST:.+]] = tensor.empty() : tensor<2x3x4x5xf32>
212282// FOLD-NOT: linalg.broadcast
213- // FOLD: iree_linalg_ext.map_load
283+ // FOLD: %[[MAP_LOAD:.+]] = iree_linalg_ext.map_load
284+ // FOLD-SAME: %[[SOURCE]] into %[[DEST]] {
285+ // FOLD-NEXT: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index, %[[IDX2:.+]]: index, %[[IDX3:.+]]: index):
286+ // FOLD: iree_linalg_ext.yield %[[IDX0]], %[[IDX1]], {{.*}} : index, index, f32
287+ // FOLD: } : tensor<2x3xf32> into tensor<2x3x4x5xf32> -> tensor<2x3x4x5xf32>
214288
215289// -----
216290
@@ -234,7 +308,14 @@ func.func @complex_relayout_chain(%buffer : memref<8x16xf32>) -> tensor<16x8xf32
234308// CHECK: iree_linalg_ext.yield %[[IDX1]], %[[IDX0]], {{.*}} : index, index, f32
235309// CHECK: } : tensor<8x16xf32> into tensor<16x8xf32> -> tensor<16x8xf32>
236310// FOLD-LABEL: @complex_relayout_chain
237- // FOLD: iree_linalg_ext.map_load
311+ // FOLD-SAME: %[[BUFFER:.+]]:
312+ // FOLD: %[[SOURCE:.+]] = iree_codegen.load_from_buffer %[[BUFFER]]
313+ // FOLD: %[[DEST:.+]] = tensor.empty() : tensor<16x8xf32>
314+ // FOLD: %[[MAP_LOAD:.+]] = iree_linalg_ext.map_load
315+ // FOLD-SAME: %[[SOURCE]] into %[[DEST]] {
316+ // FOLD-NEXT: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index):
317+ // FOLD: iree_linalg_ext.yield %[[IDX1]], %[[IDX0]], {{.*}} : index, index, f32
318+ // FOLD: } : tensor<8x16xf32> into tensor<16x8xf32> -> tensor<16x8xf32>
238319
239320// -----
240321
@@ -259,7 +340,15 @@ func.func @complex_chain_reshape_and_transpose(%buffer : memref<4x8xf32>) -> ten
259340// CHECK: iree_linalg_ext.yield %[[LINEAR]], %[[IDX0]], %0 : index, index, f32
260341// CHECK: } : tensor<4x8xf32> into tensor<8x2x2xf32> -> tensor<8x2x2xf32>
261342// FOLD-LABEL: @complex_chain_reshape_and_transpose
262- // FOLD: iree_linalg_ext.map_load
343+ // FOLD-SAME: %[[BUFFER:.+]]:
344+ // FOLD: %[[SOURCE:.+]] = iree_codegen.load_from_buffer %[[BUFFER]]
345+ // FOLD: %[[DEST:.+]] = tensor.empty() : tensor<8x2x2xf32>
346+ // FOLD: %[[MAP_LOAD:.+]] = iree_linalg_ext.map_load
347+ // FOLD-SAME: %[[SOURCE]] into %[[DEST]] {
348+ // FOLD-NEXT: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index, %[[IDX2:.+]]: index):
349+ // FOLD: %[[LINEAR:.+]] = affine.linearize_index disjoint [%[[IDX1]], %[[IDX2]]] by (2, 2) : index
350+ // FOLD: iree_linalg_ext.yield %[[LINEAR]], %[[IDX0]], {{.*}} : index, index, f32
351+ // FOLD: } : tensor<4x8xf32> into tensor<8x2x2xf32> -> tensor<8x2x2xf32>
263352
264353// -----
265354
@@ -305,5 +394,15 @@ func.func @fold_broadcast_pad_expand_shape(%buffer : memref<2x64xf32>, %batch :
305394// CHECK: } : tensor<2x64xf32> into tensor<1x4x16x4x2x16xf32> -> tensor<1x4x16x4x2x16xf32>
306395// CHECK: linalg.copy ins({{.*}}) outs(%[[MAP_LOAD]]
307396// FOLD-LABEL: @fold_broadcast_pad_expand_shape
308- // FOLD: iree_linalg_ext.map_load
397+ // FOLD-SAME: %[[BUFFER:.+]]: memref<2x64xf32>, %[[BATCH:.+]]: index
398+ // FOLD-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
399+ // FOLD: %[[SOURCE:.+]] = iree_codegen.load_from_buffer %[[BUFFER]]
400+ // FOLD: %[[DEST:.+]] = tensor.empty() : tensor<1x4x16x4x2x16xf32>
401+ // FOLD: %[[MAP_LOAD:.+]] = iree_linalg_ext.map_load
402+ // FOLD-SAME: %[[SOURCE]] into %[[DEST]] {
403+ // FOLD-NEXT: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index, %[[IDX2:.+]]: index, %[[IDX3:.+]]: index, %[[IDX4:.+]]: index, %[[IDX5:.+]]: index):
404+ // FOLD: {{.*}} = affine.linearize_index disjoint [%[[IDX1]], %[[IDX2]], %[[IDX3]], %[[IDX4]], %[[IDX5]]] by (4, 16, 4, 2, 16) : index
405+ // FOLD: {{.*}}:3 = affine.delinearize_index {{.*}} into (64, 4, 32) : index, index, index
406+ // FOLD: iree_linalg_ext.yield %[[BATCH]], {{.*}}, %[[CST]] : index, index, f32
407+ // FOLD: } : tensor<2x64xf32> into tensor<1x4x16x4x2x16xf32> -> tensor<1x4x16x4x2x16xf32>
309408// FOLD: linalg.copy
0 commit comments