Skip to content

Commit f56b871

Browse files
erwei-xilinxclaude
andcommitted
Strip segment/herd bodies during shim DMA BD loop unrolling
Add loopUnrollFullLightweight(), which strips air.segment and air.herd bodies from cloned iterations during loop unrolling in the shim DMA BD pass. Iteration 0 keeps its full bodies (needed by air-to-std channel matching), while iterations 1..N-1 get empty shells containing just the terminator.
This is safe because:
- BD folding only operates on L3 channel ops, never on segment/herd bodies.
- ChannelCounterpartCache walks the entire module and needs only the first get/put per channel name (iteration 0 provides this).
- removeDeadDeviceComputeOps in airrt-to-npu strips these bodies anyway.
The NPU instruction binary is byte-for-byte identical to the original.
Profiling on flash attention (12 heads, 1024 LQ/LK, NPU1): total aircc time 6,400 ms -> 1,496 ms (4.3x faster).
Verified on local NPU1 hardware:
- test/xrt/12_matmul_transform_1x4_bf16 - PASS
- test/xrt/17_gemm_8x16_transform_vec_4x4 - PASS
- test/xrt/25_batch_matmul_bf16 - PASS
air.insts.bin: identical for all 3 tests.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 221ae39 commit f56b871

3 files changed

Lines changed: 135 additions & 2 deletions

File tree

