Skip to content

Commit dae66d7

Browse files
erwei-xilinxclaude
andcommitted
[multi-gpu] Phase 5: air-cross-rank-dma-to-mgpu lowering pass
New conversion pass that lowers `air.dma_memcpy_nd` ops carrying a `src_rank` or `dst_rank` integer attribute (added in Phase 1) to host-side `mgpuMemcpy` calls with peer-VA addressing through `mgpuGetHeapBases()`.

The peer pointer is computed at runtime as:

    peer_ptr = bases[peer_rank] + (local_ptr - bases[my_rank])

where `local_ptr` is this rank's own address of the peer-side (symmetric) memref, extracted via `memref.extract_aligned_pointer_as_index`, and `local_base = bases[my_rank]` gives this rank's symmetric heap base.

## Restrictions (this initial version)

- Both `src` and `dst` memrefs must be in `memory_space=0` (L3/global)
- The op must be at host scope (not inside a `gpu.launch` or `gpu.func`)
- "Entire memref" form only — no explicit `[offsets][sizes][strides]`
- Only one of `src_rank` / `dst_rank` may be set per op

These restrictions match the hand-written reference's Phase 2 pattern. They can be relaxed in follow-up work.

## Files

- `mlir/include/air/Conversion/AIRCrossRankDmaToMgpuPass.h` — header
- `mlir/include/air/Conversion/GPUPasses.td` — `air-cross-rank-dma-to-mgpu` def
- `mlir/include/air/Conversion/GPUPassDetail.h` — `GEN_PASS_DEF_AIRCROSSRANKDMATOMGPU`
- `mlir/lib/Conversion/AIRCrossRankDmaToMgpuPass.cpp` — implementation
- `mlir/lib/Conversion/{CMakeLists.txt,Passes.cpp}` — registration
- `mlir/test/Conversion/AIRCrossRankDmaToMgpu/cross_rank_dma.mlir` — FileCheck
- `test/gpu/symmetric_heap_dma/air_sym_with_dma.mlir` — high-level e2e combining Phase 1 attrs + Phase 3 + Phase 4 + Phase 5 lowering
- `test/gpu/symmetric_heap_dma/run.sh` — adds `INPUT=dma` selector

## Test plan

FileCheck unit tests cover:

- src_rank lowering shape (size, ptr extraction, bases, GEP, ptrtoint, subi, byte-stride GEP, mgpuMemcpy)
- dst_rank lowering (peer pointer becomes dst arg)
- 2D memref byte size
- f64 element type byte size
- Multiple cross-rank DMAs share extern decls
- Pass is a no-op for non-cross-rank DMAs

End-to-end on rad-mi300a-sh5-1 (SHARE_GPU=1, 2 ranks):

- INPUT=handwritten — PASS (Phase 2 baseline)
- INPUT=rank — PASS (Phase 3)
- INPUT=alloc — PASS (Phase 4)
- INPUT=dma — PASS (Phase 5: chains Phase 5 -> Phase 4 -> Phase 3)

Both ranks read rank 0's symmetric src_buf via cross-rank DMA into their own dst_buf; verification reads back 1.0.

Same SHARE_GPU=1 single-physical-GPU caveat as Xilinx#1577 / Xilinx#1578 / Xilinx#1579 — true multi-GPU re-validation is needed before declaring multi-GPU production-ready.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 7dfcbfb commit dae66d7

9 files changed

Lines changed: 551 additions & 1 deletion

File tree

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//===- AIRCrossRankDmaToMgpuPass.h ------------------------------*- C++ -*-===//
2+
//
3+
// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved.
4+
// SPDX-License-Identifier: MIT
5+
//
6+
//===-----------------------------------------------------------------------===//
7+
8+
#ifndef AIR_CONVERSION_AIR_CROSS_RANK_DMA_TO_MGPU_PASS_H
9+
#define AIR_CONVERSION_AIR_CROSS_RANK_DMA_TO_MGPU_PASS_H
10+
11+
#include "mlir/Pass/Pass.h"
12+
#include <memory>
13+
14+
namespace xilinx {
15+
namespace air {
16+
17+
// Creates an instance of the `air-cross-rank-dma-to-mgpu` module pass, which
// lowers `air.dma_memcpy_nd` ops carrying a `src_rank` or `dst_rank`
// attribute to host-side `mgpuMemcpy` calls. The caller owns the returned
// pass.
std::unique_ptr<mlir::Pass> createAIRCrossRankDmaToMgpuPass();
18+
19+
} // namespace air
20+
} // namespace xilinx
21+
22+
#endif // AIR_CONVERSION_AIR_CROSS_RANK_DMA_TO_MGPU_PASS_H

