@@ -12,7 +12,7 @@ func.func @fold_transpose(%buffer : memref<2x4x16xf32>) -> tensor<4x16x2xf32> {
 // CHECK: %[[SOURCE:.+]] = iree_codegen.load_from_buffer %[[BUFFER]]
 // CHECK: %[[DEST:.+]] = tensor.empty() : tensor<4x16x2xf32>
 // CHECK-NOT: linalg.transpose
-// CHECK: %[[MAP_GATHER:.+]] = iree_linalg_ext.map_load
+// CHECK: iree_linalg_ext.map_load
 // CHECK-SAME: %[[SOURCE]] into %[[DEST]] {
 // CHECK-NEXT: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index, %[[IDX2:.+]]: index):
 // For perm=[1,2,0], inverse_perm=[2,0,1]: output[i,j,k] = input[k,i,j], so yield [idx2,idx0,idx1].
@@ -28,16 +28,9 @@ func.func @fold_expand_shape(%buffer : memref<8x16xf32>) -> tensor<2x4x16xf32> {
 }
 // CHECK-LABEL: @fold_expand_shape
 // CHECK-SAME: %[[BUFFER:[a-zA-Z0-9_]+]]
-// CHECK: %[[SOURCE:.+]] = iree_codegen.load_from_buffer %[[BUFFER]]
-// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<2x4x16xf32>
-// CHECK-NOT: tensor.expand_shape
-// CHECK: %[[MAP_GATHER:.+]] = iree_linalg_ext.map_load
-// CHECK-SAME: %[[SOURCE]] into %[[DEST]] {
-// CHECK-NEXT: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index, %[[IDX2:.+]]: index):
-// CHECK: %[[LINEARIZE:.+]] = affine.linearize_index
-// CHECK-SAME: [%[[IDX0]], %[[IDX1]]] by (2, 4)
-// CHECK: iree_linalg_ext.yield %[[LINEARIZE]], %[[IDX2]],
-// CHECK: } : tensor<8x16xf32> into tensor<2x4x16xf32> -> tensor<2x4x16xf32>
+// CHECK: iree_codegen.load_from_buffer %[[BUFFER]]
+// CHECK: tensor.expand_shape
+// CHECK-NOT: iree_linalg_ext.map_load
 
 // -----
 
@@ -48,15 +41,9 @@ func.func @fold_collapse_shape(%buffer : memref<2x4x16xf32>) -> tensor<8x16xf32>
 }
 // CHECK-LABEL: @fold_collapse_shape
 // CHECK-SAME: %[[BUFFER:[a-zA-Z0-9_]+]]
-// CHECK: %[[SOURCE:.+]] = iree_codegen.load_from_buffer %[[BUFFER]]
-// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<8x16xf32>
-// CHECK-NOT: tensor.collapse_shape
-// CHECK: %[[MAP_GATHER:.+]] = iree_linalg_ext.map_load
-// CHECK-SAME: %[[SOURCE]] into %[[DEST]] {
-// CHECK-NEXT: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index):
-// CHECK: %[[DELINEARIZE:.+]]:2 = affine.delinearize_index %[[IDX0]] into (2, 4)
-// CHECK: iree_linalg_ext.yield %[[DELINEARIZE]]#0, %[[DELINEARIZE]]#1, %[[IDX1]],
-// CHECK: } : tensor<2x4x16xf32> into tensor<8x16xf32> -> tensor<8x16xf32>
+// CHECK: iree_codegen.load_from_buffer %[[BUFFER]]
+// CHECK: tensor.collapse_shape
+// CHECK-NOT: iree_linalg_ext.map_load
 
 // -----
 
@@ -67,16 +54,9 @@ func.func @fold_extract_slice(%buffer : memref<64xf32>) -> tensor<16xf32> {
 }
 // CHECK-LABEL: @fold_extract_slice
 // CHECK-SAME: %[[BUFFER:[a-zA-Z0-9_]+]]
-// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
 // CHECK: %[[SOURCE:.+]] = iree_codegen.load_from_buffer %[[BUFFER]]
-// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<16xf32>
-// CHECK-NOT: tensor.extract_slice
-// CHECK: %[[MAP_GATHER:.+]] = iree_linalg_ext.map_load
-// CHECK-SAME: %[[SOURCE]] into %[[DEST]] {
-// CHECK-NEXT: ^bb0(%[[IDX0:.+]]: index):
-// CHECK: %[[NEW_IDX:.+]] = arith.addi %[[IDX0]], %[[C8]] overflow<nsw>
-// CHECK: iree_linalg_ext.yield %[[NEW_IDX]],
-// CHECK: } : tensor<64xf32> into tensor<16xf32> -> tensor<16xf32>
+// CHECK: tensor.extract_slice %[[SOURCE]][8] [16] [1]
+// CHECK-NOT: iree_linalg_ext.map_load
 
 // -----
 
@@ -96,15 +76,14 @@ func.func @fold_copy_transpose(%buffer : memref<4x16xf32>) -> tensor<16x4xf32> {
 // CHECK: %[[DEST:.+]] = tensor.empty() : tensor<16x4xf32>
 // CHECK-NOT: linalg.copy
 // CHECK-NOT: linalg.transpose
-// CHECK: %[[MAP_GATHER:.+]] = iree_linalg_ext.map_load
+// CHECK: iree_linalg_ext.map_load
 // CHECK-SAME: %[[SOURCE]] into %[[DEST]] {
 // CHECK-NEXT: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index):
 // CHECK: iree_linalg_ext.yield %[[IDX1]], %[[IDX0]],
 // CHECK: } : tensor<4x16xf32> into tensor<16x4xf32> -> tensor<16x4xf32>
 
 // -----
 
-// Low padding is [0, 0, 0], so indices are passed through unchanged due to subi with 0.
 func.func @fold_pad_with_zero_low_padding_offsets(%buffer : memref<1x50x64xf32>) -> tensor<1x64x64xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %source = iree_codegen.load_from_buffer %buffer : memref<1x50x64xf32> -> tensor<1x50x64xf32>
@@ -116,15 +95,9 @@ func.func @fold_pad_with_zero_low_padding_offsets(%buffer : memref<1x50x64xf32>)
 }
 // CHECK-LABEL: @fold_pad_with_zero_low_padding_offsets
 // CHECK-SAME: %[[BUFFER:[a-zA-Z0-9_]+]]
-// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[SOURCE:.+]] = iree_codegen.load_from_buffer %[[BUFFER]]
-// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<1x64x64xf32>
-// CHECK-NOT: tensor.pad
-// CHECK: %[[MAP_GATHER:.+]] = iree_linalg_ext.map_load
-// CHECK-SAME: %[[SOURCE]] into %[[DEST]] {
-// CHECK-NEXT: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index, %[[IDX2:.+]]: index):
-// CHECK: iree_linalg_ext.yield %[[IDX0]], %[[IDX1]], %[[IDX2]], %[[CST]] :
-// CHECK: } : tensor<1x50x64xf32> into tensor<1x64x64xf32> -> tensor<1x64x64xf32>
+// CHECK: iree_codegen.load_from_buffer %[[BUFFER]]
+// CHECK: tensor.pad
+// CHECK-NOT: iree_linalg_ext.map_load
 
 // -----
 
@@ -139,19 +112,9 @@ func.func @fold_pad_with_non_zero_low_padding_offsets(%buffer : memref<8x16xf32>
 }
 // CHECK-LABEL: @fold_pad_with_non_zero_low_padding_offsets
 // CHECK-SAME: %[[BUFFER:[a-zA-Z0-9_]+]]
