diff --git a/mlir/lib/Util/Dependency.cpp b/mlir/lib/Util/Dependency.cpp index deab87556..a26fa449c 100644 --- a/mlir/lib/Util/Dependency.cpp +++ b/mlir/lib/Util/Dependency.cpp @@ -8,9 +8,12 @@ #include "air/Util/Dependency.h" #include "air/Util/Util.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/IR/Iterators.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallSet.h" #define DEBUG_TYPE "air-dependency-util" @@ -965,6 +968,207 @@ void preserveAsyncDependenciesAfterUnroll(Block &parentBlock) { } } +// Returns the set of ops in herdBody that must be kept when performing a +// lightweight herd clone: channel ops, allocs, deallocations, wait_alls, the +// terminator, and any ops that transitively define operands consumed by those. +static SmallPtrSet collectHerdBodyOpsToKeep(Block &herdBody) { + SmallPtrSet toKeep; + // Always keep the block terminator. + if (!herdBody.empty()) + toKeep.insert(herdBody.getTerminator()); + // Seed: channel ops, allocs, deallocations, and wait_alls. + for (Operation &op : herdBody.without_terminator()) { + if (isa(op)) + toKeep.insert(&op); + } + // Expand: add ops whose results are consumed by any kept op (within block). + // Use a worklist to avoid redundant re-processing of already-kept ops. + SmallVector worklist(toKeep.begin(), toKeep.end()); + while (!worklist.empty()) { + Operation *op = worklist.pop_back_val(); + for (Value operand : op->getOperands()) { + if (Operation *defOp = operand.getDefiningOp()) { + if (defOp->getBlock() == &herdBody && toKeep.insert(defOp).second) + worklist.push_back(defOp); + } + } + } + return toKeep; +} + +// Lightweight clone of air.HerdOp: creates a new herd shell via OperationState +// with the same operands (remapped through mapper), result types, and +// attributes, but populates the body with ONLY channel ops, allocs, +// deallocations, wait_alls, their transitive operand-defining ops, and the +// terminator. Heavy compute ops whose results do not feed any of the above are +// skipped entirely. Updates mapper with result mappings for the new herd. +static void cloneHerdOpLightweight(OpBuilder &builder, air::HerdOp herdOp, + IRMapping &mapper) { + Block &origBody = herdOp.getBody().front(); + SmallPtrSet toKeep = collectHerdBodyOpsToKeep(origBody); + + // Map operands through the outer mapper. + SmallVector mappedOperands; + for (Value v : herdOp->getOperands()) + mappedOperands.push_back(mapper.lookupOrDefault(v)); + + // Create the herd shell via OperationState (no body yet). + OperationState state(herdOp.getLoc(), air::HerdOp::getOperationName()); + state.addOperands(mappedOperands); + state.addTypes(herdOp->getResultTypes()); + state.addAttributes(herdOp->getAttrs()); + state.addRegion(); // Placeholder for the body region. + Operation *newHerd = builder.create(state); + + // Map original herd results -> new herd results. + for (auto [origRes, newRes] : + llvm::zip(herdOp->getResults(), newHerd->getResults())) + mapper.map(origRes, newRes); + + // Create the body block with the same block-argument types as the original. + Block *newBody = new Block(); + newHerd->getRegion(0).push_back(newBody); + IRMapping innerMapper; + for (BlockArgument origArg : origBody.getArguments()) { + BlockArgument newArg = + newBody->addArgument(origArg.getType(), origArg.getLoc()); + innerMapper.map(origArg, newArg); + } + + // Clone only the kept ops in original block order to preserve use-def. + OpBuilder bodyBuilder(newBody, newBody->end()); + for (Operation &op : origBody) { + if (toKeep.contains(&op)) + bodyBuilder.clone(op, innerMapper); + } +} + +// Forward declaration for mutual recursion. +static void cloneBlockBodyLightweight(OpBuilder &builder, Block &srcBlock, + IRMapping &mapper); + +// Lightweight clone of air.SegmentOp: clones the segment shell and fully +// clones all ops directly in the segment body (including L3 channel ops needed +// by BD folding), but substitutes lightweight copies for any contained +// air.HerdOp. Updates mapper with segment result mappings. +static void cloneSegmentOpLightweight(OpBuilder &builder, air::SegmentOp segOp, + IRMapping &mapper) { + // Clone segment shell (without regions); this maps operands and results. + Operation *newSeg = segOp->cloneWithoutRegions(mapper); + builder.insert(newSeg); + + // Create and populate the segment body block. + Block &origBody = segOp.getBody().front(); + Block *newBody = new Block(); + newSeg->getRegion(0).push_back(newBody); + + // Map block arguments (segment IDs, sizes, segment_operands). + IRMapping innerMapper; + for (BlockArgument origArg : origBody.getArguments()) { + BlockArgument newArg = + newBody->addArgument(origArg.getType(), origArg.getLoc()); + innerMapper.map(origArg, newArg); + } + + // Clone segment body, using lightweight cloning for any air.HerdOp. + OpBuilder bodyBuilder(newBody, newBody->end()); + cloneBlockBodyLightweight(bodyBuilder, origBody, innerMapper); +} + +// Clone all ops in srcBlock into the current builder insertion point. +// HerdOps are cloned lightweight (body stripped to channel ops + deps); +// SegmentOps are cloned with lightweight recursion into their bodies; +// all other ops are cloned fully. mapper is updated with result mappings. +static void cloneBlockBodyLightweight(OpBuilder &builder, Block &srcBlock, + IRMapping &mapper) { + for (Operation &op : srcBlock) { + if (auto herdOp = dyn_cast(&op)) { + cloneHerdOpLightweight(builder, herdOp, mapper); + } else if (auto segOp = dyn_cast(&op)) { + cloneSegmentOpLightweight(builder, segOp, mapper); + } else { + builder.clone(op, mapper); + } + } +} + +// Manually unroll a shim-level scf.for that contains air.SegmentOp or +// air.HerdOp, using lightweight cloning for herd bodies to avoid +// O(N * body_size) IR explosion. Inserts N copies of the loop body before the +// forOp, replaces forOp results with the final iteration's yields, and erases +// the forOp. Returns failure if bounds are not statically known. +static LogicalResult manuallyUnrollForOpLightweight(scf::ForOp forOp) { + auto maybeCount = air::getStaticScfForTripCountAsInt(forOp); + if (!maybeCount) { + forOp->emitOpError("lightweight unroll failed: dynamic trip count"); + return failure(); + } + unsigned tripCount = *maybeCount; + + if (tripCount == 0) { + // Loop body never executes: replace results with init args and erase. + for (auto [result, initArg] : + llvm::zip(forOp.getResults(), forOp.getInitArgs())) + result.replaceAllUsesWith(initArg); + forOp.erase(); + return success(); + } + + auto lbConst = getConstantIntValue(forOp.getLowerBound()); + auto stepConst = getConstantIntValue(forOp.getStep()); + if (!lbConst || !stepConst) { + forOp->emitOpError("lightweight unroll failed: non-constant bounds"); + return failure(); + } + + // Insert new ops immediately before the forOp. + OpBuilder builder(forOp); + auto loc = forOp.getLoc(); + Block &loopBody = forOp.getRegion().front(); + auto yieldOp = cast(loopBody.getTerminator()); + + // prevYields starts as the loop's init args. + SmallVector prevYields(forOp.getInitArgs().begin(), + forOp.getInitArgs().end()); + + for (unsigned i = 0; i < tripCount; ++i) { + IRMapping mapper; + // Map the induction variable to a constant for this iteration. + int64_t ivVal = *lbConst + static_cast(i) * *stepConst; + Value ivConst = arith::ConstantIndexOp::create(builder, loc, ivVal); + mapper.map(forOp.getInductionVar(), ivConst); + + // Map iter_args to the previous iteration's yielded values. + for (auto [iterArg, prevYield] : + llvm::zip(forOp.getRegionIterArgs(), prevYields)) + mapper.map(iterArg, prevYield); + + // Clone the loop body ops (excluding the yield terminator) with + // lightweight treatment for air.SegmentOp and air.HerdOp. + for (Operation &op : loopBody.without_terminator()) { + if (auto segOp = dyn_cast(&op)) { + cloneSegmentOpLightweight(builder, segOp, mapper); + } else if (auto herdOp = dyn_cast(&op)) { + cloneHerdOpLightweight(builder, herdOp, mapper); + } else { + builder.clone(op, mapper); + } + } + + // Collect this iteration's yielded values for the next iteration. + prevYields.clear(); + for (Value yieldVal : yieldOp.getOperands()) + prevYields.push_back(mapper.lookupOrDefault(yieldVal)); + } + + // Replace forOp results with the final iteration's yields and erase. + for (auto [result, finalYield] : llvm::zip(forOp.getResults(), prevYields)) + result.replaceAllUsesWith(finalYield); + forOp.erase(); + return success(); +} + // Fully unrolls an `scf.for` loop while preserving async token dependencies. // // This function labels the operations that define the values yielded by @@ -976,6 +1180,13 @@ void preserveAsyncDependenciesAfterUnroll(Block &parentBlock) { // // If `annotateFn` is provided, it is passed to `loopUnrollByFactor` for result // tagging. +// +// For shim-level loops (loops whose bodies contain air.SegmentOp or directly +// contain air.HerdOp), lightweight unrolling is used: herd bodies are NOT +// deep-cloned N times. Instead, only channel ops, allocs, and their transitive +// operand-defining ops are cloned into each iteration's herd copy. This avoids +// O(N * body_size) IR explosion while preserving all channel ops needed by +// downstream passes (BD folding, air-to-std channel matching). LogicalResult loopUnrollFullWithAsyncTokenPreserved( scf::ForOp forOp, function_ref annotateFn) { @@ -984,6 +1195,24 @@ LogicalResult loopUnrollFullWithAsyncTokenPreserved( Block *parentBlock = forOp->getBlock(); + // For shim-level loops containing segments or herds, use lightweight + // unrolling to avoid O(N * herd_body_size) deep-clone cost. + bool containsSegmentOrHerd = false; + forOp->walk([&](Operation *op) -> WalkResult { + if (isa(op)) { + containsSegmentOrHerd = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + + if (!annotateFn && containsSegmentOrHerd) { + if (failed(manuallyUnrollForOpLightweight(forOp))) + return failure(); + preserveAsyncDependenciesAfterUnroll(*parentBlock); + return success(); + } + // Fully unroll the loop if (annotateFn) { auto unroll_factor = air::getStaticScfForTripCountAsInt(forOp); diff --git a/mlir/test/Transform/AIRDependencyScheduleOpt/opt_shim_dma_bds.mlir b/mlir/test/Transform/AIRDependencyScheduleOpt/opt_shim_dma_bds.mlir index 5a16bb434..f27018104 100644 --- a/mlir/test/Transform/AIRDependencyScheduleOpt/opt_shim_dma_bds.mlir +++ b/mlir/test/Transform/AIRDependencyScheduleOpt/opt_shim_dma_bds.mlir @@ -8,6 +8,7 @@ // RUN: air-opt %s -air-opt-shim-dma-bds="device=npu1" | FileCheck %s // RUN: air-opt %s -air-opt-shim-dma-bds="device=npu1 shim-dma-tile-sizes=2,2" | FileCheck %s --check-prefix=NPUTILED // RUN: air-opt %s -air-opt-shim-dma-bds="device=xcvc1902" | FileCheck %s --check-prefix=AIE1 +// RUN: air-opt %s -air-opt-shim-dma-bds="device=npu1 shim-dma-tile-sizes=2,2" | FileCheck %s --check-prefix=LIGHTWEIGHT // Optimize logical air.channel.put/get op into efficient shim dma block descriptor (BD). @@ -954,4 +955,47 @@ module { } return } + + // Lightweight unrolling: verify that after unrolling a shim-level scf.for + // containing air.segment > air.herd, channel ops in herd bodies are + // preserved while compute ops with dead results are NOT cloned into the + // lightweight herd copies. The launch has trip count 2; after the pass + // converts it to a scf.for and unrolls with shim-dma-tile-sizes=2,2, two + // lightweight herd copies should appear, each missing the arith.muli. + + // LIGHTWEIGHT-LABEL: func_lightweight_unroll + // LIGHTWEIGHT: air.segment + // LIGHTWEIGHT: air.channel.get + // LIGHTWEIGHT-NOT: arith.muli + + air.channel @ch_lite [2, 1] + + func.func @func_lightweight_unroll(%arg0: memref<256xi32>) { + %c2 = arith.constant 2 : index + %0 = air.launch async (%arg1) in (%arg2=%c2) args(%arg3=%arg0) : memref<256xi32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %1 = air.wait_all async + %2 = air.segment async [%1] { + %c0_0 = arith.constant 0 : index + %c1_0 = arith.constant 1 : index + %c64_0 = arith.constant 64 : index + %3 = air.herd async tile (%tx, %ty) in (%sx=%c1_0, %sy=%c1_0) { + %c0_1 = arith.constant 0 : index + %c1_1 = arith.constant 1 : index + %c64_1 = arith.constant 64 : index + %l1buf = memref.alloc() : memref<64xi32, 2> + %dead = arith.muli %tx, %ty : index + %4 = air.channel.get async @ch_lite[%tx, %ty] (%l1buf[%c0_1] [%c64_1] [%c1_1]) : (memref<64xi32, 2>) + memref.dealloc %l1buf : memref<64xi32, 2> + air.herd_terminator + } + air.segment_terminator + } + %5 = air.channel.put async [%1] @ch_lite[%arg1, %c0] (%arg3[%arg1, %c0] [%c1, %c64] [%c128, %c1]) : (memref<256xi32>) + } + return + } }