|
2 | 2 |
|
3 | 3 | #include <fstream> |
4 | 4 |
|
| 5 | +#include "mlir/Analysis/DataFlow/LivenessAnalysis.h" |
5 | 6 | #include "mlir/Analysis/SliceAnalysis.h" |
6 | 7 | #include "mlir/Dialect/SCF/IR/SCF.h" |
7 | 8 | #include "mlir/IR/Dominance.h" |
@@ -1241,138 +1242,23 @@ Operation *convertDistributedOpEncoding(Attribute encoding, Operation *op) { |
1241 | 1242 | return newOp; |
1242 | 1243 | } |
1243 | 1244 |
|
1244 | | -namespace { |
1245 | | - |
1246 | | -/// Detect dead arguments in scf.for op by assuming all the values are dead and |
1247 | | -/// propagate liveness property. |
1248 | | -struct ForOpDeadArgElimination : public OpRewritePattern<scf::ForOp> { |
1249 | | - using OpRewritePattern<scf::ForOp>::OpRewritePattern; |
1250 | | - |
1251 | | - LogicalResult matchAndRewrite(scf::ForOp forOp, |
1252 | | - PatternRewriter &rewriter) const final { |
1253 | | - Block &block = *forOp.getBody(); |
1254 | | - auto yieldOp = cast<scf::YieldOp>(block.getTerminator()); |
1255 | | - // Assume that nothing is live at the beginning and mark values as live |
1256 | | - // based on uses. |
1257 | | - DenseSet<Value> aliveValues; |
1258 | | - SmallVector<Value> queue; |
1259 | | - // Helper to mark values as live and add them to the queue of values to |
1260 | | - // propagate if it is the first time we detect the value as live. |
1261 | | - auto markLive = [&](Value val) { |
1262 | | - if (!forOp->isAncestor(val.getParentRegion()->getParentOp())) |
1263 | | - return; |
1264 | | - if (aliveValues.insert(val).second) |
1265 | | - queue.push_back(val); |
1266 | | - }; |
1267 | | - // Mark all yield operands as live if the associated forOp result has any |
1268 | | - // use. |
1269 | | - for (auto result : llvm::enumerate(forOp.getResults())) { |
1270 | | - if (!result.value().use_empty()) |
1271 | | - markLive(yieldOp.getOperand(result.index())); |
1272 | | - } |
1273 | | - if (aliveValues.size() == forOp.getNumResults()) |
1274 | | - return failure(); |
1275 | | - // Operations with side-effects are always live. Mark all their operands as |
1276 | | - // live. |
1277 | | - block.walk([&](Operation *op) { |
1278 | | - if (!isa<scf::YieldOp, scf::ForOp>(op) && !wouldOpBeTriviallyDead(op)) { |
1279 | | - for (Value operand : op->getOperands()) |
1280 | | - markLive(operand); |
1281 | | - } |
1282 | | - }); |
1283 | | - // Propagate live property until reaching a fixed point. |
1284 | | - while (!queue.empty()) { |
1285 | | - Value value = queue.pop_back_val(); |
1286 | | - if (auto nestedFor = value.getDefiningOp<scf::ForOp>()) { |
1287 | | - auto result = mlir::cast<OpResult>(value); |
1288 | | - OpOperand &forOperand = *nestedFor.getTiedLoopInit(result); |
1289 | | - markLive(forOperand.get()); |
1290 | | - auto nestedYieldOp = |
1291 | | - cast<scf::YieldOp>(nestedFor.getBody()->getTerminator()); |
1292 | | - Value nestedYieldOperand = |
1293 | | - nestedYieldOp.getOperand(result.getResultNumber()); |
1294 | | - markLive(nestedYieldOperand); |
1295 | | - continue; |
1296 | | - } |
1297 | | - if (auto nestedIf = value.getDefiningOp<scf::IfOp>()) { |
1298 | | - auto result = mlir::cast<OpResult>(value); |
1299 | | - // mark condition as live. |
1300 | | - markLive(nestedIf.getCondition()); |
1301 | | - for (scf::YieldOp nestedYieldOp : |
1302 | | - {nestedIf.thenYield(), nestedIf.elseYield()}) { |
1303 | | - Value nestedYieldOperand = |
1304 | | - nestedYieldOp.getOperand(result.getResultNumber()); |
1305 | | - markLive(nestedYieldOperand); |
1306 | | - } |
1307 | | - continue; |
1308 | | - } |
1309 | | - if (Operation *def = value.getDefiningOp()) { |
1310 | | - // TODO: support while ops. |
1311 | | - if (isa<scf::WhileOp>(def)) |
1312 | | - return failure(); |
1313 | | - for (Value operand : def->getOperands()) |
1314 | | - markLive(operand); |
1315 | | - continue; |
1316 | | - } |
1317 | | - // If a block argument is live then the associated yield operand and |
1318 | | - // forOp operand are live. |
1319 | | - auto arg = mlir::cast<BlockArgument>(value); |
1320 | | - if (auto forOwner = dyn_cast<scf::ForOp>(arg.getOwner()->getParentOp())) { |
1321 | | - if (arg.getArgNumber() < forOwner.getNumInductionVars()) |
1322 | | - continue; |
1323 | | - unsigned iterIdx = arg.getArgNumber() - forOwner.getNumInductionVars(); |
1324 | | - Value yieldOperand = |
1325 | | - forOwner.getBody()->getTerminator()->getOperand(iterIdx); |
1326 | | - markLive(yieldOperand); |
1327 | | - markLive(forOwner.getInitArgs()[iterIdx]); |
| 1245 | +void runDeadBlockArgumentElimination(Operation *top) { |
| 1246 | + // The op we are running on must not have any results, because the liveness |
| 1247 | + // analysis will not consider their users. |
| 1248 | + assert(top->hasTrait<OpTrait::ZeroResults>() && "op cannot have results"); |
| 1249 | + dataflow::RunLivenessAnalysis la{top}; |
| 1250 | + |
| 1251 | + // Replace all uses of dead block arguments in loops with the corresponding |
| 1252 | + // init values. |
| 1253 | + top->walk([&](Operation *op) { |
| 1254 | + if (auto loopLike = dyn_cast<LoopLikeOpInterface>(op)) { |
| 1255 | + for (auto [idx, arg] : llvm::enumerate(loopLike.getRegionIterArgs())) { |
| 1256 | + const auto *liveness = la.getLiveness(arg); |
| 1257 | + if (liveness && !liveness->isLive) |
| 1258 | + arg.replaceAllUsesWith(loopLike.getInits()[idx]); |
1328 | 1259 | } |
1329 | 1260 | } |
1330 | | - SmallVector<unsigned> deadArg; |
1331 | | - for (auto yieldOperand : llvm::enumerate(yieldOp->getOperands())) { |
1332 | | - if (aliveValues.contains(yieldOperand.value())) |
1333 | | - continue; |
1334 | | - if (yieldOperand.value() == block.getArgument(yieldOperand.index() + 1)) |
1335 | | - continue; |
1336 | | - |
1337 | | - // The yield operand might live outside the loop, e.g. |
1338 | | - // %init = ... |
1339 | | - // %x = ... |
1340 | | - // %y = for iter_args(%unused = %init) { |
1341 | | - // yield %x |
1342 | | - // } |
1343 | | - // |
1344 | | - // In this case, the loop returns %x if it runs 1 or more times, and |
1345 | | - // otherwise it returns %init. We cowardly refuse to remove this operand |
1346 | | - // from the yield. (We could, but we'd need to prove that the loop runs 0 |
1347 | | - // or >=1 times.) |
1348 | | - // |
1349 | | - // As a special case, if it doesn't matter whether the loop runs 0 or >=1 |
1350 | | - // times (because the loop returns the same value in both cases) then we |
1351 | | - // can still mark the operand as dead. This occurs in the above example |
1352 | | - // when %init is the same as %x. |
1353 | | - if (!forOp->isAncestor( |
1354 | | - yieldOperand.value().getParentRegion()->getParentOp()) && |
1355 | | - yieldOperand.value() != forOp.getInitArgs()[yieldOperand.index()]) |
1356 | | - continue; |
1357 | | - |
1358 | | - deadArg.push_back(yieldOperand.index()); |
1359 | | - } |
1360 | | - bool changed = false; |
1361 | | - // For simplicity we just replace users of the block arg with init value and |
1362 | | - // leave the operations and argument removal to dead code elimination. |
1363 | | - for (unsigned deadArgIdx : deadArg) { |
1364 | | - BlockArgument arg = block.getArgument(deadArgIdx + 1); |
1365 | | - changed |= !arg.use_empty(); |
1366 | | - rewriter.replaceAllUsesWith(arg, forOp.getTiedLoopInit(arg)->get()); |
1367 | | - } |
1368 | | - return success(changed); |
1369 | | - } |
1370 | | -}; |
1371 | | - |
1372 | | -} // namespace |
1373 | | - |
1374 | | -void populateForOpDeadArgumentElimination(RewritePatternSet &patterns) { |
1375 | | - patterns.add<ForOpDeadArgElimination>(patterns.getContext()); |
| 1261 | + }); |
1376 | 1262 | } |
1377 | 1263 |
|
1378 | 1264 | ttg::LocalAllocOp findShmemAlloc(Value operand) { |
|
0 commit comments