Skip to content

Commit 92a401f

Browse files
erwei-xilinxclaude
andcommitted
[multi-gpu] Phase 6: air-gpu-channel-to-mgpu lowering pass
New conversion pass that lowers `air.channel` ops of `channel_type = "gpu_symmetric_heap"` plus their put/get pair to host-side `mgpuMemcpy` calls with peer-VA addressing through `mgpuGetHeapBases()`, with `mgpuBarrier`-based synchronization. Per channel: - put becomes `mgpuBarrier()` (publish — the data is already in the symmetric heap via the put's `air.symmetric` source memref) - get becomes `mgpuBarrier()` followed by `mgpuMemcpy(dst, peer_va(src), sz)` where the peer rank is the get's first index operand - The channel symbol itself is erased This makes `air.channel` of type `gpu_symmetric_heap` syntactic sugar over cross-rank DMA, with the additional benefit of decoupling the producer site (where put appears) from the consumer site (where get appears) via the channel symbol. ## Restrictions (initial version) - One put and one get per channel symbol - Both at host scope (no `gpu.launch`/`gpu.func`) - put's source memref must be `air.symmetric`-tagged - "Entire memref" form on both sides (no offsets/sizes/strides) - get must take exactly one index operand (the peer rank) ## Files - `mlir/include/air/Conversion/AIRGpuChannelToMgpuPass.h` — header - `mlir/include/air/Conversion/GPUPasses.td` — pass def - `mlir/include/air/Conversion/GPUPassDetail.h` — `GEN_PASS_DEF_AIRGPUCHANNELTOMGPU` - `mlir/lib/Conversion/AIRGpuChannelToMgpuPass.cpp` — implementation - `mlir/lib/Conversion/{CMakeLists.txt,Passes.cpp}` — registration - `mlir/test/Conversion/AIRGpuChannelToMgpu/gpu_channel.mlir` — FileCheck - `test/gpu/symmetric_heap_dma/air_sym_with_channel.mlir` — high-level e2e - `test/gpu/symmetric_heap_dma/run.sh` — adds `INPUT=channel` selector ## Test plan FileCheck unit tests cover: - Basic put/get pair lowering shape (barrier + mgpuMemcpy with peer-VA) - Channel symbol is erased after lowering - Pass is a no-op for non-`gpu_symmetric_heap` channels (e.g., `npu_*`) End-to-end on rad-mi300a-sh5-1 (SHARE_GPU=1, 2 ranks): - INPUT=handwritten — PASS - INPUT=rank — PASS - INPUT=alloc — PASS - INPUT=dma — PASS - INPUT=channel — PASS (chains Phase 6 -> Phase 4 -> Phase 3 -> standard LLVM) Both ranks publish their src_buf via channel.put, then read rank 0's slot via channel.get. Verification reads back 1.0. Same SHARE_GPU=1 single-physical-GPU caveat as previous PRs in the stack — true multi-GPU re-validation is needed before declaring multi-GPU production-ready. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 71efe72 commit 92a401f

9 files changed

Lines changed: 538 additions & 1 deletion

File tree

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//===- AIRGpuChannelToMgpuPass.h --------------------------------*- C++ -*-===//
2+
//
3+
// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved.
4+
// SPDX-License-Identifier: MIT
5+
//
6+
//===-----------------------------------------------------------------------===//
7+
8+
#ifndef AIR_CONVERSION_AIR_GPU_CHANNEL_TO_MGPU_PASS_H
9+
#define AIR_CONVERSION_AIR_GPU_CHANNEL_TO_MGPU_PASS_H
10+
11+
#include "mlir/Pass/Pass.h"
12+
#include <memory>
13+
14+
namespace xilinx {
15+
namespace air {
16+
17+
std::unique_ptr<mlir::Pass> createAIRGpuChannelToMgpuPass();
18+
19+
} // namespace air
20+
} // namespace xilinx
21+
22+
#endif // AIR_CONVERSION_AIR_GPU_CHANNEL_TO_MGPU_PASS_H

mlir/include/air/Conversion/GPUPassDetail.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ using namespace mlir;
2929
#define GEN_PASS_DEF_AIRRANKTOMGPU
3030
#define GEN_PASS_DEF_AIRSYMMETRICALLOCTOMGPU
3131
#define GEN_PASS_DEF_AIRCROSSRANKDMATOMGPU
32+
#define GEN_PASS_DEF_AIRGPUCHANNELTOMGPU
3233
#include "air/Conversion/GPUPasses.h.inc"
3334

3435
} // namespace air

