@@ -532,33 +532,63 @@ static MapStoreOp insertIdentityMapStore(RewriterBase &rewriter,
532532 return mapStoreOp;
533533}
534534
535- bool isSupportedSingleInputRelayoutOp (Operation *op) {
536- if (isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp,
537- tensor::ExtractSliceOp, tensor::PadOp, linalg::CopyOp,
538- linalg::TransposeOp>(op)) {
535+ bool isSupportedSingleInputRelayoutOpForResult (Operation *op) { // True iff `op` is one of the six tensor/linalg op kinds listed below.
536+ return isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp,
537+ tensor::ExtractSliceOp, tensor::PadOp, linalg::CopyOp,
538+ linalg::TransposeOp>(op); // Broadcast ops are not accepted here; isSupportedSingleInputRelayoutOpForSource adds them.
539+ }
540+
541+ bool isSupportedSingleInputRelayoutOpForSource (Operation *op) { // Superset of the ForResult predicate: additionally accepts linalg.broadcast and broadcast-like linalg.generic ops.
542+ if (isSupportedSingleInputRelayoutOpForResult (op)) { // Everything valid on the result side is also valid on the source side.
543+ return true ;
544+ }
545+ if (isa<linalg::BroadcastOp>(op)) { // Named broadcast op; the generic form is checked after this branch.
539546 return true ;
540547 }
541548 auto genericOp = dyn_cast<linalg::GenericOp>(op);
542549 return genericOp && linalg::isaBroadcastOpInterface (genericOp).has_value ();
543550}
544551
552+ // / Recursively collects `relayoutOp` (inclusive) and its transitive relayout users into `chain`.
553+ // / A user is followed only if it satisfies isSupportedSingleInputRelayoutOpForSource and its
554+ // / operand 0 is result 0 of the current op. `chain` doubles as the visited set for the walk.
555+ static void collectRelayoutChain (Operation *relayoutOp,
556+ SmallPtrSetImpl<Operation *> &chain) {
557+ if (!chain.insert (relayoutOp).second ) { // Insertion failed: the op was already collected, so stop here.
558+ return ;
559+ }
560+ Value result = relayoutOp->getResult (0 ); // Only result 0 is inspected.
561+ for (Operation *user : result.getUsers ()) {
562+ if (isSupportedSingleInputRelayoutOpForSource (user) &&
563+ user->getOperand (0 ) == result) { // Skip users that consume `result` through an operand other than operand 0.
564+ collectRelayoutChain (user, chain);
565+ }
566+ }
567+ }
568+
545569// / Returns true if the relayout op starts a "complex" chain.
546- // / A "complex" chain is :-
547- // / - a chain of relayout ops with length >= 2,
548- // / - or a chain of length 1 with one of the supported linalg relayout ops.
570+ // / A "complex" chain is one that is difficult for bufferization to handle:
571+ // / - it has >= 2 ops, counting `relayoutOp` and all of its transitive relayout users,
572+ // / - contains at least one reshape op (expand_shape or collapse_shape),
573+ // / - and contains at least one op that is neither a reshape nor extract_slice.
549574static bool isComplexRelayoutChain (Operation *relayoutOp) {
550- assert (isSupportedSingleInputRelayoutOp (relayoutOp) &&
575+ assert (isSupportedSingleInputRelayoutOpForSource (relayoutOp) &&
551576 " expected a supported relayout op" );
552- Value result = relayoutOp->getResult (0 );
553- bool hasRelayoutUser = llvm::any_of (result.getUsers (), [](Operation *user) {
554- return isSupportedSingleInputRelayoutOp (user);
555- });
556- // Chain length >= 2 -> complex.
557- if (hasRelayoutUser) {
558- return true ;
577+ SmallPtrSet<Operation *, 4 > chain;
578+ collectRelayoutChain (relayoutOp, chain); // Gathers `relayoutOp` plus every relayout op reachable through its users.
579+ if (chain.size () < 2 ) { // A lone relayout op is never treated as complex.
580+ return false ;
559581 }
560- // Chain length 1: complex only if the op is a linalg op.
561- return isa<linalg::LinalgOp>(relayoutOp);
582+ bool hasReshape = llvm::any_of (chain, [](Operation *op) {
583+ return isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp>(op);
584+ });
585+ // Need at least one op that is not reshape and not extract_slice (e.g. copy,
586+ // transpose, broadcast, pad).
587+ bool hasOtherNonExtractSlice = llvm::any_of (chain, [](Operation *op) {
588+ return !isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp,
589+ tensor::ExtractSliceOp>(op);
590+ });
591+ return hasReshape && hasOtherNonExtractSlice; // Both conditions must hold for the chain as a whole, not for a single op.
562592}
563593
564594// / Collects direct relayout op users of `loadResult` that start a complex
@@ -567,7 +597,7 @@ static SmallPtrSet<Operation *, 4>
567597getComplexChainRelayoutUsers (Value loadResult) {
568598 SmallPtrSet<Operation *, 4 > complexUsers;
569599 for (Operation *user : loadResult.getUsers ()) {
570- if (isSupportedSingleInputRelayoutOp (user) &&
600+ if (isSupportedSingleInputRelayoutOpForSource (user) &&
571601 user->getOperand (0 ) == loadResult && isComplexRelayoutChain (user)) {
572602 complexUsers.insert (user);
573603 }
@@ -586,11 +616,11 @@ shouldDoReshapesByExpansion(IREE::Codegen::RelayoutCombinationScope scope) {
586616
587617// / Insert identity map_store ops after the given operation if it is a valid
588618// / leaf op of a relayout op chain. A relayout op chain is a sequence of
589- // / relayout ops (defined by `isSupportedSingleInputRelayoutOp `) for which the
590- // / only users of the ops in the chain are relayout ops, except for the leaves
591- // / of the chain. The leaves are simply relayout ops that have non relayout op
592- // / users. The `controlFn` is a callback on the leaf OpResult that provides
593- // / control over whether or not to insert a map_store op.
619+ // / relayout ops (defined by `isSupportedSingleInputRelayoutOpForResult `) for
620+ // / which the only users of the ops in the chain are relayout ops, except for
621+ // / the leaves of the chain. The leaves are simply relayout ops that have non
622+ // / relayout op users. The `controlFn` is a callback on the leaf OpResult that
623+ // / provides control over whether or not to insert a map_store op.
594624struct InsertMapStoreOpPattern : public RewritePattern {
595625 InsertMapStoreOpPattern (MLIRContext *context,
596626 CombineRelayoutOpsControlFnRef controlFn = nullptr ,
@@ -600,12 +630,13 @@ struct InsertMapStoreOpPattern : public RewritePattern {
600630
601631 LogicalResult matchAndRewrite (Operation *op,
602632 PatternRewriter &rewriter) const override {
603- if (!isSupportedSingleInputRelayoutOp (op)) {
633+ if (!isSupportedSingleInputRelayoutOpForResult (op)) {
604634 return failure ();
605635 }
606636 // Relayout ops with only relayout op users are not leaves.
607637 auto isDimOrSupportedRelayoutOp = [](Operation *op) {
608- return isSupportedSingleInputRelayoutOp (op) || isa<tensor::DimOp>(op);
638+ return isSupportedSingleInputRelayoutOpForResult (op) ||
639+ isa<tensor::DimOp>(op);
609640 };
610641 if (llvm::all_of (op->getUsers (), isDimOrSupportedRelayoutOp)) {
611642 return failure ();
@@ -790,7 +821,7 @@ getCombineRelayoutOpsControlFn(IREE::Codegen::RelayoutCombinationScope scope) {
790821 // it, so don't introduce map_store.
791822 llvm::SetVector<Operation *> slice;
792823 BackwardSliceOptions options;
793- options.filter = isSupportedSingleInputRelayoutOp ;
824+ options.filter = isSupportedSingleInputRelayoutOpForResult ;
794825 options.inclusive = true ;
795826 LogicalResult result =
796827 getBackwardSlice (parallelInsertOp.getSource (), &slice, options);
@@ -1021,19 +1052,20 @@ foldExtractSliceIntoMapLoad(RewriterBase &rewriter,
10211052 indexTransformBuilder);
10221053}
10231054
1024- // / Fold a consumer broadcast `linalg.generic` into a producer `map_load`.
1025- static FailureOr<MapLoadOp> foldBroadcastGenericIntoMapLoad (
1026- RewriterBase &rewriter, linalg::GenericOp genericOp, MapLoadOp mapLoadOp) {
1027- assert (genericOp.getDpsInputs ()[0 ] == mapLoadOp.getResult (0 ) &&
1028- " expected map_load to be the producer of genericOp input" );
1029- if (!linalg::isaBroadcastOpInterface (genericOp).has_value ()) {
1030- return rewriter.notifyMatchFailure (genericOp,
1031- " generic op is not a broadcast" );
1055+ // / Fold a consumer broadcast op (named or generic) into a producer `map_load`.
1056+ static FailureOr<MapLoadOp>
1057+ foldBroadcastIntoMapLoad (RewriterBase &rewriter, linalg::LinalgOp broadcastOp,
1058+ MapLoadOp mapLoadOp) {
1059+ assert (broadcastOp.getDpsInputs ()[0 ] == mapLoadOp.getResult (0 ) &&
1060+ " expected map_load to be the producer of broadcast input" );
1061+ if (!linalg::isaBroadcastOpInterface (broadcastOp).has_value ()) {
1062+ return rewriter.notifyMatchFailure (broadcastOp.getOperation (),
1063+ " op is not a broadcast" );
10321064 }
10331065
1034- AffineMap inputMap = genericOp .getIndexingMapsArray ()[0 ];
1066+ AffineMap inputMap = broadcastOp .getIndexingMapsArray ()[0 ];
10351067 return foldConsumerIntoMapLoadImpl (
1036- rewriter, genericOp , mapLoadOp,
1068+ rewriter, broadcastOp. getOperation () , mapLoadOp,
10371069 [inputMap](ArrayRef<BlockArgument> indices) -> SmallVector<Value> {
10381070 SmallVector<Value> sourceIndices;
10391071 sourceIndices.reserve (inputMap.getNumResults ());
@@ -1112,8 +1144,15 @@ FailureOr<MapLoadOp> foldIntoMapLoad(RewriterBase &rewriter, Operation *op,
11121144 .Case <tensor::PadOp>([&](tensor::PadOp padOp) {
11131145 return foldPadIntoMapLoad (rewriter, padOp, mapLoadOp);
11141146 })
1147+ .Case <linalg::BroadcastOp>([&](linalg::BroadcastOp broadcastOp) {
1148+ return foldBroadcastIntoMapLoad (
1149+ rewriter, cast<linalg::LinalgOp>(broadcastOp.getOperation ()),
1150+ mapLoadOp);
1151+ })
11151152 .Case <linalg::GenericOp>([&](linalg::GenericOp genericOp) {
1116- return foldBroadcastGenericIntoMapLoad (rewriter, genericOp, mapLoadOp);
1153+ return foldBroadcastIntoMapLoad (
1154+ rewriter, cast<linalg::LinalgOp>(genericOp.getOperation ()),
1155+ mapLoadOp);
11171156 })
11181157 .Default ([](Operation *) { return failure (); });
11191158}
@@ -1128,7 +1167,7 @@ struct FoldConsumerRelayoutIntoMapLoadPattern
11281167 // Find a consumer relayout op.
11291168 Operation *consumerOp = nullptr ;
11301169 for (Operation *user : mapLoadOp->getUsers ()) {
1131- if (isSupportedSingleInputRelayoutOp (user)) {
1170+ if (isSupportedSingleInputRelayoutOpForSource (user)) {
11321171 consumerOp = user;
11331172 break ;
11341173 }
0 commit comments