Skip to content

Commit db970a6

Browse files
krzysz00claudeefric
authored
[Ccodegen] Add util.hoistable_conversion, use it to clean up accumulators (#23969)
Firstly, this commit introduces the util.hoistable_conversion operation. This operation lets you wrap a block of operations, giving them a tag and the tag of their inverse. This operation is removed by the new `eliminateHoistableConversions` set of pattern when either 1. A block of conversions flows directly into its inverse or 2. The cancellation is loop-carried: you convert a loop arg, then, right before yielding in the loop, you perform the inverse conversion. (in that case, the conversions are moved to the loop inits and the final results, respectively) 3. If none of these can happen, the operations get inlined. This infrastructure is introduced to allow us to pull the conversions that get placed around inner_tiled intrinsics out of loops, though it may have other future uses. For example, after this change, the interleaves and deinterleaves needed to make RDNA3's f16/bf16 intrinsics behave correctly are pulled out of the loop. Less performance-sensitively, the varous unrollings and shape_casts we introduce on parallel inner_tiled accumulators now get pulled out of their loops, simplifying types and making it clearer that we're dealing with independent values (and getting rid of a bunch of shuffling) which should make it clearer that we're generating independent tiles. These hoistable_conversion ops are currently meant to be introduced and removed within a single pipeline step, given their nature as opaque black boxes that you shouldn't pull operations into or out of. However, if a usecase is found for keeeping these around longer, we might want to leave them around in the future. This PR puts hoistable_conversion ops anywhere they make sense in the back part of inner_tiled lowering, since it was the motivating example, though it can be used in more places than that going forward. Note that we add an `iree-codegen-hoist-inner-tiled-acc-swizzles` pass to be run after vector distribution to fix up the transpose + shape_cast chains that vector distribution inserts, so that we can hoist the results of various unrollings and lowerings to intrinsics later in the pipeline. This is a hack, but doing otherwise seems to require getting deep into the guts of vector distribution or adding special cases for hoistable conversions. --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Co-authored-by: Eric Feng <55723758+efric@users.noreply.github.com>
1 parent 175fae3 commit db970a6

46 files changed

Lines changed: 1884 additions & 207 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

compiler/src/iree/compiler/Codegen/Common/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ iree_compiler_cc_library(
114114
"ForallToFor.cpp",
115115
"FuseTensorPadWithConsumer.cpp",
116116
"GenericVectorization.cpp",
117+
"HoistInnerTiledAccReshapes.cpp",
117118
"HoistStaticallyBoundAllocations.cpp",
118119
"HoistUnrolledVectorExtractInsertSlice.cpp",
119120
"IREECodegenCanonicalizer.cpp",
@@ -227,6 +228,7 @@ iree_compiler_cc_library(
227228
"//compiler/src/iree/compiler/Dialect/TensorExt/Transforms",
228229
"//compiler/src/iree/compiler/Dialect/Util/Analysis",
229230
"//compiler/src/iree/compiler/Dialect/Util/IR",
231+
"//compiler/src/iree/compiler/Dialect/Util/Transforms",
230232
"//compiler/src/iree/compiler/Utils",
231233
"//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",
232234
"@llvm-project//llvm:Support",

compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ iree_cc_library(
107107
"ForallToFor.cpp"
108108
"FuseTensorPadWithConsumer.cpp"
109109
"GenericVectorization.cpp"
110+
"HoistInnerTiledAccReshapes.cpp"
110111
"HoistStaticallyBoundAllocations.cpp"
111112
"HoistUnrolledVectorExtractInsertSlice.cpp"
112113
"IREECodegenCanonicalizer.cpp"
@@ -264,6 +265,7 @@ iree_cc_library(
264265
iree::compiler::Dialect::TensorExt::Transforms
265266
iree::compiler::Dialect::Util::Analysis
266267
iree::compiler::Dialect::Util::IR
268+
iree::compiler::Dialect::Util::Transforms
267269
iree::compiler::Utils
268270
PUBLIC
269271
)
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
// Copyright 2026 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#include "iree/compiler/Codegen/Common/Passes.h"
8+
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
9+
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
10+
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
11+
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
12+
#include "mlir/Dialect/UB/IR/UBOps.h"
13+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
14+
#include "mlir/IR/IRMapping.h"
15+
#include "mlir/IR/PatternMatch.h"
16+
#include "mlir/Interfaces/LoopLikeInterface.h"
17+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18+
19+
namespace mlir::iree_compiler {
20+
21+
#define GEN_PASS_DEF_HOISTINNERTILEDACCRESHAPESPASS
22+
#include "iree/compiler/Codegen/Common/Passes.h.inc"
23+
24+
// Look for operations that reshape vectors to or from the form needed by
25+
// intrinsics, which are hard to hoist from loops up in vector distribute as
26+
// currently architected.
27+
static bool isReshapeOp(Operation *op) {
28+
return isa<vector::TransposeOp, vector::ShapeCastOp, vector::BroadcastOp>(op);
29+
}
30+
31+
static constexpr llvm::StringLiteral kAccReshapeTo = "acc_reshape_to_intrinsic";
32+
static constexpr llvm::StringLiteral kAccReshapeFrom =
33+
"acc_reshape_from_intrinsic";
34+
35+
namespace {
36+
37+
struct WrapAccReshapesPattern final
38+
: OpRewritePattern<IREE::Codegen::InnerTiledOp> {
39+
using Base::Base;
40+
41+
LogicalResult matchAndRewrite(IREE::Codegen::InnerTiledOp tiledOp,
42+
PatternRewriter &rewriter) const override {
43+
auto loopLike = dyn_cast<LoopLikeOpInterface>(tiledOp->getParentOp());
44+
if (!loopLike || loopLike.getRegionIterArgs().empty()) {
45+
return rewriter.notifyMatchFailure(tiledOp,
46+
"not inside a loop with iter_args");
47+
}
48+
49+
bool anyOutputMatched = false;
50+
for (size_t outputIdx = 0, numOutputs = tiledOp.getOutputs().size();
51+
outputIdx < numOutputs; ++outputIdx) {
52+
Value accOperand = tiledOp.getOutputs()[outputIdx];
53+
54+
SmallVector<Operation *> prefixOps;
55+
Value accRoot = accOperand;
56+
while (auto *defOp = accRoot.getDefiningOp()) {
57+
if (!defOp->hasOneUse() || !isReshapeOp(defOp)) {
58+
break;
59+
}
60+
prefixOps.push_back(defOp);
61+
accRoot = defOp->getOperand(0);
62+
}
63+
64+
SmallVector<Operation *> suffixOps;
65+
Value suffixEnd = tiledOp.getResult(outputIdx);
66+
while (suffixEnd.hasOneUse()) {
67+
Operation *user = *suffixEnd.getUsers().begin();
68+
if (!isReshapeOp(user)) {
69+
break;
70+
}
71+
suffixOps.push_back(user);
72+
suffixEnd = user->getResult(0);
73+
}
74+
75+
if (prefixOps.empty() || suffixOps.empty()) {
76+
continue;
77+
}
78+
if (!isa<BlockArgument>(accRoot)) {
79+
continue;
80+
}
81+
82+
rewriter.setInsertionPoint(prefixOps.back());
83+
// Wrap the prefix reshapes (iter_arg -> inner_tiled accumulator shape)
84+
// in a hoistable_conversion so the pair can be hoisted out of the loop.
85+
auto prefixHoist = IREE::Util::HoistableConversionOp::create(
86+
rewriter, tiledOp.getLoc(), /*tag=*/kAccReshapeTo,
87+
/*inverseTag=*/kAccReshapeFrom, accRoot,
88+
[&](OpBuilder &b, Location loc, ValueRange args) {
89+
Value v = args[0];
90+
for (auto *op : llvm::reverse(prefixOps)) {
91+
IRMapping mapping;
92+
mapping.map(op->getOperand(0), v);
93+
v = b.clone(*op, mapping)->getResult(0);
94+
}
95+
return SmallVector<Value>{v};
96+
});
97+
rewriter.replaceAllUsesWith(accOperand, prefixHoist.getResult(0));
98+
for (auto *op : prefixOps) {
99+
if (op->use_empty()) {
100+
rewriter.eraseOp(op);
101+
}
102+
}
103+
104+
Value suffixInput = tiledOp.getResult(outputIdx);
105+
rewriter.setInsertionPointAfter(suffixOps.back());
106+
// Wrap the suffix reshapes (inner_tiled result -> iter_arg shape)
107+
// as the inverse conversion.
108+
auto suffixHoist = IREE::Util::HoistableConversionOp::create(
109+
rewriter, tiledOp.getLoc(), /*tag=*/kAccReshapeFrom,
110+
/*inverseTag=*/kAccReshapeTo, TypeRange{suffixEnd.getType()},
111+
suffixInput, [&](OpBuilder &b, Location loc, ValueRange args) {
112+
Value v = args[0];
113+
for (auto *op : suffixOps) {
114+
IRMapping mapping;
115+
mapping.map(op->getOperand(0), v);
116+
v = b.clone(*op, mapping)->getResult(0);
117+
}
118+
return SmallVector<Value>{v};
119+
});
120+
rewriter.replaceAllUsesWith(suffixEnd, suffixHoist.getResult(0));
121+
for (auto *op : llvm::reverse(suffixOps)) {
122+
if (op->use_empty()) {
123+
rewriter.eraseOp(op);
124+
}
125+
}
126+
127+
anyOutputMatched = true;
128+
}
129+
130+
return success(anyOutputMatched);
131+
}
132+
};
133+
134+
struct HoistInnerTiledAccReshapesPass final
135+
: impl::HoistInnerTiledAccReshapesPassBase<HoistInnerTiledAccReshapesPass> {
136+
void runOnOperation() override {
137+
MLIRContext *context = &getContext();
138+
RewritePatternSet patterns(context);
139+
patterns.add<WrapAccReshapesPattern>(context);
140+
bool changed = false;
141+
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns),
142+
GreedyRewriteConfig(), &changed))) {
143+
return signalPassFailure();
144+
}
145+
146+
if (changed) {
147+
if (failed(IREE::Util::eliminateHoistableConversions(getOperation()))) {
148+
return signalPassFailure();
149+
}
150+
}
151+
}
152+
};
153+
154+
} // namespace
155+
} // namespace mlir::iree_compiler

