Skip to content

Commit 6f205c9

Browse files
Address Hanhan's and Max's review comments
Signed-off-by: Abhishek Varma <abhvarma@amd.com>
1 parent cfd4d88 commit 6f205c9

5 files changed

Lines changed: 169 additions & 110 deletions

File tree

compiler/src/iree/compiler/Codegen/Common/CombineLayoutTransformation.cpp

Lines changed: 77 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -532,33 +532,63 @@ static MapStoreOp insertIdentityMapStore(RewriterBase &rewriter,
532532
return mapStoreOp;
533533
}
534534

535-
bool isSupportedSingleInputRelayoutOp(Operation *op) {
536-
if (isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp,
537-
tensor::ExtractSliceOp, tensor::PadOp, linalg::CopyOp,
538-
linalg::TransposeOp>(op)) {
535+
bool isSupportedSingleInputRelayoutOpForResult(Operation *op) {
536+
return isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp,
537+
tensor::ExtractSliceOp, tensor::PadOp, linalg::CopyOp,
538+
linalg::TransposeOp>(op);
539+
}
540+
541+
bool isSupportedSingleInputRelayoutOpForSource(Operation *op) {
542+
if (isSupportedSingleInputRelayoutOpForResult(op)) {
543+
return true;
544+
}
545+
if (isa<linalg::BroadcastOp>(op)) {
539546
return true;
540547
}
541548
auto genericOp = dyn_cast<linalg::GenericOp>(op);
542549
return genericOp && linalg::isaBroadcastOpInterface(genericOp).has_value();
543550
}
544551

552+
/// Collects all relayout ops in the chain starting from `relayoutOp`
553+
/// (inclusive). The chain extends through result->user edges where the user is
554+
/// a supported relayout op.
555+
static void collectRelayoutChain(Operation *relayoutOp,
556+
SmallPtrSetImpl<Operation *> &chain) {
557+
if (!chain.insert(relayoutOp).second) {
558+
return;
559+
}
560+
Value result = relayoutOp->getResult(0);
561+
for (Operation *user : result.getUsers()) {
562+
if (isSupportedSingleInputRelayoutOpForSource(user) &&
563+
user->getOperand(0) == result) {
564+
collectRelayoutChain(user, chain);
565+
}
566+
}
567+
}
568+
545569
/// Returns true if the relayout op starts a "complex" chain.
546-
/// A "complex" chain is :-
547-
/// - a chain of relayout ops with length >= 2,
548-
/// - or a chain of length 1 with one of the supported linalg relayout ops.
570+
/// A "complex" chain is one that is difficult for bufferization to handle:
571+
/// - chain length >= 2,
572+
/// - contains at least one reshape op (expand_shape or collapse_shape),
573+
/// - and contains at least one op that is neither a reshape nor an extract_slice.
549574
static bool isComplexRelayoutChain(Operation *relayoutOp) {
550-
assert(isSupportedSingleInputRelayoutOp(relayoutOp) &&
575+
assert(isSupportedSingleInputRelayoutOpForSource(relayoutOp) &&
551576
"expected a supported relayout op");
552-
Value result = relayoutOp->getResult(0);
553-
bool hasRelayoutUser = llvm::any_of(result.getUsers(), [](Operation *user) {
554-
return isSupportedSingleInputRelayoutOp(user);
555-
});
556-
// Chain length >= 2 -> complex.
557-
if (hasRelayoutUser) {
558-
return true;
577+
SmallPtrSet<Operation *, 4> chain;
578+
collectRelayoutChain(relayoutOp, chain);
579+
if (chain.size() < 2) {
580+
return false;
559581
}
560-
// Chain length 1: complex only if the op is a linalg op.
561-
return isa<linalg::LinalgOp>(relayoutOp);
582+
bool hasReshape = llvm::any_of(chain, [](Operation *op) {
583+
return isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp>(op);
584+
});
585+
// Need at least one op that is not reshape and not extract_slice (e.g. copy,
586+
// transpose, broadcast, pad).
587+
bool hasOtherNonExtractSlice = llvm::any_of(chain, [](Operation *op) {
588+
return !isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp,
589+
tensor::ExtractSliceOp>(op);
590+
});
591+
return hasReshape && hasOtherNonExtractSlice;
562592
}
563593

564594
/// Collects direct relayout op users of `loadResult` that start a complex
@@ -567,7 +597,7 @@ static SmallPtrSet<Operation *, 4>
567597
getComplexChainRelayoutUsers(Value loadResult) {
568598
SmallPtrSet<Operation *, 4> complexUsers;
569599
for (Operation *user : loadResult.getUsers()) {
570-
if (isSupportedSingleInputRelayoutOp(user) &&
600+
if (isSupportedSingleInputRelayoutOpForSource(user) &&
571601
user->getOperand(0) == loadResult && isComplexRelayoutChain(user)) {
572602
complexUsers.insert(user);
573603
}
@@ -586,11 +616,11 @@ shouldDoReshapesByExpansion(IREE::Codegen::RelayoutCombinationScope scope) {
586616

587617
/// Insert identity map_store ops after the given operation if it is a valid
588618
/// leaf op of a relayout op chain. A relayout op chain is a sequence of
589-
/// relayout ops (defined by `isSupportedSingleInputRelayoutOp`) for which the
590-
/// only users of the ops in the chain are relayout ops, except for the leaves
591-
/// of the chain. The leaves are simply relayout ops that have non relayout op
592-
/// users. The `controlFn` is a callback on the leaf OpResult that provides
593-
/// control over whether or not to insert a map_store op.
619+
/// relayout ops (defined by `isSupportedSingleInputRelayoutOpForResult`) for
620+
/// which the only users of the ops in the chain are relayout ops, except for
621+
/// the leaves of the chain. The leaves are simply relayout ops that have non
622+
/// relayout op users. The `controlFn` is a callback on the leaf OpResult that
623+
/// provides control over whether or not to insert a map_store op.
594624
struct InsertMapStoreOpPattern : public RewritePattern {
595625
InsertMapStoreOpPattern(MLIRContext *context,
596626
CombineRelayoutOpsControlFnRef controlFn = nullptr,
@@ -600,12 +630,13 @@ struct InsertMapStoreOpPattern : public RewritePattern {
600630

601631
LogicalResult matchAndRewrite(Operation *op,
602632
PatternRewriter &rewriter) const override {
603-
if (!isSupportedSingleInputRelayoutOp(op)) {
633+
if (!isSupportedSingleInputRelayoutOpForResult(op)) {
604634
return failure();
605635
}
606636
// Relayout ops with only relayout op users are not leaves.
607637
auto isDimOrSupportedRelayoutOp = [](Operation *op) {
608-
return isSupportedSingleInputRelayoutOp(op) || isa<tensor::DimOp>(op);
638+
return isSupportedSingleInputRelayoutOpForResult(op) ||
639+
isa<tensor::DimOp>(op);
609640
};
610641
if (llvm::all_of(op->getUsers(), isDimOrSupportedRelayoutOp)) {
611642
return failure();
@@ -790,7 +821,7 @@ getCombineRelayoutOpsControlFn(IREE::Codegen::RelayoutCombinationScope scope) {
790821
// it, so don't introduce map_store.
791822
llvm::SetVector<Operation *> slice;
792823
BackwardSliceOptions options;
793-
options.filter = isSupportedSingleInputRelayoutOp;
824+
options.filter = isSupportedSingleInputRelayoutOpForResult;
794825
options.inclusive = true;
795826
LogicalResult result =
796827
getBackwardSlice(parallelInsertOp.getSource(), &slice, options);
@@ -1021,19 +1052,20 @@ foldExtractSliceIntoMapLoad(RewriterBase &rewriter,
10211052
indexTransformBuilder);
10221053
}
10231054

1024-
/// Fold a consumer broadcast `linalg.generic` into a producer `map_load`.
1025-
static FailureOr<MapLoadOp> foldBroadcastGenericIntoMapLoad(
1026-
RewriterBase &rewriter, linalg::GenericOp genericOp, MapLoadOp mapLoadOp) {
1027-
assert(genericOp.getDpsInputs()[0] == mapLoadOp.getResult(0) &&
1028-
"expected map_load to be the producer of genericOp input");
1029-
if (!linalg::isaBroadcastOpInterface(genericOp).has_value()) {
1030-
return rewriter.notifyMatchFailure(genericOp,
1031-
"generic op is not a broadcast");
1055+
/// Fold a consumer broadcast op (named or generic) into a producer `map_load`.
1056+
static FailureOr<MapLoadOp>
1057+
foldBroadcastIntoMapLoad(RewriterBase &rewriter, linalg::LinalgOp broadcastOp,
1058+
MapLoadOp mapLoadOp) {
1059+
assert(broadcastOp.getDpsInputs()[0] == mapLoadOp.getResult(0) &&
1060+
"expected map_load to be the producer of broadcast input");
1061+
if (!linalg::isaBroadcastOpInterface(broadcastOp).has_value()) {
1062+
return rewriter.notifyMatchFailure(broadcastOp.getOperation(),
1063+
"op is not a broadcast");
10321064
}
10331065

1034-
AffineMap inputMap = genericOp.getIndexingMapsArray()[0];
1066+
AffineMap inputMap = broadcastOp.getIndexingMapsArray()[0];
10351067
return foldConsumerIntoMapLoadImpl(
1036-
rewriter, genericOp, mapLoadOp,
1068+
rewriter, broadcastOp.getOperation(), mapLoadOp,
10371069
[inputMap](ArrayRef<BlockArgument> indices) -> SmallVector<Value> {
10381070
SmallVector<Value> sourceIndices;
10391071
sourceIndices.reserve(inputMap.getNumResults());
@@ -1112,8 +1144,15 @@ FailureOr<MapLoadOp> foldIntoMapLoad(RewriterBase &rewriter, Operation *op,
11121144
.Case<tensor::PadOp>([&](tensor::PadOp padOp) {
11131145
return foldPadIntoMapLoad(rewriter, padOp, mapLoadOp);
11141146
})
1147+
.Case<linalg::BroadcastOp>([&](linalg::BroadcastOp broadcastOp) {
1148+
return foldBroadcastIntoMapLoad(
1149+
rewriter, cast<linalg::LinalgOp>(broadcastOp.getOperation()),
1150+
mapLoadOp);
1151+
})
11151152
.Case<linalg::GenericOp>([&](linalg::GenericOp genericOp) {
1116-
return foldBroadcastGenericIntoMapLoad(rewriter, genericOp, mapLoadOp);
1153+
return foldBroadcastIntoMapLoad(
1154+
rewriter, cast<linalg::LinalgOp>(genericOp.getOperation()),
1155+
mapLoadOp);
11171156
})
11181157
.Default([](Operation *) { return failure(); });
11191158
}
@@ -1128,7 +1167,7 @@ struct FoldConsumerRelayoutIntoMapLoadPattern
11281167
// Find a consumer relayout op.
11291168
Operation *consumerOp = nullptr;
11301169
for (Operation *user : mapLoadOp->getUsers()) {
1131-
if (isSupportedSingleInputRelayoutOp(user)) {
1170+
if (isSupportedSingleInputRelayoutOpForSource(user)) {
11321171
consumerOp = user;
11331172
break;
11341173
}

compiler/src/iree/compiler/Codegen/Common/CombineLayoutTransformation.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,13 @@ CombineRelayoutOpsControlFn
5555
getCombineRelayoutOpsControlFn(IREE::Codegen::RelayoutCombinationScope scope);
5656

5757
/// Returns true if the `op` type has a folding pattern into
58-
/// iree_linalg_ext.map_store or iree_linalg_ext.map_load.
59-
bool isSupportedSingleInputRelayoutOp(Operation *op);
58+
/// iree_linalg_ext.map_store (used by CombineResultLayoutTransformationPass).
59+
bool isSupportedSingleInputRelayoutOpForResult(Operation *op);
60+
61+
/// Returns true if the `op` type has a folding pattern into
62+
/// iree_linalg_ext.map_load (used by CombineSourceLayoutTransformationPass).
63+
/// Includes linalg.broadcast and broadcast-like GenericOps in addition to the
/// ops supported for Result.
64+
bool isSupportedSingleInputRelayoutOpForSource(Operation *op);
6065

6166
/// Fold the `op` into the `mapLoadOp` and return the resulting map_load,
6267
/// or failure if the transformation is not supported. The `op` should be a

compiler/src/iree/compiler/Codegen/Common/GPU/GPUCombineLayoutTransformation.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ static bool gpuRelayoutCombinationControlFn(OpResult leaf) {
8383
}
8484
llvm::SetVector<Operation *> slice;
8585
BackwardSliceOptions options;
86-
options.filter = isSupportedSingleInputRelayoutOp;
86+
options.filter = isSupportedSingleInputRelayoutOpForResult;
8787
options.inclusive = true;
8888
LogicalResult result = getBackwardSlice(leaf, &slice, options);
8989
if (failed(result)) {

compiler/src/iree/compiler/Codegen/Common/Transforms.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,8 @@ struct FoldRelayoutOpIntoMapStorePattern
128128
return failure();
129129
}
130130
// Folding tensor.pad is handled by a separate pattern.
131-
if (!isSupportedSingleInputRelayoutOp(op) || isa<tensor::PadOp>(op)) {
131+
if (!isSupportedSingleInputRelayoutOpForResult(op) ||
132+
isa<tensor::PadOp>(op)) {
132133
return failure();
133134
}
134135
if (failed(foldIntoMapStore(rewriter, op, mapStoreOp))) {

0 commit comments

Comments
 (0)