@@ -1006,6 +1006,134 @@ LogicalResult loopUnrollFullWithAsyncTokenPreserved(
10061006 return success ();
10071007}
10081008
1009+ // Lightweight unroll: clones empty shells for SegmentOp/HerdOp on iterations
1010+ // 1..N-1, keeping the full body only for iteration 0. This avoids the
1011+ // O(N * body_size) IR explosion from deep-cloning segment/herd bodies that
1012+ // BD folding never touches. The downstream air-to-std channel matching only
1013+ // needs one copy of the segment body (via ChannelCounterpartCache which
1014+ // stores the first get/put per channel name).
1015+ LogicalResult loopUnrollFullLightweight (scf::ForOp forOp) {
1016+ auto tripCount = air::getStaticScfForTripCountAsInt (forOp);
1017+ if (!tripCount) {
1018+ // Dynamic loop bound — fall back to standard unroll.
1019+ return loopUnrollFullWithAsyncTokenPreserved (forOp);
1020+ }
1021+
1022+ // Pre-processing: label yield-defining ops for async token fixup.
1023+ labelYieldDefiningOpsOfForLoop (forOp, " scf.for_result_id" );
1024+ Block *parentBlock = forOp->getBlock ();
1025+
1026+ // Trivial case: single iteration — just promote.
1027+ if (*tripCount == 1 ) {
1028+ IRRewriter rewriter (forOp.getContext ());
1029+ (void )forOp.promoteIfSingleIteration (rewriter);
1030+ preserveAsyncDependenciesAfterUnroll (*parentBlock);
1031+ return success ();
1032+ }
1033+
1034+ Block *body = forOp.getBody ();
1035+ OpBuilder builder (forOp.getContext ());
1036+ builder.setInsertionPoint (body->getTerminator ());
1037+
1038+ // Identify the last op before the yield (the "end" of the body to clone).
1039+ Block::iterator srcBlockEnd = std::prev (body->end (), 2 );
1040+
1041+ Value iv = forOp.getInductionVar ();
1042+ Value lb = forOp.getLowerBound ();
1043+ Value step = forOp.getStep ();
1044+ auto loc = forOp.getLoc ();
1045+
1046+ // Track yielded values for iter_arg remapping across iterations.
1047+ SmallVector<Value> lastYielded (
1048+ forOp.getBody ()->getTerminator ()->getOperands ());
1049+
1050+ // Clone iterations 1..N-1 into the body (iteration 0 is the original).
1051+ for (unsigned i = 1 ; i < *tripCount; i++) {
1052+ IRMapping operandMap;
1053+
1054+ // Map iter_args to the previous iteration's yielded values.
1055+ for (auto [iterArg, yielded] :
1056+ llvm::zip (forOp.getRegionIterArgs (), lastYielded))
1057+ operandMap.map (iterArg, yielded);
1058+
1059+ // Map IV to lb + i * step.
1060+ if (!iv.use_empty ()) {
1061+ Value iterConst = arith::ConstantIndexOp::create (builder, loc, i);
1062+ Value offset = arith::MulIOp::create (builder, loc, step, iterConst);
1063+ Value newIV = arith::AddIOp::create (builder, loc, lb, offset);
1064+ operandMap.map (iv, newIV);
1065+ }
1066+
1067+ // Clone each op in the body. After cloning, strip segment/herd bodies
1068+ // in the cloned copy (they are deep inside inner loops but BD folding
1069+ // never touches them — only L3 channel ops at the launch level matter).
1070+ for (auto it = body->begin (); it != std::next (srcBlockEnd); it++) {
1071+ Operation *clonedOp = builder.clone (*it, operandMap);
1072+
1073+ // Walk the cloned op and strip segment/herd bodies, but only those
1074+ // inside a dummyLaunch (from AIRLaunchToScfForPattern). Standalone
1075+ // herds at the function level must keep their bodies.
1076+ clonedOp->walk ([&](Operation *nested) {
1077+ if (!isa<air::SegmentOp, air::HerdOp>(nested))
1078+ return ;
1079+ auto parentLaunch = nested->getParentOfType <air::LaunchOp>();
1080+ if (!parentLaunch || !parentLaunch->hasAttr (" dummyLaunch" ))
1081+ return ;
1082+ for (Region ®ion : nested->getRegions ()) {
1083+ if (region.empty ())
1084+ continue ;
1085+ Block &blk = region.front ();
1086+ // Erase all ops except the terminator.
1087+ while (blk.getOperations ().size () > 1 ) {
1088+ Operation &op = *std::prev (blk.end (), 2 );
1089+ op.dropAllUses ();
1090+ op.erase ();
1091+ }
1092+ }
1093+ });
1094+ }
1095+
1096+ // Update yielded values for the next iteration.
1097+ auto *yield = forOp.getBody ()->getTerminator ();
1098+ for (unsigned j = 0 ; j < lastYielded.size (); j++)
1099+ lastYielded[j] = operandMap.lookupOrDefault (yield->getOperand (j));
1100+ }
1101+
1102+ // Replace the loop's results with the last iteration's yielded values,
1103+ // then inline the body into the parent block and erase the loop.
1104+ {
1105+ IRRewriter rewriter (forOp.getContext ());
1106+
1107+ // Replace IV uses in iteration 0 with lb (the first iteration's IV value).
1108+ if (!iv.use_empty ())
1109+ rewriter.replaceAllUsesWith (iv, lb);
1110+
1111+ // Replace iter_arg uses in iteration 0 with init values.
1112+ for (auto [iterArg, initVal] :
1113+ llvm::zip (forOp.getRegionIterArgs (), forOp.getInitArgs ()))
1114+ rewriter.replaceAllUsesWith (iterArg, initVal);
1115+
1116+ // Replace loop results with the last iteration's yielded values.
1117+ for (auto [loopResult, yielded] :
1118+ llvm::zip (forOp.getResults (), lastYielded))
1119+ rewriter.replaceAllUsesWith (loopResult, yielded);
1120+
1121+ // Move all body ops (except the yield) before the loop in the parent.
1122+ Block *parentBlock2 = forOp->getBlock ();
1123+ auto &bodyOps = body->getOperations ();
1124+ auto yieldIt = std::prev (bodyOps.end ()); // the yield terminator
1125+ parentBlock2->getOperations ().splice (Block::iterator (forOp), bodyOps,
1126+ bodyOps.begin (), yieldIt);
1127+
1128+ // Erase the now-empty loop.
1129+ rewriter.eraseOp (forOp);
1130+ }
1131+
1132+ // Post-processing: reconnect async token dependencies.
1133+ preserveAsyncDependenciesAfterUnroll (*parentBlock);
1134+ return success ();
1135+ }
1136+
10091137// Unrolls an `scf.for` loop by a given factor while preserving async token
10101138// dependencies.
10111139//
0 commit comments