mlir/include/air/Conversion/GPUPassDetail.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ using namespace mlir;
2828
#define GEN_PASS_DEF_CONVERTGPUKERNELOUTLINE
2929
#define GEN_PASS_DEF_AIRRANKTOMGPU
3030
#define GEN_PASS_DEF_AIRSYMMETRICALLOCTOMGPU
31+
#define GEN_PASS_DEF_AIRCROSSRANKDMATOMGPU
3132
#include "air/Conversion/GPUPasses.h.inc"
3233

3334
} // namespace air

mlir/include/air/Conversion/GPUPasses.td

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,31 @@ def ConvertGPUKernelOutline : Pass<"air-gpu-outlining", "ModuleOp"> {
4949
let options = [];
5050
}
5151

52+
def AIRCrossRankDmaToMgpu : Pass<"air-cross-rank-dma-to-mgpu", "ModuleOp"> {
53+
let summary = "Lower air.dma_memcpy_nd with src_rank/dst_rank to mgpuMemcpy "
54+
"with peer-VA addressing through mgpuGetHeapBases()";
55+
let constructor = "xilinx::air::createAIRCrossRankDmaToMgpuPass()";
56+
let description = [{
57+
For each `air.dma_memcpy_nd` op carrying a `src_rank` or `dst_rank`
58+
integer attribute, emit a host-side `mgpuMemcpy` whose peer-side pointer
59+
is computed as `mgpuGetHeapBases()[peer] + (local_ptr - local_base)`.
60+
61+
Restrictions in this initial version:
62+
- Both `src` and `dst` memrefs must be in `memory_space=0`.
63+
- The op must be at host scope (not inside any `gpu.launch`/`gpu.func`).
64+
- "Entire memref" form only: `[]` `[]` `[]` for both sides — no
65+
custom offsets / sizes / strides.
- Only one of `src_rank` / `dst_rank` may be set on a given op.
66+
67+
Run this pass *before* `air-symmetric-alloc-to-mgpu` so that pointer
68+
extraction (`memref.extract_aligned_pointer_as_index`) sees plain
69+
memrefs rather than already-cast LLVM struct values.
70+
}];
71+
let dependentDialects = [
72+
"func::FuncDialect", "arith::ArithDialect", "memref::MemRefDialect",
73+
"LLVM::LLVMDialect"
74+
];
75+
}
76+
5277
def AIRSymmetricAllocToMgpu : Pass<"air-symmetric-alloc-to-mgpu", "ModuleOp"> {
5378
let summary = "Lower memref.alloc {air.symmetric} to mgpuSymmetricAlloc and "
5479
"memref.dealloc of the result to mgpuSymmetricFree";
Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
//===- AIRCrossRankDmaToMgpuPass.cpp ---------------------------*- C++ -*-===//
2+
//
3+
// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved.
4+
// SPDX-License-Identifier: MIT
5+
//
6+
//===-----------------------------------------------------------------------===//
7+
//
8+
// Lower air.dma_memcpy_nd ops carrying a `src_rank` or `dst_rank` integer
9+
// attribute to host-side mgpuMemcpy calls with peer-VA addressing through
10+
// mgpuGetHeapBases().
11+
//
12+
// Pattern emitted (for src_rank = R):
13+
// %size = arith.constant <bytes> : i64
14+
// %nullptr = llvm.mlir.zero : !llvm.ptr
15+
// %dst_ptr = (extract aligned ptr from %dst memref)
16+
// %src_ptr = (extract aligned ptr from %src memref)
17+
// %my_rank = call @mgpuGetRank() : () -> i32
18+
// %bases = call @mgpuGetHeapBases() : () -> !llvm.ptr
19+
// %my_base_at = llvm.getelementptr %bases[%my_rank] : ... -> !llvm.ptr, !llvm.ptr
20+
// %my_base = llvm.load %my_base_at : !llvm.ptr -> !llvm.ptr
21+
// %src_int = llvm.ptrtoint %src_ptr : !llvm.ptr to i64
22+
// %my_base_int = llvm.ptrtoint %my_base : !llvm.ptr to i64
23+
// %offset = arith.subi %src_int, %my_base_int : i64
24+
// %peer_base_at = llvm.getelementptr %bases[<R>] : ... -> !llvm.ptr, !llvm.ptr
25+
// %peer_base = llvm.load %peer_base_at : !llvm.ptr -> !llvm.ptr
26+
// %peer_src = llvm.getelementptr %peer_base[%offset] : ... -> !llvm.ptr, i8
27+
// call @mgpuMemcpy(%dst_ptr, %peer_src, %size, %nullptr)
28+
//
29+
// Initial restrictions:
30+
// - Both memrefs must have memory_space=0 (L3/global).
31+
// - Op must be at host scope (not inside a gpu.launch / gpu.func).
32+
// - "Entire memref" form only: empty offsets/sizes/strides on both sides.
33+
//
34+
//===-----------------------------------------------------------------------===//
35+
36+
#include "air/Conversion/AIRCrossRankDmaToMgpuPass.h"
37+
#include "air/Conversion/GPUPassDetail.h"
38+
#include "air/Dialect/AIR/AIRDialect.h"
39+
40+
#include "mlir/Dialect/Arith/IR/Arith.h"
41+
#include "mlir/Dialect/Func/IR/FuncOps.h"
42+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
43+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
44+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
45+
#include "mlir/IR/Builders.h"
46+
#include "mlir/Pass/Pass.h"
47+
48+
using namespace mlir;
49+
using namespace xilinx;
50+
51+
namespace {
52+
53+
// Ensure a private extern func declaration exists at module scope.
54+
static func::FuncOp ensureExternFunc(ModuleOp module, OpBuilder &builder,
55+
StringRef name, FunctionType type) {
56+
if (auto fn = module.lookupSymbol<func::FuncOp>(name))
57+
return fn;
58+
OpBuilder::InsertionGuard guard(builder);
59+
builder.setInsertionPointToStart(module.getBody());
60+
auto fn = func::FuncOp::create(builder, module.getLoc(), name, type);
61+
fn.setPrivate();
62+
return fn;
63+
}
64+
65+
// Compute byte size of a static-shape memref as an i64 SSA value.
66+
static Value computeMemrefByteSize(OpBuilder &b, Location loc, MemRefType ty) {
67+
if (!ty.hasStaticShape())
68+
return nullptr;
69+
int64_t numElts = 1;
70+
for (int64_t d : ty.getShape())
71+
numElts *= d;
72+
unsigned eltBits = ty.getElementType().getIntOrFloatBitWidth();
73+
if (eltBits == 0 || (eltBits % 8) != 0)
74+
return nullptr;
75+
int64_t totalBytes = numElts * (eltBits / 8);
76+
return arith::ConstantOp::create(b, loc, b.getI64Type(),
77+
b.getI64IntegerAttr(totalBytes));
78+
}
79+
80+
// Extract an aligned !llvm.ptr from a memref via the standard idiom.
81+
static Value extractAlignedPtr(OpBuilder &b, Location loc, Value memref) {
82+
Value idx = memref::ExtractAlignedPointerAsIndexOp::create(b, loc, memref);
83+
Value i64 = arith::IndexCastOp::create(b, loc, b.getI64Type(), idx);
84+
auto ptrTy = LLVM::LLVMPointerType::get(b.getContext());
85+
return LLVM::IntToPtrOp::create(b, loc, ptrTy, i64);
86+
}
87+
88+
struct AIRCrossRankDmaToMgpuPass
89+
: public xilinx::air::impl::AIRCrossRankDmaToMgpuBase<
90+
AIRCrossRankDmaToMgpuPass> {
91+
92+
AIRCrossRankDmaToMgpuPass() = default;
93+
AIRCrossRankDmaToMgpuPass(const AIRCrossRankDmaToMgpuPass &) {}
94+
95+
void runOnOperation() override {
96+
auto module = getOperation();
97+
OpBuilder builder(module.getContext());
98+
auto i32Ty = builder.getI32Type();
99+
auto i64Ty = builder.getI64Type();
100+
auto ptrTy = LLVM::LLVMPointerType::get(module.getContext());
101+
102+
// Collect cross-rank DMA ops.
103+
SmallVector<air::DmaMemcpyNdOp> crossRankDmas;
104+
module.walk([&](air::DmaMemcpyNdOp op) {
105+
if (op.hasCrossRank())
106+
crossRankDmas.push_back(op);
107+
});
108+
if (crossRankDmas.empty())
109+
return;
110+
111+
// Declare the runtime ABI functions we may need.
112+
auto getRankFn = ensureExternFunc(module, builder, "mgpuGetRank",
113+
builder.getFunctionType({}, {i32Ty}));
114+
auto getBasesFn =
115+
ensureExternFunc(module, builder, "mgpuGetHeapBases",
116+
builder.getFunctionType({}, {ptrTy}));
117+
auto memcpyFn = ensureExternFunc(
118+
module, builder, "mgpuMemcpy",
119+
builder.getFunctionType({ptrTy, ptrTy, i64Ty, ptrTy}, {}));
120+
121+
for (air::DmaMemcpyNdOp dma : crossRankDmas) {
122+
Location loc = dma.getLoc();
123+
124+
// Restrictions
125+
if (dma->getParentOfType<gpu::LaunchOp>() ||
126+
dma->getParentOfType<gpu::GPUFuncOp>()) {
127+
dma.emitOpError(
128+
"cross-rank DMA inside a GPU kernel is not yet supported");
129+
signalPassFailure();
130+
return;
131+
}
132+
if (!dma.getSrcOffsets().empty() || !dma.getSrcSizes().empty() ||
133+
!dma.getSrcStrides().empty() || !dma.getDstOffsets().empty() ||
134+
!dma.getDstSizes().empty() || !dma.getDstStrides().empty()) {
135+
dma.emitOpError("cross-rank DMA with explicit offsets/sizes/strides "
136+
"is not yet supported");
137+
signalPassFailure();
138+
return;
139+
}
140+
141+
auto srcType = cast<MemRefType>(dma.getSrcMemref().getType());
142+
auto dstType = cast<MemRefType>(dma.getDstMemref().getType());
143+
if (srcType.getMemorySpaceAsInt() != 0 ||
144+
dstType.getMemorySpaceAsInt() != 0) {
145+
dma.emitOpError(
146+
"cross-rank DMA requires both memrefs in memory_space=0");
147+
signalPassFailure();
148+
return;
149+
}
150+
151+
// Determine which side has the rank attribute. (Only one is supported
152+
// per op for now.)
153+
bool srcIsPeer = dma.getSrcRank().has_value();
154+
bool dstIsPeer = dma.getDstRank().has_value();
155+
if (srcIsPeer && dstIsPeer) {
156+
dma.emitOpError(
157+
"cross-rank DMA with both src_rank and dst_rank set is not yet "
158+
"supported");
159+
signalPassFailure();
160+
return;
161+
}
162+
int64_t peerRank =
163+
srcIsPeer ? *dma.getSrcRank() : *dma.getDstRank();
164+
auto peerSideType = srcIsPeer ? srcType : dstType;
165+
Value peerMemref = srcIsPeer ? dma.getSrcMemref() : dma.getDstMemref();
166+
Value localMemref =
167+
srcIsPeer ? dma.getDstMemref() : dma.getSrcMemref();
168+
169+
builder.setInsertionPoint(dma);
170+
Value sizeBytes = computeMemrefByteSize(builder, loc, peerSideType);
171+
if (!sizeBytes) {
172+
dma.emitOpError("cross-rank DMA requires static memref shape with "
173+
"byte-aligned element type");
174+
signalPassFailure();
175+
return;
176+
}
177+
Value nullPtr = LLVM::ZeroOp::create(builder, loc, ptrTy);
178+
179+
Value peerLocalPtr = extractAlignedPtr(builder, loc, peerMemref);
180+
Value localPtr = extractAlignedPtr(builder, loc, localMemref);
181+
182+
// bases = mgpuGetHeapBases()
183+
Value bases = func::CallOp::create(builder, loc, getBasesFn, ValueRange{})
184+
.getResult(0);
185+
186+
// my_rank = mgpuGetRank() (i32 -> i64)
187+
Value myRankI32 =
188+
func::CallOp::create(builder, loc, getRankFn, ValueRange{})
189+
.getResult(0);
190+
Value myRankI64 = arith::ExtSIOp::create(builder, loc, i64Ty, myRankI32);
191+
192+
// my_base = bases[my_rank]
193+
Value myBaseAddr = LLVM::GEPOp::create(builder, loc, ptrTy, ptrTy, bases,
194+
ArrayRef<Value>{myRankI64});
195+
Value myBase = LLVM::LoadOp::create(builder, loc, ptrTy, myBaseAddr);
196+
197+
// peer_base = bases[<peerRank>]
198+
Value peerRankIdx = LLVM::ConstantOp::create(
199+
builder, loc, i64Ty, builder.getI64IntegerAttr(peerRank));
200+
Value peerBaseAddr = LLVM::GEPOp::create(
201+
builder, loc, ptrTy, ptrTy, bases, ArrayRef<Value>{peerRankIdx});
202+
Value peerBase = LLVM::LoadOp::create(builder, loc, ptrTy, peerBaseAddr);
203+
204+
// offset = peerLocalPtr (as i64) - my_base (as i64)
205+
Value peerLocalInt =
206+
LLVM::PtrToIntOp::create(builder, loc, i64Ty, peerLocalPtr);
207+
Value myBaseInt = LLVM::PtrToIntOp::create(builder, loc, i64Ty, myBase);
208+
Value offset =
209+
arith::SubIOp::create(builder, loc, peerLocalInt, myBaseInt);
210+
211+
// peer_ptr = peer_base + offset (byte-stride GEP)
212+
auto i8Ty = builder.getI8Type();
213+
Value peerPtr = LLVM::GEPOp::create(builder, loc, ptrTy, i8Ty, peerBase,
214+
ArrayRef<Value>{offset});
215+
216+
// mgpuMemcpy(dst, src, size, nullptr) — substitute peerPtr on the
217+
// peer side.
218+
Value srcArg = srcIsPeer ? peerPtr : localPtr;
219+
Value dstArg = dstIsPeer ? peerPtr : localPtr;
220+
func::CallOp::create(builder, loc, memcpyFn,
221+
ValueRange{dstArg, srcArg, sizeBytes, nullPtr});
222+
223+
// If this DMA returned an async token, replace it with a wait_all.
224+
if (dma.getAsyncToken()) {
225+
Value tok = air::WaitAllOp::create(
226+
builder, loc,
227+
air::AsyncTokenType::get(builder.getContext()),
228+
ValueRange{})
229+
.getAsyncToken();
230+
dma.getAsyncToken().replaceAllUsesWith(tok);
231+
}
232+
dma.erase();
233+
}
234+
}
235+
};
236+
237+
} // namespace
238+
239+
namespace xilinx {
240+
namespace air {
241+
242+
std::unique_ptr<mlir::Pass> createAIRCrossRankDmaToMgpuPass() {
243+
return std::make_unique<AIRCrossRankDmaToMgpuPass>();
244+
}
245+
246+
} // namespace air
247+
} // namespace xilinx

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ if(AIR_ENABLE_GPU)
5959
GPUKernelOutlinePass.cpp
6060
AIRRankToMgpuPass.cpp
6161
AIRSymmetricAllocToMgpuPass.cpp
62+
AIRCrossRankDmaToMgpuPass.cpp
6263
)
6364
list(APPEND CONVERSION_LINK_LIBS
6465
MLIRGPUDialect

mlir/lib/Conversion/Passes.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "air/Conversion/Passes.h"
1010

1111
#if AIR_ENABLE_GPU
12+
#include "air/Conversion/AIRCrossRankDmaToMgpuPass.h"
1213
#include "air/Conversion/AIRRankToMgpuPass.h"
1314
#include "air/Conversion/AIRSymmetricAllocToMgpuPass.h"
1415
#include "air/Conversion/AIRToROCDLPass.h"

0 commit comments

Comments
 (0)