1919#include " mlir/Dialect/UB/IR/UBOps.h"
2020#include " mlir/Dialect/Utils/StaticValueUtils.h"
2121#include " mlir/IR/IRMapping.h"
22+ #include " mlir/Interfaces/DestinationStyleOpInterface.h"
2223#include " mlir/Interfaces/ValueBoundsOpInterface.h"
2324#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
2425
@@ -546,6 +547,34 @@ bool isSupportedSingleInputRelayoutOpForSource(Operation *op) {
546547 return linalgOp && linalg::isaBroadcastOpInterface (linalgOp).has_value ();
547548}
548549
550+ // / Returns true if `user` is a supported relayout op that uses `result` as a
551+ // / valid source for extending the relayout chain.
552+ // /
553+ // / We use different logic for DPS vs non-DPS ops so neither case needs to
554+ // / consider operand numbers:
555+ // / - DPS ops (e.g. linalg broadcast): Iterate getDpsInputOperands(). Only
556+ // / consider uses as an *input* operand.
557+ // / - Non-DPS ops (expand_shape, collapse_shape, transpose, etc.): Caller
558+ // / iterates result.getUsers(), so we know user consumes result. These ops
559+ // / have a single input.
560+ static bool isRelayoutChainExtension (Operation *user, Value result) {
561+ if (!isSupportedSingleInputRelayoutOpForSource (user)) {
562+ return false ;
563+ }
564+ auto dpsOp = dyn_cast<DestinationStyleOpInterface>(user);
565+ if (dpsOp) {
566+ for (OpOperand *input : dpsOp.getDpsInputOperands ()) {
567+ if (input->get () == result) {
568+ return true ;
569+ }
570+ }
571+ return false ;
572+ }
573+ // Non-DPS: caller iterates result.getUsers(), so user consumes result.
574+ // Single-input relayout ops have only one place for it.
575+ return true ;
576+ }
577+
549578// / Collects all relayout ops in the chain starting from `relayoutOp`
550579// / (inclusive). The chain extends through result->user edges where the user is
551580// / a supported relayout op.
@@ -556,8 +585,7 @@ static void collectRelayoutChain(Operation *relayoutOp,
556585 }
557586 Value result = relayoutOp->getResult (0 );
558587 for (Operation *user : result.getUsers ()) {
559- if (isSupportedSingleInputRelayoutOpForSource (user) &&
560- user->getOperand (0 ) == result) {
588+ if (isRelayoutChainExtension (user, result)) {
561589 collectRelayoutChain (user, chain);
562590 }
563591 }
@@ -579,16 +607,14 @@ static bool isComplexRelayoutChain(Operation *relayoutOp) {
579607 if (chain.size () < 2 ) {
580608 return false ;
581609 }
582- bool hasReshape = llvm::any_of (chain, [](Operation *op) {
583- return isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp>(op);
584- });
610+ bool hasReshape = llvm::any_of (
611+ chain, llvm::IsaPred<tensor::ExpandShapeOp, tensor::CollapseShapeOp>);
585612 // Need at least one op that is not reshape and not extract_slice (e.g. copy,
586613 // transpose, broadcast, pad).
587- bool hasOtherNonExtractSlice = llvm::any_of (chain, [](Operation *op) {
588- return !isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp,
589- tensor::ExtractSliceOp>(op);
590- });
591- return hasReshape && hasOtherNonExtractSlice;
614+ bool allReshapeOrExtractSlice = llvm::all_of (
615+ chain, llvm::IsaPred<tensor::ExpandShapeOp, tensor::CollapseShapeOp,
616+ tensor::ExtractSliceOp>);
617+ return hasReshape && !allReshapeOrExtractSlice;
592618}
593619
594620// / Collects direct relayout op users of `loadResult` that start complex
@@ -599,8 +625,7 @@ getComplexChainRelayoutUsers(Value loadResult,
599625 bool combineNonComplexChains = false ) {
600626 SmallPtrSet<Operation *, 4 > complexUsers;
601627 for (Operation *user : loadResult.getUsers ()) {
602- if (isSupportedSingleInputRelayoutOpForSource (user) &&
603- user->getOperand (0 ) == loadResult &&
628+ if (isRelayoutChainExtension (user, loadResult) &&
604629 (combineNonComplexChains || isComplexRelayoutChain (user))) {
605630 complexUsers.insert (user);
606631 }
@@ -1164,8 +1189,7 @@ struct FoldConsumerRelayoutIntoMapLoadPattern
11641189 Operation *consumerOp = nullptr ;
11651190 Value mapLoadResult = mapLoadOp.getResult (0 );
11661191 for (Operation *user : mapLoadOp->getUsers ()) {
1167- if (isSupportedSingleInputRelayoutOpForSource (user) &&
1168- user->getOperand (0 ) == mapLoadResult) {
1192+ if (isRelayoutChainExtension (user, mapLoadResult)) {
11691193 consumerOp = user;
11701194 break ;
11711195 }
0 commit comments