88
99#include " air/Util/Dependency.h"
1010#include " air/Util/Util.h"
11+ #include " mlir/Dialect/MemRef/IR/MemRef.h"
1112#include " mlir/Dialect/SCF/Utils/Utils.h"
1213#include " mlir/IR/Iterators.h"
14+ #include " mlir/IR/OperationSupport.h"
1315#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
16+ #include " llvm/ADT/SmallPtrSet.h"
1417#include " llvm/ADT/SmallSet.h"
1518
1619#define DEBUG_TYPE " air-dependency-util"
@@ -965,6 +968,209 @@ void preserveAsyncDependenciesAfterUnroll(Block &parentBlock) {
965968 }
966969}
967970
971+ // Returns the set of ops in herdBody that must be kept when performing a
972+ // lightweight herd clone: channel ops, allocs, deallocations, wait_alls, the
973+ // terminator, and any ops that transitively define operands consumed by those.
974+ static SmallPtrSet<Operation *, 16 > collectHerdBodyOpsToKeep (Block &herdBody) {
975+ SmallPtrSet<Operation *, 16 > toKeep;
976+ // Always keep the block terminator.
977+ if (!herdBody.empty ())
978+ toKeep.insert (herdBody.getTerminator ());
979+ // Seed: channel ops, allocs, deallocations, and wait_alls.
980+ for (Operation &op : herdBody.without_terminator ()) {
981+ if (isa<air::ChannelInterface, memref::AllocOp, memref::DeallocOp,
982+ air::WaitAllOp>(op))
983+ toKeep.insert (&op);
984+ }
985+ // Expand: add ops whose results are consumed by any kept op (within block).
986+ bool changed = true ;
987+ while (changed) {
988+ changed = false ;
989+ for (Operation *op :
990+ SmallVector<Operation *>(toKeep.begin (), toKeep.end ())) {
991+ for (Value operand : op->getOperands ()) {
992+ if (Operation *defOp = operand.getDefiningOp ()) {
993+ if (defOp->getBlock () == &herdBody && toKeep.insert (defOp).second )
994+ changed = true ;
995+ }
996+ }
997+ }
998+ }
999+ return toKeep;
1000+ }
1001+
1002+ // Lightweight clone of air.HerdOp: creates a new herd shell via OperationState
1003+ // with the same operands (remapped through mapper), result types, and
1004+ // attributes, but populates the body with ONLY channel ops, allocs,
1005+ // deallocations, wait_alls, their transitive operand-defining ops, and the
1006+ // terminator. Heavy compute ops whose results do not feed any of the above are
1007+ // skipped entirely. Updates mapper with result mappings for the new herd.
1008+ static void cloneHerdOpLightweight (OpBuilder &builder, air::HerdOp herdOp,
1009+ IRMapping &mapper) {
1010+ Block &origBody = herdOp.getBody ().front ();
1011+ SmallPtrSet<Operation *, 16 > toKeep = collectHerdBodyOpsToKeep (origBody);
1012+
1013+ // Map operands through the outer mapper.
1014+ SmallVector<Value> mappedOperands;
1015+ for (Value v : herdOp->getOperands ())
1016+ mappedOperands.push_back (mapper.lookupOrDefault (v));
1017+
1018+ // Create the herd shell via OperationState (no body yet).
1019+ OperationState state (herdOp.getLoc (), air::HerdOp::getOperationName ());
1020+ state.addOperands (mappedOperands);
1021+ state.addTypes (herdOp->getResultTypes ());
1022+ state.addAttributes (herdOp->getAttrs ());
1023+ state.addRegion (); // Placeholder for the body region.
1024+ Operation *newHerd = builder.create (state);
1025+
1026+ // Map original herd results -> new herd results.
1027+ for (auto [origRes, newRes] :
1028+ llvm::zip (herdOp->getResults (), newHerd->getResults ()))
1029+ mapper.map (origRes, newRes);
1030+
1031+ // Create the body block with the same block-argument types as the original.
1032+ Block *newBody = new Block ();
1033+ newHerd->getRegion (0 ).push_back (newBody);
1034+ IRMapping innerMapper;
1035+ for (BlockArgument origArg : origBody.getArguments ()) {
1036+ BlockArgument newArg =
1037+ newBody->addArgument (origArg.getType (), origArg.getLoc ());
1038+ innerMapper.map (origArg, newArg);
1039+ }
1040+
1041+ // Clone only the kept ops in original block order to preserve use-def.
1042+ OpBuilder bodyBuilder (newBody, newBody->end ());
1043+ for (Operation &op : origBody) {
1044+ if (toKeep.contains (&op))
1045+ bodyBuilder.clone (op, innerMapper);
1046+ }
1047+ }
1048+
1049+ // Forward declaration for mutual recursion.
1050+ static void cloneBlockBodyLightweight (OpBuilder &builder, Block &srcBlock,
1051+ IRMapping &mapper);
1052+
1053+ // Lightweight clone of air.SegmentOp: clones the segment shell and fully
1054+ // clones all ops directly in the segment body (including L3 channel ops needed
1055+ // by BD folding), but substitutes lightweight copies for any contained
1056+ // air.HerdOp. Updates mapper with segment result mappings.
1057+ static void cloneSegmentOpLightweight (OpBuilder &builder, air::SegmentOp segOp,
1058+ IRMapping &mapper) {
1059+ // Clone segment shell (without regions); this maps operands and results.
1060+ Operation *newSeg = segOp->cloneWithoutRegions (mapper);
1061+ builder.insert (newSeg);
1062+
1063+ // Create and populate the segment body block.
1064+ Block &origBody = segOp.getBody ().front ();
1065+ Block *newBody = new Block ();
1066+ newSeg->getRegion (0 ).push_back (newBody);
1067+
1068+ // Map block arguments (segment IDs, sizes, segment_operands).
1069+ IRMapping innerMapper;
1070+ for (BlockArgument origArg : origBody.getArguments ()) {
1071+ BlockArgument newArg =
1072+ newBody->addArgument (origArg.getType (), origArg.getLoc ());
1073+ innerMapper.map (origArg, newArg);
1074+ }
1075+
1076+ // Clone segment body, using lightweight cloning for any air.HerdOp.
1077+ OpBuilder bodyBuilder (newBody, newBody->end ());
1078+ cloneBlockBodyLightweight (bodyBuilder, origBody, innerMapper);
1079+ }
1080+
1081+ // Clone all ops in srcBlock into the current builder insertion point.
1082+ // HerdOps are cloned lightweight (body stripped to channel ops + deps);
1083+ // SegmentOps are cloned with lightweight recursion into their bodies;
1084+ // all other ops are cloned fully. mapper is updated with result mappings.
1085+ static void cloneBlockBodyLightweight (OpBuilder &builder, Block &srcBlock,
1086+ IRMapping &mapper) {
1087+ for (Operation &op : srcBlock) {
1088+ if (auto herdOp = dyn_cast<air::HerdOp>(&op)) {
1089+ cloneHerdOpLightweight (builder, herdOp, mapper);
1090+ } else if (auto segOp = dyn_cast<air::SegmentOp>(&op)) {
1091+ cloneSegmentOpLightweight (builder, segOp, mapper);
1092+ } else {
1093+ builder.clone (op, mapper);
1094+ }
1095+ }
1096+ }
1097+
1098+ // Manually unroll a shim-level scf.for that contains air.SegmentOp or
1099+ // air.HerdOp, using lightweight cloning for herd bodies to avoid
1100+ // O(N * body_size) IR explosion. Inserts N copies of the loop body before the
1101+ // forOp, replaces forOp results with the final iteration's yields, and erases
1102+ // the forOp. Returns failure if bounds are not statically known.
1103+ static LogicalResult manuallyUnrollForOpLightweight (scf::ForOp forOp) {
1104+ auto maybeCount = air::getStaticScfForTripCountAsInt (forOp);
1105+ if (!maybeCount) {
1106+ forOp->emitOpError (" lightweight unroll failed: dynamic trip count" );
1107+ return failure ();
1108+ }
1109+ unsigned tripCount = *maybeCount;
1110+
1111+ if (tripCount == 0 ) {
1112+ // Loop body never executes: replace results with init args and erase.
1113+ for (auto [result, initArg] :
1114+ llvm::zip (forOp.getResults (), forOp.getInitArgs ()))
1115+ result.replaceAllUsesWith (initArg);
1116+ forOp.erase ();
1117+ return success ();
1118+ }
1119+
1120+ auto lbConst = getConstantIntValue (forOp.getLowerBound ());
1121+ auto stepConst = getConstantIntValue (forOp.getStep ());
1122+ if (!lbConst || !stepConst) {
1123+ forOp->emitOpError (" lightweight unroll failed: non-constant bounds" );
1124+ return failure ();
1125+ }
1126+
1127+ // Insert new ops immediately before the forOp.
1128+ OpBuilder builder (forOp);
1129+ auto loc = forOp.getLoc ();
1130+ Block &loopBody = forOp.getRegion ().front ();
1131+ auto yieldOp = cast<scf::YieldOp>(loopBody.getTerminator ());
1132+
1133+ // prevYields starts as the loop's init args.
1134+ SmallVector<Value> prevYields (forOp.getInitArgs ().begin (),
1135+ forOp.getInitArgs ().end ());
1136+
1137+ for (unsigned i = 0 ; i < tripCount; ++i) {
1138+ IRMapping mapper;
1139+ // Map the induction variable to a constant for this iteration.
1140+ int64_t ivVal = *lbConst + static_cast <int64_t >(i) * *stepConst;
1141+ Value ivConst = arith::ConstantIndexOp::create (builder, loc, ivVal);
1142+ mapper.map (forOp.getInductionVar (), ivConst);
1143+
1144+ // Map iter_args to the previous iteration's yielded values.
1145+ for (auto [iterArg, prevYield] :
1146+ llvm::zip (forOp.getRegionIterArgs (), prevYields))
1147+ mapper.map (iterArg, prevYield);
1148+
1149+ // Clone the loop body ops (excluding the yield terminator) with
1150+ // lightweight treatment for air.SegmentOp and air.HerdOp.
1151+ for (Operation &op : loopBody.without_terminator ()) {
1152+ if (auto segOp = dyn_cast<air::SegmentOp>(&op)) {
1153+ cloneSegmentOpLightweight (builder, segOp, mapper);
1154+ } else if (auto herdOp = dyn_cast<air::HerdOp>(&op)) {
1155+ cloneHerdOpLightweight (builder, herdOp, mapper);
1156+ } else {
1157+ builder.clone (op, mapper);
1158+ }
1159+ }
1160+
1161+ // Collect this iteration's yielded values for the next iteration.
1162+ prevYields.clear ();
1163+ for (Value yieldVal : yieldOp.getOperands ())
1164+ prevYields.push_back (mapper.lookupOrDefault (yieldVal));
1165+ }
1166+
1167+ // Replace forOp results with the final iteration's yields and erase.
1168+ for (auto [result, finalYield] : llvm::zip (forOp.getResults (), prevYields))
1169+ result.replaceAllUsesWith (finalYield);
1170+ forOp.erase ();
1171+ return success ();
1172+ }
1173+
9681174// Fully unrolls an `scf.for` loop while preserving async token dependencies.
9691175//
9701176// This function labels the operations that define the values yielded by
@@ -976,6 +1182,13 @@ void preserveAsyncDependenciesAfterUnroll(Block &parentBlock) {
9761182//
9771183// If `annotateFn` is provided, it is passed to `loopUnrollByFactor` for result
9781184// tagging.
1185+ //
1186+ // For shim-level loops (loops whose bodies contain air.SegmentOp or directly
1187+ // contain air.HerdOp), lightweight unrolling is used: herd bodies are NOT
1188+ // deep-cloned N times. Instead, only channel ops, allocs, and their transitive
1189+ // operand-defining ops are cloned into each iteration's herd copy. This avoids
1190+ // O(N * body_size) IR explosion while preserving all channel ops needed by
1191+ // downstream passes (BD folding, air-to-std channel matching).
9791192LogicalResult loopUnrollFullWithAsyncTokenPreserved (
9801193 scf::ForOp forOp,
9811194 function_ref<void (unsigned , Operation *, OpBuilder)> annotateFn) {
@@ -984,6 +1197,20 @@ LogicalResult loopUnrollFullWithAsyncTokenPreserved(
9841197
9851198 Block *parentBlock = forOp->getBlock ();
9861199
1200+ // For shim-level loops containing segments or herds, use lightweight
1201+ // unrolling to avoid O(N * herd_body_size) deep-clone cost.
1202+ bool containsSegmentOrHerd = false ;
1203+ forOp.walk ([&](air::SegmentOp) { containsSegmentOrHerd = true ; });
1204+ if (!containsSegmentOrHerd)
1205+ forOp.walk ([&](air::HerdOp) { containsSegmentOrHerd = true ; });
1206+
1207+ if (!annotateFn && containsSegmentOrHerd) {
1208+ if (failed (manuallyUnrollForOpLightweight (forOp)))
1209+ return failure ();
1210+ preserveAsyncDependenciesAfterUnroll (*parentBlock);
1211+ return success ();
1212+ }
1213+
9871214 // Fully unroll the loop
9881215 if (annotateFn) {
9891216 auto unroll_factor = air::getStaticScfForTripCountAsInt (forOp);
0 commit comments