@@ -532,33 +532,60 @@ static MapStoreOp insertIdentityMapStore(RewriterBase &rewriter,
532532 return mapStoreOp;
533533}
534534
535- bool isSupportedSingleInputRelayoutOp (Operation *op) {
536- if (isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp,
537- tensor::ExtractSliceOp, tensor::PadOp, linalg::CopyOp,
538- linalg::TransposeOp>(op)) {
535+ bool isSupportedSingleInputRelayoutOpForResult (Operation *op) {
536+ return isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp,
537+ tensor::ExtractSliceOp, tensor::PadOp, linalg::CopyOp,
538+ linalg::TransposeOp>(op);
539+ }
540+
541+ bool isSupportedSingleInputRelayoutOpForSource (Operation *op) {
542+ if (isSupportedSingleInputRelayoutOpForResult (op)) {
539543 return true ;
540544 }
541- auto genericOp = dyn_cast<linalg::GenericOp>(op);
542- return genericOp && linalg::isaBroadcastOpInterface (genericOp).has_value ();
545+ auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
546+ return linalgOp && linalg::isaBroadcastOpInterface (linalgOp).has_value ();
547+ }
548+
549+ // / Collects all relayout ops in the chain starting from `relayoutOp`
550+ // / (inclusive). The chain extends through result->user edges where the user is
551+ // / a supported relayout op that consumes that result as its first operand.
552+ static void collectRelayoutChain (Operation *relayoutOp,
553+ SmallPtrSetImpl<Operation *> &chain) {
554+ if (!chain.insert (relayoutOp).second ) {
555+ return ;
556+ }
557+ Value result = relayoutOp->getResult (0 );
558+ for (Operation *user : result.getUsers ()) {
559+ if (isSupportedSingleInputRelayoutOpForSource (user) &&
560+ user->getOperand (0 ) == result) {
561+ collectRelayoutChain (user, chain);
562+ }
563+ }
543564}
544565
545566// / Returns true if the relayout op starts a "complex" chain.
546- // / A "complex" chain is :-
547- // / - a chain of relayout ops with length >= 2,
548- // / - or a chain of length 1 with one of the supported linalg relayout ops.
567+ // / A "complex" chain is one that is difficult for bufferization to handle:
568+ // / - chain length >= 2,
569+ // / - contains at least one reshape op (expand_shape or collapse_shape),
570+ // / - and contains at least one op that is neither reshape nor extract_slice.
549571static bool isComplexRelayoutChain (Operation *relayoutOp) {
550- assert (isSupportedSingleInputRelayoutOp (relayoutOp) &&
572+ assert (isSupportedSingleInputRelayoutOpForSource (relayoutOp) &&
551573 " expected a supported relayout op" );
552- Value result = relayoutOp->getResult (0 );
553- bool hasRelayoutUser = llvm::any_of (result.getUsers (), [](Operation *user) {
554- return isSupportedSingleInputRelayoutOp (user);
555- });
556- // Chain length >= 2 -> complex.
557- if (hasRelayoutUser) {
558- return true ;
574+ SmallPtrSet<Operation *, 4 > chain;
575+ collectRelayoutChain (relayoutOp, chain);
576+ if (chain.size () < 2 ) {
577+ return false ;
559578 }
560- // Chain length 1: complex only if the op is a linalg op.
561- return isa<linalg::LinalgOp>(relayoutOp);
579+ bool hasReshape = llvm::any_of (chain, [](Operation *op) {
580+ return isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp>(op);
581+ });
582+ // Need at least one op that is not reshape and not extract_slice (e.g. copy,
583+ // transpose, broadcast, pad).
584+ bool hasOtherNonExtractSlice = llvm::any_of (chain, [](Operation *op) {
585+ return !isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp,
586+ tensor::ExtractSliceOp>(op);
587+ });
588+ return hasReshape && hasOtherNonExtractSlice;
562589}
563590
564591// / Collects direct relayout op users of `loadResult` that start a complex
@@ -567,7 +594,7 @@ static SmallPtrSet<Operation *, 4>
567594getComplexChainRelayoutUsers (Value loadResult) {
568595 SmallPtrSet<Operation *, 4 > complexUsers;
569596 for (Operation *user : loadResult.getUsers ()) {
570- if (isSupportedSingleInputRelayoutOp (user) &&
597+ if (isSupportedSingleInputRelayoutOpForSource (user) &&
571598 user->getOperand (0 ) == loadResult && isComplexRelayoutChain (user)) {
572599 complexUsers.insert (user);
573600 }
@@ -586,11 +613,11 @@ shouldDoReshapesByExpansion(IREE::Codegen::RelayoutCombinationScope scope) {
586613
587614// / Insert identity map_store ops after the given operation if it is a valid
588615// / leaf op of a relayout op chain. A relayout op chain is a sequence of
589- // / relayout ops (defined by `isSupportedSingleInputRelayoutOp `) for which the
590- // / only users of the ops in the chain are relayout ops, except for the leaves
591- // / of the chain. The leaves are simply relayout ops that have non relayout op
592- // / users. The `controlFn` is a callback on the leaf OpResult that provides
593- // / control over whether or not to insert a map_store op.
616+ // / relayout ops (defined by `isSupportedSingleInputRelayoutOpForResult `) for
617+ // / which the only users of the ops in the chain are relayout ops, except for
618+ // / the leaves of the chain. A leaf is a relayout op with at least one user
619+ // / that is neither a relayout op nor `tensor.dim`. The `controlFn` is a
620+ // / callback on the leaf OpResult that decides whether to insert a map_store op.
594621struct InsertMapStoreOpPattern : public RewritePattern {
595622 InsertMapStoreOpPattern (MLIRContext *context,
596623 CombineRelayoutOpsControlFnRef controlFn = nullptr ,
@@ -600,12 +627,13 @@ struct InsertMapStoreOpPattern : public RewritePattern {
600627
601628 LogicalResult matchAndRewrite (Operation *op,
602629 PatternRewriter &rewriter) const override {
603- if (!isSupportedSingleInputRelayoutOp (op)) {
630+ if (!isSupportedSingleInputRelayoutOpForResult (op)) {
604631 return failure ();
605632 }
606633 // Relayout ops whose users are all relayout or tensor.dim ops are not leaves.
607634 auto isDimOrSupportedRelayoutOp = [](Operation *op) {
608- return isSupportedSingleInputRelayoutOp (op) || isa<tensor::DimOp>(op);
635+ return isSupportedSingleInputRelayoutOpForResult (op) ||
636+ isa<tensor::DimOp>(op);
609637 };
610638 if (llvm::all_of (op->getUsers (), isDimOrSupportedRelayoutOp)) {
611639 return failure ();
@@ -790,7 +818,7 @@ getCombineRelayoutOpsControlFn(IREE::Codegen::RelayoutCombinationScope scope) {
790818 // it, so don't introduce map_store.
791819 llvm::SetVector<Operation *> slice;
792820 BackwardSliceOptions options;
793- options.filter = isSupportedSingleInputRelayoutOp ;
821+ options.filter = isSupportedSingleInputRelayoutOpForResult ;
794822 options.inclusive = true ;
795823 LogicalResult result =
796824 getBackwardSlice (parallelInsertOp.getSource (), &slice, options);
@@ -1021,19 +1049,20 @@ foldExtractSliceIntoMapLoad(RewriterBase &rewriter,
10211049 indexTransformBuilder);
10221050}
10231051
1024- // / Fold a consumer broadcast `linalg.generic` into a producer `map_load`.
1025- static FailureOr<MapLoadOp> foldBroadcastGenericIntoMapLoad (
1026- RewriterBase &rewriter, linalg::GenericOp genericOp, MapLoadOp mapLoadOp) {
1027- assert (genericOp.getDpsInputs ()[0 ] == mapLoadOp.getResult (0 ) &&
1028- " expected map_load to be the producer of genericOp input" );
1029- if (!linalg::isaBroadcastOpInterface (genericOp).has_value ()) {
1030- return rewriter.notifyMatchFailure (genericOp,
1031- " generic op is not a broadcast" );
1052+ // / Fold a consumer broadcast op (named or generic) into a producer `map_load`.
1053+ static FailureOr<MapLoadOp>
1054+ foldBroadcastIntoMapLoad (RewriterBase &rewriter, linalg::LinalgOp broadcastOp,
1055+ MapLoadOp mapLoadOp) {
1056+ assert (broadcastOp.getDpsInputs ()[0 ] == mapLoadOp.getResult (0 ) &&
1057+ " expected map_load to be the producer of broadcast input" );
1058+ if (!linalg::isaBroadcastOpInterface (broadcastOp).has_value ()) {
1059+ return rewriter.notifyMatchFailure (broadcastOp.getOperation (),
1060+ " op is not a broadcast" );
10321061 }
10331062
1034- AffineMap inputMap = genericOp .getIndexingMapsArray ()[0 ];
1063+ AffineMap inputMap = broadcastOp .getIndexingMapsArray ()[0 ];
10351064 return foldConsumerIntoMapLoadImpl (
1036- rewriter, genericOp , mapLoadOp,
1065+ rewriter, broadcastOp. getOperation () , mapLoadOp,
10371066 [inputMap](ArrayRef<BlockArgument> indices) -> SmallVector<Value> {
10381067 SmallVector<Value> sourceIndices;
10391068 sourceIndices.reserve (inputMap.getNumResults ());
@@ -1112,8 +1141,8 @@ FailureOr<MapLoadOp> foldIntoMapLoad(RewriterBase &rewriter, Operation *op,
11121141 .Case <tensor::PadOp>([&](tensor::PadOp padOp) {
11131142 return foldPadIntoMapLoad (rewriter, padOp, mapLoadOp);
11141143 })
1115- .Case <linalg::GenericOp >([&](linalg::GenericOp genericOp ) {
1116- return foldBroadcastGenericIntoMapLoad (rewriter, genericOp , mapLoadOp);
1144+ .Case <linalg::LinalgOp >([&](linalg::LinalgOp linalgOp ) {
1145+ return foldBroadcastIntoMapLoad (rewriter, linalgOp , mapLoadOp);
11171146 })
11181147 .Default ([](Operation *) { return failure (); });
11191148}
@@ -1128,7 +1157,7 @@ struct FoldConsumerRelayoutIntoMapLoadPattern
11281157 // Find a consumer relayout op.
11291158 Operation *consumerOp = nullptr ;
11301159 for (Operation *user : mapLoadOp->getUsers ()) {
1131- if (isSupportedSingleInputRelayoutOp (user)) {
1160+ if (isSupportedSingleInputRelayoutOpForSource (user)) {
11321161 consumerOp = user;
11331162 break ;
11341163 }
0 commit comments