Skip to content

Commit 0fdae09

Browse files
erwei-xilinxclaude
andcommitted
Avoid O(N*body_size) deep-cloning herd bodies during shim loop unroll
For shim-level scf.for loops that contain air.SegmentOp / air.HerdOp, replace the standard loopUnrollFull (which deep-clones the entire body N times) with a manual unroller that uses lightweight herd cloning. The lightweight herd clone (cloneHerdOpLightweight) creates a new herd shell via OperationState and populates the body with ONLY channel ops, allocs, deallocations, wait_alls, and their transitive operand-defining ops. Heavy compute ops (matrix multiply, vector ops, etc.) are never cloned. The segment clone (cloneSegmentOpLightweight) preserves the full segment body (L3 channel ops needed by BD folding) but applies lightweight cloning recursively to any contained air.HerdOp. This brings flash attention 12x4 (tiles=2,2) from crashing (T002) or taking O(N*body_size) time back to ~50ms, matching the main-branch baseline. Also adds a lit test (func_lightweight_unroll) that verifies channel ops are preserved and compute ops are excluded from the lightweight herd copies. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 73b10c5 commit 0fdae09

2 files changed

Lines changed: 271 additions & 0 deletions

File tree

mlir/lib/Util/Dependency.cpp

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,12 @@
88

99
#include "air/Util/Dependency.h"
1010
#include "air/Util/Util.h"
11+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1112
#include "mlir/Dialect/SCF/Utils/Utils.h"
1213
#include "mlir/IR/Iterators.h"
14+
#include "mlir/IR/OperationSupport.h"
1315
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
16+
#include "llvm/ADT/SmallPtrSet.h"
1417
#include "llvm/ADT/SmallSet.h"
1518

1619
#define DEBUG_TYPE "air-dependency-util"
@@ -965,6 +968,209 @@ void preserveAsyncDependenciesAfterUnroll(Block &parentBlock) {
965968
}
966969
}
967970

