Skip to content

Commit c47f2c9

Browse files
committed
Use LivenessAnalysis for dead block arg elimination
1 parent 552306f commit c47f2c9

7 files changed

Lines changed: 54 additions & 137 deletions

File tree

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,9 @@ LogicalResult getConvertBackwardSlice(
189189
std::function<Value(OpOperand &, Attribute)> getExistingConversion =
190190
nullptr);
191191

192-
// Populate pattern to remove dead cycles in ForOp.
193-
void populateForOpDeadArgumentElimination(RewritePatternSet &patterns);
192+
/// Run a dataflow analysis over \p top to identify block arguments to loops
193+
/// that are dead, and replace their usage with the corresponding init value.
194+
void runDeadBlockArgumentElimination(Operation *top);
194195

195196
// Convert an \param index to a multi-dim coordinate given \param shape and
196197
// \param order.

lib/Dialect/Gluon/Transforms/Canonicalize.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ struct Canonicalize : public gluon::impl::GluonCanonicalizeBase<Canonicalize> {
2929
} // namespace
3030

3131
void Canonicalize::runOnOperation() {
32+
runDeadBlockArgumentElimination(getOperation());
3233
MLIRContext *ctx = &getContext();
3334
RewritePatternSet patterns(&getContext());
3435

@@ -48,7 +49,6 @@ void Canonicalize::runOnOperation() {
4849
for (mlir::RegisteredOperationName op : ctx->getRegisteredOperationsByDialect(
4950
cf::ControlFlowDialect::getDialectNamespace()))
5051
op.getCanonicalizationPatterns(patterns, ctx);
51-
populateForOpDeadArgumentElimination(patterns);
5252

5353
// Populate select Triton canonicalization patterns. The important patterns to
5454
// EXCLUDE are those that modify layouts, especially `ConvertLayoutOp`

lib/Dialect/Gluon/Transforms/SimplifyControlFlow.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ struct SimplifyControlFlow
2525
} // namespace
2626

2727
void SimplifyControlFlow::runOnOperation() {
28+
runDeadBlockArgumentElimination(getOperation());
2829
MLIRContext *ctx = &getContext();
2930
RewritePatternSet patterns(&getContext());
3031

@@ -39,7 +40,6 @@ void SimplifyControlFlow::runOnOperation() {
3940
for (mlir::RegisteredOperationName op : ctx->getRegisteredOperationsByDialect(
4041
cf::ControlFlowDialect::getDialectNamespace()))
4142
op.getCanonicalizationPatterns(patterns, ctx);
42-
populateForOpDeadArgumentElimination(patterns);
4343

4444
GreedyRewriteConfig config;
4545
// This is intended to run before AutoLayouts are resolved, in which case

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1685,10 +1685,12 @@ class TritonGPURemoveLayoutConversionsPass
16851685
m.dump();
16861686
});
16871687

1688-
// 4. Apply clean up patterns to remove dead convert and dead code
1688+
// 4. Run dataflow analysis to identify dead code before cleanup.
1689+
runDeadBlockArgumentElimination(m);
1690+
1691+
// 5. Apply clean up patterns to remove dead convert and dead code
16891692
// generated by the previous transformations.
16901693
RewritePatternSet cleanUpPatterns2(context);
1691-
populateForOpDeadArgumentElimination(cleanUpPatterns2);
16921694
scf::ForOp::getCanonicalizationPatterns(cleanUpPatterns2, context);
16931695
scf::IfOp::getCanonicalizationPatterns(cleanUpPatterns2, context);
16941696
ConvertLayoutOp::getCanonicalizationPatterns(cleanUpPatterns2, context);

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 16 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <fstream>
44

5+
#include "mlir/Analysis/DataFlow/LivenessAnalysis.h"
56
#include "mlir/Analysis/SliceAnalysis.h"
67
#include "mlir/Dialect/SCF/IR/SCF.h"
78
#include "mlir/IR/Dominance.h"
@@ -1241,138 +1242,23 @@ Operation *convertDistributedOpEncoding(Attribute encoding, Operation *op) {
12411242
return newOp;
12421243
}
12431244

1244-
namespace {
1245-
1246-
/// Detect dead arguments in scf.for op by assuming all the values are dead and
1247-
/// propagate liveness property.
1248-
struct ForOpDeadArgElimination : public OpRewritePattern<scf::ForOp> {
1249-
using OpRewritePattern<scf::ForOp>::OpRewritePattern;
1250-
1251-
LogicalResult matchAndRewrite(scf::ForOp forOp,
1252-
PatternRewriter &rewriter) const final {
1253-
Block &block = *forOp.getBody();
1254-
auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
1255-
// Assume that nothing is live at the beginning and mark values as live
1256-
// based on uses.
1257-
DenseSet<Value> aliveValues;
1258-
SmallVector<Value> queue;
1259-
// Helper to mark values as live and add them to the queue of value to
1260-
// propagate if it is the first time we detect the value as live.
1261-
auto markLive = [&](Value val) {
1262-
if (!forOp->isAncestor(val.getParentRegion()->getParentOp()))
1263-
return;
1264-
if (aliveValues.insert(val).second)
1265-
queue.push_back(val);
1266-
};
1267-
// Mark all yield operands as live if the associated forOp result has any
1268-
// use.
1269-
for (auto result : llvm::enumerate(forOp.getResults())) {
1270-
if (!result.value().use_empty())
1271-
markLive(yieldOp.getOperand(result.index()));
1272-
}
1273-
if (aliveValues.size() == forOp.getNumResults())
1274-
return failure();
1275-
// Operations with side-effects are always live. Mark all their operands as
1276-
// live.
1277-
block.walk([&](Operation *op) {
1278-
if (!isa<scf::YieldOp, scf::ForOp>(op) && !wouldOpBeTriviallyDead(op)) {
1279-
for (Value operand : op->getOperands())
1280-
markLive(operand);
1281-
}
1282-
});
1283-
// Propagate live property until reaching a fixed point.
1284-
while (!queue.empty()) {
1285-
Value value = queue.pop_back_val();
1286-
if (auto nestedFor = value.getDefiningOp<scf::ForOp>()) {
1287-
auto result = mlir::cast<OpResult>(value);
1288-
OpOperand &forOperand = *nestedFor.getTiedLoopInit(result);
1289-
markLive(forOperand.get());
1290-
auto nestedYieldOp =
1291-
cast<scf::YieldOp>(nestedFor.getBody()->getTerminator());
1292-
Value nestedYieldOperand =
1293-
nestedYieldOp.getOperand(result.getResultNumber());
1294-
markLive(nestedYieldOperand);
1295-
continue;
1296-
}
1297-
if (auto nestedIf = value.getDefiningOp<scf::IfOp>()) {
1298-
auto result = mlir::cast<OpResult>(value);
1299-
// mark condition as live.
1300-
markLive(nestedIf.getCondition());
1301-
for (scf::YieldOp nestedYieldOp :
1302-
{nestedIf.thenYield(), nestedIf.elseYield()}) {
1303-
Value nestedYieldOperand =
1304-
nestedYieldOp.getOperand(result.getResultNumber());
1305-
markLive(nestedYieldOperand);
1306-
}
1307-
continue;
1308-
}
1309-
if (Operation *def = value.getDefiningOp()) {
1310-
// TODO: support while ops.
1311-
if (isa<scf::WhileOp>(def))
1312-
return failure();
1313-
for (Value operand : def->getOperands())
1314-
markLive(operand);
1315-
continue;
1316-
}
1317-
// If a block argument is live then the associated yield operand and
1318-
// forOp operand are live.
1319-
auto arg = mlir::cast<BlockArgument>(value);
1320-
if (auto forOwner = dyn_cast<scf::ForOp>(arg.getOwner()->getParentOp())) {
1321-
if (arg.getArgNumber() < forOwner.getNumInductionVars())
1322-
continue;
1323-
unsigned iterIdx = arg.getArgNumber() - forOwner.getNumInductionVars();
1324-
Value yieldOperand =
1325-
forOwner.getBody()->getTerminator()->getOperand(iterIdx);
1326-
markLive(yieldOperand);
1327-
markLive(forOwner.getInitArgs()[iterIdx]);
1245+
void runDeadBlockArgumentElimination(Operation *top) {
1246+
// The op we are running on must not have any results, because the liveness
1247+
// analysis will not consider their users.
1248+
assert(top->hasTrait<OpTrait::ZeroResults>() && "op cannot have results");
1249+
dataflow::RunLivenessAnalysis la{top};
1250+
1251+
// Remove users of dead block arguments in loops by replacing them with their
1252+
// init values.
1253+
top->walk([&](Operation *op) {
1254+
if (auto loopLike = dyn_cast<LoopLikeOpInterface>(op)) {
1255+
for (auto [idx, arg] : llvm::enumerate(loopLike.getRegionIterArgs())) {
1256+
const auto *liveness = la.getLiveness(arg);
1257+
if (liveness && !liveness->isLive)
1258+
arg.replaceAllUsesWith(loopLike.getInits()[idx]);
13281259
}
13291260
}
1330-
SmallVector<unsigned> deadArg;
1331-
for (auto yieldOperand : llvm::enumerate(yieldOp->getOperands())) {
1332-
if (aliveValues.contains(yieldOperand.value()))
1333-
continue;
1334-
if (yieldOperand.value() == block.getArgument(yieldOperand.index() + 1))
1335-
continue;
1336-
1337-
// The yield operand might live outside the loop, e.g.
1338-
// %init = ...
1339-
// %x = ...
1340-
// %y = for iter_args(%unused = %init) {
1341-
// yield %x
1342-
// }
1343-
//
1344-
// In this case, the loop returns %x if it runs 1 or more times, and
1345-
// otherwise it returns %init. We cowardly refuse to remove this operand
1346-
// from the yield. (We could, but we'd need to prove that the loop runs 0
1347-
// or >=1 times.)
1348-
//
1349-
// As a special case, if it doesn't matter whether the loop runs 0 or >=1
1350-
// times (because the loop returns the same value in both cases) then we
1351-
// can still mark the operand as dead. This occurs in the above example
1352-
// when %init is the same as %x.
1353-
if (!forOp->isAncestor(
1354-
yieldOperand.value().getParentRegion()->getParentOp()) &&
1355-
yieldOperand.value() != forOp.getInitArgs()[yieldOperand.index()])
1356-
continue;
1357-
1358-
deadArg.push_back(yieldOperand.index());
1359-
}
1360-
bool changed = false;
1361-
// For simplicity we just replace users of the block arg with init value and
1362-
// leave the operations and argument removal to dead code elimination.
1363-
for (unsigned deadArgIdx : deadArg) {
1364-
BlockArgument arg = block.getArgument(deadArgIdx + 1);
1365-
changed |= !arg.use_empty();
1366-
rewriter.replaceAllUsesWith(arg, forOp.getTiedLoopInit(arg)->get());
1367-
}
1368-
return success(changed);
1369-
}
1370-
};
1371-
1372-
} // namespace
1373-
1374-
void populateForOpDeadArgumentElimination(RewritePatternSet &patterns) {
1375-
patterns.add<ForOpDeadArgElimination>(patterns.getContext());
1261+
});
13761262
}
13771263

13781264
ttg::LocalAllocOp findShmemAlloc(Value operand) {

test/TritonGPU/combine.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4056,3 +4056,31 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
40564056
tt.return %3 : tensor<32xf32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 0, parent = #blocked1}>}>>
40574057
}
40584058
}
4059+
4060+
// -----
4061+
4062+
#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
4063+
module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} {
4064+
// CHECK-LABEL: @for_arg_used_in_nested_for_bound
4065+
tt.func public @for_arg_used_in_nested_for_bound(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
4066+
%c0_i32 = arith.constant 0 : i32
4067+
%c1_i32 = arith.constant 1 : i32
4068+
%c128_i32 = arith.constant 128 : i32
4069+
%c256_i32 = arith.constant 256 : i32
4070+
// CHECK: scf.for
4071+
// CHECK-SAME: iter_args(%[[OUTER_ARG0:.*]] = %arg0, %[[OUTER_ARG1:.*]] = %c0_i32)
4072+
%0:2 = scf.for %iv = %c0_i32 to %c128_i32 step %c1_i32 iter_args(%outer_arg0 = %arg0, %outer_arg1 = %c0_i32) -> (i32, i32) : i32 {
4073+
%1 = arith.remsi %outer_arg1, %arg2 : i32
4074+
%2 = arith.addi %1, %c1_i32 : i32
4075+
%3 = arith.muli %2, %c256_i32 : i32
4076+
// CHECK: scf.for %[[INNER_IV:.*]] = %c0_i32 to %3 step %c1_i32
4077+
%inner_result = scf.for %inner_iv = %c0_i32 to %3 step %c1_i32 iter_args(%inner_acc = %outer_arg0) -> (i32) : i32 {
4078+
%sum = arith.addi %inner_acc, %inner_iv : i32
4079+
scf.yield %sum : i32
4080+
}
4081+
%4 = arith.addi %outer_arg0, %c1_i32 : i32
4082+
scf.yield %inner_result, %4 : i32, i32
4083+
}
4084+
tt.return %0#0 : i32
4085+
}
4086+
}

third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSDataPartition.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1299,8 +1299,8 @@ static bool doDeepCleanup(triton::FuncOp &funcOp,
12991299
});
13001300

13011301
// delete block arguments
1302+
runDeadBlockArgumentElimination(funcOp);
13021303
RewritePatternSet cleanUpPatterns(funcOp.getContext());
1303-
populateForOpDeadArgumentElimination(cleanUpPatterns);
13041304
scf::ForOp::getCanonicalizationPatterns(cleanUpPatterns,
13051305
funcOp.getContext());
13061306
scf::IfOp::getCanonicalizationPatterns(cleanUpPatterns,

0 commit comments

Comments
 (0)