mlir/include/air/Util/Dependency.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,10 @@ scf::ForOp hoistTargetOpsToNewSCFFor(PatternRewriter &rewriter,
7979
LogicalResult loopUnrollFullWithAsyncTokenPreserved(
8080
scf::ForOp forOp,
8181
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
82+
// Lightweight variant that clones empty shells for SegmentOp/HerdOp on
83+
// iterations 1..N-1, avoiding O(N*body_size) IR explosion from deep-cloning
84+
// hierarchy bodies that shim DMA BD folding never touches.
85+
LogicalResult loopUnrollFullLightweight(scf::ForOp forOp);
8286

8387
// Unrolls an `scf.for` loop by a given factor while preserving async token
8488
// dependencies.

mlir/lib/Transform/AIRDependencyScheduleOpt.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6373,9 +6373,10 @@ class AIROptimizeShimDMABDs
63736373
// Canonicalize IR to make loop bounds explicitly static.
63746374
applyCanonicalizationPatterns(ctx, func.getBody());
63756375

6376-
// Unroll outer scf.for loop nest.
6376+
// Unroll outer scf.for loop nest. Use lightweight clone to avoid
6377+
// deep-copying segment/herd bodies that BD folding never touches.
63776378
for (auto scfFor : forLoopsToUnroll) {
6378-
if (failed(air::loopUnrollFullWithAsyncTokenPreserved(scfFor)))
6379+
if (failed(air::loopUnrollFullLightweight(scfFor)))
63796380
signalPassFailure();
63806381
}
63816382
// Canonicalize IR to make loop bounds explicitly static.

mlir/lib/Util/Dependency.cpp

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,6 +1006,134 @@ LogicalResult loopUnrollFullWithAsyncTokenPreserved(
10061006
return success();
10071007
}
10081008

1009+
// Lightweight unroll: clones empty shells for SegmentOp/HerdOp on iterations
1010+
// 1..N-1, keeping the full body only for iteration 0. This avoids the
1011+
// O(N * body_size) IR explosion from deep-cloning segment/herd bodies that
1012+
// BD folding never touches. The downstream air-to-std channel matching only
1013+
// needs one copy of the segment body (via ChannelCounterpartCache which
1014+
// stores the first get/put per channel name).
1015+
LogicalResult loopUnrollFullLightweight(scf::ForOp forOp) {
1016+
auto tripCount = air::getStaticScfForTripCountAsInt(forOp);
1017+
if (!tripCount) {
1018+
// Dynamic loop bound — fall back to standard unroll.
1019+
return loopUnrollFullWithAsyncTokenPreserved(forOp);
1020+
}
1021+
1022+
// Pre-processing: label yield-defining ops for async token fixup.
1023+
labelYieldDefiningOpsOfForLoop(forOp, "scf.for_result_id");
1024+
Block *parentBlock = forOp->getBlock();
1025+
1026+
// Trivial case: single iteration — just promote.
1027+
if (*tripCount == 1) {
1028+
IRRewriter rewriter(forOp.getContext());
1029+
(void)forOp.promoteIfSingleIteration(rewriter);
1030+
preserveAsyncDependenciesAfterUnroll(*parentBlock);
1031+
return success();
1032+
}
1033+
1034+
Block *body = forOp.getBody();
1035+
OpBuilder builder(forOp.getContext());
1036+
builder.setInsertionPoint(body->getTerminator());
1037+
1038+
// Identify the last op before the yield (the "end" of the body to clone).
1039+
Block::iterator srcBlockEnd = std::prev(body->end(), 2);
1040+
1041+
Value iv = forOp.getInductionVar();
1042+
Value lb = forOp.getLowerBound();
1043+
Value step = forOp.getStep();
1044+
auto loc = forOp.getLoc();
1045+
1046+
// Track yielded values for iter_arg remapping across iterations.
1047+
SmallVector<Value> lastYielded(
1048+
forOp.getBody()->getTerminator()->getOperands());
1049+
1050+
// Clone iterations 1..N-1 into the body (iteration 0 is the original).
1051+
for (unsigned i = 1; i < *tripCount; i++) {
1052+
IRMapping operandMap;
1053+
1054+
// Map iter_args to the previous iteration's yielded values.
1055+
for (auto [iterArg, yielded] :
1056+
llvm::zip(forOp.getRegionIterArgs(), lastYielded))
1057+
operandMap.map(iterArg, yielded);
1058+
1059+
// Map IV to lb + i * step.
1060+
if (!iv.use_empty()) {
1061+
Value iterConst = arith::ConstantIndexOp::create(builder, loc, i);
1062+
Value offset = arith::MulIOp::create(builder, loc, step, iterConst);
1063+
Value newIV = arith::AddIOp::create(builder, loc, lb, offset);
1064+
operandMap.map(iv, newIV);
1065+
}
1066+
1067+
// Clone each op in the body. After cloning, strip segment/herd bodies
1068+
// in the cloned copy (they are deep inside inner loops but BD folding
1069+
// never touches them — only L3 channel ops at the launch level matter).
1070+
for (auto it = body->begin(); it != std::next(srcBlockEnd); it++) {
1071+
Operation *clonedOp = builder.clone(*it, operandMap);
1072+
1073+
// Walk the cloned op and strip segment/herd bodies, but only those
1074+
// inside a dummyLaunch (from AIRLaunchToScfForPattern). Standalone
1075+
// herds at the function level must keep their bodies.
1076+
clonedOp->walk([&](Operation *nested) {
1077+
if (!isa<air::SegmentOp, air::HerdOp>(nested))
1078+
return;
1079+
auto parentLaunch = nested->getParentOfType<air::LaunchOp>();
1080+
if (!parentLaunch || !parentLaunch->hasAttr("dummyLaunch"))
1081+
return;
1082+
for (Region &region : nested->getRegions()) {
1083+
if (region.empty())
1084+
continue;
1085+
Block &blk = region.front();
1086+
// Erase all ops except the terminator.
1087+
while (blk.getOperations().size() > 1) {
1088+
Operation &op = *std::prev(blk.end(), 2);
1089+
op.dropAllUses();
1090+
op.erase();
1091+
}
1092+
}
1093+
});
1094+
}
1095+
1096+
// Update yielded values for the next iteration.
1097+
auto *yield = forOp.getBody()->getTerminator();
1098+
for (unsigned j = 0; j < lastYielded.size(); j++)
1099+
lastYielded[j] = operandMap.lookupOrDefault(yield->getOperand(j));
1100+
}
1101+
1102+
// Replace the loop's results with the last iteration's yielded values,
1103+
// then inline the body into the parent block and erase the loop.
1104+
{
1105+
IRRewriter rewriter(forOp.getContext());
1106+
1107+
// Replace IV uses in iteration 0 with lb (the first iteration's IV value).
1108+
if (!iv.use_empty())
1109+
rewriter.replaceAllUsesWith(iv, lb);
1110+
1111+
// Replace iter_arg uses in iteration 0 with init values.
1112+
for (auto [iterArg, initVal] :
1113+
llvm::zip(forOp.getRegionIterArgs(), forOp.getInitArgs()))
1114+
rewriter.replaceAllUsesWith(iterArg, initVal);
1115+
1116+
// Replace loop results with the last iteration's yielded values.
1117+
for (auto [loopResult, yielded] :
1118+
llvm::zip(forOp.getResults(), lastYielded))
1119+
rewriter.replaceAllUsesWith(loopResult, yielded);
1120+
1121+
// Move all body ops (except the yield) before the loop in the parent.
1122+
Block *parentBlock2 = forOp->getBlock();
1123+
auto &bodyOps = body->getOperations();
1124+
auto yieldIt = std::prev(bodyOps.end()); // the yield terminator
1125+
parentBlock2->getOperations().splice(Block::iterator(forOp), bodyOps,
1126+
bodyOps.begin(), yieldIt);
1127+
1128+
// Erase the now-empty loop.
1129+
rewriter.eraseOp(forOp);
1130+
}
1131+
1132+
// Post-processing: reconnect async token dependencies.
1133+
preserveAsyncDependenciesAfterUnroll(*parentBlock);
1134+
return success();
1135+
}
1136+
10091137
// Unrolls an `scf.for` loop by a given factor while preserving async token
10101138
// dependencies.
10111139
//

0 commit comments

Comments
 (0)