-// CHECK-DAG: %[[CST:.+]] = arith.constant 1.000000e+00 : f32
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
-// CHECK: %[[SOURCE:.+]] = iree_codegen.load_from_buffer %[[BUFFER]]
-// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<10x20xf32>
-// CHECK-NOT: tensor.pad
-// CHECK: %[[MAP_GATHER:.+]] = iree_linalg_ext.map_load
-// CHECK-SAME: %[[SOURCE]] into %[[DEST]] {
-// CHECK-NEXT: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index):
-// CHECK: %[[NEW_IDX0:.+]] = arith.subi %[[IDX0]], %[[C1]] overflow<nsw> : index
-// CHECK: %[[NEW_IDX1:.+]] = arith.subi %[[IDX1]], %[[C2]] overflow<nsw> : index
-// CHECK: iree_linalg_ext.yield %[[NEW_IDX0]], %[[NEW_IDX1]], %[[CST]] :
-// CHECK: } : tensor<8x16xf32> into tensor<10x20xf32> -> tensor<10x20xf32>
+// CHECK: iree_codegen.load_from_buffer %[[BUFFER]]
+// CHECK: tensor.pad
+// CHECK-NOT: iree_linalg_ext.map_load
 
 // -----
 
@@ -182,3 +145,54 @@ func.func @nested_pads_different_values(%buffer : memref<8x16xf32>) -> tensor<14
 // Second pad is NOT folded because the map_load already has a padding value.
 // CHECK: tensor.pad
 // CHECK: tensor.yield %[[CST1]] : f32
+
+// -----
+
+func.func @fold_broadcast_generic(%buffer : memref<2x3xf32>) -> tensor<2x3x4x5xf32> {
+  %source = iree_codegen.load_from_buffer %buffer : memref<2x3xf32> -> tensor<2x3xf32>
+  %init = tensor.empty() : tensor<2x3x4x5xf32>
+  %broadcast = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+    ],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  } ins(%source : tensor<2x3xf32>) outs(%init : tensor<2x3x4x5xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x3x4x5xf32>
+  return %broadcast : tensor<2x3x4x5xf32>
+}
+// CHECK-LABEL: @fold_broadcast_generic
+// CHECK-SAME: %[[BUFFER:[a-zA-Z0-9_]+]]
+// CHECK: %[[SOURCE:.+]] = iree_codegen.load_from_buffer %[[BUFFER]]
+// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<2x3x4x5xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: iree_linalg_ext.map_load
+// CHECK-SAME: %[[SOURCE]] into %[[DEST]] {
+// CHECK-NEXT: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index, %[[IDX2:.+]]: index, %[[IDX3:.+]]: index):
+// Broadcast: output (d0, d1, d2, d3) reads from source at (d0, d1).
+// CHECK: iree_linalg_ext.yield %[[IDX0]], %[[IDX1]],
+// CHECK: } : tensor<2x3xf32> into tensor<2x3x4x5xf32> -> tensor<2x3x4x5xf32>
+
+// -----
+
+func.func @complex_relayout_chain(%buffer : memref<8x16xf32>) -> tensor<16x8xf32> {
+  %source = iree_codegen.load_from_buffer %buffer : memref<8x16xf32> -> tensor<8x16xf32>
+  %expanded = tensor.expand_shape %source [[0, 1], [2]] output_shape [2, 4, 16] : tensor<8x16xf32> into tensor<2x4x16xf32>
+  %collapsed = tensor.collapse_shape %expanded [[0, 1], [2]] : tensor<2x4x16xf32> into tensor<8x16xf32>
+  %init = tensor.empty() : tensor<16x8xf32>
+  %transposed = linalg.transpose ins(%collapsed : tensor<8x16xf32>) outs(%init : tensor<16x8xf32>) permutation = [1, 0]
+  return %transposed : tensor<16x8xf32>
+}
+// CHECK-LABEL: @complex_relayout_chain
+// CHECK-SAME: %[[BUFFER:[a-zA-Z0-9_]+]]
+// CHECK: iree_codegen.load_from_buffer %[[BUFFER]]
+// CHECK: tensor.empty() : tensor<16x8xf32>
+// CHECK-NOT: tensor.expand_shape
+// CHECK-NOT: tensor.collapse_shape
+// CHECK-NOT: linalg.transpose
+// CHECK: iree_linalg_ext.map_load {{.*}} into {{.*}} {
+// CHECK-NEXT: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index):
+// CHECK-NEXT: iree_linalg_ext.yield %[[IDX1]], %[[IDX0]], {{.*}} : index, index, f32
+// CHECK: } : tensor<8x16xf32> into tensor<16x8xf32> -> tensor<16x8xf32>