Commit b78b73c
[Common] Update CombineSourceLayoutTransform pass to fold broadcast
-- This commit adds folding of a linalg.generic broadcast op into MapGatherOp.
-- It also improves the folding algorithm so that an identity MapGather (and consequently a folding trigger) is inserted only if the relayout chain originating from LoadFromBufferOp is "complex". A "complex" relayout chain here is either:
   i. a chain of length >= 2, or
   ii. a chain of length == 1 whose sole op is a supported linalg relayout op.
   This prevents creating a MapGatherOp for simple primitives like tensor.extract_slice, which would unnecessarily create an empty tensor for MapGather's destination and, in turn, lead to large memref.alloca ops later in the pipeline, causing stack-size-limit issues.

Signed-off-by: Abhishek Varma <abhvarma@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 557fb9a commit b78b73c
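
For context, the distinction the commit message draws between "simple" and "complex" relayout chains looks roughly like this (a sketch using the op spellings from the tests in this commit; function names are illustrative):

// A single tensor.extract_slice user is a "simple" chain: no identity
// map_load / MapGather is inserted and the slice is left untouched.
func.func @simple_chain(%buf : memref<64xf32>) -> tensor<16xf32> {
  %src = iree_codegen.load_from_buffer %buf : memref<64xf32> -> tensor<64xf32>
  %slice = tensor.extract_slice %src[8] [16] [1] : tensor<64xf32> to tensor<16xf32>
  return %slice : tensor<16xf32>
}

// Two or more relayout ops (or a single linalg relayout op such as
// linalg.transpose) form a "complex" chain: an identity map_load is inserted
// after the load and the chain is folded into it.
func.func @complex_chain(%buf : memref<8x16xf32>) -> tensor<8x16xf32> {
  %src = iree_codegen.load_from_buffer %buf : memref<8x16xf32> -> tensor<8x16xf32>
  %exp = tensor.expand_shape %src [[0, 1], [2]] output_shape [2, 4, 16] : tensor<8x16xf32> into tensor<2x4x16xf32>
  %col = tensor.collapse_shape %exp [[0, 1], [2]] : tensor<2x4x16xf32> into tensor<8x16xf32>
  return %col : tensor<8x16xf32>
}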

3 files changed

Lines changed: 152 additions & 66 deletions


.pre-commit-config.yaml

Lines changed: 2 additions & 1 deletion
@@ -41,10 +41,11 @@ repos:
         exclude_types: ["jupyter"]
 
   - repo: https://github.com/igorshubovych/markdownlint-cli
-    rev: v0.47.0
+    rev: v0.39.0
     hooks:
       - id: markdownlint
         name: Run markdownlint on .md files
+        language_version: "system"
         args: ["--config", "docs/.markdownlint.yml"]
         files: "docs/website/.*.md"
         exclude: "mlir-dialects/!(index).md"

compiler/src/iree/compiler/Codegen/Common/CombineLayoutTransformation.cpp

Lines changed: 85 additions & 14 deletions
@@ -13,6 +13,7 @@
 #include "llvm/Support/DebugLog.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
@@ -532,9 +533,46 @@ static MapStoreOp insertIdentityMapStore(RewriterBase &rewriter,
 }
 
 bool isSupportedSingleInputRelayoutOp(Operation *op) {
-  return isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp,
-             tensor::ExtractSliceOp, tensor::PadOp, linalg::CopyOp,
-             linalg::TransposeOp>(op);
+  if (isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp,
+          tensor::ExtractSliceOp, tensor::PadOp, linalg::CopyOp,
+          linalg::TransposeOp>(op)) {
+    return true;
+  }
+  auto genericOp = dyn_cast<linalg::GenericOp>(op);
+  return genericOp && linalg::isaBroadcastOpInterface(genericOp).has_value();
+}
+
+/// Returns true if the relayout op starts a "complex" chain.
+/// A "complex" chain is:
+/// - a chain of relayout ops with length >= 2,
+/// - or a chain of length 1 with one of the supported linalg relayout ops.
+static bool isComplexRelayoutChain(Operation *relayoutOp) {
+  assert(isSupportedSingleInputRelayoutOp(relayoutOp) &&
+         "expected a supported relayout op");
+  Value result = relayoutOp->getResult(0);
+  bool hasRelayoutUser = llvm::any_of(result.getUsers(), [](Operation *user) {
+    return isSupportedSingleInputRelayoutOp(user);
+  });
+  // Chain length >= 2 -> complex.
+  if (hasRelayoutUser) {
+    return true;
+  }
+  // Chain length 1: complex only if the op is a linalg op.
+  return isa<linalg::LinalgOp>(relayoutOp);
+}
+
+/// Collects direct relayout op users of `loadResult` that start a complex
+/// relayout chain.
+static SmallPtrSet<Operation *, 4>
+getComplexChainRelayoutUsers(Value loadResult) {
+  SmallPtrSet<Operation *, 4> complexUsers;
+  for (Operation *user : loadResult.getUsers()) {
+    if (isSupportedSingleInputRelayoutOp(user) &&
+        user->getOperand(0) == loadResult && isComplexRelayoutChain(user)) {
+      complexUsers.insert(user);
+    }
+  }
+  return complexUsers;
 }
 
 // This is only desirable in the dispatch scope but not in the workgroup scope.
@@ -983,7 +1021,31 @@ foldExtractSliceIntoMapLoad(RewriterBase &rewriter,
                             indexTransformBuilder);
 }
 
-/// Fold a consumer `padOp` into a producer `mapLoadOp`.
+/// Fold a consumer broadcast `linalg.generic` into a producer `map_load`.
+static FailureOr<MapLoadOp> foldBroadcastGenericIntoMapLoad(
+    RewriterBase &rewriter, linalg::GenericOp genericOp, MapLoadOp mapLoadOp) {
+  assert(genericOp.getDpsInputs()[0] == mapLoadOp.getResult(0) &&
+         "expected map_load to be the producer of genericOp input");
+  if (!linalg::isaBroadcastOpInterface(genericOp).has_value()) {
+    return rewriter.notifyMatchFailure(genericOp,
+                                       "generic op is not a broadcast");
+  }
+
+  AffineMap inputMap = genericOp.getIndexingMapsArray()[0];
+  return foldConsumerIntoMapLoadImpl(
+      rewriter, genericOp, mapLoadOp,
+      [inputMap](ArrayRef<BlockArgument> indices) -> SmallVector<Value> {
+        SmallVector<Value> sourceIndices;
+        sourceIndices.reserve(inputMap.getNumResults());
+        for (AffineExpr expr : inputMap.getResults()) {
+          unsigned pos = cast<AffineDimExpr>(expr).getPosition();
+          sourceIndices.push_back(indices[pos]);
+        }
+        return sourceIndices;
+      });
+}
+
+/// Fold a consumer `padOp` into a producer `mapGatherOp`.
 /// Index transformation: source_idx = new_idx - low_pad
 /// Fill value is set to the pad value.
 static FailureOr<MapLoadOp> foldPadIntoMapLoad(RewriterBase &rewriter,
@@ -1050,6 +1112,9 @@ FailureOr<MapLoadOp> foldIntoMapLoad(RewriterBase &rewriter, Operation *op,
       .Case<tensor::PadOp>([&](tensor::PadOp padOp) {
         return foldPadIntoMapLoad(rewriter, padOp, mapLoadOp);
       })
+      .Case<linalg::GenericOp>([&](linalg::GenericOp genericOp) {
+        return foldBroadcastGenericIntoMapLoad(rewriter, genericOp, mapLoadOp);
+      })
       .Default([](Operation *) { return failure(); });
 }
 
@@ -1078,12 +1143,18 @@ struct FoldConsumerRelayoutIntoMapLoadPattern
   }
 };
 
-// Insert identity map_load op after the root and replace uses.
-static MapLoadOp insertIdentityMapLoad(RewriterBase &rewriter, OpResult root) {
+// Insert identity map_gather op after the root and replace only uses whose
+// owner is in `complexChainUsers` (i.e. uses that are part of a complex
+// relayout chain). Other uses keep using the load/root directly.
+static MapLoadOp
+insertIdentityMapLoad(RewriterBase &rewriter, OpResult root,
+                      const SmallPtrSetImpl<Operation *> &complexChainUsers) {
   Location loc = root.getLoc();
   SetVector<OpOperand *> originalUses;
   for (OpOperand &use : root.getUses()) {
-    originalUses.insert(&use);
+    if (complexChainUsers.contains(use.getOwner())) {
+      originalUses.insert(&use);
+    }
   }
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPointAfterValue(root);
@@ -1109,12 +1180,11 @@ struct InsertMapLoadOpPattern
 
   LogicalResult matchAndRewrite(IREE::Codegen::LoadFromBufferOp loadOp,
                                 PatternRewriter &rewriter) const override {
-    // Check if the load has at least one relayout op user.
-    bool hasRelayoutUser =
-        llvm::any_of(loadOp->getUsers(), [](Operation *user) {
-          return isSupportedSingleInputRelayoutOp(user);
-        });
-    if (!hasRelayoutUser) {
+    Value loadResult = loadOp.getResult();
+    SmallPtrSet<Operation *, 4> complexChainUsers =
+        getComplexChainRelayoutUsers(loadResult);
+    // Only introduce map_gather when there is at least one complex chain.
+    if (complexChainUsers.empty()) {
       return failure();
     }
     // Check that the load doesn't already have a map_load user (avoid
@@ -1125,7 +1195,8 @@ struct InsertMapLoadOpPattern
     if (hasMapLoadUser) {
      return failure();
    }
-    (void)insertIdentityMapLoad(rewriter, loadOp->getResult(0));
+    (void)insertIdentityMapLoad(rewriter, cast<OpResult>(loadResult),
+                                complexChainUsers);
    return success();
   }
 };
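
The test updates below reflect both changes: the single tensor-op cases (@fold_expand_shape, @fold_collapse_shape, @fold_extract_slice, and the pad tests) no longer introduce a map_load, while new tests cover the broadcast fold and a multi-op chain. For the broadcast fold, the generic's input indexing map decides which induction indices of the map_load are forwarded as source indices; with input map (d0, d1, d2, d3) -> (d0, d1), the folded body yields only the first two indices, roughly as follows (a paraphrase of the FileCheck expectations in @fold_broadcast_generic, not complete IR):

^bb0(%idx0: index, %idx1: index, %idx2: index, %idx3: index):
  // The tensor<2x3xf32> source is addressed only by the projected dims d0, d1.
  iree_linalg_ext.yield %idx0, %idx1, ...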

compiler/src/iree/compiler/Codegen/Common/test/combine_source_layout_transformation.mlir

Lines changed: 65 additions & 51 deletions
@@ -28,16 +28,9 @@ func.func @fold_expand_shape(%buffer : memref<8x16xf32>) -> tensor<2x4x16xf32> {
 }
 // CHECK-LABEL: @fold_expand_shape
 // CHECK-SAME: %[[BUFFER:[a-zA-Z0-9_]+]]
-// CHECK: %[[SOURCE:.+]] = iree_codegen.load_from_buffer %[[BUFFER]]
-// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<2x4x16xf32>
-// CHECK-NOT: tensor.expand_shape
-// CHECK: %[[MAP_GATHER:.+]] = iree_linalg_ext.map_load
-// CHECK-SAME: %[[SOURCE]] into %[[DEST]] {
-// CHECK-NEXT: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index, %[[IDX2:.+]]: index):
-// CHECK: %[[LINEARIZE:.+]] = affine.linearize_index
-// CHECK-SAME: [%[[IDX0]], %[[IDX1]]] by (2, 4)
-// CHECK: iree_linalg_ext.yield %[[LINEARIZE]], %[[IDX2]],
-// CHECK: } : tensor<8x16xf32> into tensor<2x4x16xf32> -> tensor<2x4x16xf32>
+// CHECK: iree_codegen.load_from_buffer %[[BUFFER]]
+// CHECK: tensor.expand_shape
+// CHECK-NOT: iree_linalg_ext.map_load
 
 // -----
 
@@ -48,15 +41,9 @@ func.func @fold_collapse_shape(%buffer : memref<2x4x16xf32>) -> tensor<8x16xf32>
 }
 // CHECK-LABEL: @fold_collapse_shape
 // CHECK-SAME: %[[BUFFER:[a-zA-Z0-9_]+]]
-// CHECK: %[[SOURCE:.+]] = iree_codegen.load_from_buffer %[[BUFFER]]
-// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<8x16xf32>
-// CHECK-NOT: tensor.collapse_shape
-// CHECK: %[[MAP_GATHER:.+]] = iree_linalg_ext.map_load
-// CHECK-SAME: %[[SOURCE]] into %[[DEST]] {
-// CHECK-NEXT: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index):
-// CHECK: %[[DELINEARIZE:.+]]:2 = affine.delinearize_index %[[IDX0]] into (2, 4)
-// CHECK: iree_linalg_ext.yield %[[DELINEARIZE]]#0, %[[DELINEARIZE]]#1, %[[IDX1]],
-// CHECK: } : tensor<2x4x16xf32> into tensor<8x16xf32> -> tensor<8x16xf32>
+// CHECK: iree_codegen.load_from_buffer %[[BUFFER]]
+// CHECK: tensor.collapse_shape
+// CHECK-NOT: iree_linalg_ext.map_load
 
 // -----
 
@@ -67,16 +54,9 @@ func.func @fold_extract_slice(%buffer : memref<64xf32>) -> tensor<16xf32> {
 }
 // CHECK-LABEL: @fold_extract_slice
 // CHECK-SAME: %[[BUFFER:[a-zA-Z0-9_]+]]
-// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
 // CHECK: %[[SOURCE:.+]] = iree_codegen.load_from_buffer %[[BUFFER]]
-// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<16xf32>
-// CHECK-NOT: tensor.extract_slice
-// CHECK: %[[MAP_GATHER:.+]] = iree_linalg_ext.map_load
-// CHECK-SAME: %[[SOURCE]] into %[[DEST]] {
-// CHECK-NEXT: ^bb0(%[[IDX0:.+]]: index):
-// CHECK: %[[NEW_IDX:.+]] = arith.addi %[[IDX0]], %[[C8]] overflow<nsw>
-// CHECK: iree_linalg_ext.yield %[[NEW_IDX]],
-// CHECK: } : tensor<64xf32> into tensor<16xf32> -> tensor<16xf32>
+// CHECK: tensor.extract_slice %[[SOURCE]][8] [16] [1]
+// CHECK-NOT: iree_linalg_ext.map_load
 
 // -----
 
@@ -104,7 +84,6 @@ func.func @fold_copy_transpose(%buffer : memref<4x16xf32>) -> tensor<16x4xf32> {
 
 // -----
 
-// Low padding is [0, 0, 0], so indices are passed through unchanged due to subi with 0.
 func.func @fold_pad_with_zero_low_padding_offsets(%buffer : memref<1x50x64xf32>) -> tensor<1x64x64xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %source = iree_codegen.load_from_buffer %buffer : memref<1x50x64xf32> -> tensor<1x50x64xf32>
@@ -116,15 +95,9 @@ func.func @fold_pad_with_zero_low_padding_offsets(%buffer : memref<1x50x64xf32>)
 }
 // CHECK-LABEL: @fold_pad_with_zero_low_padding_offsets
 // CHECK-SAME: %[[BUFFER:[a-zA-Z0-9_]+]]
-// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[SOURCE:.+]] = iree_codegen.load_from_buffer %[[BUFFER]]
-// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<1x64x64xf32>
-// CHECK-NOT: tensor.pad
-// CHECK: %[[MAP_GATHER:.+]] = iree_linalg_ext.map_load
-// CHECK-SAME: %[[SOURCE]] into %[[DEST]] {
-// CHECK-NEXT: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index, %[[IDX2:.+]]: index):
-// CHECK: iree_linalg_ext.yield %[[IDX0]], %[[IDX1]], %[[IDX2]], %[[CST]] :
-// CHECK: } : tensor<1x50x64xf32> into tensor<1x64x64xf32> -> tensor<1x64x64xf32>
+// CHECK: iree_codegen.load_from_buffer %[[BUFFER]]
+// CHECK: tensor.pad
+// CHECK-NOT: iree_linalg_ext.map_load
 
 // -----
 
@@ -139,19 +112,9 @@ func.func @fold_pad_with_non_zero_low_padding_offsets(%buffer : memref<8x16xf32>
 }
 // CHECK-LABEL: @fold_pad_with_non_zero_low_padding_offsets
 // CHECK-SAME: %[[BUFFER:[a-zA-Z0-9_]+]]
-// CHECK-DAG: %[[CST:.+]] = arith.constant 1.000000e+00 : f32
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
-// CHECK: %[[SOURCE:.+]] = iree_codegen.load_from_buffer %[[BUFFER]]
-// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<10x20xf32>
-// CHECK-NOT: tensor.pad
-// CHECK: %[[MAP_GATHER:.+]] = iree_linalg_ext.map_load
-// CHECK-SAME: %[[SOURCE]] into %[[DEST]] {
-// CHECK-NEXT: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index):
-// CHECK: %[[NEW_IDX0:.+]] = arith.subi %[[IDX0]], %[[C1]] overflow<nsw> : index
-// CHECK: %[[NEW_IDX1:.+]] = arith.subi %[[IDX1]], %[[C2]] overflow<nsw> : index
-// CHECK: iree_linalg_ext.yield %[[NEW_IDX0]], %[[NEW_IDX1]], %[[CST]] :
-// CHECK: } : tensor<8x16xf32> into tensor<10x20xf32> -> tensor<10x20xf32>
+// CHECK: iree_codegen.load_from_buffer %[[BUFFER]]
+// CHECK: tensor.pad
+// CHECK-NOT: iree_linalg_ext.map_load
 
 // -----
 
@@ -182,3 +145,54 @@ func.func @nested_pads_different_values(%buffer : memref<8x16xf32>) -> tensor<14
 // Second pad is NOT folded because the map_load already has a padding value.
 // CHECK: tensor.pad
 // CHECK: tensor.yield %[[CST1]] : f32
+
+// -----
+
+func.func @fold_broadcast_generic(%buffer : memref<2x3xf32>) -> tensor<2x3x4x5xf32> {
+  %source = iree_codegen.load_from_buffer %buffer : memref<2x3xf32> -> tensor<2x3xf32>
+  %init = tensor.empty() : tensor<2x3x4x5xf32>
+  %broadcast = linalg.generic {
+      indexing_maps = [
+        affine_map<(d0, d1, d2, d3) -> (d0, d1)>,
+        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+      ],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+    } ins(%source : tensor<2x3xf32>) outs(%init : tensor<2x3x4x5xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x3x4x5xf32>
+  return %broadcast : tensor<2x3x4x5xf32>
+}
+// CHECK-LABEL: @fold_broadcast_generic
+// CHECK-SAME: %[[BUFFER:[a-zA-Z0-9_]+]]
+// CHECK: %[[SOURCE:.+]] = iree_codegen.load_from_buffer %[[BUFFER]]
+// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<2x3x4x5xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: iree_linalg_ext.map_load
+// CHECK-SAME: %[[SOURCE]] into %[[DEST]] {
+// CHECK-NEXT: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index, %[[IDX2:.+]]: index, %[[IDX3:.+]]: index):
+// Broadcast: output (d0, d1, d2, d3) reads from the source at (d0, d1).
+// CHECK: iree_linalg_ext.yield %[[IDX0]], %[[IDX1]],
+// CHECK: } : tensor<2x3xf32> into tensor<2x3x4x5xf32> -> tensor<2x3x4x5xf32>
+
+// -----
+
+func.func @complex_relayout_chain(%buffer : memref<8x16xf32>) -> tensor<16x8xf32> {
+  %source = iree_codegen.load_from_buffer %buffer : memref<8x16xf32> -> tensor<8x16xf32>
+  %expanded = tensor.expand_shape %source [[0, 1], [2]] output_shape [2, 4, 16] : tensor<8x16xf32> into tensor<2x4x16xf32>
+  %collapsed = tensor.collapse_shape %expanded [[0, 1], [2]] : tensor<2x4x16xf32> into tensor<8x16xf32>
+  %init = tensor.empty() : tensor<16x8xf32>
+  %transposed = linalg.transpose ins(%collapsed : tensor<8x16xf32>) outs(%init : tensor<16x8xf32>) permutation = [1, 0]
+  return %transposed : tensor<16x8xf32>
+}
+// CHECK-LABEL: @complex_relayout_chain
+// CHECK-SAME: %[[BUFFER:[a-zA-Z0-9_]+]]
+// CHECK: iree_codegen.load_from_buffer %[[BUFFER]]
+// CHECK: tensor.empty() : tensor<16x8xf32>
+// CHECK-NOT: tensor.expand_shape
+// CHECK-NOT: tensor.collapse_shape
+// CHECK-NOT: linalg.transpose
+// CHECK: iree_linalg_ext.map_load {{.*}} into {{.*}} {
+// CHECK-NEXT: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index):
+// CHECK-NEXT: iree_linalg_ext.yield %[[IDX1]], %[[IDX0]], {{.*}} : index, index, f32
+// CHECK: } : tensor<8x16xf32> into tensor<16x8xf32> -> tensor<16x8xf32>
