Skip to content

Commit 400cd2a

Browse files
erwei-xilinxclaude
andcommitted
Address review comments: worklist transitive closure, single early-terminating walk
Two efficiency improvements identified in self-review: 1. collectHerdBodyOpsToKeep: replace fixed-point iteration (copies toKeep set each round) with a worklist algorithm. Each newly-added op is pushed once and processed once, avoiding redundant re-scanning of already-kept ops in subsequent rounds. 2. loopUnrollFullWithAsyncTokenPreserved: replace two sequential walks (one for SegmentOp, one for HerdOp) with a single interruptible walk that stops as soon as the first SegmentOp or HerdOp is found. Avoids walking the entire IR a second time in the common case where a segment is present. No functional change; build and all tests pass (same 2 pre-existing ROCDL failures unrelated to T003). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 0fdae09 commit 400cd2a

1 file changed

Lines changed: 15 additions & 13 deletions

File tree

mlir/lib/Util/Dependency.cpp

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -983,16 +983,14 @@ static SmallPtrSet<Operation *, 16> collectHerdBodyOpsToKeep(Block &herdBody) {
983983
toKeep.insert(&op);
984984
}
985985
// Expand: add ops whose results are consumed by any kept op (within block).
986-
bool changed = true;
987-
while (changed) {
988-
changed = false;
989-
for (Operation *op :
990-
SmallVector<Operation *>(toKeep.begin(), toKeep.end())) {
991-
for (Value operand : op->getOperands()) {
992-
if (Operation *defOp = operand.getDefiningOp()) {
993-
if (defOp->getBlock() == &herdBody && toKeep.insert(defOp).second)
994-
changed = true;
995-
}
986+
// Use a worklist to avoid redundant re-processing of already-kept ops.
987+
SmallVector<Operation *> worklist(toKeep.begin(), toKeep.end());
988+
while (!worklist.empty()) {
989+
Operation *op = worklist.pop_back_val();
990+
for (Value operand : op->getOperands()) {
991+
if (Operation *defOp = operand.getDefiningOp()) {
992+
if (defOp->getBlock() == &herdBody && toKeep.insert(defOp).second)
993+
worklist.push_back(defOp);
996994
}
997995
}
998996
}
@@ -1200,9 +1198,13 @@ LogicalResult loopUnrollFullWithAsyncTokenPreserved(
12001198
// For shim-level loops containing segments or herds, use lightweight
12011199
// unrolling to avoid O(N * herd_body_size) deep-clone cost.
12021200
bool containsSegmentOrHerd = false;
1203-
forOp.walk([&](air::SegmentOp) { containsSegmentOrHerd = true; });
1204-
if (!containsSegmentOrHerd)
1205-
forOp.walk([&](air::HerdOp) { containsSegmentOrHerd = true; });
1201+
forOp->walk([&](Operation *op) -> WalkResult {
1202+
if (isa<air::SegmentOp, air::HerdOp>(op)) {
1203+
containsSegmentOrHerd = true;
1204+
return WalkResult::interrupt();
1205+
}
1206+
return WalkResult::advance();
1207+
});
12061208

12071209
if (!annotateFn && containsSegmentOrHerd) {
12081210
if (failed(manuallyUnrollForOpLightweight(forOp)))

0 commit comments

Comments
 (0)