compiler/src/iree/compiler/Codegen/Common/Passes.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -656,6 +656,23 @@ def OptimizeTensorInsertExtractSlicesPass
656656
];
657657
}
658658

659+
def HoistInnerTiledAccReshapesPass :
660+
InterfacePass<"iree-codegen-hoist-inner-tiled-acc-reshapes", "mlir::FunctionOpInterface"> {
661+
let summary = "Hoist vector reshapes surrounding inner_tiled ops out of loops";
662+
let description = [{
663+
This pass, mainly intended for use in vector distribution-based pipelines,
664+
searches for chains of vector reshapes (currently transpose, shape_cast, and
665+
broadcast, but it could be extended to other operations) that bracket an
666+
`iree_codegen.inner_tiled`'s accumulator(s) on the way from and to the
667+
argument of the reduction loop it's in, and adds in `util.hoistable_conversion`
668+
markers to move these operations out of the loop to enable further optimizations.
669+
}];
670+
let dependentDialects = [
671+
"::mlir::iree_compiler::IREE::Util::UtilDialect",
672+
"::mlir::ub::UBDialect",
673+
];
674+
}
675+
659676
def HoistUnrolledVectorExtractInsertSlicePass :
660677
InterfacePass<"iree-codegen-hoist-vector-extract-insert-slice", "mlir::FunctionOpInterface"> {
661678
let summary = "Hoist unrolled vector (extract, insert) pairs out of scf.for op";

compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ iree_lit_test_suite(
6969
"generic_vectorization_masked_inferred.mlir",
7070
"generic_vectorization_unmasked.mlir",
7171
"generic_vectorization_using_transfer_gather.mlir",
72+
"hoist_inner_tiled_acc_reshapes.mlir",
7273
"hoist_statically_bound_allocations.mlir",
7374
"hoist_unrolled_vector_extract_insert_slice.mlir",
7475
"insert_batch_dim_for_batchless_conv.mlir",

compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ iree_lit_test_suite(
6464
"generic_vectorization_masked_inferred.mlir"
6565
"generic_vectorization_unmasked.mlir"
6666
"generic_vectorization_using_transfer_gather.mlir"
67+
"hoist_inner_tiled_acc_reshapes.mlir"
6768
"hoist_statically_bound_allocations.mlir"
6869
"hoist_unrolled_vector_extract_insert_slice.mlir"
6970
"insert_batch_dim_for_batchless_conv.mlir"
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-codegen-hoist-inner-tiled-acc-reshapes))" %s | FileCheck %s
2+
3+
#contraction_accesses = [
4+
affine_map<(i, j, k) -> (i, k)>,
5+
affine_map<(i, j, k) -> (k, j)>,
6+
affine_map<(i, j, k) -> (i, j)>
7+
]
8+
9+
// CHECK-LABEL: @hoist_shape_cast_chain
10+
// CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: vector<2x2x1x1x4x1xf32>
11+
// CHECK-DAG: %[[POISON:.+]] = ub.poison : vector<2x2x1x1x4x1xf32>
12+
// CHECK-DAG: %[[SC0:.+]] = vector.shape_cast %[[INIT]]
13+
// CHECK: %[[LOOP:.+]]:2 = scf.for {{.*}} iter_args(%[[DEADACC:.*]] = %[[POISON]], %[[ACC:.*]] = %[[SC0]])
14+
// CHECK: %[[OUT:.+]] = iree_codegen.inner_tiled {{.*}} outs(%[[ACC]])
15+
// CHECK: scf.yield %[[DEADACC]], %[[OUT]]
16+
// CHECK: vector.shape_cast %[[LOOP]]#1
17+
// CHECK-NOT: util.hoistable_conversion
18+
func.func @hoist_shape_cast_chain(
19+
%lhs: vector<2x2x4xf16>, %rhs: vector<2x2x4xf16>,
20+
%init: vector<2x2x1x1x4x1xf32>) -> vector<2x2x1x1x4x1xf32> {
21+
%c0 = arith.constant 0 : index
22+
%c1 = arith.constant 1 : index
23+
%c10 = arith.constant 10 : index
24+
%result = scf.for %iv = %c0 to %c10 step %c1 iter_args(%acc = %init) -> vector<2x2x1x1x4x1xf32> {
25+
%inner_acc = vector.shape_cast %acc : vector<2x2x1x1x4x1xf32> to vector<2x2x4x1xf32>
26+
%mma = iree_codegen.inner_tiled ins(%lhs, %rhs) outs(%inner_acc) {
27+
indexing_maps = #contraction_accesses,
28+
iterator_types = [#linalg.iterator_type<parallel>,
29+
#linalg.iterator_type<parallel>,
30+
#linalg.iterator_type<reduction>],
31+
kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
32+
semantics = #iree_gpu.mma_semantics<distributed = true, opaque = false>
33+
} : vector<2x2x4xf16>, vector<2x2x4xf16> into vector<2x2x4x1xf32>
34+
%back = vector.shape_cast %mma : vector<2x2x4x1xf32> to vector<2x2x1x1x4x1xf32>
35+
scf.yield %back : vector<2x2x1x1x4x1xf32>
36+
}
37+
return %result : vector<2x2x1x1x4x1xf32>
38+
}
39+
40+
// -----
41+
42+
#contraction_accesses2 = [
43+
affine_map<(i, j, k) -> (i, k)>,
44+
affine_map<(i, j, k) -> (k, j)>,
45+
affine_map<(i, j, k) -> (i, j)>
46+
]
47+
48+
// CHECK-LABEL: @no_reshape
49+
// CHECK-NOT: util.hoistable_conversion
50+
// CHECK-NOT: vector.shape_cast
51+
func.func @no_reshape(
52+
%lhs: vector<2x2x4xf16>, %rhs: vector<2x2x4xf16>,
53+
%init: vector<2x2x4x1xf32>) -> vector<2x2x4x1xf32> {
54+
%c0 = arith.constant 0 : index
55+
%c1 = arith.constant 1 : index
56+
%c10 = arith.constant 10 : index
57+
%result = scf.for %iv = %c0 to %c10 step %c1 iter_args(%acc = %init) -> vector<2x2x4x1xf32> {
58+
%mma = iree_codegen.inner_tiled ins(%lhs, %rhs) outs(%acc) {
59+
indexing_maps = #contraction_accesses2,
60+
iterator_types = [#linalg.iterator_type<parallel>,
61+
#linalg.iterator_type<parallel>,
62+
#linalg.iterator_type<reduction>],
63+
kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
64+
semantics = #iree_gpu.mma_semantics<distributed = true, opaque = false>
65+
} : vector<2x2x4xf16>, vector<2x2x4xf16> into vector<2x2x4x1xf32>
66+
scf.yield %mma : vector<2x2x4x1xf32>
67+
}
68+
return %result : vector<2x2x4x1xf32>
69+
}

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ iree_compiler_cc_library(
9797
"//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
9898
"//compiler/src/iree/compiler/Dialect/LinalgExt/Utils",
9999
"//compiler/src/iree/compiler/Dialect/TensorExt/IR",
100+
"//compiler/src/iree/compiler/Dialect/Util/IR",
100101
"//compiler/src/iree/compiler/Utils",
101102
"@llvm-project//llvm:Support",
102103
"@llvm-project//mlir:AMDGPUDialect",

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ iree_cc_library(
7979
iree::compiler::Dialect::LinalgExt::IR
8080
iree::compiler::Dialect::LinalgExt::Utils
8181
iree::compiler::Dialect::TensorExt::IR
82+
iree::compiler::Dialect::Util::IR
8283
iree::compiler::Utils
8384
iree::compiler::bindings::c::headers
8485
PUBLIC

0 commit comments

Comments
 (0)