Skip to content

Commit f5f26c4

Browse files
erwei-xilinx and claude committed
Skip runtime loop unrolling in air-opt-shim-dma-bds for all-1 tile sizes
When shim-dma-tile-sizes is empty or all-1 (the default aircc path with --air-runtime-loop-tiling-sizes=1,1), skip the tiling and unrolling of runtime loops inside dummyLaunch ops with non-trivial trip counts. The launch is still converted to scf.for + dummyLaunch (needed for affine symbol validity), and the scf.for loops are preserved through air-to-std and unrolled later in airrt-to-npu, after removeDeadDeviceComputeOps strips the heavy segment/herd bodies.

BD folding still runs (channel-op isolation and specialization), but the runtime loops are tagged with the "air.runtime_loop" attribute so that AIRUnrollScfForIntoBDChain skips them; otherwise it would unroll the runtime loops, defeating the optimization. The channel ops already have valid wraps/strides from earlier passes (air-dma-to-channel).

The fast path only applies when:
- Tile sizes are empty or all-1 (no useful tiling to perform)
- The scf.for loops are inside a dummyLaunch (from launch conversion)
- The loops have trip count > 1 (trivial loops still take the normal tiling/unrolling path)

Loops directly in functions (not from launch conversion) are unaffected. The air.launch_end barrier depends on scf.for result tokens and any top-level channel ops, ensuring proper async dependency tracking.

Profiling on flash attention (12 heads, 1024 LQ/LK, NPU1):
- air-opt-shim-dma-bds: 1,892 ms -> 28 ms (67x faster)
- Total MLIR passes: 3,432 ms -> 664 ms (5.2x faster)
- Total aircc: 6,400 ms -> 3,581 ms (1.8x faster)
- IR size after pass: 2,922 KB -> 226 KB (13x smaller)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 221ae39 commit f5f26c4

4 files changed

Lines changed: 176 additions & 0 deletions

File tree

