Skip to content

Commit 0a8a751

Browse files
Hanhan review comment: DPS vs non-DPS
Signed-off-by: Abhishek Varma <abhvarma@amd.com>
1 parent 37a49de commit 0a8a751

1 file changed

Lines changed: 38 additions & 14 deletions

File tree

compiler/src/iree/compiler/Codegen/Common/CombineLayoutTransformation.cpp

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/Dialect/UB/IR/UBOps.h"
2020
#include "mlir/Dialect/Utils/StaticValueUtils.h"
2121
#include "mlir/IR/IRMapping.h"
22+
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
2223
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
2324
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2425

@@ -546,6 +547,34 @@ bool isSupportedSingleInputRelayoutOpForSource(Operation *op) {
546547
return linalgOp && linalg::isaBroadcastOpInterface(linalgOp).has_value();
547548
}
548549

550+
/// Returns true if `user` is a supported relayout op that uses `result` as a
551+
/// valid source for extending the relayout chain.
552+
///
553+
/// We use different logic for DPS vs non-DPS ops so neither case needs to
554+
/// consider operand numbers:
555+
/// - DPS ops (e.g. linalg broadcast): Iterate getDpsInputOperands(). Only
556+
/// consider uses as an *input* operand.
557+
/// - Non-DPS ops (expand_shape, collapse_shape, transpose, etc.): Caller
558+
/// iterates result.getUsers(), so we know user consumes result. These ops
559+
/// have a single input.
560+
static bool isRelayoutChainExtension(Operation *user, Value result) {
561+
if (!isSupportedSingleInputRelayoutOpForSource(user)) {
562+
return false;
563+
}
564+
auto dpsOp = dyn_cast<DestinationStyleOpInterface>(user);
565+
if (dpsOp) {
566+
for (OpOperand *input : dpsOp.getDpsInputOperands()) {
567+
if (input->get() == result) {
568+
return true;
569+
}
570+
}
571+
return false;
572+
}
573+
// Non-DPS: caller iterates result.getUsers(), so user consumes result.
574+
// Single-input relayout ops have only one place for it.
575+
return true;
576+
}
577+
549578
/// Collects all relayout ops in the chain starting from `relayoutOp`
550579
/// (inclusive). The chain extends through result->user edges where the user is
551580
/// a supported relayout op.
@@ -556,8 +585,7 @@ static void collectRelayoutChain(Operation *relayoutOp,
556585
}
557586
Value result = relayoutOp->getResult(0);
558587
for (Operation *user : result.getUsers()) {
559-
if (isSupportedSingleInputRelayoutOpForSource(user) &&
560-
user->getOperand(0) == result) {
588+
if (isRelayoutChainExtension(user, result)) {
561589
collectRelayoutChain(user, chain);
562590
}
563591
}
@@ -579,16 +607,14 @@ static bool isComplexRelayoutChain(Operation *relayoutOp) {
579607
if (chain.size() < 2) {
580608
return false;
581609
}
582-
bool hasReshape = llvm::any_of(chain, [](Operation *op) {
583-
return isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp>(op);
584-
});
610+
bool hasReshape = llvm::any_of(
611+
chain, llvm::IsaPred<tensor::ExpandShapeOp, tensor::CollapseShapeOp>);
585612
// Need at least one op that is not reshape and not extract_slice (e.g. copy,
586613
// transpose, broadcast, pad).
587-
bool hasOtherNonExtractSlice = llvm::any_of(chain, [](Operation *op) {
588-
return !isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp,
589-
tensor::ExtractSliceOp>(op);
590-
});
591-
return hasReshape && hasOtherNonExtractSlice;
614+
bool allReshapeOrExtractSlice = llvm::all_of(
615+
chain, llvm::IsaPred<tensor::ExpandShapeOp, tensor::CollapseShapeOp,
616+
tensor::ExtractSliceOp>);
617+
return hasReshape && !allReshapeOrExtractSlice;
592618
}
593619

594620
/// Collects direct relayout op users of `loadResult` that start complex
@@ -599,8 +625,7 @@ getComplexChainRelayoutUsers(Value loadResult,
599625
bool combineNonComplexChains = false) {
600626
SmallPtrSet<Operation *, 4> complexUsers;
601627
for (Operation *user : loadResult.getUsers()) {
602-
if (isSupportedSingleInputRelayoutOpForSource(user) &&
603-
user->getOperand(0) == loadResult &&
628+
if (isRelayoutChainExtension(user, loadResult) &&
604629
(combineNonComplexChains || isComplexRelayoutChain(user))) {
605630
complexUsers.insert(user);
606631
}
@@ -1164,8 +1189,7 @@ struct FoldConsumerRelayoutIntoMapLoadPattern
11641189
Operation *consumerOp = nullptr;
11651190
Value mapLoadResult = mapLoadOp.getResult(0);
11661191
for (Operation *user : mapLoadOp->getUsers()) {
1167-
if (isSupportedSingleInputRelayoutOpForSource(user) &&
1168-
user->getOperand(0) == mapLoadResult) {
1192+
if (isRelayoutChainExtension(user, mapLoadResult)) {
11691193
consumerOp = user;
11701194
break;
11711195
}

0 commit comments

Comments
 (0)