Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 229 additions & 0 deletions mlir/lib/Util/Dependency.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@

#include "air/Util/Dependency.h"
#include "air/Util/Util.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/IR/Iterators.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"

#define DEBUG_TYPE "air-dependency-util"
Expand Down Expand Up @@ -965,6 +968,207 @@ void preserveAsyncDependenciesAfterUnroll(Block &parentBlock) {
}
}

// Returns the set of ops in herdBody that must be kept when performing a
// lightweight herd clone: channel ops, allocs, deallocations, wait_alls, the
// terminator, and any ops that transitively define operands consumed by those.
static SmallPtrSet<Operation *, 16> collectHerdBodyOpsToKeep(Block &herdBody) {
SmallPtrSet<Operation *, 16> toKeep;
// Always keep the block terminator.
if (!herdBody.empty())
toKeep.insert(herdBody.getTerminator());
// Seed: channel ops, allocs, deallocations, and wait_alls.
for (Operation &op : herdBody.without_terminator()) {
if (isa<air::ChannelInterface, memref::AllocOp, memref::DeallocOp,
air::WaitAllOp>(op))
toKeep.insert(&op);
}
// Expand: add ops whose results are consumed by any kept op (within block).
// Use a worklist to avoid redundant re-processing of already-kept ops.
SmallVector<Operation *> worklist(toKeep.begin(), toKeep.end());
while (!worklist.empty()) {
Operation *op = worklist.pop_back_val();
for (Value operand : op->getOperands()) {
if (Operation *defOp = operand.getDefiningOp()) {
if (defOp->getBlock() == &herdBody && toKeep.insert(defOp).second)
worklist.push_back(defOp);
}
}
}
return toKeep;
}

// Lightweight clone of air.HerdOp: creates a new herd shell via OperationState
// with the same operands (remapped through mapper), result types, and
// attributes, but populates the body with ONLY channel ops, allocs,
// deallocations, wait_alls, their transitive operand-defining ops, and the
// terminator. Heavy compute ops whose results do not feed any of the above are
// skipped entirely. Updates mapper with result mappings for the new herd.
static void cloneHerdOpLightweight(OpBuilder &builder, air::HerdOp herdOp,
IRMapping &mapper) {
Block &origBody = herdOp.getBody().front();
SmallPtrSet<Operation *, 16> toKeep = collectHerdBodyOpsToKeep(origBody);

// Map operands through the outer mapper.
SmallVector<Value> mappedOperands;
for (Value v : herdOp->getOperands())
mappedOperands.push_back(mapper.lookupOrDefault(v));

// Create the herd shell via OperationState (no body yet).
OperationState state(herdOp.getLoc(), air::HerdOp::getOperationName());
state.addOperands(mappedOperands);
state.addTypes(herdOp->getResultTypes());
state.addAttributes(herdOp->getAttrs());
state.addRegion(); // Placeholder for the body region.
Operation *newHerd = builder.create(state);

// Map original herd results -> new herd results.
for (auto [origRes, newRes] :
llvm::zip(herdOp->getResults(), newHerd->getResults()))
mapper.map(origRes, newRes);

// Create the body block with the same block-argument types as the original.
Block *newBody = new Block();
newHerd->getRegion(0).push_back(newBody);
IRMapping innerMapper;
for (BlockArgument origArg : origBody.getArguments()) {
BlockArgument newArg =
newBody->addArgument(origArg.getType(), origArg.getLoc());
innerMapper.map(origArg, newArg);
}

// Clone only the kept ops in original block order to preserve use-def.
OpBuilder bodyBuilder(newBody, newBody->end());
for (Operation &op : origBody) {
if (toKeep.contains(&op))
bodyBuilder.clone(op, innerMapper);
}
}

// Forward declaration for mutual recursion.
static void cloneBlockBodyLightweight(OpBuilder &builder, Block &srcBlock,
IRMapping &mapper);

// Lightweight clone of air.SegmentOp: clones the segment shell and fully
// clones all ops directly in the segment body (including L3 channel ops needed
// by BD folding), but substitutes lightweight copies for any contained
// air.HerdOp. Updates mapper with segment result mappings.
static void cloneSegmentOpLightweight(OpBuilder &builder, air::SegmentOp segOp,
IRMapping &mapper) {
// Clone segment shell (without regions); this maps operands and results.
Operation *newSeg = segOp->cloneWithoutRegions(mapper);
builder.insert(newSeg);

// Create and populate the segment body block.
Block &origBody = segOp.getBody().front();
Block *newBody = new Block();
newSeg->getRegion(0).push_back(newBody);

// Map block arguments (segment IDs, sizes, segment_operands).
IRMapping innerMapper;
for (BlockArgument origArg : origBody.getArguments()) {
BlockArgument newArg =
newBody->addArgument(origArg.getType(), origArg.getLoc());
innerMapper.map(origArg, newArg);
}

// Clone segment body, using lightweight cloning for any air.HerdOp.
OpBuilder bodyBuilder(newBody, newBody->end());
cloneBlockBodyLightweight(bodyBuilder, origBody, innerMapper);
}

