Skip to content

Commit e5498d7

Browse files
authored
[mlir][linalg][elementwise] Fold broadcast into new elementwise (llvm#167626)
Fold a broadcast into the new elementwise op, which has an affine map attached. Merging on behalf of @someoneinjd
1 parent 0b0f02d commit e5498d7

File tree

3 files changed

+248
-25
lines changed

3 files changed

+248
-25
lines changed

mlir/include/mlir/Dialect/Linalg/Passes.td

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -169,14 +169,15 @@ def LinalgInlineScalarOperandsPass : Pass<"linalg-inline-scalar-operands"> {
169169
}
170170

171171
def LinalgFoldIntoElementwisePass : Pass<"linalg-fold-into-elementwise"> {
172-
let summary = "Fold transpose ops into elementwise";
172+
let summary = "Fold transpose and broadcast ops into elementwise";
173173
let dependentDialects = ["linalg::LinalgDialect"];
174174

175175
let description = [{
176-
Fold transpose ops that feed `linalg.elementwise` into the elementwise op
177-
by updating its indexing maps. `linalg.transpose` producers whose consumer
178-
indexing map is the identity are absorbed, turning the permutation into
179-
the elementwise map itself. Other operands remain untouched.
176+
Fold transpose or broadcast op that feeds a `linalg.elementwise` into the
177+
elementwise op. `linalg.transpose` and `linalg.broadcast` producers whose
178+
consumer indexing map is a projected permutation can be absorbed into the
179+
indexing map of the `linalg.elementwise` by composing the producer's map
180+
into the elementwise op's indexing map. Other operands remain untouched.
180181
}];
181182
}
182183

mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,27 @@ using namespace mlir::linalg;
2929
#define DEBUG_TYPE "linalg-fold-into-elementwise"
3030

3131
namespace {
32-
struct FoldTransposePattern : public OpRewritePattern<ElementwiseOp> {
32+
template <typename ProducerOpTy>
33+
struct ElementwiseOpFolder {
34+
// Helper function to fold broadcast etc into elementwise op.
35+
// Producer in this context is `broadcast op` etc, consumer is elwise operand.
36+
static bool fold(OpOperand *elwiseOperand, AffineMap elwiseMap,
37+
SmallVector<Value> &newIns,
38+
SmallVector<AffineMap> &newMaps) {
39+
auto producerOp = elwiseOperand->get().getDefiningOp<ProducerOpTy>();
40+
if (!producerOp || !elwiseMap.isProjectedPermutation())
41+
return false;
42+
newIns.push_back(producerOp.getInput());
43+
// push in the new composed affine map
44+
newMaps.push_back(
45+
producerOp.getMatchingIndexingMap(producerOp.getDpsInputOperand(0))
46+
.compose(elwiseMap));
47+
return true;
48+
}
49+
};
50+
51+
template <typename... ProducerOps>
52+
struct FoldIntoElementwisePattern : public OpRewritePattern<ElementwiseOp> {
3353
using OpRewritePattern<ElementwiseOp>::OpRewritePattern;
3454

3555
LogicalResult matchAndRewrite(ElementwiseOp op,
@@ -38,20 +58,17 @@ struct FoldTransposePattern : public OpRewritePattern<ElementwiseOp> {
3858
SmallVector<Value> newIns;
3959
SmallVector<AffineMap> newMaps;
4060
for (OpOperand *operand : op.getDpsInputOperands()) {
41-
AffineMap map = op.getMatchingIndexingMap(operand);
42-
auto transposeOp = operand->get().getDefiningOp<TransposeOp>();
43-
44-
if (!map.isIdentity() || !transposeOp) {
61+
AffineMap consumerMap = op.getMatchingIndexingMap(operand);
62+
const bool folded = (ElementwiseOpFolder<ProducerOps>::fold(
63+
operand, consumerMap, newIns, newMaps) ||
64+
...);
65+
if (folded) {
66+
changed = true;
67+
} else {
4568
// push in original operand and its map.
4669
newIns.push_back(operand->get());
47-
newMaps.push_back(map);
48-
continue;
70+
newMaps.push_back(consumerMap);
4971
}
50-
newIns.push_back(transposeOp.getInput());
51-
// push in transposeOp's inverse permutation map.
52-
newMaps.push_back(transposeOp.getMatchingIndexingMap(
53-
transposeOp.getDpsInputOperand(0)));
54-
changed = true;
5572
}
5673
if (!changed)
5774
return failure();
@@ -83,5 +100,6 @@ struct LinalgFoldIntoElementwisePass
83100

84101
void mlir::linalg::populateLinalgFoldIntoElementwisePatterns(
85102
RewritePatternSet &patterns) {
86-
patterns.add<FoldTransposePattern>(patterns.getContext());
103+
patterns.add<FoldIntoElementwisePattern<TransposeOp, BroadcastOp>>(
104+
patterns.getContext());
87105
}

mlir/test/Dialect/Linalg/elementwise/fold.mlir

Lines changed: 211 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99
// CHECK-SAME: ins(%[[A]] : tensor<16x8x32xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
1010
// CHECK-NEXT: return %[[RES]] : tensor<8x16x32xf32>
1111
//
12-
func.func @unary_transpose(%A : tensor<16x8x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
12+
func.func @unary_transpose(%A: tensor<16x8x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
1313
%empty = tensor.empty() : tensor<8x16x32xf32>
14-
%transposed_A = linalg.transpose ins(%A : tensor<16x8x32xf32>) outs(%empty : tensor<8x16x32xf32>) permutation = [1, 0, 2]
14+
%transposed_A = linalg.transpose ins(%A : tensor<16x8x32xf32>) outs(%empty : tensor<8x16x32xf32>) permutation = [1, 0, 2]
1515
%result = linalg.elementwise kind=#linalg.elementwise_kind<exp>
16-
ins(%transposed_A : tensor<8x16x32xf32>) outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
16+
ins(%transposed_A : tensor<8x16x32xf32>) outs(%B : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
1717
return %result : tensor<8x16x32xf32>
1818
}
1919

@@ -28,16 +28,220 @@ func.func @unary_transpose(%A : tensor<16x8x32xf32>, %B: tensor<8x16x32xf32>) ->
2828
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C]] : tensor<?x?xf32>) -> tensor<?x?xf32>
2929
// CHECK-NEXT: return %[[RES]] : tensor<?x?xf32>
3030
//
31-
func.func @binary_transposed(%A : tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
31+
func.func @binary_transposed(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
3232
%c0 = arith.constant 0 : index
3333
%c1 = arith.constant 1 : index
3434
%dim0 = tensor.dim %A, %c0 : tensor<?x?xf32>
3535
%dim1 = tensor.dim %A, %c1 : tensor<?x?xf32>
3636

3737
%empty = tensor.empty(%dim1, %dim0) : tensor<?x?xf32>
38-
%transposed_B = linalg.transpose ins(%B : tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) permutation = [1, 0]
38+
%transposed_B = linalg.transpose ins(%B : tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) permutation = [1, 0]
3939
%result = linalg.elementwise kind=#linalg.elementwise_kind<add>
40-
ins(%A, %transposed_B : tensor<?x?xf32>, tensor<?x?xf32>)
41-
outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
40+
ins(%A, %transposed_B : tensor<?x?xf32>, tensor<?x?xf32>)
41+
outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
4242
return %result : tensor<?x?xf32>
4343
}
44+
45+
// -----
46+
47+
// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
48+
// CHECK-DAG: #[[BROADCASTED:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
49+
//
50+
// CHECK: func.func @unary_broadcasted(%[[A:.+]]: tensor<8x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
51+
// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
52+
// CHECK-SAME: indexing_maps = [#[[BROADCASTED]], #[[IDENTITY]]]
53+
// CHECK-SAME: ins(%[[A]] : tensor<8x32xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
54+
// CHECK-NEXT: return %[[RES]] : tensor<8x16x32xf32>
55+
//
56+
func.func @unary_broadcasted(%A: tensor<8x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
57+
%empty = tensor.empty() : tensor<8x16x32xf32>
58+
%broadcasted_A = linalg.broadcast ins(%A : tensor<8x32xf32>) outs(%empty : tensor<8x16x32xf32>) dimensions = [1]
59+
%result = linalg.elementwise kind=#linalg.elementwise_kind<exp>
60+
ins(%broadcasted_A : tensor<8x16x32xf32>) outs(%B : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
61+
return %result : tensor<8x16x32xf32>
62+
}
63+
64+
// -----
65+
66+
// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
67+
// CHECK-DAG: #[[BROADCASTED:.+]] = affine_map<(d0, d1) -> (d0)>
68+
//
69+
// CHECK: func.func @binary_broadcasted(%[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?xf32>, %[[C:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
70+
// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
71+
// CHECK-SAME: indexing_maps = [#[[IDENTITY]], #[[BROADCASTED]], #[[IDENTITY]]]
72+
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[C]] : tensor<?x?xf32>) -> tensor<?x?xf32>
73+
// CHECK-NEXT: return %[[RES]] : tensor<?x?xf32>
74+
//
75+
func.func @binary_broadcasted(%A: tensor<?x?xf32>, %B: tensor<?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
76+
%c0 = arith.constant 0 : index
77+
%c1 = arith.constant 1 : index
78+
%dim0 = tensor.dim %A, %c0 : tensor<?x?xf32>
79+
%dim1 = tensor.dim %A, %c1 : tensor<?x?xf32>
80+
81+
%empty = tensor.empty(%dim1, %dim0) : tensor<?x?xf32>
82+
%broadcasted_B = linalg.broadcast ins(%B : tensor<?xf32>) outs(%empty : tensor<?x?xf32>) dimensions = [1]
83+
%result = linalg.elementwise kind=#linalg.elementwise_kind<add>
84+
ins(%A, %broadcasted_B : tensor<?x?xf32>, tensor<?x?xf32>)
85+
outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
86+
return %result : tensor<?x?xf32>
87+
}
88+
89+
// -----
90+
91+
// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
92+
// CHECK-DAG: #[[COMPOSED_MAP:.+]] = affine_map<(d0, d1) -> (d0)>
93+
//
94+
// CHECK: func.func @fold_broadcast_after_transpose_fold(%[[A:.+]]: tensor<16xf32>, %[[B:.+]]: tensor<16x32xf32>) -> tensor<16x32xf32> {
95+
// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
96+
// CHECK-SAME: indexing_maps = [#[[COMPOSED_MAP]], #[[IDENTITY]]]
97+
// CHECK-SAME: ins(%[[A]] : tensor<16xf32>) outs(%[[B]] : tensor<16x32xf32>) -> tensor<16x32xf32>
98+
// CHECK-NEXT: return %[[RES]] : tensor<16x32xf32>
99+
//
100+
#identity = affine_map<(d0, d1) -> (d0, d1)>
101+
#transpose = affine_map<(d0, d1) -> (d1, d0)>
102+
103+
func.func @fold_broadcast_after_transpose_fold(%A: tensor<16xf32>, %B: tensor<16x32xf32>) -> tensor<16x32xf32> {
104+
%empty_b = tensor.empty() : tensor<32x16xf32>
105+
106+
%broadcasted_A = linalg.broadcast ins(%A : tensor<16xf32>) outs(%empty_b : tensor<32x16xf32>) dimensions = [0]
107+
108+
%result = linalg.elementwise kind=#linalg.elementwise_kind<exp>
109+
indexing_maps = [#transpose, #identity]
110+
ins(%broadcasted_A : tensor<32x16xf32>) outs(%B : tensor<16x32xf32>) -> tensor<16x32xf32>
111+
return %result : tensor<16x32xf32>
112+
}
113+
114+
// -----
115+
116+
// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
117+
// CHECK-DAG: #[[COMPOSED_MAP:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
118+
//
119+
// CHECK: func.func @fold_transpose_after_broadcast_fold(%[[A:.+]]: tensor<32x16xf32>, %[[B:.+]]: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
120+
// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
121+
// CHECK-SAME: indexing_maps = [#[[COMPOSED_MAP]], #[[IDENTITY]]]
122+
// CHECK-SAME: ins(%[[A]] : tensor<32x16xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
123+
// CHECK-NEXT: return %[[RES]] : tensor<8x16x32xf32>
124+
//
125+
#identity = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
126+
#broadcast = affine_map<(d0, d1, d2) -> (d1, d2)>
127+
128+
func.func @fold_transpose_after_broadcast_fold(%A: tensor<32x16xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
129+
%empty_t = tensor.empty() : tensor<16x32xf32>
130+
%transposed_A = linalg.transpose ins(%A : tensor<32x16xf32>) outs(%empty_t : tensor<16x32xf32>) permutation = [1, 0]
131+
132+
%result = linalg.elementwise kind=#linalg.elementwise_kind<exp>
133+
indexing_maps = [#broadcast, #identity]
134+
ins(%transposed_A : tensor<16x32xf32>) outs(%B : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
135+
return %result : tensor<8x16x32xf32>
136+
}
137+
138+
// -----
139+
140+
// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
141+
// CHECK-DAG: #[[COMPOSED_MAP:.+]] = affine_map<(d0, d1) -> (d0)>
142+
//
143+
// CHECK: func.func @fold_broadcast_after_transpose_fold_binary(%[[A:.+]]: tensor<?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[C:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
144+
// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
145+
// CHECK-SAME: indexing_maps = [#[[COMPOSED_MAP]], #[[IDENTITY]], #[[IDENTITY]]]
146+
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[C]] : tensor<?x?xf32>) -> tensor<?x?xf32>
147+
// CHECK-NEXT: return %[[RES]] : tensor<?x?xf32>
148+
//
149+
#identity = affine_map<(d0, d1) -> (d0, d1)>
150+
#transpose = affine_map<(d0, d1) -> (d1, d0)>
151+
152+
func.func @fold_broadcast_after_transpose_fold_binary(%A: tensor<?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
153+
%c0 = arith.constant 0 : index
154+
%c1 = arith.constant 1 : index
155+
%dim0 = tensor.dim %B, %c0 : tensor<?x?xf32>
156+
%dim1 = tensor.dim %B, %c1 : tensor<?x?xf32>
157+
158+
%empty_b = tensor.empty(%dim1, %dim0) : tensor<?x?xf32>
159+
%broadcasted_A = linalg.broadcast ins(%A : tensor<?xf32>) outs(%empty_b : tensor<?x?xf32>) dimensions = [0]
160+
161+
%result = linalg.elementwise kind=#linalg.elementwise_kind<add>
162+
indexing_maps = [#transpose, #identity, #identity]
163+
ins(%broadcasted_A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
164+
165+
return %result : tensor<?x?xf32>
166+
}
167+
168+
// -----
169+
170+
// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
171+
// CHECK-DAG: #[[COMPOSED_MAP:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
172+
//
173+
// CHECK: func.func @fold_transpose_after_broadcast_fold_binary(%[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?x?xf32>, %[[C:.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
174+
// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
175+
// CHECK-SAME: indexing_maps = [#[[COMPOSED_MAP]], #[[IDENTITY]], #[[IDENTITY]]]
176+
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?x?xf32>) outs(%[[C]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
177+
// CHECK-NEXT: return %[[RES]] : tensor<?x?x?xf32>
178+
//
179+
#identity = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
180+
#broadcast = affine_map<(d0, d1, d2) -> (d1, d2)>
181+
182+
func.func @fold_transpose_after_broadcast_fold_binary(%A: tensor<?x?xf32>, %B: tensor<?x?x?xf32>, %C: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
183+
%c0 = arith.constant 0 : index
184+
%c1 = arith.constant 1 : index
185+
%c2 = arith.constant 2 : index
186+
%dim0 = tensor.dim %B, %c0 : tensor<?x?x?xf32>
187+
%dim1 = tensor.dim %B, %c1 : tensor<?x?x?xf32>
188+
%dim2 = tensor.dim %B, %c2 : tensor<?x?x?xf32>
189+
190+
%empty_t = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
191+
%transposed_A = linalg.transpose ins(%A : tensor<?x?xf32>) outs(%empty_t : tensor<?x?xf32>) permutation = [1, 0]
192+
193+
%result = linalg.elementwise kind=#linalg.elementwise_kind<add>
194+
indexing_maps = [#broadcast, #identity, #identity]
195+
ins(%transposed_A, %B : tensor<?x?xf32>, tensor<?x?x?xf32>) outs(%C : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
196+
return %result : tensor<?x?x?xf32>
197+
}
198+
199+
// -----
200+
201+
// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0) -> (d0)>
202+
// CHECK-DAG: #[[DIAGONAL:.+]] = affine_map<(d0) -> (d0, d0)>
203+
//
204+
// CHECK: func.func @fold_failed_diagonal_map(%[[A:.+]]: tensor<16xf32>, %[[B:.+]]: tensor<16xf32>, %[[C:.+]]: tensor<16xf32>) -> tensor<16xf32> {
205+
// CHECK-NEXT: %[[EMPTY:.+]] = tensor.empty() : tensor<16x16xf32>
206+
// CHECK-NEXT: %[[BROADCASTED_B:.+]] = linalg.broadcast ins(%[[B]] : tensor<16xf32>) outs(%[[EMPTY]] : tensor<16x16xf32>) dimensions = [0]
207+
// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
208+
// CHECK-SAME: indexing_maps = [#[[IDENTITY]], #[[DIAGONAL]], #[[IDENTITY]]]
209+
// CHECK-SAME: ins(%[[A]], %[[BROADCASTED_B]] : tensor<16xf32>, tensor<16x16xf32>) outs(%[[C]] : tensor<16xf32>) -> tensor<16xf32>
210+
// CHECK-NEXT: return %[[RES]] : tensor<16xf32>
211+
//
212+
#identity = affine_map<(d0) -> (d0)>
213+
#diagonal = affine_map<(d0) -> (d0, d0)>
214+
215+
func.func @fold_failed_diagonal_map(%A: tensor<16xf32>, %B: tensor<16xf32>, %C: tensor<16xf32>) -> tensor<16xf32> {
216+
%empty = tensor.empty() : tensor<16x16xf32>
217+
%broadcasted_B = linalg.broadcast ins(%B : tensor<16xf32>) outs(%empty : tensor<16x16xf32>) dimensions = [0]
218+
%result = linalg.elementwise kind=#linalg.elementwise_kind<add>
219+
indexing_maps = [#identity, #diagonal, #identity]
220+
ins(%A, %broadcasted_B : tensor<16xf32>, tensor<16x16xf32>) outs(%C : tensor<16xf32>) -> tensor<16xf32>
221+
return %result : tensor<16xf32>
222+
}
223+
224+
// -----
225+
226+
// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0) -> (d0)>
227+
// CHECK-DAG: #[[CONSTANT:.+]] = affine_map<(d0) -> (0, d0)>
228+
//
229+
// CHECK: func.func @fold_failed_constant_map(%[[A:.+]]: tensor<16xf32>, %[[B:.+]]: tensor<16x32xf32>, %[[C:.+]]: tensor<16xf32>) -> tensor<16xf32> {
230+
// CHECK-NEXT: %[[EMPTY:.+]] = tensor.empty() : tensor<32x16xf32>
231+
// CHECK-NEXT: %[[TRANSPOSED_B:.+]] = linalg.transpose ins(%[[B]] : tensor<16x32xf32>) outs(%[[EMPTY]] : tensor<32x16xf32>) permutation = [1, 0]
232+
// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
233+
// CHECK-SAME: indexing_maps = [#[[IDENTITY]], #[[CONSTANT]], #[[IDENTITY]]]
234+
// CHECK-SAME: ins(%[[A]], %[[TRANSPOSED_B]] : tensor<16xf32>, tensor<32x16xf32>) outs(%[[C]] : tensor<16xf32>) -> tensor<16xf32>
235+
// CHECK-NEXT: return %[[RES]] : tensor<16xf32>
236+
//
237+
#identity = affine_map<(d0) -> (d0)>
238+
#constant = affine_map<(d0) -> (0, d0)>
239+
240+
func.func @fold_failed_constant_map(%A: tensor<16xf32>, %B: tensor<16x32xf32>, %C: tensor<16xf32>) -> tensor<16xf32> {
241+
%empty = tensor.empty() : tensor<32x16xf32>
242+
%transposed_B = linalg.transpose ins(%B : tensor<16x32xf32>) outs(%empty : tensor<32x16xf32>) permutation = [1, 0]
243+
%result = linalg.elementwise kind=#linalg.elementwise_kind<add>
244+
indexing_maps = [#identity, #constant, #identity]
245+
ins(%A, %transposed_B : tensor<16xf32>, tensor<32x16xf32>) outs(%C : tensor<16xf32>) -> tensor<16xf32>
246+
return %result : tensor<16xf32>
247+
}

0 commit comments

Comments
 (0)