971+
// Returns the set of ops in herdBody that must be kept when performing a
972+
// lightweight herd clone: channel ops, allocs, deallocations, wait_alls, the
973+
// terminator, and any ops that transitively define operands consumed by those.
974+
static SmallPtrSet<Operation *, 16> collectHerdBodyOpsToKeep(Block &herdBody) {
975+
SmallPtrSet<Operation *, 16> toKeep;
976+
// Always keep the block terminator.
977+
if (!herdBody.empty())
978+
toKeep.insert(herdBody.getTerminator());
979+
// Seed: channel ops, allocs, deallocations, and wait_alls.
980+
for (Operation &op : herdBody.without_terminator()) {
981+
if (isa<air::ChannelInterface, memref::AllocOp, memref::DeallocOp,
982+
air::WaitAllOp>(op))
983+
toKeep.insert(&op);
984+
}
985+
// Expand: add ops whose results are consumed by any kept op (within block).
986+
bool changed = true;
987+
while (changed) {
988+
changed = false;
989+
for (Operation *op :
990+
SmallVector<Operation *>(toKeep.begin(), toKeep.end())) {
991+
for (Value operand : op->getOperands()) {
992+
if (Operation *defOp = operand.getDefiningOp()) {
993+
if (defOp->getBlock() == &herdBody && toKeep.insert(defOp).second)
994+
changed = true;
995+
}
996+
}
997+
}
998+
}
999+
return toKeep;
1000+
}
1001+
1002+
// Lightweight clone of air.HerdOp: creates a new herd shell via OperationState
1003+
// with the same operands (remapped through mapper), result types, and
1004+
// attributes, but populates the body with ONLY channel ops, allocs,
1005+
// deallocations, wait_alls, their transitive operand-defining ops, and the
1006+
// terminator. Heavy compute ops whose results do not feed any of the above are
1007+
// skipped entirely. Updates mapper with result mappings for the new herd.
1008+
static void cloneHerdOpLightweight(OpBuilder &builder, air::HerdOp herdOp,
1009+
IRMapping &mapper) {
1010+
Block &origBody = herdOp.getBody().front();
1011+
SmallPtrSet<Operation *, 16> toKeep = collectHerdBodyOpsToKeep(origBody);
1012+
1013+
// Map operands through the outer mapper.
1014+
SmallVector<Value> mappedOperands;
1015+
for (Value v : herdOp->getOperands())
1016+
mappedOperands.push_back(mapper.lookupOrDefault(v));
1017+
1018+
// Create the herd shell via OperationState (no body yet).
1019+
OperationState state(herdOp.getLoc(), air::HerdOp::getOperationName());
1020+
state.addOperands(mappedOperands);
1021+
state.addTypes(herdOp->getResultTypes());
1022+
state.addAttributes(herdOp->getAttrs());
1023+
state.addRegion(); // Placeholder for the body region.
1024+
Operation *newHerd = builder.create(state);
1025+
1026+
// Map original herd results -> new herd results.
1027+
for (auto [origRes, newRes] :
1028+
llvm::zip(herdOp->getResults(), newHerd->getResults()))
1029+
mapper.map(origRes, newRes);
1030+
1031+
// Create the body block with the same block-argument types as the original.
1032+
Block *newBody = new Block();
1033+
newHerd->getRegion(0).push_back(newBody);
1034+
IRMapping innerMapper;
1035+
for (BlockArgument origArg : origBody.getArguments()) {
1036+
BlockArgument newArg =
1037+
newBody->addArgument(origArg.getType(), origArg.getLoc());
1038+
innerMapper.map(origArg, newArg);
1039+
}
1040+
1041+
// Clone only the kept ops in original block order to preserve use-def.
1042+
OpBuilder bodyBuilder(newBody, newBody->end());
1043+
for (Operation &op : origBody) {
1044+
if (toKeep.contains(&op))
1045+
bodyBuilder.clone(op, innerMapper);
1046+
}
1047+
}
1048+
1049+
// Forward declaration for mutual recursion.
1050+
static void cloneBlockBodyLightweight(OpBuilder &builder, Block &srcBlock,
1051+
IRMapping &mapper);
1052+
1053+
// Lightweight clone of air.SegmentOp: clones the segment shell and fully
1054+
// clones all ops directly in the segment body (including L3 channel ops needed
1055+
// by BD folding), but substitutes lightweight copies for any contained
1056+
// air.HerdOp. Updates mapper with segment result mappings.
1057+
static void cloneSegmentOpLightweight(OpBuilder &builder, air::SegmentOp segOp,
1058+
IRMapping &mapper) {
1059+
// Clone segment shell (without regions); this maps operands and results.
1060+
Operation *newSeg = segOp->cloneWithoutRegions(mapper);
1061+
builder.insert(newSeg);
1062+
1063+
// Create and populate the segment body block.
1064+
Block &origBody = segOp.getBody().front();
1065+
Block *newBody = new Block();
1066+
newSeg->getRegion(0).push_back(newBody);
1067+
1068+
// Map block arguments (segment IDs, sizes, segment_operands).
1069+
IRMapping innerMapper;
1070+
for (BlockArgument origArg : origBody.getArguments()) {
1071+
BlockArgument newArg =
1072+
newBody->addArgument(origArg.getType(), origArg.getLoc());
1073+
innerMapper.map(origArg, newArg);
1074+
}
1075+
1076+
// Clone segment body, using lightweight cloning for any air.HerdOp.
1077+
OpBuilder bodyBuilder(newBody, newBody->end());
1078+
cloneBlockBodyLightweight(bodyBuilder, origBody, innerMapper);
1079+
}
1080+
1081+
// Clone all ops in srcBlock into the current builder insertion point.
1082+
// HerdOps are cloned lightweight (body stripped to channel ops + deps);
1083+
// SegmentOps are cloned with lightweight recursion into their bodies;
1084+
// all other ops are cloned fully. mapper is updated with result mappings.
1085+
static void cloneBlockBodyLightweight(OpBuilder &builder, Block &srcBlock,
1086+
IRMapping &mapper) {
1087+
for (Operation &op : srcBlock) {
1088+
if (auto herdOp = dyn_cast<air::HerdOp>(&op)) {
1089+
cloneHerdOpLightweight(builder, herdOp, mapper);
1090+
} else if (auto segOp = dyn_cast<air::SegmentOp>(&op)) {
1091+
cloneSegmentOpLightweight(builder, segOp, mapper);
1092+
} else {
1093+
builder.clone(op, mapper);
1094+
}
1095+
}
1096+
}
1097+
1098+
// Manually unroll a shim-level scf.for that contains air.SegmentOp or
1099+
// air.HerdOp, using lightweight cloning for herd bodies to avoid
1100+
// O(N * body_size) IR explosion. Inserts N copies of the loop body before the
1101+
// forOp, replaces forOp results with the final iteration's yields, and erases
1102+
// the forOp. Returns failure if bounds are not statically known.
1103+
static LogicalResult manuallyUnrollForOpLightweight(scf::ForOp forOp) {
1104+
auto maybeCount = air::getStaticScfForTripCountAsInt(forOp);
1105+
if (!maybeCount) {
1106+
forOp->emitOpError("lightweight unroll failed: dynamic trip count");
1107+
return failure();
1108+
}
1109+
unsigned tripCount = *maybeCount;
1110+
1111+
if (tripCount == 0) {
1112+
// Loop body never executes: replace results with init args and erase.
1113+
for (auto [result, initArg] :
1114+
llvm::zip(forOp.getResults(), forOp.getInitArgs()))
1115+
result.replaceAllUsesWith(initArg);
1116+
forOp.erase();
1117+
return success();
1118+
}
1119+
1120+
auto lbConst = getConstantIntValue(forOp.getLowerBound());
1121+
auto stepConst = getConstantIntValue(forOp.getStep());
1122+
if (!lbConst || !stepConst) {
1123+
forOp->emitOpError("lightweight unroll failed: non-constant bounds");
1124+
return failure();
1125+
}
1126+
1127+
// Insert new ops immediately before the forOp.
1128+
OpBuilder builder(forOp);
1129+
auto loc = forOp.getLoc();
1130+
Block &loopBody = forOp.getRegion().front();
1131+
auto yieldOp = cast<scf::YieldOp>(loopBody.getTerminator());
1132+
1133+
// prevYields starts as the loop's init args.
1134+
SmallVector<Value> prevYields(forOp.getInitArgs().begin(),
1135+
forOp.getInitArgs().end());
1136+
1137+
for (unsigned i = 0; i < tripCount; ++i) {
1138+
IRMapping mapper;
1139+
// Map the induction variable to a constant for this iteration.
1140+
int64_t ivVal = *lbConst + static_cast<int64_t>(i) * *stepConst;
1141+
Value ivConst = arith::ConstantIndexOp::create(builder, loc, ivVal);
1142+
mapper.map(forOp.getInductionVar(), ivConst);
1143+
1144+
// Map iter_args to the previous iteration's yielded values.
1145+
for (auto [iterArg, prevYield] :
1146+
llvm::zip(forOp.getRegionIterArgs(), prevYields))
1147+
mapper.map(iterArg, prevYield);
1148+
1149+
// Clone the loop body ops (excluding the yield terminator) with
1150+
// lightweight treatment for air.SegmentOp and air.HerdOp.
1151+
for (Operation &op : loopBody.without_terminator()) {
1152+
if (auto segOp = dyn_cast<air::SegmentOp>(&op)) {
1153+
cloneSegmentOpLightweight(builder, segOp, mapper);
1154+
} else if (auto herdOp = dyn_cast<air::HerdOp>(&op)) {
1155+
cloneHerdOpLightweight(builder, herdOp, mapper);
1156+
} else {
1157+
builder.clone(op, mapper);
1158+
}
1159+
}
1160+
1161+
// Collect this iteration's yielded values for the next iteration.
1162+
prevYields.clear();
1163+
for (Value yieldVal : yieldOp.getOperands())
1164+
prevYields.push_back(mapper.lookupOrDefault(yieldVal));
1165+
}
1166+
1167+
// Replace forOp results with the final iteration's yields and erase.
1168+
for (auto [result, finalYield] : llvm::zip(forOp.getResults(), prevYields))
1169+
result.replaceAllUsesWith(finalYield);
1170+
forOp.erase();
1171+
return success();
1172+
}
1173+
9681174
// Fully unrolls an `scf.for` loop while preserving async token dependencies.
9691175
//
9701176
// This function labels the operations that define the values yielded by
@@ -976,6 +1182,13 @@ void preserveAsyncDependenciesAfterUnroll(Block &parentBlock) {
9761182
//
9771183
// If `annotateFn` is provided, it is passed to `loopUnrollByFactor` for result
9781184
// tagging.
1185+
//
1186+
// For shim-level loops (loops whose bodies contain air.SegmentOp or directly
1187+
// contain air.HerdOp), lightweight unrolling is used: herd bodies are NOT
1188+
// deep-cloned N times. Instead, only channel ops, allocs, and their transitive
1189+
// operand-defining ops are cloned into each iteration's herd copy. This avoids
1190+
// O(N * body_size) IR explosion while preserving all channel ops needed by
1191+
// downstream passes (BD folding, air-to-std channel matching).
9791192
LogicalResult loopUnrollFullWithAsyncTokenPreserved(
9801193
scf::ForOp forOp,
9811194
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
@@ -984,6 +1197,20 @@ LogicalResult loopUnrollFullWithAsyncTokenPreserved(
9841197

9851198
Block *parentBlock = forOp->getBlock();
9861199

1200+
// For shim-level loops containing segments or herds, use lightweight
1201+
// unrolling to avoid O(N * herd_body_size) deep-clone cost.
1202+
bool containsSegmentOrHerd = false;
1203+
forOp.walk([&](air::SegmentOp) { containsSegmentOrHerd = true; });
1204+
if (!containsSegmentOrHerd)
1205+
forOp.walk([&](air::HerdOp) { containsSegmentOrHerd = true; });
1206+
1207+
if (!annotateFn && containsSegmentOrHerd) {
1208+
if (failed(manuallyUnrollForOpLightweight(forOp)))
1209+
return failure();
1210+
preserveAsyncDependenciesAfterUnroll(*parentBlock);
1211+
return success();
1212+
}
1213+
9871214
// Fully unroll the loop
9881215
if (annotateFn) {
9891216
auto unroll_factor = air::getStaticScfForTripCountAsInt(forOp);

mlir/test/Transform/AIRDependencyScheduleOpt/opt_shim_dma_bds.mlir

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
// RUN: air-opt %s -air-opt-shim-dma-bds="device=npu1" | FileCheck %s
99
// RUN: air-opt %s -air-opt-shim-dma-bds="device=npu1 shim-dma-tile-sizes=2,2" | FileCheck %s --check-prefix=NPUTILED
1010
// RUN: air-opt %s -air-opt-shim-dma-bds="device=xcvc1902" | FileCheck %s --check-prefix=AIE1
11+
// RUN: air-opt %s -air-opt-shim-dma-bds="device=npu1 shim-dma-tile-sizes=2,2" | FileCheck %s --check-prefix=LIGHTWEIGHT
1112

1213
// Optimize logical air.channel.put/get op into efficient shim dma block descriptor (BD).
1314

@@ -954,4 +955,47 @@ module {
954955
}
955956
return
956957
}
958+
959+
// Lightweight unrolling: verify that after unrolling a shim-level scf.for
960+
// containing air.segment > air.herd, channel ops in herd bodies are
961+
// preserved while compute ops with dead results are NOT cloned into the
962+
// lightweight herd copies. The launch has trip count 2; after the pass
963+
// converts it to a scf.for and unrolls with shim-dma-tile-sizes=2,2, two
964+
// lightweight herd copies should appear, each missing the arith.muli.
965+
966+
// LIGHTWEIGHT-LABEL: func_lightweight_unroll
967+
// LIGHTWEIGHT: air.segment
968+
// LIGHTWEIGHT: air.channel.get
969+
// LIGHTWEIGHT-NOT: arith.muli
970+
971+
air.channel @ch_lite [2, 1]
972+
973+
func.func @func_lightweight_unroll(%arg0: memref<256xi32>) {
974+
%c2 = arith.constant 2 : index
975+
%0 = air.launch async (%arg1) in (%arg2=%c2) args(%arg3=%arg0) : memref<256xi32> {
976+
%c0 = arith.constant 0 : index
977+
%c1 = arith.constant 1 : index
978+
%c64 = arith.constant 64 : index
979+
%c128 = arith.constant 128 : index
980+
%1 = air.wait_all async
981+
%2 = air.segment async [%1] {
982+
%c0_0 = arith.constant 0 : index
983+
%c1_0 = arith.constant 1 : index
984+
%c64_0 = arith.constant 64 : index
985+
%3 = air.herd async tile (%tx, %ty) in (%sx=%c1_0, %sy=%c1_0) {
986+
%c0_1 = arith.constant 0 : index
987+
%c1_1 = arith.constant 1 : index
988+
%c64_1 = arith.constant 64 : index
989+
%l1buf = memref.alloc() : memref<64xi32, 2>
990+
%dead = arith.muli %tx, %ty : index
991+
%4 = air.channel.get async @ch_lite[%tx, %ty] (%l1buf[%c0_1] [%c64_1] [%c1_1]) : (memref<64xi32, 2>)
992+
memref.dealloc %l1buf : memref<64xi32, 2>
993+
air.herd_terminator
994+
}
995+
air.segment_terminator
996+
}
997+
%5 = air.channel.put async [%1] @ch_lite[%arg1, %c0] (%arg3[%arg1, %c0] [%c1, %c64] [%c128, %c1]) : (memref<256xi32>)
998+
}
999+
return
1000+
}
9571001
}

0 commit comments

Comments
 (0)