 
 #include "iree/compiler/Codegen/Common/Transforms.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
-#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/Transforms/Transforms.h"
-#include "iree/compiler/Codegen/Utils/GPUUtils.h"
-#include "iree/compiler/Codegen/Utils/Utils.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "iree/compiler/Dialect/TensorExt/IR/TensorExtOps.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -24,305 +21,6 @@ namespace mlir::iree_compiler {
 
 namespace {
 
-/// Calculate the expanded shape of `dest` if it can be expanded with the inner
-/// expanded sizes of `sliceStaticSizes`. Returns failure if such expansion is
-/// not possible.
-static LogicalResult
-getExpandedShape(SmallVector<ReassociationIndices> reIndices,
-                 ArrayRef<int64_t> sliceStaticSizes, Value dest,
-                 SmallVectorImpl<int64_t> &expandedShape,
-                 SmallVectorImpl<int64_t> &totalInnerSizes) {
-  auto destType = dyn_cast<ShapedType>(dest.getType());
-  if (!destType) {
-    return failure();
-  }
-  // TODO(nirvedhmeshram): Support rank reducing parallel_insert_slice.
-  if (reIndices.size() != destType.getShape().size()) {
-    return failure();
-  }
-  // Iterator to insert outer sizes.
-  auto outerShapeIdx = 0;
-  for (auto [reassociations, destSize] :
-       llvm::zip_equal(reIndices, destType.getShape())) {
-    // Dynamic destination dims that are not getting expanded are allowed.
-    if (ShapedType::isDynamic(destSize) && reassociations.size() == 1) {
-      expandedShape.insert(expandedShape.begin() + outerShapeIdx, destSize);
-      outerShapeIdx++;
-      totalInnerSizes.push_back(1);
-      continue;
-    }
-    // Dynamic destination dims that are expanded are currently unsupported,
-    // but this support can be added if needed.
-    if (ShapedType::isDynamic(destSize)) {
-      return failure();
-    }
-    int64_t totalInnerSize = 1;
-    for (int64_t reassociation : llvm::drop_begin(reassociations)) {
-      int64_t expandedInnerSize = sliceStaticSizes[reassociation];
-      // It is not safe to do this pattern if inner dimensions are dynamic.
-      if (ShapedType::isDynamic(expandedInnerSize)) {
-        return failure();
-      }
-      expandedShape.push_back(expandedInnerSize);
-      totalInnerSize *= expandedInnerSize;
-    }
-    if (destSize % totalInnerSize != 0) {
-      return failure();
-    }
-    totalInnerSizes.push_back(totalInnerSize);
-    // Insert the outer size in front of any inner sizes.
-    expandedShape.insert(expandedShape.begin() + outerShapeIdx,
-                         destSize / totalInnerSize);
-    // Set up the iterator for the next uncollapsed dimension.
-    outerShapeIdx = expandedShape.size();
-  }
-  return success();
-}
-
-/// Check if the users of the expanded scf.forall destination can be updated to
-/// account for the expand. If not, we bail out. The two supported users are an
-/// extract_slice -> expand_shape pair with exactly the same reassociation map
-/// as the collapse op being hoisted out, or the root parallel_insert_slice.
-static LogicalResult verifyAndCollectExpandableUsers(
-    Value insertDest, SmallVector<ReassociationIndices> reIndices,
-    tensor::ParallelInsertSliceOp parallelInsertOp,
-    SmallVector<tensor::ExtractSliceOp> &expandableUsers) {
-  for (Operation *user : insertDest.getUsers()) {
-    if (user == parallelInsertOp) {
-      continue;
-    }
-    auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
-    if (!extractSliceOp) {
-      return failure();
-    }
-    if (extractSliceOp.getMixedSizes() != parallelInsertOp.getMixedSizes()) {
-      return failure();
-    }
-    if (extractSliceOp.getMixedOffsets() !=
-        parallelInsertOp.getMixedOffsets()) {
-      return failure();
-    }
-    for (Operation *user : extractSliceOp->getUsers()) {
-      auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(user);
-      if (!expandShapeOp) {
-        return failure();
-      }
-      SmallVector<ReassociationIndices> expandReIndices =
-          expandShapeOp.getReassociationIndices();
-      if (reIndices != expandReIndices) {
-        return failure();
-      }
-    }
-    expandableUsers.push_back(extractSliceOp);
-  }
-  return success();
-}
-
-/// Utility to expand the pre-verified expandable users of the scf.forall
-/// output.
-static void
-expandVerifiedUsers(PatternRewriter &rewriter, Location loc, MLIRContext *ctx,
-                    SmallVector<tensor::ExtractSliceOp> expandableUsers,
-                    SmallVector<int64_t> totalInnerSizes,
-                    SmallVector<ReassociationIndices> reIndices,
-                    scf::ForallOp forallOp,
-                    tensor::ParallelInsertSliceOp parallelInsertOp) {
-  // Compute the offsets, sizes, and strides in the expanded dimensions.
-  auto computeExpandedAccess = [&](ArrayRef<OpFoldResult> mixedOffsets,
-                                   ShapedType resultType)
-      -> std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
-                    SmallVector<OpFoldResult>> {
-    SmallVector<OpFoldResult> expandedOffsets;
-    auto expandedOffsetsIdx = 0;
-
-    for (auto [index, offset] : llvm::enumerate(mixedOffsets)) {
-      // Add zero offsets for the extra dimensions from reIndices.
-      for (size_t i = 1, e = reIndices[index].size(); i < e; ++i) {
-        expandedOffsets.push_back(getAsIndexOpFoldResult(ctx, 0));
-      }
-      Value offsetVal = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
-      // Make sure we insert after the offset value.
-      rewriter.setInsertionPointAfterValue(offsetVal);
-      // Compute the outer dimension expression.
-      AffineExpr s0, s1;
-      bindSymbols(rewriter.getContext(), s0, s1);
-      AffineExpr outerDimExpr = (s0).floorDiv(s1);
-      // Insert the computed offset using the affine expression.
-      expandedOffsets.insert(
-          expandedOffsets.begin() + expandedOffsetsIdx,
-          affine::makeComposedFoldedAffineApply(
-              rewriter, loc, outerDimExpr,
-              {offsetVal, rewriter.getIndexAttr(totalInnerSizes[index])}));
-
-      expandedOffsetsIdx = expandedOffsets.size();
-    }
-    SmallVector<OpFoldResult> expandedSizes =
-        getAsIndexOpFoldResult(ctx, resultType.getShape());
-    SmallVector<OpFoldResult> expandedStrides(resultType.getRank(),
-                                              rewriter.getIndexAttr(1));
-    return {expandedOffsets, expandedSizes, expandedStrides};
-  };
-  auto collapseShapeOp =
-      parallelInsertOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
-  RankedTensorType resultType = collapseShapeOp.getSrcType();
-  auto [expandedOffsets, expandedSizes, expandedStrides] =
-      computeExpandedAccess(parallelInsertOp.getMixedOffsets(), resultType);
-  rewriter.setInsertionPoint(parallelInsertOp);
-  rewriter.replaceOpWithNewOp<tensor::ParallelInsertSliceOp>(
-      parallelInsertOp, collapseShapeOp.getSrc(), parallelInsertOp.getDest(),
-      expandedOffsets, expandedSizes, expandedStrides);
-  for (tensor::ExtractSliceOp extractSliceOp : expandableUsers) {
-    rewriter.setInsertionPoint(extractSliceOp);
-    auto newExtractSliceOp =
-        rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
-            extractSliceOp, resultType, extractSliceOp.getSource(),
-            expandedOffsets, expandedSizes, expandedStrides);
-    for (Operation *user : newExtractSliceOp->getUsers()) {
-      auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(user);
-      expandShapeOp->replaceAllUsesWith(newExtractSliceOp);
-    }
-  }
-  return;
-}
-
-/// This pattern expands the destination of workgroup-mapped scf.foralls by
-/// hoisting out the collapse_shape op consumed by its parallel_insert_slice.
-struct ExpandDestinationForallOp final
-    : OpRewritePattern<tensor::ParallelInsertSliceOp> {
-  using Base::Base;
-  LogicalResult matchAndRewrite(tensor::ParallelInsertSliceOp parallelInsertOp,
-                                PatternRewriter &rewriter) const override {
-    Location loc = parallelInsertOp.getLoc();
-    MLIRContext *ctx = getContext();
-    auto collapseOp =
-        parallelInsertOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
-    // No collapse op to hoist out.
-    if (!collapseOp) {
-      return failure();
-    }
-
-    // Ignore trivially foldable collapse ops.
-    if (collapseOp.getSrcType().getRank() ==
-        collapseOp.getResultType().getRank()) {
-      return failure();
-    }
-
-    // Get the destination to expand.
-    Value insertDest = parallelInsertOp.getDest();
-
-    // Get the enclosing scf.forall op.
-    OpResult tiedResult = parallelInsertOp.getTiedOpResult();
-    int64_t tiedResultIdx = tiedResult.getResultNumber();
-
-    auto forallOp = dyn_cast<scf::ForallOp>(tiedResult.getOwner());
-    if (!forallOp) {
-      return failure();
-    }
-
-    SmallVector<int64_t> expandedDestShape;
-    SmallVector<int64_t> totalInnerSizes;
-    // Get the shape of the outer expand, which will be the new destination of
-    // the scf.forall, and the total size of inner dimensions per uncollapsed
-    // dimension.
-    SmallVector<ReassociationIndices> reIndices =
-        collapseOp.getReassociationIndices();
-    if (failed(getExpandedShape(reIndices, collapseOp.getSrcType().getShape(),
-                                insertDest, expandedDestShape,
-                                totalInnerSizes))) {
-      return failure();
-    }
-
-    // We only want this pattern if the forall op result is being written to a
-    // full slice or an expandable buffer. Otherwise the hoisted collapse op is
-    // not foldable.
-    for (Operation *forallUser : tiedResult.getUsers()) {
-      auto storeOp =
-          dyn_cast<IREE::TensorExt::DispatchTensorStoreOp>(forallUser);
-      if (storeOp && isFullSlice(storeOp, storeOp.getTargetType(),
-                                 storeOp.getTargetDims())) {
-        continue;
-      }
-      auto storeToBufferOp =
-          dyn_cast<IREE::Codegen::StoreToBufferOp>(forallUser);
-      if (!storeToBufferOp) {
-        return failure();
-      }
-      MemRefType bufferType = storeToBufferOp.getBuffer().getType();
-      if (failed(memref::ExpandShapeOp::computeExpandedType(
-              bufferType, expandedDestShape, reIndices))) {
-        return failure();
-      }
-    }
-
-    // This allows us to assume that the extracts/inserts in the loop are
-    // disjoint and makes the application of this pattern safe.
-    if (!forallOpHasMappingType<IREE::Codegen::WorkgroupMappingAttr>(
-            forallOp)) {
-      return failure();
-    }
-
-    // Verify that the users of the destination are valid to expand, and
-    // collect all such users.
-    SmallVector<tensor::ExtractSliceOp> expandableUsers;
-    if (failed(verifyAndCollectExpandableUsers(
-            insertDest, collapseOp.getReassociationIndices(), parallelInsertOp,
-            expandableUsers))) {
-      return failure();
-    }
-
-    // Expand the users of the destination.
-    rewriter.setInsertionPointToStart(forallOp.getBody());
-    expandVerifiedUsers(rewriter, loc, ctx, expandableUsers, totalInnerSizes,
-                        reIndices, forallOp, parallelInsertOp);
-    rewriter.setInsertionPoint(forallOp);
-
-    // This pattern only supports forall ops with a single output.
-    SmallVector<Value> forallOutputs(forallOp.getOutputs());
-    // Create the expand -> new scf.forall -> collapse chain.
-    auto expandedDestType =
-        cast<RankedTensorType>(forallOutputs[tiedResultIdx].getType())
-            .clone(expandedDestShape);
-    auto expandedDest =
-        tensor::ExpandShapeOp::create(rewriter, loc, expandedDestType,
-                                      forallOutputs[tiedResultIdx], reIndices);
-
-    forallOutputs[tiedResultIdx] = expandedDest;
-
-    scf::ForallOp newForallOp = scf::ForallOp::create(
-        rewriter, loc, forallOp.getMixedLowerBound(),
-        forallOp.getMixedUpperBound(), forallOp.getMixedStep(), forallOutputs,
-        forallOp.getMappingAttr());
-
-    auto collapsedResultOp = tensor::CollapseShapeOp::create(
-        rewriter, loc,
-        cast<ShapedType>(forallOp->getResult(tiedResultIdx).getType()),
-        newForallOp->getResult(tiedResultIdx), reIndices);
-
-    // Merge the old scf.forall block, which has the expanded users, into the
-    // new scf.forall, which has the expanded destination.
-    SmallVector<Value> argReplacements(newForallOp.getInductionVars());
-    argReplacements.append(newForallOp.getRegionIterArgs().begin(),
-                           newForallOp.getRegionIterArgs().end());
-    scf::InParallelOp parallelTerminator = newForallOp.getTerminator();
-    parallelTerminator->erase();
-    rewriter.mergeBlocks(forallOp.getBody(), newForallOp.getBody(),
-                         argReplacements);
-
-    // Replace the uses of the old scf.forall with the new scf.forall.
-    for (int idx = 0; idx < forallOp->getNumResults(); ++idx) {
-      if (idx == tiedResultIdx) {
-        forallOp->getResult(idx).replaceAllUsesWith(
-            collapsedResultOp->getResult(0));
-      } else {
-        forallOp->getResult(idx).replaceAllUsesWith(
-            newForallOp->getResult(idx));
-      }
-    }
-    return success();
-  }
-};
 /// This pattern hoists expand_shape & collapse_shape ops out of scf.for loops.
 struct ExpandDestinationForOp final : OpRewritePattern<scf::YieldOp> {
   using Base::Base;
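
For reference, the rewrite that the removed `ExpandDestinationForallOp` pattern performs (it is now registered through `populateExpandDestinationForallPatterns`; see the second hunk below) looks roughly as follows. This is an illustrative sketch only: the `tensor.empty` producer, the shapes, the trip count, and the exact offsets are assumed for the example, not taken from this commit.

// Before: the collapse_shape feeding the parallel_insert_slice keeps the
// scf.forall destination in its collapsed form.
%result = scf.forall (%i) in (4) shared_outs(%out = %dest) -> (tensor<64xf32>) {
  %src = tensor.empty() : tensor<1x16xf32>
  %collapsed = tensor.collapse_shape %src [[0, 1]]
      : tensor<1x16xf32> into tensor<16xf32>
  %off = affine.apply affine_map<(d0) -> (d0 * 16)>(%i)
  scf.forall.in_parallel {
    tensor.parallel_insert_slice %collapsed into %out[%off] [16] [1]
        : tensor<16xf32> into tensor<64xf32>
  }
} {mapping = [#iree_codegen.workgroup_mapping<x>]}

// After: the expand -> new scf.forall -> collapse chain. The destination is
// expanded, the insert writes at [%off floordiv 16, 0] = [%i, 0], and the
// collapse is hoisted out of the loop, where it can fold away against a
// full-slice store.
%expanded = tensor.expand_shape %dest [[0, 1]] output_shape [4, 16]
    : tensor<64xf32> into tensor<4x16xf32>
%new = scf.forall (%i) in (4) shared_outs(%out = %expanded)
    -> (tensor<4x16xf32>) {
  %src = tensor.empty() : tensor<1x16xf32>
  scf.forall.in_parallel {
    tensor.parallel_insert_slice %src into %out[%i, 0] [1, 16] [1, 1]
        : tensor<1x16xf32> into tensor<4x16xf32>
  }
} {mapping = [#iree_codegen.workgroup_mapping<x>]}
%result = tensor.collapse_shape %new [[0, 1]]
    : tensor<4x16xf32> into tensor<64xf32>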
@@ -551,9 +249,9 @@ void PropagateReshapesByExpansionPass::runOnOperation() {
           context);
   populateReshapeToInterfaceTensorPatterns(bubbleExpandShapePatterns);
   populateFoldTensorReshapeIntoBufferPatterns(bubbleExpandShapePatterns);
+  populateExpandDestinationForallPatterns(bubbleExpandShapePatterns);
   bubbleExpandShapePatterns
-      .add<ExpandDestinationForallOp, ExpandDestinationForOp,
-           SwapInnerBitcastWithExtractSlice>(context);
+      .add<ExpandDestinationForOp, SwapInnerBitcastWithExtractSlice>(context);
 
   if (failed(applyPatternsGreedily(getOperation(),
                                    std::move(bubbleExpandShapePatterns)))) {
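
Note: with this change the forall-destination expansion reaches the pass through a population hook rather than a file-local `add` call, so the pattern can be reused by other passes. Judging from the call site above, the relocated pattern presumably ships with a declaration along the lines of `void populateExpandDestinationForallPatterns(RewritePatternSet &patterns);` in Common/Transforms.h; that signature is an assumption inferred from this hunk, not copied from the commit.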