Skip to content

Commit b79312a

Browse files
Address Hanhan's and Max's review comments
Signed-off-by: Abhishek Varma <abhvarma@amd.com>
1 parent cfd4d88 commit b79312a

5 files changed

Lines changed: 209 additions & 114 deletions

File tree

compiler/src/iree/compiler/Codegen/Common/CombineLayoutTransformation.cpp

Lines changed: 73 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -532,33 +532,60 @@ static MapStoreOp insertIdentityMapStore(RewriterBase &rewriter,
532532
return mapStoreOp;
533533
}
534534

535-
bool isSupportedSingleInputRelayoutOp(Operation *op) {
536-
if (isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp,
537-
tensor::ExtractSliceOp, tensor::PadOp, linalg::CopyOp,
538-
linalg::TransposeOp>(op)) {
535+
bool isSupportedSingleInputRelayoutOpForResult(Operation *op) {
536+
return isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp,
537+
tensor::ExtractSliceOp, tensor::PadOp, linalg::CopyOp,
538+
linalg::TransposeOp>(op);
539+
}
540+
541+
bool isSupportedSingleInputRelayoutOpForSource(Operation *op) {
542+
if (isSupportedSingleInputRelayoutOpForResult(op)) {
539543
return true;
540544
}
541-
auto genericOp = dyn_cast<linalg::GenericOp>(op);
542-
return genericOp && linalg::isaBroadcastOpInterface(genericOp).has_value();
545+
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
546+
return linalgOp && linalg::isaBroadcastOpInterface(linalgOp).has_value();
547+
}
548+
549+
/// Collects all relayout ops in the chain starting from `relayoutOp`
550+
/// (inclusive). The chain extends through result->user edges where the user is
551+
/// a supported relayout op.
552+
static void collectRelayoutChain(Operation *relayoutOp,
553+
SmallPtrSetImpl<Operation *> &chain) {
554+
if (!chain.insert(relayoutOp).second) {
555+
return;
556+
}
557+
Value result = relayoutOp->getResult(0);
558+
for (Operation *user : result.getUsers()) {
559+
if (isSupportedSingleInputRelayoutOpForSource(user) &&
560+
user->getOperand(0) == result) {
561+
collectRelayoutChain(user, chain);
562+
}
563+
}
543564
}
544565

545566
/// Returns true if the relayout op starts a "complex" chain.
546-
/// A "complex" chain is :-
547-
/// - a chain of relayout ops with length >= 2,
548-
/// - or a chain of length 1 with one of the supported linalg relayout ops.
567+
/// A "complex" chain is one that is difficult for bufferization to handle:
568+
/// - chain length >= 2,
569+
/// - contains at least one reshape op (expand_shape or collapse_shape),
570+
/// - and contains at least one op that is not extract_slice.
549571
static bool isComplexRelayoutChain(Operation *relayoutOp) {
550-
assert(isSupportedSingleInputRelayoutOp(relayoutOp) &&
572+
assert(isSupportedSingleInputRelayoutOpForSource(relayoutOp) &&
551573
"expected a supported relayout op");
552-
Value result = relayoutOp->getResult(0);
553-
bool hasRelayoutUser = llvm::any_of(result.getUsers(), [](Operation *user) {
554-
return isSupportedSingleInputRelayoutOp(user);
555-
});
556-
// Chain length >= 2 -> complex.
557-
if (hasRelayoutUser) {
558-
return true;
574+
SmallPtrSet<Operation *, 4> chain;
575+
collectRelayoutChain(relayoutOp, chain);
576+
if (chain.size() < 2) {
577+
return false;
559578
}
560-
// Chain length 1: complex only if the op is a linalg op.
561-
return isa<linalg::LinalgOp>(relayoutOp);
579+
bool hasReshape = llvm::any_of(chain, [](Operation *op) {
580+
return isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp>(op);
581+
});
582+
// Need at least one op that is not reshape and not extract_slice (e.g. copy,
583+
// transpose, broadcast, pad).
584+
bool hasOtherNonExtractSlice = llvm::any_of(chain, [](Operation *op) {
585+
return !isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp,
586+
tensor::ExtractSliceOp>(op);
587+
});
588+
return hasReshape && hasOtherNonExtractSlice;
562589
}
563590

564591
/// Collects direct relayout op users of `loadResult` that start a complex
@@ -567,7 +594,7 @@ static SmallPtrSet<Operation *, 4>
567594
getComplexChainRelayoutUsers(Value loadResult) {
568595
SmallPtrSet<Operation *, 4> complexUsers;
569596
for (Operation *user : loadResult.getUsers()) {
570-
if (isSupportedSingleInputRelayoutOp(user) &&
597+
if (isSupportedSingleInputRelayoutOpForSource(user) &&
571598
user->getOperand(0) == loadResult && isComplexRelayoutChain(user)) {
572599
complexUsers.insert(user);
573600
}
@@ -586,11 +613,11 @@ shouldDoReshapesByExpansion(IREE::Codegen::RelayoutCombinationScope scope) {
586613

587614
/// Insert identity map_store ops after the given operation if it is a valid
588615
/// leaf op of a relayout op chain. A relayout op chain is a sequence of
589-
/// relayout ops (defined by `isSupportedSingleInputRelayoutOp`) for which the
590-
/// only users of the ops in the chain are relayout ops, except for the leaves
591-
/// of the chain. The leaves are simply relayout ops that have non relayout op
592-
/// users. The `controlFn` is a callback on the leaf OpResult that provides
593-
/// control over whether or not to insert a map_store op.
616+
/// relayout ops (defined by `isSupportedSingleInputRelayoutOpForResult`) for
617+
/// which the only users of the ops in the chain are relayout ops, except for
618+
/// the leaves of the chain. The leaves are simply relayout ops that have non
619+
/// relayout op users. The `controlFn` is a callback on the leaf OpResult that
620+
/// provides control over whether or not to insert a map_store op.
594621
struct InsertMapStoreOpPattern : public RewritePattern {
595622
InsertMapStoreOpPattern(MLIRContext *context,
596623
CombineRelayoutOpsControlFnRef controlFn = nullptr,
@@ -600,12 +627,13 @@ struct InsertMapStoreOpPattern : public RewritePattern {
600627

601628
LogicalResult matchAndRewrite(Operation *op,
602629
PatternRewriter &rewriter) const override {
603-
if (!isSupportedSingleInputRelayoutOp(op)) {
630+
if (!isSupportedSingleInputRelayoutOpForResult(op)) {
604631
return failure();
605632
}
606633
// Relayout ops with only relayout op users are not leaves.
607634
auto isDimOrSupportedRelayoutOp = [](Operation *op) {
608-
return isSupportedSingleInputRelayoutOp(op) || isa<tensor::DimOp>(op);
635+
return isSupportedSingleInputRelayoutOpForResult(op) ||
636+
isa<tensor::DimOp>(op);
609637
};
610638
if (llvm::all_of(op->getUsers(), isDimOrSupportedRelayoutOp)) {
611639
return failure();
@@ -790,7 +818,7 @@ getCombineRelayoutOpsControlFn(IREE::Codegen::RelayoutCombinationScope scope) {
790818
// it, so don't introduce map_store.
791819
llvm::SetVector<Operation *> slice;
792820
BackwardSliceOptions options;
793-
options.filter = isSupportedSingleInputRelayoutOp;
821+
options.filter = isSupportedSingleInputRelayoutOpForResult;
794822
options.inclusive = true;
795823
LogicalResult result =
796824
getBackwardSlice(parallelInsertOp.getSource(), &slice, options);
@@ -1021,19 +1049,20 @@ foldExtractSliceIntoMapLoad(RewriterBase &rewriter,
10211049
indexTransformBuilder);
10221050
}
10231051

1024-
/// Fold a consumer broadcast `linalg.generic` into a producer `map_load`.
1025-
static FailureOr<MapLoadOp> foldBroadcastGenericIntoMapLoad(
1026-
RewriterBase &rewriter, linalg::GenericOp genericOp, MapLoadOp mapLoadOp) {
1027-
assert(genericOp.getDpsInputs()[0] == mapLoadOp.getResult(0) &&
1028-
"expected map_load to be the producer of genericOp input");
1029-
if (!linalg::isaBroadcastOpInterface(genericOp).has_value()) {
1030-
return rewriter.notifyMatchFailure(genericOp,
1031-
"generic op is not a broadcast");
1052+
/// Fold a consumer broadcast op (named or generic) into a producer `map_load`.
1053+
static FailureOr<MapLoadOp>
1054+
foldBroadcastIntoMapLoad(RewriterBase &rewriter, linalg::LinalgOp broadcastOp,
1055+
MapLoadOp mapLoadOp) {
1056+
if (!linalg::isaBroadcastOpInterface(broadcastOp).has_value()) {
1057+
return rewriter.notifyMatchFailure(broadcastOp.getOperation(),
1058+
"op is not a broadcast");
10321059
}
1060+
assert(broadcastOp.getDpsInputs()[0] == mapLoadOp.getResult(0) &&
1061+
"expected map_load to be the producer of broadcast input");
10331062

1034-
AffineMap inputMap = genericOp.getIndexingMapsArray()[0];
1063+
AffineMap inputMap = broadcastOp.getIndexingMapsArray()[0];
10351064
return foldConsumerIntoMapLoadImpl(
1036-
rewriter, genericOp, mapLoadOp,
1065+
rewriter, broadcastOp.getOperation(), mapLoadOp,
10371066
[inputMap](ArrayRef<BlockArgument> indices) -> SmallVector<Value> {
10381067
SmallVector<Value> sourceIndices;
10391068
sourceIndices.reserve(inputMap.getNumResults());
@@ -1112,8 +1141,8 @@ FailureOr<MapLoadOp> foldIntoMapLoad(RewriterBase &rewriter, Operation *op,
11121141
.Case<tensor::PadOp>([&](tensor::PadOp padOp) {
11131142
return foldPadIntoMapLoad(rewriter, padOp, mapLoadOp);
11141143
})
1115-
.Case<linalg::GenericOp>([&](linalg::GenericOp genericOp) {
1116-
return foldBroadcastGenericIntoMapLoad(rewriter, genericOp, mapLoadOp);
1144+
.Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) {
1145+
return foldBroadcastIntoMapLoad(rewriter, linalgOp, mapLoadOp);
11171146
})
11181147
.Default([](Operation *) { return failure(); });
11191148
}
@@ -1125,10 +1154,12 @@ struct FoldConsumerRelayoutIntoMapLoadPattern
11251154

11261155
LogicalResult matchAndRewrite(IREE::LinalgExt::MapLoadOp mapLoadOp,
11271156
PatternRewriter &rewriter) const override {
1128-
// Find a consumer relayout op.
1157+
// Find a consumer relayout op (one that uses map_load result as its input).
11291158
Operation *consumerOp = nullptr;
1159+
Value mapLoadResult = mapLoadOp.getResult(0);
11301160
for (Operation *user : mapLoadOp->getUsers()) {
1131-
if (isSupportedSingleInputRelayoutOp(user)) {
1161+
if (isSupportedSingleInputRelayoutOpForSource(user) &&
1162+
user->getOperand(0) == mapLoadResult) {
11321163
consumerOp = user;
11331164
break;
11341165
}

compiler/src/iree/compiler/Codegen/Common/CombineLayoutTransformation.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,13 @@ CombineRelayoutOpsControlFn
5555
getCombineRelayoutOpsControlFn(IREE::Codegen::RelayoutCombinationScope scope);
5656

5757
/// Returns true if the `op` type has a folding pattern into
58-
/// iree_linalg_ext.map_store or iree_linalg_ext.map_load.
59-
bool isSupportedSingleInputRelayoutOp(Operation *op);
58+
/// iree_linalg_ext.map_store (used by CombineResultLayoutTransformationPass).
59+
bool isSupportedSingleInputRelayoutOpForResult(Operation *op);
60+
61+
/// Returns true if the `op` type has a folding pattern into
62+
/// iree_linalg_ext.map_load (used by CombineSourceLayoutTransformationPass).
63+
/// Includes broadcast GenericOp in addition to the ops supported for Result.
64+
bool isSupportedSingleInputRelayoutOpForSource(Operation *op);
6065

6166
/// Fold the `op` into the `mapLoadOp` and return the resulting map_load,
6267
/// or failure if the transformation is not supported. The `op` should be a

compiler/src/iree/compiler/Codegen/Common/GPU/GPUCombineLayoutTransformation.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ static bool gpuRelayoutCombinationControlFn(OpResult leaf) {
8383
}
8484
llvm::SetVector<Operation *> slice;
8585
BackwardSliceOptions options;
86-
options.filter = isSupportedSingleInputRelayoutOp;
86+
options.filter = isSupportedSingleInputRelayoutOpForResult;
8787
options.inclusive = true;
8888
LogicalResult result = getBackwardSlice(leaf, &slice, options);
8989
if (failed(result)) {

compiler/src/iree/compiler/Codegen/Common/Transforms.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,8 @@ struct FoldRelayoutOpIntoMapStorePattern
128128
return failure();
129129
}
130130
// Folding tensor.pad is handled by a separate pattern.
131-
if (!isSupportedSingleInputRelayoutOp(op) || isa<tensor::PadOp>(op)) {
131+
if (!isSupportedSingleInputRelayoutOpForResult(op) ||
132+
isa<tensor::PadOp>(op)) {
132133
return failure();
133134
}
134135
if (failed(foldIntoMapStore(rewriter, op, mapStoreOp))) {

0 commit comments

Comments (0)