// Clone all ops in srcBlock into the current builder insertion point.
// HerdOps are cloned lightweight (body stripped to channel ops + deps);
// SegmentOps are cloned with lightweight recursion into their bodies;
// all other ops are cloned fully. mapper is updated with result mappings.
static void cloneBlockBodyLightweight(OpBuilder &builder, Block &srcBlock,
IRMapping &mapper) {
for (Operation &op : srcBlock) {
if (auto herdOp = dyn_cast<air::HerdOp>(&op)) {
cloneHerdOpLightweight(builder, herdOp, mapper);
} else if (auto segOp = dyn_cast<air::SegmentOp>(&op)) {
cloneSegmentOpLightweight(builder, segOp, mapper);
} else {
builder.clone(op, mapper);
}
}
}

// Manually unroll a shim-level scf.for that contains air.SegmentOp or
// air.HerdOp, using lightweight cloning for herd bodies to avoid
// O(N * body_size) IR explosion. Inserts N copies of the loop body before the
// forOp, replaces forOp results with the final iteration's yields, and erases
// the forOp. Returns failure if bounds are not statically known.
static LogicalResult manuallyUnrollForOpLightweight(scf::ForOp forOp) {
auto maybeCount = air::getStaticScfForTripCountAsInt(forOp);
if (!maybeCount) {
forOp->emitOpError("lightweight unroll failed: dynamic trip count");
return failure();
}
unsigned tripCount = *maybeCount;

if (tripCount == 0) {
// Loop body never executes: replace results with init args and erase.
for (auto [result, initArg] :
llvm::zip(forOp.getResults(), forOp.getInitArgs()))
result.replaceAllUsesWith(initArg);
forOp.erase();
return success();
}

auto lbConst = getConstantIntValue(forOp.getLowerBound());
auto stepConst = getConstantIntValue(forOp.getStep());
if (!lbConst || !stepConst) {
forOp->emitOpError("lightweight unroll failed: non-constant bounds");
return failure();
}

// Insert new ops immediately before the forOp.
OpBuilder builder(forOp);
auto loc = forOp.getLoc();
Block &loopBody = forOp.getRegion().front();
auto yieldOp = cast<scf::YieldOp>(loopBody.getTerminator());

// prevYields starts as the loop's init args.
SmallVector<Value> prevYields(forOp.getInitArgs().begin(),
forOp.getInitArgs().end());

for (unsigned i = 0; i < tripCount; ++i) {
IRMapping mapper;
// Map the induction variable to a constant for this iteration.
int64_t ivVal = *lbConst + static_cast<int64_t>(i) * *stepConst;
Value ivConst = arith::ConstantIndexOp::create(builder, loc, ivVal);
mapper.map(forOp.getInductionVar(), ivConst);

// Map iter_args to the previous iteration's yielded values.
for (auto [iterArg, prevYield] :
llvm::zip(forOp.getRegionIterArgs(), prevYields))
mapper.map(iterArg, prevYield);

// Clone the loop body ops (excluding the yield terminator) with
// lightweight treatment for air.SegmentOp and air.HerdOp.
for (Operation &op : loopBody.without_terminator()) {
if (auto segOp = dyn_cast<air::SegmentOp>(&op)) {
cloneSegmentOpLightweight(builder, segOp, mapper);
} else if (auto herdOp = dyn_cast<air::HerdOp>(&op)) {
cloneHerdOpLightweight(builder, herdOp, mapper);
} else {
builder.clone(op, mapper);
}
}

// Collect this iteration's yielded values for the next iteration.
prevYields.clear();
for (Value yieldVal : yieldOp.getOperands())
prevYields.push_back(mapper.lookupOrDefault(yieldVal));
}

// Replace forOp results with the final iteration's yields and erase.
for (auto [result, finalYield] : llvm::zip(forOp.getResults(), prevYields))
result.replaceAllUsesWith(finalYield);
forOp.erase();
return success();
}

