From 0fdae096cd2bb72250d881701c65708f44493fcf Mon Sep 17 00:00:00 2001 From: erweiw Date: Sat, 25 Apr 2026 12:43:30 -0700 Subject: [PATCH 1/2] Avoid O(N*body_size) deep-cloning herd bodies during shim loop unroll For shim-level scf.for loops that contain air.SegmentOp / air.HerdOp, replace the standard loopUnrollFull (which deep-clones the entire body N times) with a manual unroller that uses lightweight herd cloning. The lightweight herd clone (cloneHerdOpLightweight) creates a new herd shell via OperationState and populates the body with ONLY channel ops, allocs, deallocations, wait_alls, and their transitive operand-defining ops. Heavy compute ops (matrix multiply, vector ops, etc.) are never cloned. The segment clone (cloneSegmentOpLightweight) preserves the full segment body (L3 channel ops needed by BD folding) but applies lightweight cloning recursively to any contained air.HerdOp. This brings flash attention 12x4 (tiles=2,2) from crashing (T002) or taking O(N*body_size) time back to ~50ms, matching the main-branch baseline. Also adds a lit test (func_lightweight_unroll) that verifies channel ops are preserved and compute ops are excluded from the lightweight herd copies. Co-Authored-By: Claude Sonnet 4.6 --- mlir/lib/Util/Dependency.cpp | 227 ++++++++++++++++++ .../opt_shim_dma_bds.mlir | 44 ++++ 2 files changed, 271 insertions(+) diff --git a/mlir/lib/Util/Dependency.cpp b/mlir/lib/Util/Dependency.cpp index deab87556..091fbd191 100644 --- a/mlir/lib/Util/Dependency.cpp +++ b/mlir/lib/Util/Dependency.cpp @@ -8,9 +8,12 @@ #include "air/Util/Dependency.h" #include "air/Util/Util.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/IR/Iterators.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallSet.h" #define DEBUG_TYPE "air-dependency-util" @@ -965,6 +968,209 @@ void preserveAsyncDependenciesAfterUnroll(Block &parentBlock) { } } +// Returns the set of ops in herdBody that must be kept when performing a +// lightweight herd clone: channel ops, allocs, deallocations, wait_alls, the +// terminator, and any ops that transitively define operands consumed by those. +static SmallPtrSet collectHerdBodyOpsToKeep(Block &herdBody) { + SmallPtrSet toKeep; + // Always keep the block terminator. + if (!herdBody.empty()) + toKeep.insert(herdBody.getTerminator()); + // Seed: channel ops, allocs, deallocations, and wait_alls. + for (Operation &op : herdBody.without_terminator()) { + if (isa(op)) + toKeep.insert(&op); + } + // Expand: add ops whose results are consumed by any kept op (within block). + bool changed = true; + while (changed) { + changed = false; + for (Operation *op : + SmallVector(toKeep.begin(), toKeep.end())) { + for (Value operand : op->getOperands()) { + if (Operation *defOp = operand.getDefiningOp()) { + if (defOp->getBlock() == &herdBody && toKeep.insert(defOp).second) + changed = true; + } + } + } + } + return toKeep; +} + +// Lightweight clone of air.HerdOp: creates a new herd shell via OperationState +// with the same operands (remapped through mapper), result types, and +// attributes, but populates the body with ONLY channel ops, allocs, +// deallocations, wait_alls, their transitive operand-defining ops, and the +// terminator. Heavy compute ops whose results do not feed any of the above are +// skipped entirely. Updates mapper with result mappings for the new herd. +static void cloneHerdOpLightweight(OpBuilder &builder, air::HerdOp herdOp, + IRMapping &mapper) { + Block &origBody = herdOp.getBody().front(); + SmallPtrSet toKeep = collectHerdBodyOpsToKeep(origBody); + + // Map operands through the outer mapper. + SmallVector mappedOperands; + for (Value v : herdOp->getOperands()) + mappedOperands.push_back(mapper.lookupOrDefault(v)); + + // Create the herd shell via OperationState (no body yet). + OperationState state(herdOp.getLoc(), air::HerdOp::getOperationName()); + state.addOperands(mappedOperands); + state.addTypes(herdOp->getResultTypes()); + state.addAttributes(herdOp->getAttrs()); + state.addRegion(); // Placeholder for the body region. + Operation *newHerd = builder.create(state); + + // Map original herd results -> new herd results. + for (auto [origRes, newRes] : + llvm::zip(herdOp->getResults(), newHerd->getResults())) + mapper.map(origRes, newRes); + + // Create the body block with the same block-argument types as the original. + Block *newBody = new Block(); + newHerd->getRegion(0).push_back(newBody); + IRMapping innerMapper; + for (BlockArgument origArg : origBody.getArguments()) { + BlockArgument newArg = + newBody->addArgument(origArg.getType(), origArg.getLoc()); + innerMapper.map(origArg, newArg); + } + + // Clone only the kept ops in original block order to preserve use-def. + OpBuilder bodyBuilder(newBody, newBody->end()); + for (Operation &op : origBody) { + if (toKeep.contains(&op)) + bodyBuilder.clone(op, innerMapper); + } +} + +// Forward declaration for mutual recursion. +static void cloneBlockBodyLightweight(OpBuilder &builder, Block &srcBlock, + IRMapping &mapper); + +// Lightweight clone of air.SegmentOp: clones the segment shell and fully +// clones all ops directly in the segment body (including L3 channel ops needed +// by BD folding), but substitutes lightweight copies for any contained +// air.HerdOp. Updates mapper with segment result mappings. +static void cloneSegmentOpLightweight(OpBuilder &builder, air::SegmentOp segOp, + IRMapping &mapper) { + // Clone segment shell (without regions); this maps operands and results. + Operation *newSeg = segOp->cloneWithoutRegions(mapper); + builder.insert(newSeg); + + // Create and populate the segment body block. + Block &origBody = segOp.getBody().front(); + Block *newBody = new Block(); + newSeg->getRegion(0).push_back(newBody); + + // Map block arguments (segment IDs, sizes, segment_operands). + IRMapping innerMapper; + for (BlockArgument origArg : origBody.getArguments()) { + BlockArgument newArg = + newBody->addArgument(origArg.getType(), origArg.getLoc()); + innerMapper.map(origArg, newArg); + } + + // Clone segment body, using lightweight cloning for any air.HerdOp. + OpBuilder bodyBuilder(newBody, newBody->end()); + cloneBlockBodyLightweight(bodyBuilder, origBody, innerMapper); +} + +// Clone all ops in srcBlock into the current builder insertion point. +// HerdOps are cloned lightweight (body stripped to channel ops + deps); +// SegmentOps are cloned with lightweight recursion into their bodies; +// all other ops are cloned fully. mapper is updated with result mappings. +static void cloneBlockBodyLightweight(OpBuilder &builder, Block &srcBlock, + IRMapping &mapper) { + for (Operation &op : srcBlock) { + if (auto herdOp = dyn_cast(&op)) { + cloneHerdOpLightweight(builder, herdOp, mapper); + } else if (auto segOp = dyn_cast(&op)) { + cloneSegmentOpLightweight(builder, segOp, mapper); + } else { + builder.clone(op, mapper); + } + } +} + +// Manually unroll a shim-level scf.for that contains air.SegmentOp or +// air.HerdOp, using lightweight cloning for herd bodies to avoid +// O(N * body_size) IR explosion. Inserts N copies of the loop body before the +// forOp, replaces forOp results with the final iteration's yields, and erases +// the forOp. Returns failure if bounds are not statically known. +static LogicalResult manuallyUnrollForOpLightweight(scf::ForOp forOp) { + auto maybeCount = air::getStaticScfForTripCountAsInt(forOp); + if (!maybeCount) { + forOp->emitOpError("lightweight unroll failed: dynamic trip count"); + return failure(); + } + unsigned tripCount = *maybeCount; + + if (tripCount == 0) { + // Loop body never executes: replace results with init args and erase. + for (auto [result, initArg] : + llvm::zip(forOp.getResults(), forOp.getInitArgs())) + result.replaceAllUsesWith(initArg); + forOp.erase(); + return success(); + } + + auto lbConst = getConstantIntValue(forOp.getLowerBound()); + auto stepConst = getConstantIntValue(forOp.getStep()); + if (!lbConst || !stepConst) { + forOp->emitOpError("lightweight unroll failed: non-constant bounds"); + return failure(); + } + + // Insert new ops immediately before the forOp. + OpBuilder builder(forOp); + auto loc = forOp.getLoc(); + Block &loopBody = forOp.getRegion().front(); + auto yieldOp = cast(loopBody.getTerminator()); + + // prevYields starts as the loop's init args. + SmallVector prevYields(forOp.getInitArgs().begin(), + forOp.getInitArgs().end()); + + for (unsigned i = 0; i < tripCount; ++i) { + IRMapping mapper; + // Map the induction variable to a constant for this iteration. + int64_t ivVal = *lbConst + static_cast(i) * *stepConst; + Value ivConst = arith::ConstantIndexOp::create(builder, loc, ivVal); + mapper.map(forOp.getInductionVar(), ivConst); + + // Map iter_args to the previous iteration's yielded values. + for (auto [iterArg, prevYield] : + llvm::zip(forOp.getRegionIterArgs(), prevYields)) + mapper.map(iterArg, prevYield); + + // Clone the loop body ops (excluding the yield terminator) with + // lightweight treatment for air.SegmentOp and air.HerdOp. + for (Operation &op : loopBody.without_terminator()) { + if (auto segOp = dyn_cast(&op)) { + cloneSegmentOpLightweight(builder, segOp, mapper); + } else if (auto herdOp = dyn_cast(&op)) { + cloneHerdOpLightweight(builder, herdOp, mapper); + } else { + builder.clone(op, mapper); + } + } + + // Collect this iteration's yielded values for the next iteration. + prevYields.clear(); + for (Value yieldVal : yieldOp.getOperands()) + prevYields.push_back(mapper.lookupOrDefault(yieldVal)); + } + + // Replace forOp results with the final iteration's yields and erase. + for (auto [result, finalYield] : llvm::zip(forOp.getResults(), prevYields)) + result.replaceAllUsesWith(finalYield); + forOp.erase(); + return success(); +} + // Fully unrolls an `scf.for` loop while preserving async token dependencies. // // This function labels the operations that define the values yielded by @@ -976,6 +1182,13 @@ void preserveAsyncDependenciesAfterUnroll(Block &parentBlock) { // // If `annotateFn` is provided, it is passed to `loopUnrollByFactor` for result // tagging. +// +// For shim-level loops (loops whose bodies contain air.SegmentOp or directly +// contain air.HerdOp), lightweight unrolling is used: herd bodies are NOT +// deep-cloned N times. Instead, only channel ops, allocs, and their transitive +// operand-defining ops are cloned into each iteration's herd copy. This avoids +// O(N * body_size) IR explosion while preserving all channel ops needed by +// downstream passes (BD folding, air-to-std channel matching). LogicalResult loopUnrollFullWithAsyncTokenPreserved( scf::ForOp forOp, function_ref annotateFn) { @@ -984,6 +1197,20 @@ LogicalResult loopUnrollFullWithAsyncTokenPreserved( Block *parentBlock = forOp->getBlock(); + // For shim-level loops containing segments or herds, use lightweight + // unrolling to avoid O(N * herd_body_size) deep-clone cost. + bool containsSegmentOrHerd = false; + forOp.walk([&](air::SegmentOp) { containsSegmentOrHerd = true; }); + if (!containsSegmentOrHerd) + forOp.walk([&](air::HerdOp) { containsSegmentOrHerd = true; }); + + if (!annotateFn && containsSegmentOrHerd) { + if (failed(manuallyUnrollForOpLightweight(forOp))) + return failure(); + preserveAsyncDependenciesAfterUnroll(*parentBlock); + return success(); + } + // Fully unroll the loop if (annotateFn) { auto unroll_factor = air::getStaticScfForTripCountAsInt(forOp); diff --git a/mlir/test/Transform/AIRDependencyScheduleOpt/opt_shim_dma_bds.mlir b/mlir/test/Transform/AIRDependencyScheduleOpt/opt_shim_dma_bds.mlir index 5a16bb434..f27018104 100644 --- a/mlir/test/Transform/AIRDependencyScheduleOpt/opt_shim_dma_bds.mlir +++ b/mlir/test/Transform/AIRDependencyScheduleOpt/opt_shim_dma_bds.mlir @@ -8,6 +8,7 @@ // RUN: air-opt %s -air-opt-shim-dma-bds="device=npu1" | FileCheck %s // RUN: air-opt %s -air-opt-shim-dma-bds="device=npu1 shim-dma-tile-sizes=2,2" | FileCheck %s --check-prefix=NPUTILED // RUN: air-opt %s -air-opt-shim-dma-bds="device=xcvc1902" | FileCheck %s --check-prefix=AIE1 +// RUN: air-opt %s -air-opt-shim-dma-bds="device=npu1 shim-dma-tile-sizes=2,2" | FileCheck %s --check-prefix=LIGHTWEIGHT // Optimize logical air.channel.put/get op into efficient shim dma block descriptor (BD). @@ -954,4 +955,47 @@ module { } return } + + // Lightweight unrolling: verify that after unrolling a shim-level scf.for + // containing air.segment > air.herd, channel ops in herd bodies are + // preserved while compute ops with dead results are NOT cloned into the + // lightweight herd copies. The launch has trip count 2; after the pass + // converts it to a scf.for and unrolls with shim-dma-tile-sizes=2,2, two + // lightweight herd copies should appear, each missing the arith.muli. + + // LIGHTWEIGHT-LABEL: func_lightweight_unroll + // LIGHTWEIGHT: air.segment + // LIGHTWEIGHT: air.channel.get + // LIGHTWEIGHT-NOT: arith.muli + + air.channel @ch_lite [2, 1] + + func.func @func_lightweight_unroll(%arg0: memref<256xi32>) { + %c2 = arith.constant 2 : index + %0 = air.launch async (%arg1) in (%arg2=%c2) args(%arg3=%arg0) : memref<256xi32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %1 = air.wait_all async + %2 = air.segment async [%1] { + %c0_0 = arith.constant 0 : index + %c1_0 = arith.constant 1 : index + %c64_0 = arith.constant 64 : index + %3 = air.herd async tile (%tx, %ty) in (%sx=%c1_0, %sy=%c1_0) { + %c0_1 = arith.constant 0 : index + %c1_1 = arith.constant 1 : index + %c64_1 = arith.constant 64 : index + %l1buf = memref.alloc() : memref<64xi32, 2> + %dead = arith.muli %tx, %ty : index + %4 = air.channel.get async @ch_lite[%tx, %ty] (%l1buf[%c0_1] [%c64_1] [%c1_1]) : (memref<64xi32, 2>) + memref.dealloc %l1buf : memref<64xi32, 2> + air.herd_terminator + } + air.segment_terminator + } + %5 = air.channel.put async [%1] @ch_lite[%arg1, %c0] (%arg3[%arg1, %c0] [%c1, %c64] [%c128, %c1]) : (memref<256xi32>) + } + return + } } From 400cd2a52d418a2804165a3c11bc32dc784fe840 Mon Sep 17 00:00:00 2001 From: erweiw Date: Sat, 25 Apr 2026 20:08:53 -0700 Subject: [PATCH 2/2] Address review comments: worklist transitive closure, single early-terminating walk Two efficiency improvements identified in self-review: 1. collectHerdBodyOpsToKeep: replace fixed-point iteration (copies toKeep set each round) with a worklist algorithm. Each newly-added op is pushed once and processed once, avoiding redundant re-scanning of already-kept ops in subsequent rounds. 2. loopUnrollFullWithAsyncTokenPreserved: replace two sequential walks (one for SegmentOp, one for HerdOp) with a single interruptible walk that stops as soon as the first SegmentOp or HerdOp is found. Avoids walking the entire IR a second time in the common case where a segment is present. No functional change; build and all tests pass (same 2 pre-existing ROCDL failures unrelated to T003). Co-Authored-By: Claude Sonnet 4.6 --- mlir/lib/Util/Dependency.cpp | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/mlir/lib/Util/Dependency.cpp b/mlir/lib/Util/Dependency.cpp index 091fbd191..a26fa449c 100644 --- a/mlir/lib/Util/Dependency.cpp +++ b/mlir/lib/Util/Dependency.cpp @@ -983,16 +983,14 @@ static SmallPtrSet collectHerdBodyOpsToKeep(Block &herdBody) { toKeep.insert(&op); } // Expand: add ops whose results are consumed by any kept op (within block). - bool changed = true; - while (changed) { - changed = false; - for (Operation *op : - SmallVector(toKeep.begin(), toKeep.end())) { - for (Value operand : op->getOperands()) { - if (Operation *defOp = operand.getDefiningOp()) { - if (defOp->getBlock() == &herdBody && toKeep.insert(defOp).second) - changed = true; - } + // Use a worklist to avoid redundant re-processing of already-kept ops. + SmallVector worklist(toKeep.begin(), toKeep.end()); + while (!worklist.empty()) { + Operation *op = worklist.pop_back_val(); + for (Value operand : op->getOperands()) { + if (Operation *defOp = operand.getDefiningOp()) { + if (defOp->getBlock() == &herdBody && toKeep.insert(defOp).second) + worklist.push_back(defOp); } } } @@ -1200,9 +1198,13 @@ LogicalResult loopUnrollFullWithAsyncTokenPreserved( // For shim-level loops containing segments or herds, use lightweight // unrolling to avoid O(N * herd_body_size) deep-clone cost. bool containsSegmentOrHerd = false; - forOp.walk([&](air::SegmentOp) { containsSegmentOrHerd = true; }); - if (!containsSegmentOrHerd) - forOp.walk([&](air::HerdOp) { containsSegmentOrHerd = true; }); + forOp->walk([&](Operation *op) -> WalkResult { + if (isa(op)) { + containsSegmentOrHerd = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); if (!annotateFn && containsSegmentOrHerd) { if (failed(manuallyUnrollForOpLightweight(forOp)))