Skip to content

Commit f9e8a6e

Browse files
erwei-xilinxclaude
andauthored
Three follow-up fixes for channel_type="mmio" multi-tile / non-i32 use (#1573)
* Skip mmio channels in air-dma-to-channel shim-pressure auto-upgrade The shim-pressure heuristic in AIRDmaToChannel auto-upgrades L3-bound channels to dma_packet when their per-column count exceeds the shim S2MM/MM2S limit. mmio channels are runtime-sequence MMIO writes, not shim DMA, so they neither contribute to pressure nor are eligible for the upgrade. Counting them in the pressure check force-flipped their type tag to dma_packet, which destroyed the mmio lowering before it ran. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Skip mmio channels in specializeChannelBundle specializeChannelBundle splits a `[N]`-sized channel into N single-position channels and rewrites all matching put/get ops inside the device. For mmio channels with a multi-tile herd, this left the host-side puts orphaned on the original bundle symbol — they sit outside the device, where this pattern's rewrites don't reach — while the per-tile gets had been moved to new specialized channel names. The mmio lowering then saw N hostPuts on the bundled channel with 0 matching device-side gets and emitted "channel_type=\"mmio\" put has no matching device-side air.channel.get". Skip mmio channels here and let lowerAIRMMIOChannelOps match host-side puts to per-core gets directly across the full bundle by constant index — the same path it already takes for single-tile mmio. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Repack non-i32 mmio source globals as memref<Nxi32> for blockwrite aiex.npu.blockwrite's translator only handles 32-bit element types ("Only 32-bit data type is supported for now"); a bf16 source warned then segfaulted in AIETranslateNpuToBinary. The destination buffer type is irrelevant on the wire — `buffer = @sym` is just a symbol ref — so the fix is local to the data side. When the original memref.global isn't already i32-typed, mirror it into the device as a 1-D memref<Nxi32> with the same raw bytes (suffixed `_mmio_i32`) and reference that from the blockwrite. The original global is kept undisturbed for any other uses. Splat attributes are expanded to a full byte buffer before reinterpretation. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Factor mmio repack lowering, drop duplication, expand tests The non-i32 mmio repack path in lowerAIRMMIOChannelOps had grown a ~120-line inline block with three issues review surfaced: - The repack-to-i32 byte transform was rebuilt as a stack lambda inside the per-put loop and called twice with identical inputs (once for the module-scope mirror, once for the in-device clone). - The two memref.global creation paths used inconsistent builders; the in-device path went through the rewriter then called cloned->remove() to undo the rewriter's insertion. - A `(void)modI32;` swallowed the return of the module-scope create because the op was found later by symbol lookup, not by the local binding. Lift the byte transform to a file-scope `repackAsI32Bytes` static helper, compute the repacked DenseElementsAttr once, and use a fresh OpBuilder for both mirror creations so the rewriter-detach hack is gone. Reject collisions where the suffixed `_mmio_i32` symbol already exists at module scope and isn't itself an `air.mmio_global`. Test coverage: - Tighten the splat bf16 test to assert the exact repacked value (`dense<1069563840>` = two bf16 1.5 packed into one i32) so the splat-expansion bytes are checked, not just the type. - Add a non-splat bf16 case (`@bf16_nonsplat`) that exercises the raw-buffer copy branch and asserts both packed i32 values. - Add an invalid case for a byte-aligned element type whose total payload size isn't a multiple of 4 bytes (memref<3xbf16> = 6 bytes), asserting the new diagnostic. No functional change for callers — same diagnostics, same emitted IR. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Preserve source alignment on the mmio i32 mirror globals Both `memref::GlobalOp::create` calls in the repack path passed `IntegerAttr{}` for the alignment argument, dropping any explicit `alignment = N : i64` attribute carried by the source `memref.global`. The non-repack path preserves alignment for free via `clone()`. Forward `moduleGlobal.getAlignmentAttr()` to both the module-scope and in-device mirrors so the two paths agree, and extend the splat bf16 lit case to assert the round-trip. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Use SymbolTable lookup instead of device.walk for in-device global AIE::DeviceOp carries the SymbolTable trait, so the linear walk to find the in-device mirror by name is an O(N) operation where O(log N) suffices. Switch to `SymbolTable::lookupSymbolIn(device, cloneName)`, matching the lookup style used for the module-side global a few lines above (and elsewhere in this file, e.g. line 4554). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Mark mmio repack path as removable once upstream blockwrite gains non-i32 Add a TODO at the top of the repack block pointing at the two mlir-aie sites that enforce the 32-bit-only payload restriction (`NpuBlockWriteOp::getDataWords` in AIEXDialect.cpp and the analogous warning in AIETxnToControlPacket.cpp). When those learn to handle non-i32 element types, the entire repack path and its `_mmio_i32` mirror globals can be deleted as dead code. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Add i8 and i16 mmio repack lit cases The repack helper handles any positive-multiple-of-8 element bitwidth, but coverage was bf16-only. Add two more positive cases that exercise different strides through repackAsI32Bytes: - i8 splat (4xi8 = dense<66>): bytesPerElt=1, splat-expansion loop copies a single byte per iteration; result packs to one i32 = 0x42424242 = 1111638594. - i16 non-splat (2xi16 = {0x1234, 0xABCD}): pure-int storage path through the wholesale-copy branch; LE byte stream {34, 12, CD, AB} packs to one signed i32 0xABCD1234 = -1412623820. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Reject mmio repack on uninitialized memref.global with a clean diagnostic The repack path called `*moduleGlobal.getInitialValue()` and cast the result to DenseElementsAttr without checking either step. A pure declaration (no `= dense<...>` initializer) returns `std::nullopt` and crashed via the optional dereference assertion; an uninitialized definition (UnitAttr initializer) would have crashed the cast. Guard both up front: bind `getInitialValue()` to an optional, then `dyn_cast<DenseElementsAttr>` and reject with a clean op error if either fails. Add a lit case for the uninitialized-global path. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Name the symbol-dce site that cleans up mmio module-scope globals The two AIRToAIEPass.cpp comments and the matching test-file note referenced "symbol-dce" generically, leaving a future debugger to guess where the cleanup actually fires. The concrete site is the `symbol-dce` pass invoked twice in the NPU pipeline at tools/aircc/aircc.cpp:1117 and 1123 — once before and once after `airrt-to-npu` — which already carries an inline note about dropping mmio-orphaned globals. Update the three comments to name that file/pass so the load-bearing dependency is greppable from either side. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Trim verbose mmio comments Comments across the mmio lowering and its lit cases had grown into multi-paragraph doc blocks. Compress each to one or two lines — the load-bearing "why" stays, but the prose explanations move to the PR description. - AIRToAIEPass.cpp: shrink the launch-hoist note, the IsolatedFromAbove mirror rationale, the V1 limitation reject, the i32-only repack preamble + TODO, and the post-blockwrite cleanup note. Helper docstring on `repackAsI32Bytes` also tightened. - air_channel_mmio.mlir: drop the bullet-list file header and the multi-line preambles on each split (simple/mixed/bcast/indexed/ bf16/bf16ns/i8/i16). Each case now reads as a single sentence above its CHECK lines. - air_channel_mmio_invalid.mlir: same treatment for the negative cases. No functional change. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Stamp mmio source onto destination buffer's initial_value, drop blockwrite The runtime-sequence blockwrite path for channel_type="mmio" had a host↔core race that the V1 design didn't address: aiex.configure enables the cores during CDO load, before the runtime sequence's blockwrites run. A core that reads its mmio destination buffer before any lock-gated DMA acquire races the host writes — the existing mmio_simple lit test only happens to dodge this because the core's first action is a lock acquire on a shim DMA that the runtime sequence fires after the blockwrite, by which time the blockwrite has long completed. Without that lucky ordering (or in any non-trivial example that reads the mmio buffer before the first DMA), the data read is undefined. Stamp the source memref.global's initializer onto the destination L1 aie.buffer's initial_value attribute instead. AIERTControl::initBuffers already loads buffer initial_values into the tile via XAie_DataMemBlockWrite at device-init time — before any core starts — which makes the data delivery race-free relative to core execution. Side benefits: * Obsoletes the i32 repack hack (and its sub-byte / non-multiple-of-4 guards): XAie_DataMemBlockWrite handles APInt and APFloat element types natively, so the destination buffer can carry its native bf16 / i8 / i16 / etc. initializer. * Drops the module-level mirror, get_global hoisting, and the symbol-dce dance in aircc that protected against orphan-global collisions in LLVM lowering — none of those exist anymore. * Simpler V1 invariants: source/dest must agree on element type and count (the buffer's natural shape constrains the initializer); a given destination buffer can have at most one mmio source. Net -184 lines. All 379 existing check-air-mlir tests still pass. On NPU2 hardware, the existing mmio_simple programming example still PASSes, and an [NKV=8, NS=1] decode-attention prototype with Q delivered via mmio (mocking cascade-Q from #1565) PASSes with correlation 0.999655 against the NumPy reference — without any defer-the-read workaround. The "variable-data MMIO via bo_instr patching" V2 plan (host-loaded data, separate from compile-time constants) is unaffected: when it arrives it can re-introduce blockwrite with a proper sync mechanism (tile lock that the host releases after blockwrite completion). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent afd381f commit f9e8a6e

5 files changed

Lines changed: 293 additions & 179 deletions

File tree

mlir/lib/Conversion/AIRToAIEPass.cpp

Lines changed: 69 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -2536,6 +2536,15 @@ struct SpecializeChannelBundlePattern
25362536
if (channel.getBundleSize() <= 1)
25372537
return failure();
25382538

2539+
// mmio channels are handled directly by lowerAIRMMIOChannelOps, which
2540+
// matches host-side puts to per-core gets by constant index across the
2541+
// full bundle. Splitting the bundle here would orphan the original
2542+
// host-side puts (they sit outside the device, where this pattern's
2543+
// rewrites don't reach), leaving them to fail later as
2544+
// "no matching device-side air.channel.get".
2545+
if (channel.getChannelType() == "mmio")
2546+
return failure();
2547+
25392548
std::vector<air::ChannelPutOp> channelPuts =
25402549
getChannelPutOpThroughSymbol(channel, device);
25412550
std::vector<air::ChannelGetOp> channelGets =
@@ -2950,14 +2959,9 @@ static void removeDeadGlobalOps(AIE::DeviceOp device) {
29502959
device.walk(
29512960
[&](memref::GetGlobalOp op) { referencedGlobals.insert(op.getName()); });
29522961

2953-
// Erase unreferenced memref.global declarations, but preserve any
2954-
// tagged with `air.mmio_global` — those are MMIO-channel mirrors of
2955-
// module-level globals whose runtime_sequence get_global users are
2956-
// synthesized later in the pipeline.
2962+
// Erase unreferenced memref.global declarations.
29572963
SmallVector<memref::GlobalOp> deadGlobals;
29582964
for (auto globalOp : device.getOps<memref::GlobalOp>()) {
2959-
if (globalOp->hasAttr("air.mmio_global"))
2960-
continue;
29612965
if (!referencedGlobals.contains(globalOp.getSymName()))
29622966
deadGlobals.push_back(globalOp);
29632967
}
@@ -5809,11 +5813,20 @@ class AIRToAIEPass : public air::impl::AIRToAIEBase<AIRToAIEPass> {
58095813
return failure();
58105814
}
58115815

5812-
// For each L3-side put, find every matching get and emit one
5813-
// blockwrite per destination buffer. Match rule:
5816+
// For each L3-side put, find every matching get and stamp the
5817+
// source data onto the destination L1 buffer's `initial_value`
5818+
// attribute. Match rule:
58145819
// * non-broadcast: indices must be constant-equal between put/get;
58155820
// * broadcast: every device-side get on this channel matches every
58165821
// put (one put fans out to all destinations).
5822+
//
5823+
// The aie.buffer initial_value is loaded into the tile by
5824+
// AIERTControl::initBuffers (XAie_DataMemBlockWrite) at device-init
5825+
// time — before any core starts. This eliminates the host↔core
5826+
// race that would arise from placing the data delivery in the
5827+
// runtime sequence (where blockwrites would race CDO-started cores).
5828+
// It also handles bf16/other float types natively, so no i32
5829+
// repack is required.
58175830
for (auto put : hostPuts) {
58185831
Value src = put.getMemref();
58195832
memref::GetGlobalOp getGlobalOp = getSourceGlobal(src);
@@ -5822,6 +5835,26 @@ class AIRToAIEPass : public air::impl::AIRToAIEBase<AIRToAIEPass> {
58225835
"channel_type=\"mmio\" put requires source memref defined by "
58235836
"memref.get_global of a constant memref.global");
58245837

5838+
StringAttr origName = getGlobalOp.getNameAttr().getAttr();
5839+
Operation *moduleOp = device->getParentOp();
5840+
while (moduleOp && !isa<ModuleOp>(moduleOp))
5841+
moduleOp = moduleOp->getParentOp();
5842+
auto moduleGlobal = dyn_cast_if_present<memref::GlobalOp>(
5843+
moduleOp ? SymbolTable::lookupSymbolIn(moduleOp, origName)
5844+
: nullptr);
5845+
if (!moduleGlobal)
5846+
return getGlobalOp.emitOpError(
5847+
"channel_type=\"mmio\" lowering: cannot find memref.global "
5848+
"for the put source at module scope");
5849+
5850+
auto initOpt = moduleGlobal.getInitialValue();
5851+
auto initDense =
5852+
initOpt ? dyn_cast<DenseElementsAttr>(*initOpt) : nullptr;
5853+
if (!initDense)
5854+
return put.emitOpError(
5855+
"channel_type=\"mmio\" source memref.global must have a "
5856+
"DenseElementsAttr initializer");
5857+
58255858
unsigned matchCount = 0;
58265859
for (auto get : deviceGets) {
58275860
if (!isBcast && !sameConstIndices(put.getIndices(), get.getIndices()))
@@ -5832,105 +5865,37 @@ class AIRToAIEPass : public air::impl::AIRToAIEBase<AIRToAIEPass> {
58325865
"channel_type=\"mmio\" get destination does not resolve to "
58335866
"an aie.buffer (must be an L1 allocation)");
58345867

5835-
auto bufSymOpt = bufferOp.getSymName();
5836-
if (!bufSymOpt)
5868+
// Element type and total element count must match between source
5869+
// and destination so the DenseElementsAttr is valid for the
5870+
// buffer's memref type.
5871+
auto bufMemTy = bufferOp.getType();
5872+
auto srcMemTy = cast<MemRefType>(getGlobalOp.getType());
5873+
if (bufMemTy.getElementType() != srcMemTy.getElementType())
58375874
return get.emitOpError(
5838-
"channel_type=\"mmio\" get destination aie.buffer has no "
5839-
"sym_name; cannot reference from blockwrite");
5840-
// The blockwrite must end up OUTSIDE any enclosing air.launch:
5841-
// later passes (AIROptimizeShimDMABDs / AIRLaunchToScfForPattern)
5842-
// recreate a fresh dummy launch and only carry forward channel
5843-
// ops, dropping any other ops that happened to live in the
5844-
// launch body. Hoisting the blockwrite to the func.func level
5845-
// also forces us to hoist a clone of the source
5846-
// `memref.get_global` so SSA dominance is preserved when the
5847-
// original get_global lives inside the launch.
5848-
air::LaunchOp outerLaunch = put->getParentOfType<air::LaunchOp>();
5849-
Operation *insertionAnchor =
5850-
outerLaunch ? outerLaunch.getOperation() : put.getOperation();
5851-
rewriter.setInsertionPoint(insertionAnchor);
5852-
5853-
// The blockwrite carries a constant payload that must remain
5854-
// resolvable from inside `aie.runtime_sequence` (which lives
5855-
// inside `aie.device`) after AIRRtToNpuPass wraps the func.
5856-
// `aie.device` is `IsolatedFromAbove`, so a get_global inside
5857-
// it cannot reference a `memref.global` at module scope.
5858-
// Mirror the global into the device under the SAME name
5859-
// (different SymbolTables admit identical names — vanilla
5860-
// MLIR verifiers accept this), then erase the module-level
5861-
// original. The aircc LLVM-lowering pipeline promotes
5862-
// memref.global to llvm.mlir.global at module scope, where
5863-
// duplicates collide; deleting the module-level original
5864-
// keeps the in-device copy as the unique source of truth.
5865-
StringAttr origName = getGlobalOp.getNameAttr().getAttr();
5866-
Operation *moduleOp = device->getParentOp();
5867-
while (moduleOp && !isa<ModuleOp>(moduleOp))
5868-
moduleOp = moduleOp->getParentOp();
5869-
auto moduleGlobal = dyn_cast_if_present<memref::GlobalOp>(
5870-
moduleOp ? SymbolTable::lookupSymbolIn(moduleOp, origName)
5871-
: nullptr);
5872-
if (!moduleGlobal)
5873-
return getGlobalOp.emitOpError(
5874-
"channel_type=\"mmio\" lowering: cannot find memref.global "
5875-
"for the put source at module scope");
5876-
5877-
// V1 limitation: the in-device mirror is cloned under the same
5878-
// sym_name, so a later symbol-dce of the module-level original
5879-
// is required to avoid a duplicate-symbol collision in LLVM
5880-
// lowering. That requires no users to survive outside the
5881-
// func that becomes the runtime_sequence and moves into the
5882-
// device. Reject other module-scope users loudly.
5883-
auto putFunc = put->getParentOfType<func::FuncOp>();
5884-
auto uses = SymbolTable::getSymbolUses(origName, moduleOp);
5885-
if (uses && llvm::any_of(*uses, [&](SymbolTable::SymbolUse u) {
5886-
return u.getUser()->getParentOfType<func::FuncOp>() != putFunc;
5887-
}))
5888-
return getGlobalOp.emitOpError(
5889-
"channel_type=\"mmio\" V1 requires the source memref.global "
5890-
"to be used only inside the func containing the put");
5891-
5892-
memref::GlobalOp inDevGlobal;
5893-
device.walk([&](memref::GlobalOp g) {
5894-
if (g.getSymName() == origName.getValue())
5895-
inDevGlobal = g;
5896-
});
5897-
if (!inDevGlobal) {
5898-
auto cloned = cast<memref::GlobalOp>(moduleGlobal->clone());
5899-
cloned->removeAttr(SymbolTable::getVisibilityAttrName());
5900-
cloned->setAttr("air.mmio_global",
5901-
UnitAttr::get(rewriter.getContext()));
5902-
// Place at the very top of the device body so it dominates
5903-
// the runtime_sequence created later. push_front bypasses
5904-
// SymbolTable::insert renaming logic.
5905-
device.getBody()->getOperations().push_front(cloned);
5906-
inDevGlobal = cloned;
5907-
}
5908-
5909-
auto hoistedGG = memref::GetGlobalOp::create(
5910-
rewriter, getGlobalOp.getLoc(),
5911-
cast<MemRefType>(getGlobalOp.getResult().getType()),
5912-
FlatSymbolRefAttr::get(rewriter.getContext(),
5913-
inDevGlobal.getSymNameAttr()));
5914-
Value dataOperand = hoistedGG.getResult();
5915-
5916-
// After this point the hoisted get_global at module scope
5917-
// resolves to the module-level original; once airrt-to-npu
5918-
// moves the surrounding func into the device, lookup will
5919-
// pick up the in-device copy first. The module-level original
5920-
// is not removed here because there may be a brief window
5921-
// where the func still references it.
5922-
5923-
AIEX::NpuBlockWriteOp::create(
5924-
rewriter, put.getLoc(),
5925-
/*address=*/rewriter.getUI32IntegerAttr(0),
5926-
/*data=*/dataOperand,
5927-
/*buffer=*/
5928-
FlatSymbolRefAttr::get(rewriter.getContext(), *bufSymOpt),
5929-
/*column=*/IntegerAttr{},
5930-
/*row=*/IntegerAttr{});
5875+
"channel_type=\"mmio\" source/destination element type "
5876+
"mismatch (source: ")
5877+
<< srcMemTy.getElementType()
5878+
<< ", destination: " << bufMemTy.getElementType() << ")";
5879+
if (bufMemTy.getNumElements() != srcMemTy.getNumElements())
5880+
return get.emitOpError(
5881+
"channel_type=\"mmio\" source/destination element count "
5882+
"mismatch (source: ")
5883+
<< srcMemTy.getNumElements()
5884+
<< ", destination: " << bufMemTy.getNumElements() << ")";
5885+
5886+
// Reshape the source DenseElementsAttr to match the destination
5887+
// buffer's tensor shape (same bytes, possibly different rank).
5888+
auto bufTensorTy = RankedTensorType::get(bufMemTy.getShape(),
5889+
bufMemTy.getElementType());
5890+
auto reshapedInit = initDense.reshape(bufTensorTy);
5891+
5892+
if (auto existing = bufferOp.getInitialValue())
5893+
return bufferOp.emitOpError(
5894+
"channel_type=\"mmio\" destination aie.buffer already has an "
5895+
"initial_value; cannot stamp two sources into one buffer");
5896+
bufferOp.setInitialValueAttr(reshapedInit);
59315897
++matchCount;
59325898
}
5933-
// The put would otherwise be erased below with no replacement.
59345899
if (matchCount == 0)
59355900
return put.emitOpError("channel_type=\"mmio\" put has no matching "
59365901
"device-side air.channel.get");

mlir/lib/Transform/AIRDmaToChannel.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1603,6 +1603,12 @@ struct DmaToChannelPass : public air::impl::DmaToChannelBase<DmaToChannelPass> {
16031603
if (!chanOp)
16041604
continue;
16051605

1606+
// mmio channels are runtime-sequence MMIO writes, not shim DMA, so
1607+
// they neither contribute to per-column shim pressure nor are
1608+
// eligible for dma_packet upgrade.
1609+
if (chanOp.getChannelType() == "mmio")
1610+
continue;
1611+
16061612
bool isAlreadyPacket = (chanOp.getChannelType() == "dma_packet");
16071613
auto channelName = chanOp.getSymName();
16081614

0 commit comments

Comments
 (0)