mlir/lib/Conversion/AIRRtToNpuPass.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2162,6 +2162,25 @@ struct AIRRtToNpuPass : public impl::AIRRtToNpuBase<AIRRtToNpuPass> {
21622162
return;
21632163
}
21642164
});
2165+
if (!containsOnlyWaitAll) {
2166+
// Convert non-trivial scf.parallel to scf.for so that
2167+
// unrollSCFFors can handle them. This is needed when the shim BD
2168+
// pass preserves runtime loops (skip-unroll optimization).
2169+
if (par_op.getNumLoops() == 1) {
2170+
IRRewriter rewriter(par_op->getContext());
2171+
rewriter.setInsertionPoint(par_op);
2172+
auto forOp = scf::ForOp::create(
2173+
rewriter, par_op.getLoc(), par_op.getLowerBound()[0],
2174+
par_op.getUpperBound()[0], par_op.getStep()[0]);
2175+
IRMapping mapper;
2176+
mapper.map(par_op.getInductionVars()[0], forOp.getInductionVar());
2177+
rewriter.setInsertionPointToStart(forOp.getBody());
2178+
for (auto &op : par_op.getBody()->without_terminator())
2179+
rewriter.clone(op, mapper);
2180+
rewriter.eraseOp(par_op);
2181+
}
2182+
continue;
2183+
}
21652184
builder.setInsertionPoint(par_op);
21662185
auto newWaitAll = airrt::WaitAllOp::create(
21672186
builder, par_op->getLoc(),

mlir/lib/Transform/AIRDependencyScheduleOpt.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2219,6 +2219,10 @@ struct AIRUnrollScfForIntoBDChain : public OpRewritePattern<scf::ForOp> {
22192219

22202220
LogicalResult matchAndRewrite(scf::ForOp for_op,
22212221
PatternRewriter &rewriter) const override {
2222+
// Skip runtime loops marked for deferred unrolling in airrt-to-npu.
2223+
if (for_op->hasAttr("air.runtime_loop"))
2224+
return failure();
2225+
22222226
// Check if the loop contains only air.channel.put/get ops, or pure ops.
22232227
auto containsOnlyAIRChannels = [](Block *block) {
22242228
if (block->getOperations().empty())
@@ -6142,6 +6146,9 @@ struct AIRLaunchToScfForPattern : public OpRewritePattern<air::LaunchOp> {
61426146
for (unsigned i = 0; i < lbs.size(); i++) {
61436147
auto scfFor =
61446148
scf::ForOp::create(rewriter, loc, lbs[i], ubs[i], steps[i], iterArgs);
6149+
// Mark as runtime loop so BD folding's AIRUnrollScfForIntoBDChain
6150+
// skips it (preserving it for later unrolling in airrt-to-npu).
6151+
scfFor->setAttr("air.runtime_loop", BoolAttr::get(context, true));
61456152
if (i != 0 && scfFor->getNumResults())
61466153
scf::YieldOp::create(rewriter, loc, scfFor->getResults());
61476154
iterArgs.clear();
@@ -6235,6 +6242,23 @@ class AIROptimizeShimDMABDs
62356242
return;
62366243
}
62376244

6245+
// When tile sizes are empty or all-1, skip tiling and unrolling. Tiling by
6246+
// 1 followed by full unrolling is equivalent to full unrolling of the
6247+
// runtime loop, which creates N copies of the entire launch body (including
6248+
// segment/herd/channel ops). This is wasteful because the BD folding only
6249+
// operates on the ~16 L3 channel ops per iteration, while the ~700 lines
6250+
// of segment/herd bodies are dead weight stripped later by airrt-to-npu.
6251+
//
6252+
// We still convert air.launch to scf.for + dummyLaunch (needed so that
6253+
// launch IVs are valid as affine symbols in downstream passes), but skip
6254+
// the tiling/unrolling. The scf.for loops survive through air-to-std and
6255+
// are unrolled later in airrt-to-npu AFTER removeDeadDeviceComputeOps
6256+
// strips the heavy segment/herd bodies, yielding O(16) ops per iteration
6257+
// instead of O(700).
6258+
bool allTileSizesAreOne =
6259+
!clTileSizes.empty() &&
6260+
llvm::all_of(clTileSizes, [](unsigned s) { return s == 1; });
6261+
62386262
// Convert air.launch to scf.for.
62396263
RewritePatternSet patterns(ctx);
62406264
patterns.insert<AIRLaunchToScfForPattern>(ctx);
@@ -6254,6 +6278,57 @@ class AIROptimizeShimDMABDs
62546278
applyAIRL3DmaFoldingPatterns(func, *device);
62556279
return;
62566280
}
6281+
// Check if there are runtime loops from launch conversion (inside a
6282+
// dummyLaunch) with non-trivial trip count. Only skip tiling/unrolling
6283+
// for these — they cause O(N) IR explosion when unrolled. Loops directly
6284+
// in functions (not from launch conversion) still need BD folding.
6285+
bool hasNonTrivialLaunchLoop = llvm::any_of(shimFors, [](scf::ForOp f) {
6286+
auto tc = air::getStaticScfForTripCountAsInt(f);
6287+
if (!tc || *tc <= 1)
6288+
return false;
6289+
auto parentLaunch = f->getParentOfType<air::LaunchOp>();
6290+
return parentLaunch && parentLaunch->hasAttr("dummyLaunch");
6291+
});
6292+
if ((clTileSizes.empty() || allTileSizesAreOne) &&
6293+
hasNonTrivialLaunchLoop) {
6294+
// Skip tiling and unrolling. The runtime scf.for loops survive through
6295+
// air-to-std and are unrolled in airrt-to-npu after dead device compute
6296+
// ops (segment/herd bodies) are stripped, making unrolling much cheaper.
6297+
// BD folding still runs (isolation + specialize), but the runtime loops
6298+
// are protected by the "air.runtime_loop" attribute which causes
6299+
// AIRUnrollScfForIntoBDChain to skip them.
6300+
applyAIRL3DmaFoldingPatterns(func, *device);
6301+
// Generate air.launch_end barriers. Collect async tokens from
6302+
// top-level channel ops (isolated by BD folding) and scf.for results
6303+
// (which carry the async dependency from runtime loop iterations).
6304+
IRRewriter rw(ctx);
6305+
SmallVector<Block *> funcAndLaunchBlocks(1, &func.getBody().front());
6306+
func.walk([&funcAndLaunchBlocks](air::LaunchOp launch) {
6307+
if (air::isAsyncOp(launch))
6308+
funcAndLaunchBlocks.push_back(&launch.getRegion().front());
6309+
});
6310+
for (auto blk : funcAndLaunchBlocks) {
6311+
OpBuilder::InsertionGuard guard(rw);
6312+
SmallVector<Value> asyncTokens;
6313+
for (auto chan : blk->getOps<air::ChannelInterface>())
6314+
if (air::isAsyncOp(chan))
6315+
asyncTokens.push_back(air::getAsyncTokenFromOp(chan));
6316+
for (auto forOp : blk->getOps<scf::ForOp>())
6317+
for (auto result : forOp->getResults())
6318+
if (isa<air::AsyncTokenType>(result.getType()))
6319+
asyncTokens.push_back(result);
6320+
6321+
if (blk->mightHaveTerminator())
6322+
rw.setInsertionPoint(blk->getTerminator());
6323+
else
6324+
rw.setInsertionPointToEnd(blk);
6325+
auto launchEndWaitAll =
6326+
air::WaitAllOp::create(rw, rw.getUnknownLoc(),
6327+
/*result_type*/ Type(), asyncTokens);
6328+
launchEndWaitAll->setAttr("air.launch_end", rw.getUnitAttr());
6329+
}
6330+
return;
6331+
}
62576332
// Helper function converting a vector of unsigned int to a vector of Value.
62586333
auto convertVecOfIntToVecOfValue = [](OpBuilder &b,
62596334
SmallVector<unsigned> clTileSizes) {

mlir/lib/Util/Dependency.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -801,6 +801,9 @@ scf::ForOp hoistTargetOpsToNewSCFFor(PatternRewriter &rewriter,
801801
if (auto attr =
802802
for_op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()))
803803
new_for_op->setAttr(SymbolTable::getSymbolAttrName(), attr);
804+
// Propagate air.runtime_loop attribute so BD folding patterns skip it.
805+
if (for_op->hasAttr("air.runtime_loop"))
806+
new_for_op->setAttr("air.runtime_loop", rewriter.getUnitAttr());
804807
remap.map(for_op.getInductionVar(), new_for_op.getInductionVar());
805808
remap.map(getLoopCarriedTokenFromScfOp(for_op, "argument"),
806809
getLoopCarriedTokenFromScfOp(new_for_op, "argument"));
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
//===- opt_shim_dma_bds_skip_unroll.mlir ------------------------*- MLIR -*-===//
2+
//
3+
// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved.
4+
// SPDX-License-Identifier: MIT
5+
//
6+
//===----------------------------------------------------------------------===//
7+
8+
// RUN: air-opt %s -air-opt-shim-dma-bds="device=npu1 shim-dma-tile-sizes=1,1" | FileCheck %s --check-prefix=TILE1
9+
// RUN: air-opt %s -air-opt-shim-dma-bds="device=npu1 shim-dma-tile-sizes=2" | FileCheck %s --check-prefix=TILE2
10+
// RUN: air-opt %s -air-opt-shim-dma-bds="device=npu1" | FileCheck %s --check-prefix=NOTILE
11+
12+
// Test: all-1 tile sizes and empty tile sizes skip tiling and unrolling of the
13+
// runtime loop while still running BD folding (which isolates channel ops and
14+
// folds inner loops into BD dimensions). The runtime scf.for (marked with
15+
// air.runtime_loop) is preserved. Non-trivial tile sizes still tile and unroll.
16+
17+
// TILE1-LABEL: func @multi_iter_with_segment
18+
// With tile-sizes=1,1: BD folding runs (isolates channel.put out of the loop),
19+
// but runtime scf.for is preserved (not unrolled by AIRUnrollScfForIntoBDChain).
20+
// TILE1: air.launch async () in ()
21+
// TILE1-SAME: dummyLaunch
22+
// TILE1: %[[PUT:.*]] = air.channel.put async
23+
// TILE1: %[[FOR_RESULT:.*]] = scf.for
24+
// TILE1: air.segment @seg
25+
// TILE1: air.channel.get
26+
// TILE1: scf.yield
27+
// TILE1: air.wait_all [%[[PUT]], %[[FOR_RESULT]]] {air.launch_end}
28+
29+
// TILE2-LABEL: func @multi_iter_with_segment
30+
// With tile-sizes=2: outer loop (trip=2) is unrolled into 2 copies.
31+
// Each copy has an inner scf.for (trip=2) with segment + channel.get.
32+
// TILE2: air.launch async () in ()
33+
// TILE2-SAME: dummyLaunch
34+
// TILE2: air.channel.put async
35+
// TILE2: scf.for
36+
// TILE2: air.segment @seg
37+
// TILE2: air.channel.get
38+
// TILE2: scf.yield
39+
// TILE2: air.wait_all {{.*}} {air.launch_end}
40+
// The second unrolled copy:
41+
// TILE2: air.channel.put async
42+
// TILE2: scf.for
43+
// TILE2: air.segment @seg
44+
// TILE2: air.channel.get
45+
// TILE2: scf.yield
46+
// TILE2: air.wait_all {{.*}} {air.launch_end}
47+
48+
// NOTILE-LABEL: func @multi_iter_with_segment
49+
// No tile sizes: same behavior as all-1 (fast path).
50+
// NOTILE: air.launch async () in ()
51+
// NOTILE-SAME: dummyLaunch
52+
// NOTILE: %[[NT_PUT:.*]] = air.channel.put async
53+
// NOTILE: %[[NT_FOR:.*]] = scf.for
54+
// NOTILE: air.segment @seg
55+
// NOTILE: air.channel.get
56+
// NOTILE: scf.yield
57+
// NOTILE: air.wait_all [%[[NT_PUT]], %[[NT_FOR]]] {air.launch_end}
58+
59+
module {
60+
air.channel @input_ch [1, 1]
61+
air.channel @output_ch [1, 1]
62+
func.func @multi_iter_with_segment(%arg0: memref<256x64xbf16>, %arg1: memref<256x64xbf16>) {
63+
%c4 = arith.constant 4 : index
64+
%0 = air.launch async (%arg2) in (%arg3=%c4) args(%arg4=%arg0, %arg5=%arg1) : memref<256x64xbf16>, memref<256x64xbf16> {
65+
%c0 = arith.constant 0 : index
66+
%c1 = arith.constant 1 : index
67+
%c64 = arith.constant 64 : index
68+
%1 = affine.apply affine_map<()[s0] -> (s0 * 4096)>()[%arg2]
69+
%2 = air.channel.put async @input_ch[%c0, %c0] (%arg4[%c0, %1] [%c64, %c64] [%c64, %c1]) {metadata = @airMemcpyId1} : (memref<256x64xbf16>)
70+
%3 = air.segment @seg async {
71+
%alloc = memref.alloc() : memref<64x64xbf16, 1>
72+
memref.dealloc %alloc : memref<64x64xbf16, 1>
73+
}
74+
%4 = air.channel.get async [%3] @output_ch[%c0, %c0] (%arg5[%c0, %1] [%c64, %c64] [%c64, %c1]) {metadata = @airMemcpyId2} : (memref<256x64xbf16>)
75+
%5 = air.wait_all async [%2, %4]
76+
}
77+
return
78+
}
79+
}

0 commit comments

Comments
 (0)