// Fully unrolls an `scf.for` loop while preserving async token dependencies.
//
// This function labels the operations that define the values yielded by
Expand All @@ -976,6 +1180,13 @@ void preserveAsyncDependenciesAfterUnroll(Block &parentBlock) {
//
// If `annotateFn` is provided, it is passed to `loopUnrollByFactor` for result
// tagging.
//
// For shim-level loops (loops whose bodies contain air.SegmentOp or directly
// contain air.HerdOp), lightweight unrolling is used: herd bodies are NOT
// deep-cloned N times. Instead, only channel ops, allocs, and their transitive
// operand-defining ops are cloned into each iteration's herd copy. This avoids
// O(N * body_size) IR explosion while preserving all channel ops needed by
// downstream passes (BD folding, air-to-std channel matching).
LogicalResult loopUnrollFullWithAsyncTokenPreserved(
scf::ForOp forOp,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
Expand All @@ -984,6 +1195,24 @@ LogicalResult loopUnrollFullWithAsyncTokenPreserved(

Block *parentBlock = forOp->getBlock();

// For shim-level loops containing segments or herds, use lightweight
// unrolling to avoid O(N * herd_body_size) deep-clone cost.
bool containsSegmentOrHerd = false;
forOp->walk([&](Operation *op) -> WalkResult {
if (isa<air::SegmentOp, air::HerdOp>(op)) {
containsSegmentOrHerd = true;
return WalkResult::interrupt();
}
return WalkResult::advance();
});

if (!annotateFn && containsSegmentOrHerd) {
if (failed(manuallyUnrollForOpLightweight(forOp)))
return failure();
preserveAsyncDependenciesAfterUnroll(*parentBlock);
return success();
}

// Fully unroll the loop
if (annotateFn) {
auto unroll_factor = air::getStaticScfForTripCountAsInt(forOp);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
// RUN: air-opt %s -air-opt-shim-dma-bds="device=npu1" | FileCheck %s
// RUN: air-opt %s -air-opt-shim-dma-bds="device=npu1 shim-dma-tile-sizes=2,2" | FileCheck %s --check-prefix=NPUTILED
// RUN: air-opt %s -air-opt-shim-dma-bds="device=xcvc1902" | FileCheck %s --check-prefix=AIE1
// RUN: air-opt %s -air-opt-shim-dma-bds="device=npu1 shim-dma-tile-sizes=2,2" | FileCheck %s --check-prefix=LIGHTWEIGHT

// Optimize logical air.channel.put/get op into efficient shim dma block descriptor (BD).

Expand Down Expand Up @@ -954,4 +955,47 @@ module {
}
return
}

// Lightweight unrolling: verify that after unrolling a shim-level scf.for
// containing air.segment > air.herd, channel ops in herd bodies are
// preserved while compute ops with dead results are NOT cloned into the
// lightweight herd copies. The launch has trip count 2; after the pass
// converts it to a scf.for and unrolls with shim-dma-tile-sizes=2,2, two
// lightweight herd copies should appear, each missing the arith.muli.

// LIGHTWEIGHT-LABEL: func_lightweight_unroll
// LIGHTWEIGHT: air.segment
// LIGHTWEIGHT: air.channel.get
// LIGHTWEIGHT-NOT: arith.muli

air.channel @ch_lite [2, 1]

func.func @func_lightweight_unroll(%arg0: memref<256xi32>) {
%c2 = arith.constant 2 : index
%0 = air.launch async (%arg1) in (%arg2=%c2) args(%arg3=%arg0) : memref<256xi32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c64 = arith.constant 64 : index
%c128 = arith.constant 128 : index
%1 = air.wait_all async
%2 = air.segment async [%1] {
%c0_0 = arith.constant 0 : index
%c1_0 = arith.constant 1 : index
%c64_0 = arith.constant 64 : index
%3 = air.herd async tile (%tx, %ty) in (%sx=%c1_0, %sy=%c1_0) {
%c0_1 = arith.constant 0 : index
%c1_1 = arith.constant 1 : index
%c64_1 = arith.constant 64 : index
%l1buf = memref.alloc() : memref<64xi32, 2>
%dead = arith.muli %tx, %ty : index
%4 = air.channel.get async @ch_lite[%tx, %ty] (%l1buf[%c0_1] [%c64_1] [%c1_1]) : (memref<64xi32, 2>)
memref.dealloc %l1buf : memref<64xi32, 2>
air.herd_terminator
}
air.segment_terminator
}
%5 = air.channel.put async [%1] @ch_lite[%arg1, %c0] (%arg3[%arg1, %c0] [%c1, %c64] [%c128, %c1]) : (memref<256xi32>)
}
return
}
}
Loading