mlir/include/air/Conversion/GPUPasses.td

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,33 @@ def ConvertGPUKernelOutline : Pass<"air-gpu-outlining", "ModuleOp"> {
4949
let options = [];
5050
}
5151

52+
def AIRGpuChannelToMgpu : Pass<"air-gpu-channel-to-mgpu", "ModuleOp"> {
53+
let summary = "Lower air.channel.put/get of channel_type=\"gpu_symmetric_heap\" "
54+
"to host-side mgpuMemcpy (peer-VA) + mgpuBarrier";
55+
let constructor = "xilinx::air::createAIRGpuChannelToMgpuPass()";
56+
let description = [{
57+
For each `air.channel @C [...] {channel_type = "gpu_symmetric_heap"}`,
58+
pair its single `air.channel.put` and single `air.channel.get`. The put
59+
becomes `mgpuBarrier()` (publish: data is already in the symmetric heap
60+
via the put's `air.symmetric` source memref). The get becomes
61+
`mgpuBarrier()` followed by `mgpuMemcpy(dst, peer_va(put_src), size)`
62+
where the peer rank is the get's first index operand and the peer VA is
63+
computed via `mgpuGetHeapBases()`.
64+
65+
Restrictions in this initial version:
66+
- One put and one get per channel symbol.
67+
- Both put and get at host scope (no `gpu.launch`/`gpu.func`).
68+
- put's source memref must be `air.symmetric`-tagged.
69+
- get's destination memref must be in `memory_space=0`.
70+
- "Entire memref" form only on both sides.
71+
- get must take exactly one index operand (the peer rank).
72+
}];
73+
let dependentDialects = [
74+
"func::FuncDialect", "arith::ArithDialect", "memref::MemRefDialect",
75+
"LLVM::LLVMDialect"
76+
];
77+
}
78+
5279
def AIRCrossRankDmaToMgpu : Pass<"air-cross-rank-dma-to-mgpu", "ModuleOp"> {
5380
let summary = "Lower air.dma_memcpy_nd with src_rank/dst_rank to mgpuMemcpy "
5481
"with peer-VA addressing through mgpuGetHeapBases()";
Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,285 @@
1+
//===- AIRGpuChannelToMgpuPass.cpp ------------------------------*- C++ -*-===//
2+
//
3+
// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved.
4+
// SPDX-License-Identifier: MIT
5+
//
6+
//===-----------------------------------------------------------------------===//
7+
//
8+
// Lower air.channel of channel_type="gpu_symmetric_heap" plus its put/get
9+
// pair to host-side mgpuMemcpy with peer-VA addressing through
10+
// mgpuGetHeapBases(), with mgpuBarrier-based synchronization.
11+
//
12+
// Per channel:
13+
// - put becomes mgpuBarrier() (publish — the data is already in the
14+
// symmetric heap via the put's air.symmetric source memref)
15+
// - get becomes mgpuBarrier() followed by mgpuMemcpy(dst, peer_va(src), sz)
16+
// where the peer rank is the get's first index operand
17+
//
18+
//===-----------------------------------------------------------------------===//
19+
20+
#include "air/Conversion/AIRGpuChannelToMgpuPass.h"
21+
#include "air/Conversion/GPUPassDetail.h"
22+
#include "air/Dialect/AIR/AIRDialect.h"
23+
24+
#include "mlir/Dialect/Arith/IR/Arith.h"
25+
#include "mlir/Dialect/Func/IR/FuncOps.h"
26+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
27+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
28+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
29+
#include "mlir/IR/Builders.h"
30+
#include "mlir/IR/SymbolTable.h"
31+
#include "mlir/Pass/Pass.h"
32+
33+
using namespace mlir;
34+
using namespace xilinx;
35+
36+
namespace {
37+
38+
static func::FuncOp ensureExternFunc(ModuleOp module, OpBuilder &builder,
39+
StringRef name, FunctionType type) {
40+
if (auto fn = module.lookupSymbol<func::FuncOp>(name))
41+
return fn;
42+
OpBuilder::InsertionGuard guard(builder);
43+
builder.setInsertionPointToStart(module.getBody());
44+
auto fn = func::FuncOp::create(builder, module.getLoc(), name, type);
45+
fn.setPrivate();
46+
return fn;
47+
}
48+
49+
static Value computeMemrefByteSize(OpBuilder &b, Location loc, MemRefType ty) {
50+
if (!ty.hasStaticShape())
51+
return nullptr;
52+
int64_t numElts = 1;
53+
for (int64_t d : ty.getShape())
54+
numElts *= d;
55+
unsigned eltBits = ty.getElementType().getIntOrFloatBitWidth();
56+
if (eltBits == 0 || (eltBits % 8) != 0)
57+
return nullptr;
58+
int64_t totalBytes = numElts * (eltBits / 8);
59+
return arith::ConstantOp::create(b, loc, b.getI64Type(),
60+
b.getI64IntegerAttr(totalBytes));
61+
}
62+
63+
static Value extractAlignedPtr(OpBuilder &b, Location loc, Value memref) {
64+
Value idx = memref::ExtractAlignedPointerAsIndexOp::create(b, loc, memref);
65+
Value i64 = arith::IndexCastOp::create(b, loc, b.getI64Type(), idx);
66+
auto ptrTy = LLVM::LLVMPointerType::get(b.getContext());
67+
return LLVM::IntToPtrOp::create(b, loc, ptrTy, i64);
68+
}
69+
70+
struct AIRGpuChannelToMgpuPass
71+
: public xilinx::air::impl::AIRGpuChannelToMgpuBase<
72+
AIRGpuChannelToMgpuPass> {
73+
74+
AIRGpuChannelToMgpuPass() = default;
75+
AIRGpuChannelToMgpuPass(const AIRGpuChannelToMgpuPass &) {}
76+
77+
void runOnOperation() override {
78+
auto module = getOperation();
79+
OpBuilder builder(module.getContext());
80+
auto i32Ty = builder.getI32Type();
81+
auto i64Ty = builder.getI64Type();
82+
auto ptrTy = LLVM::LLVMPointerType::get(module.getContext());
83+
84+
// Collect gpu_symmetric_heap channel decls and their put/get sites.
85+
SmallVector<air::ChannelOp> chans;
86+
module.walk([&](air::ChannelOp ch) {
87+
if (ch.getChannelType() == "gpu_symmetric_heap")
88+
chans.push_back(ch);
89+
});
90+
if (chans.empty())
91+
return;
92+
93+
auto getRankFn = ensureExternFunc(module, builder, "mgpuGetRank",
94+
builder.getFunctionType({}, {i32Ty}));
95+
auto getBasesFn =
96+
ensureExternFunc(module, builder, "mgpuGetHeapBases",
97+
builder.getFunctionType({}, {ptrTy}));
98+
auto memcpyFn = ensureExternFunc(
99+
module, builder, "mgpuMemcpy",
100+
builder.getFunctionType({ptrTy, ptrTy, i64Ty, ptrTy}, {}));
101+
auto barrierFn = ensureExternFunc(
102+
module, builder, "mgpuBarrier", builder.getFunctionType({}, {}));
103+
104+
for (air::ChannelOp ch : chans) {
105+
StringAttr sym = ch.getSymNameAttr();
106+
107+
// Find puts and gets that reference this channel symbol.
108+
SmallVector<air::ChannelPutOp> puts;
109+
SmallVector<air::ChannelGetOp> gets;
110+
module.walk([&](air::ChannelPutOp p) {
111+
if (p.getChanName() == sym.getValue())
112+
puts.push_back(p);
113+
});
114+
module.walk([&](air::ChannelGetOp g) {
115+
if (g.getChanName() == sym.getValue())
116+
gets.push_back(g);
117+
});
118+
119+
if (puts.size() != 1 || gets.size() != 1) {
120+
ch.emitOpError()
121+
<< "channel_type=\"gpu_symmetric_heap\" requires exactly one "
122+
"put and one get per channel; found "
123+
<< puts.size() << " put(s), " << gets.size() << " get(s)";
124+
signalPassFailure();
125+
return;
126+
}
127+
air::ChannelPutOp put = puts.front();
128+
air::ChannelGetOp get = gets.front();
129+
130+
// Restrictions
131+
if (put->getParentOfType<gpu::LaunchOp>() ||
132+
put->getParentOfType<gpu::GPUFuncOp>() ||
133+
get->getParentOfType<gpu::LaunchOp>() ||
134+
get->getParentOfType<gpu::GPUFuncOp>()) {
135+
ch.emitOpError("gpu_symmetric_heap put/get inside a GPU kernel is "
136+
"not yet supported");
137+
signalPassFailure();
138+
return;
139+
}
140+
if (!put.getSrcOffsets().empty() || !put.getSrcSizes().empty() ||
141+
!put.getSrcStrides().empty() || !get.getDstOffsets().empty() ||
142+
!get.getDstSizes().empty() || !get.getDstStrides().empty()) {
143+
ch.emitOpError("gpu_symmetric_heap put/get with explicit "
144+
"offsets/sizes/strides is not yet supported");
145+
signalPassFailure();
146+
return;
147+
}
148+
149+
auto srcType = cast<MemRefType>(put.getSrc().getType());
150+
auto dstType = cast<MemRefType>(get.getDst().getType());
151+
if (srcType.getMemorySpaceAsInt() != 0 ||
152+
dstType.getMemorySpaceAsInt() != 0) {
153+
ch.emitOpError(
154+
"gpu_symmetric_heap put/get requires both memrefs in memory_space=0");
155+
signalPassFailure();
156+
return;
157+
}
158+
159+
// The put's source must be air.symmetric so peers can read it.
160+
if (auto allocOp = put.getSrc().getDefiningOp<memref::AllocOp>())
161+
if (!allocOp->hasAttr("air.symmetric")) {
162+
ch.emitOpError("gpu_symmetric_heap put requires a memref.alloc "
163+
"carrying the \"air.symmetric\" attribute");
164+
signalPassFailure();
165+
return;
166+
}
167+
168+
if (get.getIndices().size() != 1) {
169+
ch.emitOpError("gpu_symmetric_heap get requires exactly one index "
170+
"operand (the peer rank)");
171+
signalPassFailure();
172+
return;
173+
}
174+
Value peerRankIdx = get.getIndices().front();
175+
176+
// ---- Lower put: emit barrier (publish) and erase ----
177+
Location putLoc = put.getLoc();
178+
builder.setInsertionPointAfter(put);
179+
func::CallOp::create(builder, putLoc, barrierFn, ValueRange{});
180+
if (put.getAsyncToken()) {
181+
Value tok = air::WaitAllOp::create(
182+
builder, putLoc,
183+
air::AsyncTokenType::get(builder.getContext()),
184+
ValueRange{})
185+
.getAsyncToken();
186+
put.getAsyncToken().replaceAllUsesWith(tok);
187+
}
188+
put.erase();
189+
190+
// ---- Lower get: barrier + cross-rank mgpuMemcpy(dst, peer_va(src), sz) ----
191+
Location getLoc = get.getLoc();
192+
builder.setInsertionPoint(get);
193+
194+
// Barrier (consume)
195+
func::CallOp::create(builder, getLoc, barrierFn, ValueRange{});
196+
197+
Value sizeBytes = computeMemrefByteSize(builder, getLoc, srcType);
198+
if (!sizeBytes) {
199+
ch.emitOpError("gpu_symmetric_heap requires static memref shape");
200+
signalPassFailure();
201+
return;
202+
}
203+
Value nullPtr = LLVM::ZeroOp::create(builder, getLoc, ptrTy);
204+
205+
Value srcLocalPtr = extractAlignedPtr(builder, getLoc, put.getSrc());
206+
Value dstLocalPtr = extractAlignedPtr(builder, getLoc, get.getDst());
207+
208+
Value bases =
209+
func::CallOp::create(builder, getLoc, getBasesFn, ValueRange{})
210+
.getResult(0);
211+
Value myRankI32 =
212+
func::CallOp::create(builder, getLoc, getRankFn, ValueRange{})
213+
.getResult(0);
214+
Value myRankI64 =
215+
arith::ExtSIOp::create(builder, getLoc, i64Ty, myRankI32);
216+
Value myBaseAddr = LLVM::GEPOp::create(builder, getLoc, ptrTy, ptrTy,
217+
bases, ArrayRef<Value>{myRankI64});
218+
Value myBase = LLVM::LoadOp::create(builder, getLoc, ptrTy, myBaseAddr);
219+
220+
// Peer rank: convert dynamic index operand to i64.
221+
Value peerRankI64;
222+
Type peerTy = peerRankIdx.getType();
223+
if (isa<IndexType>(peerTy))
224+
peerRankI64 = arith::IndexCastOp::create(builder, getLoc, i64Ty,
225+
peerRankIdx);
226+
else if (auto intTy = dyn_cast<IntegerType>(peerTy)) {
227+
if (intTy.getWidth() == 64)
228+
peerRankI64 = peerRankIdx;
229+
else
230+
peerRankI64 =
231+
arith::ExtSIOp::create(builder, getLoc, i64Ty, peerRankIdx);
232+
} else {
233+
ch.emitOpError("gpu_symmetric_heap get peer-rank index must be index "
234+
"or integer type");
235+
signalPassFailure();
236+
return;
237+
}
238+
239+
Value peerBaseAddr = LLVM::GEPOp::create(
240+
builder, getLoc, ptrTy, ptrTy, bases, ArrayRef<Value>{peerRankI64});
241+
Value peerBase =
242+
LLVM::LoadOp::create(builder, getLoc, ptrTy, peerBaseAddr);
243+
244+
Value srcLocalInt =
245+
LLVM::PtrToIntOp::create(builder, getLoc, i64Ty, srcLocalPtr);
246+
Value myBaseInt =
247+
LLVM::PtrToIntOp::create(builder, getLoc, i64Ty, myBase);
248+
Value offset =
249+
arith::SubIOp::create(builder, getLoc, srcLocalInt, myBaseInt);
250+
251+
auto i8Ty = builder.getI8Type();
252+
Value peerSrc = LLVM::GEPOp::create(builder, getLoc, ptrTy, i8Ty,
253+
peerBase, ArrayRef<Value>{offset});
254+
255+
func::CallOp::create(
256+
builder, getLoc, memcpyFn,
257+
ValueRange{dstLocalPtr, peerSrc, sizeBytes, nullPtr});
258+
259+
if (get.getAsyncToken()) {
260+
Value tok = air::WaitAllOp::create(
261+
builder, getLoc,
262+
air::AsyncTokenType::get(builder.getContext()),
263+
ValueRange{})
264+
.getAsyncToken();
265+
get.getAsyncToken().replaceAllUsesWith(tok);
266+
}
267+
get.erase();
268+
269+
// The channel symbol can now be erased.
270+
ch.erase();
271+
}
272+
}
273+
};
274+
275+
} // namespace
276+
277+
namespace xilinx {
278+
namespace air {
279+
280+
std::unique_ptr<mlir::Pass> createAIRGpuChannelToMgpuPass() {
281+
return std::make_unique<AIRGpuChannelToMgpuPass>();
282+
}
283+
284+
} // namespace air
285+
} // namespace xilinx

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ if(AIR_ENABLE_GPU)
6060
AIRRankToMgpuPass.cpp
6161
AIRSymmetricAllocToMgpuPass.cpp
6262
AIRCrossRankDmaToMgpuPass.cpp
63+
AIRGpuChannelToMgpuPass.cpp
6364
)
6465
list(APPEND CONVERSION_LINK_LIBS
6566
MLIRGPUDialect

mlir/lib/Conversion/Passes.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#if AIR_ENABLE_GPU
1212
#include "air/Conversion/AIRCrossRankDmaToMgpuPass.h"
13+
#include "air/Conversion/AIRGpuChannelToMgpuPass.h"
1314
#include "air/Conversion/AIRRankToMgpuPass.h"
1415
#include "air/Conversion/AIRSymmetricAllocToMgpuPass.h"
1516
#include "air/Conversion/AIRToROCDLPass.h"

0 commit comments

Comments
 (0)