Skip to content

Commit a6722ae

Browse files
Max191 and claude authored
[Codegen] Add pattern to hoist expand_shape out of forall (#23930)
Adds a pattern to collapse scf.forall destinations by hoisting expand_shape ops out to the forall's result. This enables further propagation by collapsing during the IGEMM transformation when we have pre-existing split-reduction loops. The inverse pattern for hoisting collapse_shape and expanding the forall is moved to Codegen/Common/Transforms.cpp because some utilities from that pattern are reused in this one. Signed-off-by: Max Dawkins <max.dawkins@gmail.com> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent c490027 commit a6722ae

9 files changed

Lines changed: 820 additions & 305 deletions

File tree

compiler/src/iree/compiler/Codegen/Common/ConvolutionToIGEMM.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,9 @@ convertToIGEMMAndSetConfig(FunctionOpInterface funcOp,
147147
bubbleCollapseShapePatterns, context);
148148
populateReshapeToInterfaceTensorPatterns(bubbleCollapseShapePatterns);
149149
populateFoldTensorReshapeIntoBufferPatterns(bubbleCollapseShapePatterns);
150+
populateSwapExtractWithExpandPattern(bubbleCollapseShapePatterns,
151+
bubbleUpExpansionControlFn);
152+
populateCollapseDestinationForallPatterns(bubbleCollapseShapePatterns);
150153
if (failed(applyPatternsGreedily(funcOp,
151154
std::move(bubbleCollapseShapePatterns)))) {
152155
return failure();

compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp

Lines changed: 3 additions & 305 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,8 @@
66

77
#include "iree/compiler/Codegen/Common/Transforms.h"
88
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
9-
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
109
#include "iree/compiler/Codegen/Dialect/Codegen/Transforms/Transforms.h"
11-
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
12-
#include "iree/compiler/Codegen/Utils/Utils.h"
13-
#include "mlir/Dialect/Affine/IR/AffineOps.h"
10+
#include "iree/compiler/Dialect/TensorExt/IR/TensorExtOps.h"
1411
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
1512
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1613
#include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -24,305 +21,6 @@ namespace mlir::iree_compiler {
2421

2522
namespace {
2623

27-
/// Calculate the expanded shape of `dest` if it can be expanded with the inner
28-
/// expanded sizes of `sliceStaticSizes`. Returns failure if such expansion is
29-
/// not possible.
30-
static LogicalResult
31-
getExpandedShape(SmallVector<ReassociationIndices> reIndices,
32-
ArrayRef<int64_t> sliceStaticSizes, Value dest,
33-
SmallVectorImpl<int64_t> &expandedShape,
34-
SmallVectorImpl<int64_t> &totalInnerSizes) {
35-
auto destType = dyn_cast<ShapedType>(dest.getType());
36-
if (!destType) {
37-
return failure();
38-
}
39-
// TODO (nirvedhmeshram): Support rank reducing parallel_insert_slice.
40-
if (reIndices.size() != destType.getShape().size()) {
41-
return failure();
42-
}
43-
// Iterator to insert outer sizes.
44-
auto outerShapeIdx = 0;
45-
for (auto [reassociations, destSize] :
46-
llvm::zip_equal(reIndices, destType.getShape())) {
47-
// Dynamic destination dims that are not getting expanded are allowed.
48-
if (ShapedType::isDynamic(destSize) && reassociations.size() == 1) {
49-
expandedShape.insert(expandedShape.begin() + outerShapeIdx, destSize);
50-
outerShapeIdx++;
51-
totalInnerSizes.push_back(1);
52-
continue;
53-
}
54-
// Dynamic destination dims that are expanded are currently unsupported but
55-
// this support can be added if needed.
56-
if (ShapedType::isDynamic(destSize)) {
57-
return failure();
58-
}
59-
int64_t totalInnerSize = 1;
60-
for (int64_t reasociation : llvm::drop_begin(reassociations)) {
61-
int64_t expandedInnerSize = sliceStaticSizes[reasociation];
62-
// It is not safe to do this pattern if inner dimensions are dynamic.
63-
if (ShapedType::isDynamic(expandedInnerSize)) {
64-
return failure();
65-
}
66-
expandedShape.push_back(expandedInnerSize);
67-
totalInnerSize *= expandedInnerSize;
68-
}
69-
if (destSize % totalInnerSize != 0) {
70-
return failure();
71-
}
72-
totalInnerSizes.push_back(totalInnerSize);
73-
// insert the outer size in front of any inner sizes.
74-
expandedShape.insert(expandedShape.begin() + outerShapeIdx,
75-
destSize / totalInnerSize);
76-
// set up the iterator for the next uncollapsed dimension.
77-
outerShapeIdx = expandedShape.size();
78-
}
79-
return success();
80-
}
81-
82-
/// Check if the users of the expanded scf.forall destination can be updated to
83-
/// account for the expand. If not we bail out. There are two supported users
84-
/// which are extract_slice -> expand_shape with the same exact reassociation
85-
/// map as the collapse op to be hoisted out or the root parallel_insert_slice.
86-
static LogicalResult verifyAndCollectExpandableUsers(
87-
Value insertDest, SmallVector<ReassociationIndices> reIndices,
88-
tensor::ParallelInsertSliceOp parallelInsertOp,
89-
SmallVector<tensor::ExtractSliceOp> &expandableUsers) {
90-
for (Operation *user : insertDest.getUsers()) {
91-
if (user == parallelInsertOp) {
92-
continue;
93-
}
94-
auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
95-
if (!extractSliceOp) {
96-
return failure();
97-
}
98-
if (extractSliceOp.getMixedSizes() != parallelInsertOp.getMixedSizes()) {
99-
return failure();
100-
}
101-
if (extractSliceOp.getMixedOffsets() !=
102-
parallelInsertOp.getMixedOffsets()) {
103-
return failure();
104-
}
105-
for (Operation *user : extractSliceOp->getUsers()) {
106-
auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(user);
107-
if (!expandShapeOp) {
108-
return failure();
109-
}
110-
SmallVector<ReassociationIndices> expandReIndices =
111-
expandShapeOp.getReassociationIndices();
112-
if (reIndices != expandReIndices) {
113-
return failure();
114-
}
115-
}
116-
expandableUsers.push_back(extractSliceOp);
117-
}
118-
return success();
119-
}
120-
121-
/// Utility to expand the pre-verified expandable users of the scf.forall
122-
/// output.
123-
static void
124-
expandVerifiedUsers(PatternRewriter &rewriter, Location loc, MLIRContext *ctx,
125-
SmallVector<tensor::ExtractSliceOp> expandableUsers,
126-
SmallVector<int64_t> totalInnerSizes,
127-
SmallVector<ReassociationIndices> reIndices,
128-
scf::ForallOp forallOp,
129-
tensor::ParallelInsertSliceOp parallelInsertOp) {
130-
// compute the offsets,sizes,strides in the expanded dimensions.
131-
auto computeExpandedAccess = [&](ArrayRef<OpFoldResult> mixedOffsets,
132-
ShapedType resultType)
133-
-> std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
134-
SmallVector<OpFoldResult>> {
135-
SmallVector<OpFoldResult> expandedOffsets;
136-
auto expandedOffsetsIdx = 0;
137-
138-
for (auto [index, offset] : llvm::enumerate(mixedOffsets)) {
139-
// Add zero offsets for the extra dimensions from reIndices.
140-
for (size_t i = 1, e = reIndices[index].size(); i < e; ++i) {
141-
expandedOffsets.push_back(getAsIndexOpFoldResult(ctx, 0));
142-
}
143-
Value offsetVal = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
144-
// Make sure we insert after offset.
145-
rewriter.setInsertionPointAfterValue(offsetVal);
146-
// Compute the outer dimension expression.
147-
AffineExpr s0, s1;
148-
bindSymbols(rewriter.getContext(), s0, s1);
149-
AffineExpr outerDimExpr = (s0).floorDiv(s1);
150-
// Insert computed offset using affine expression.
151-
expandedOffsets.insert(
152-
expandedOffsets.begin() + expandedOffsetsIdx,
153-
affine::makeComposedFoldedAffineApply(
154-
rewriter, loc, outerDimExpr,
155-
{offsetVal, rewriter.getIndexAttr(totalInnerSizes[index])}));
156-
157-
expandedOffsetsIdx = expandedOffsets.size();
158-
}
159-
SmallVector<OpFoldResult> expandedSizes =
160-
getAsIndexOpFoldResult(ctx, resultType.getShape());
161-
SmallVector<OpFoldResult> expandedStrides(resultType.getRank(),
162-
rewriter.getIndexAttr(1));
163-
return {expandedOffsets, expandedSizes, expandedStrides};
164-
};
165-
auto collapseShapeOp =
166-
parallelInsertOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
167-
RankedTensorType resultType = collapseShapeOp.getSrcType();
168-
auto [expandedOffsets, expandedSizes, expandedStrides] =
169-
computeExpandedAccess(parallelInsertOp.getMixedOffsets(), resultType);
170-
rewriter.setInsertionPoint(parallelInsertOp);
171-
rewriter.replaceOpWithNewOp<tensor::ParallelInsertSliceOp>(
172-
parallelInsertOp, collapseShapeOp.getSrc(), parallelInsertOp.getDest(),
173-
expandedOffsets, expandedSizes, expandedStrides);
174-
for (tensor::ExtractSliceOp extractSliceOp : expandableUsers) {
175-
rewriter.setInsertionPoint(extractSliceOp);
176-
auto newExtractSliceOp =
177-
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
178-
extractSliceOp, resultType, extractSliceOp.getSource(),
179-
expandedOffsets, expandedSizes, expandedStrides);
180-
for (Operation *user : newExtractSliceOp->getUsers()) {
181-
auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(user);
182-
expandShapeOp->replaceAllUsesWith(newExtractSliceOp);
183-
}
184-
}
185-
return;
186-
}
187-
188-
/// This pattern expands destination of workgroup mapped scf.foralls by
189-
/// hoisting out collapse_shape op consumed by its parallel.insert_slice op.
190-
struct ExpandDestinationForallOp final
191-
: OpRewritePattern<tensor::ParallelInsertSliceOp> {
192-
using Base::Base;
193-
LogicalResult matchAndRewrite(tensor::ParallelInsertSliceOp parallelInsertOp,
194-
PatternRewriter &rewriter) const override {
195-
Location loc = parallelInsertOp.getLoc();
196-
MLIRContext *ctx = getContext();
197-
auto collapseOp =
198-
parallelInsertOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
199-
// No collapse op to hoist out.
200-
if (!collapseOp) {
201-
return failure();
202-
}
203-
204-
// Ignore trivially foldable collapse ops.
205-
if (collapseOp.getSrcType().getRank() ==
206-
collapseOp.getResultType().getRank()) {
207-
return failure();
208-
}
209-
210-
// Get the destination to expand.
211-
Value insertDest = parallelInsertOp.getDest();
212-
213-
// Get the enclosing scf.forall op.
214-
OpResult tiedResult = parallelInsertOp.getTiedOpResult();
215-
int64_t tiedResultIdx = tiedResult.getResultNumber();
216-
217-
auto forallOp = dyn_cast<scf::ForallOp>(tiedResult.getOwner());
218-
if (!forallOp) {
219-
return failure();
220-
}
221-
222-
SmallVector<int64_t> expandedDestShape;
223-
SmallVector<int64_t> totalInnerSizes;
224-
// Get the shape of the outer expand which will be the new destination
225-
// of the scf.forall and the total size of inner dimensions per uncollapsed
226-
// dimension.
227-
SmallVector<ReassociationIndices> reIndices =
228-
collapseOp.getReassociationIndices();
229-
if (failed(getExpandedShape(reIndices, collapseOp.getSrcType().getShape(),
230-
insertDest, expandedDestShape,
231-
totalInnerSizes))) {
232-
return failure();
233-
}
234-
235-
// We only want this pattern if the forall op result is being written to a
236-
// full slice, or an expandable buffer. Otherwise the hoisted collapse op is
237-
// not foldable.
238-
for (Operation *foralluser : tiedResult.getUsers()) {
239-
auto storeOp =
240-
dyn_cast<IREE::TensorExt::DispatchTensorStoreOp>(foralluser);
241-
if (storeOp && isFullSlice(storeOp, storeOp.getTargetType(),
242-
storeOp.getTargetDims())) {
243-
continue;
244-
}
245-
auto storeToBufferOp =
246-
dyn_cast<IREE::Codegen::StoreToBufferOp>(foralluser);
247-
if (!storeToBufferOp) {
248-
return failure();
249-
}
250-
MemRefType bufferType = storeToBufferOp.getBuffer().getType();
251-
if (failed(memref::ExpandShapeOp::computeExpandedType(
252-
bufferType, expandedDestShape, reIndices))) {
253-
return failure();
254-
}
255-
}
256-
257-
// This allows us to assume that the extract/inserts in the loop are
258-
// disjoint and makes the application of this pattern safe.
259-
if (!forallOpHasMappingType<IREE::Codegen::WorkgroupMappingAttr>(
260-
forallOp)) {
261-
return failure();
262-
}
263-
264-
// Verify that the users of destination are valid to expand and collect all
265-
// such users.
266-
SmallVector<tensor::ExtractSliceOp> expandableUsers;
267-
if (failed(verifyAndCollectExpandableUsers(
268-
insertDest, collapseOp.getReassociationIndices(), parallelInsertOp,
269-
expandableUsers))) {
270-
return failure();
271-
}
272-
273-
// Expand the users of the destination.
274-
rewriter.setInsertionPointToStart(forallOp.getBody());
275-
expandVerifiedUsers(rewriter, loc, ctx, expandableUsers, totalInnerSizes,
276-
reIndices, forallOp, parallelInsertOp);
277-
rewriter.setInsertionPoint(forallOp);
278-
279-
// This pattern only supports forall ops with single
280-
// output.
281-
SmallVector<Value> forallOutputs(forallOp.getOutputs());
282-
// Create the expand -> new scf.forall -> collapse chain.
283-
auto expandedDestType =
284-
cast<RankedTensorType>(forallOutputs[tiedResultIdx].getType())
285-
.clone(expandedDestShape);
286-
auto expandedDest =
287-
tensor::ExpandShapeOp::create(rewriter, loc, expandedDestType,
288-
forallOutputs[tiedResultIdx], reIndices);
289-
290-
forallOutputs[tiedResultIdx] = expandedDest;
291-
292-
scf::ForallOp newForallOp = scf::ForallOp::create(
293-
rewriter, loc, forallOp.getMixedLowerBound(),
294-
forallOp.getMixedUpperBound(), forallOp.getMixedStep(), forallOutputs,
295-
forallOp.getMappingAttr());
296-
297-
auto collapsedResultOp = tensor::CollapseShapeOp::create(
298-
rewriter, loc,
299-
cast<ShapedType>(forallOp->getResult(tiedResultIdx).getType()),
300-
newForallOp->getResult(tiedResultIdx), reIndices);
301-
302-
// Merge the old scf.forall block which has the expanded users into the new
303-
// scf.forall which has the expanded destination.
304-
SmallVector<Value> argReplacements(newForallOp.getInductionVars());
305-
argReplacements.append(newForallOp.getRegionIterArgs().begin(),
306-
newForallOp.getRegionIterArgs().end());
307-
scf::InParallelOp parallelTerminator = newForallOp.getTerminator();
308-
parallelTerminator->erase();
309-
rewriter.mergeBlocks(forallOp.getBody(), newForallOp.getBody(),
310-
argReplacements);
311-
312-
// Replaces the uses of the old scf.forall with the new scf.forall
313-
for (int idx = 0; idx < forallOp->getNumResults(); ++idx) {
314-
if (idx == tiedResultIdx) {
315-
forallOp->getResult(idx).replaceAllUsesWith(
316-
collapsedResultOp->getResult(0));
317-
} else {
318-
forallOp->getResult(idx).replaceAllUsesWith(
319-
newForallOp->getResult(idx));
320-
}
321-
}
322-
return success();
323-
}
324-
};
325-
32624
/// This pattern hoists expand_shape & collapse_shape ops out of scf.for loops.
32725
struct ExpandDestinationForOp final : OpRewritePattern<scf::YieldOp> {
32826
using Base::Base;
@@ -551,9 +249,9 @@ void PropagateReshapesByExpansionPass::runOnOperation() {
551249
context);
552250
populateReshapeToInterfaceTensorPatterns(bubbleExpandShapePatterns);
553251
populateFoldTensorReshapeIntoBufferPatterns(bubbleExpandShapePatterns);
252+
populateExpandDestinationForallPatterns(bubbleExpandShapePatterns);
554253
bubbleExpandShapePatterns
555-
.add<ExpandDestinationForallOp, ExpandDestinationForOp,
556-
SwapInnerBitcastWithExtractSlice>(context);
254+
.add<ExpandDestinationForOp, SwapInnerBitcastWithExtractSlice>(context);
557255

558256
if (failed(applyPatternsGreedily(getOperation(),
559257
std::move(bubbleExpandShapePatterns)))) {

compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,11 @@ void transform_dialect::ApplyBubbleExpandPatternsOp::populatePatterns(
204204
patterns, [](OpOperand *) { return true; });
205205
}
206206

207+
void transform_dialect::ApplyCollapseForallDestPatternsOp::populatePatterns(
208+
RewritePatternSet &patterns) {
209+
populateCollapseDestinationForallPatterns(patterns);
210+
}
211+
207212
void transform_dialect::ApplyBubblePackUnpackPatternsOp::populatePatterns(
208213
RewritePatternSet &patterns) {
209214
linalg::populateDataLayoutPropagationPatterns(

compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,19 @@ def ApplyBubbleExpandPatternsOp : Op<Transform_Dialect,
4545
let assemblyFormat = "attr-dict";
4646
}
4747

48+
def ApplyCollapseForallDestPatternsOp : Op<Transform_Dialect,
49+
"apply_patterns.iree.collapse_forall_dest",
50+
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>,
51+
ReportTrackingListenerFailuresOpTrait]> {
52+
let description = [{
53+
Populate patterns to collapse the destination of scf.forall ops by hoisting
54+
expand_shape ops out of the parallel_insert_slice.
55+
}];
56+
57+
let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
58+
let assemblyFormat = "attr-dict";
59+
}
60+
4861
def ApplyBubblePackUnpackPatternsOp : Op<Transform_Dialect,
4962
"apply_patterns.iree.bubble_pack_unpack",
5063
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>,

0 commit comments

Comments (0)