Skip to content

Commit 6875ed3

Browse files
erwei-xilinxclaude
andcommitted
[multi-gpu] Phase 3: air-rank-to-mgpu lowering pass
New conversion pass that replaces each `air.rank` op by inlining its body in place, with rank IDs computed at runtime via `mgpuGetRank()` and delinearized into the rank's N-D iteration space. Replaces `air-rank-to-launch` for the GPU pipeline (which serialized ranks via scf.for — a placeholder for single-process execution). After this pass each process executes the entire `air.rank` body once, with its rank id resolved dynamically from the runtime. Heap lifecycle (`mgpuSymmetricHeapInit` / `mgpuSymmetricHeapDestroy`) is bracketed around the parent function once per function (not per rank). - `mlir/include/air/Conversion/AIRRankToMgpuPass.h` — public header - `mlir/include/air/Conversion/GPUPasses.td` — `air-rank-to-mgpu` def with `heap-size` option (default 256 MB) - `mlir/include/air/Conversion/GPUPassDetail.h` — `GEN_PASS_DEF_AIRRANKTOMGPU` - `mlir/lib/Conversion/AIRRankToMgpuPass.cpp` — pass implementation - `mlir/lib/Conversion/CMakeLists.txt`, `Passes.cpp` — registration - `mlir/test/Conversion/AIRRankToMgpu/rank_to_mgpu.mlir` — FileCheck unit tests (10 cases; see Test plan below) - `test/gpu/symmetric_heap_dma/air_sym_with_rank.mlir` — high-level air.rank-based equivalent of the Phase 2 hand-written reference - `test/gpu/symmetric_heap_dma/run.sh` — `INPUT=rank|handwritten` selector to run either form through the same multi-process driver FileCheck unit tests cover: - 1D / 2D rank delinearization (remsi/divsi) - Default + custom heap-size option - Async form (token replacement via wait_all) - Async dependencies (blocking wait_all insertion) - Multiple `air.rank` ops per function (init/destroy emitted once) - Multiple `func.return` paths (destroy before each) - Kernel operand mapping (block args replaced by SSA operands) - Idempotent extern decls across multiple functions - No-op when no `air.rank` is present (audit-found bug fixed: pass was unconditionally inserting decls) End-to-end: rad-mi300a-sh5-1, SHARE_GPU=1, 2 ranks, INPUT=rank — both ranks PASS the cross-rank read. Caveat: same SHARE_GPU=1 single-physical-GPU caveat as Phase 2. True multi-GPU re-validation is needed before declaring multi-GPU production- ready (blocked on ROCm-side work). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 977767d commit 6875ed3

9 files changed

Lines changed: 574 additions & 8 deletions

File tree

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//===- AIRRankToMgpuPass.h ---------------------------------------*- C++ -*-===//
2+
//
3+
// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved.
4+
// SPDX-License-Identifier: MIT
5+
//
6+
//===-----------------------------------------------------------------------===//
7+
8+
#ifndef AIR_CONVERSION_AIR_RANK_TO_MGPU_PASS_H
9+
#define AIR_CONVERSION_AIR_RANK_TO_MGPU_PASS_H
10+
11+
#include "mlir/Pass/Pass.h"
12+
#include <memory>
13+
14+
namespace xilinx {
15+
namespace air {
16+
17+
std::unique_ptr<mlir::Pass> createAIRRankToMgpuPass();
18+
19+
} // namespace air
20+
} // namespace xilinx
21+
22+
#endif // AIR_CONVERSION_AIR_RANK_TO_MGPU_PASS_H

mlir/include/air/Conversion/GPUPassDetail.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ using namespace mlir;
2626
#define GEN_PASS_DEF_AIRTRANSLATETOLLVM
2727
#define GEN_PASS_DEF_CONVERTAIRTOROCDL
2828
#define GEN_PASS_DEF_CONVERTGPUKERNELOUTLINE
29+
#define GEN_PASS_DEF_AIRRANKTOMGPU
2930
#include "air/Conversion/GPUPasses.h.inc"
3031

3132
} // namespace air

mlir/include/air/Conversion/GPUPasses.td

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,32 @@ def ConvertGPUKernelOutline : Pass<"air-gpu-outlining", "ModuleOp"> {
4949
let options = [];
5050
}
5151

52+
def AIRRankToMgpu : Pass<"air-rank-to-mgpu", "ModuleOp"> {
53+
let summary = "Lower air.rank to mgpu* runtime calls (multi-GPU process model)";
54+
let constructor = "xilinx::air::createAIRRankToMgpuPass()";
55+
let description = [{
56+
Each `air.rank` op is replaced by inlining its body in place, with rank
57+
IDs computed from `mgpuGetRank()` (delinearized into the rank's N-D
58+
iteration space) and rank sizes substituted from the static size operands.
59+
60+
The pass also inserts `mgpuSymmetricHeapInit(heap_size)` at the entry of
61+
the enclosing `func.func` (default 256 MB; configurable via the
62+
`heap-size` option) and `mgpuSymmetricHeapDestroy()` before each
63+
`func.return` in that function.
64+
65+
This replaces `air-rank-to-launch` for the GPU pipeline. Unlike
66+
`air-rank-to-launch` (which serializes ranks via `scf.for`), this pass
67+
assumes each process executes the whole rank body once and runtime
68+
coordinates across processes via env vars (RANK / WORLD_SIZE / LOCAL_RANK)
69+
and the symmetric-heap fabric.
70+
}];
71+
let options = [
72+
Option<"heapSize", "heap-size", "uint64_t", "/*default=*/268435456",
73+
"Symmetric heap size in bytes (default: 256 MB)">
74+
];
75+
let dependentDialects = [
76+
"func::FuncDialect", "arith::ArithDialect"
77+
];
78+
}
79+
5280
#endif // AIR_CONVERSION_GPU_PASSES
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
//===- AIRRankToMgpuPass.cpp -----------------------------------*- C++ -*-===//
2+
//
3+
// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved.
4+
// SPDX-License-Identifier: MIT
5+
//
6+
//===-----------------------------------------------------------------------===//
7+
//
8+
// Lower air.rank to mgpu* runtime calls (multi-GPU process model).
9+
//
10+
// Each `air.rank` op is replaced by inlining its body in place, with rank
11+
// IDs computed from `mgpuGetRank()` (delinearized into the rank's N-D
12+
// iteration space) and rank sizes substituted from the static size operands.
13+
//
14+
// The pass also inserts `mgpuSymmetricHeapInit(heap_size)` at the entry of
15+
// the enclosing `func.func` and `mgpuSymmetricHeapDestroy()` before each
16+
// `func.return` in that function.
17+
//
18+
//===-----------------------------------------------------------------------===//
19+
20+
#include "air/Conversion/AIRRankToMgpuPass.h"
21+
#include "air/Conversion/GPUPassDetail.h"
22+
#include "air/Dialect/AIR/AIRDialect.h"
23+
24+
#include "mlir/Dialect/Arith/IR/Arith.h"
25+
#include "mlir/Dialect/Func/IR/FuncOps.h"
26+
#include "mlir/IR/Builders.h"
27+
#include "mlir/IR/IRMapping.h"
28+
#include "mlir/Pass/Pass.h"
29+
30+
using namespace mlir;
31+
using namespace xilinx;
32+
33+
namespace {
34+
35+
// Ensure a private extern func declaration exists at the top of the module.
36+
static func::FuncOp ensureExternFunc(ModuleOp module, OpBuilder &builder,
37+
StringRef name, FunctionType type) {
38+
if (auto fn = module.lookupSymbol<func::FuncOp>(name))
39+
return fn;
40+
OpBuilder::InsertionGuard guard(builder);
41+
builder.setInsertionPointToStart(module.getBody());
42+
auto fn = func::FuncOp::create(builder, module.getLoc(), name, type);
43+
fn.setPrivate();
44+
return fn;
45+
}
46+
47+
struct AIRRankToMgpuPass
48+
: public xilinx::air::impl::AIRRankToMgpuBase<AIRRankToMgpuPass> {
49+
50+
AIRRankToMgpuPass() = default;
51+
AIRRankToMgpuPass(const AIRRankToMgpuPass &pass) {}
52+
53+
void runOnOperation() override {
54+
auto module = getOperation();
55+
OpBuilder builder(module.getContext());
56+
auto i32Ty = builder.getI32Type();
57+
auto i64Ty = builder.getI64Type();
58+
auto idxTy = builder.getIndexType();
59+
60+
// Collect all air.rank ops and their parent functions.
61+
SmallVector<air::RankOp> rankOps;
62+
SetVector<func::FuncOp> rankParentFuncs;
63+
module.walk([&](air::RankOp op) {
64+
rankOps.push_back(op);
65+
if (auto fn = op->getParentOfType<func::FuncOp>())
66+
rankParentFuncs.insert(fn);
67+
});
68+
69+
// If no air.rank ops exist, leave the module untouched.
70+
if (rankOps.empty())
71+
return;
72+
73+
// Declare the mgpu* runtime ABI functions (only when needed).
74+
auto initFn = ensureExternFunc(module, builder, "mgpuSymmetricHeapInit",
75+
builder.getFunctionType({i64Ty}, {}));
76+
auto destroyFn =
77+
ensureExternFunc(module, builder, "mgpuSymmetricHeapDestroy",
78+
builder.getFunctionType({}, {}));
79+
auto getRankFn = ensureExternFunc(module, builder, "mgpuGetRank",
80+
builder.getFunctionType({}, {i32Ty}));
81+
82+
// For each parent function, insert mgpuSymmetricHeapInit at entry and
83+
// mgpuSymmetricHeapDestroy before each return.
84+
for (func::FuncOp fn : rankParentFuncs) {
85+
if (fn.empty())
86+
continue;
87+
Block &entry = fn.front();
88+
Location loc = fn.getLoc();
89+
builder.setInsertionPointToStart(&entry);
90+
Value heapSizeVal = arith::ConstantOp::create(
91+
builder, loc, i64Ty,
92+
builder.getI64IntegerAttr(static_cast<int64_t>(heapSize)));
93+
func::CallOp::create(builder, loc, initFn, ValueRange{heapSizeVal});
94+
95+
// Insert destroy before every return op.
96+
SmallVector<func::ReturnOp> returns;
97+
fn.walk([&](func::ReturnOp r) { returns.push_back(r); });
98+
for (func::ReturnOp r : returns) {
99+
builder.setInsertionPoint(r);
100+
func::CallOp::create(builder, r.getLoc(), destroyFn, ValueRange{});
101+
}
102+
}
103+
104+
// Lower each air.rank op.
105+
for (air::RankOp rankOp : rankOps) {
106+
builder.setInsertionPoint(rankOp);
107+
Location loc = rankOp.getLoc();
108+
109+
// If the rank has async dependencies, insert a blocking wait before
110+
// proceeding.
111+
if (!rankOp.getAsyncDependencies().empty()) {
112+
air::WaitAllOp::create(builder, loc, Type{},
113+
rankOp.getAsyncDependencies());
114+
}
115+
116+
// Get the flat rank id from mgpuGetRank() and convert to index.
117+
Value rankI32 =
118+
func::CallOp::create(builder, loc, getRankFn, ValueRange{})
119+
.getResult(0);
120+
Value rankI64 =
121+
arith::ExtSIOp::create(builder, loc, i64Ty, rankI32);
122+
Value flatRank =
123+
arith::IndexCastOp::create(builder, loc, idxTy, rankI64);
124+
125+
// Delinearize flatRank into N rank IDs using the static size operands.
126+
// For sizes [s0, s1, ..., sn-1]:
127+
// id[0] = flat % s0
128+
// id[1] = (flat / s0) % s1
129+
// ...
130+
// id[n-1] = flat / (s0 * s1 * ... * sn-2)
131+
auto sizeOpers = rankOp.getSizeOperands();
132+
unsigned n = rankOp.getNumDims();
133+
SmallVector<Value> ids(n);
134+
Value remaining = flatRank;
135+
for (unsigned d = 0; d < n; ++d) {
136+
if (d == n - 1) {
137+
ids[d] = remaining;
138+
} else {
139+
ids[d] = arith::RemSIOp::create(builder, loc, remaining, sizeOpers[d]);
140+
remaining =
141+
arith::DivSIOp::create(builder, loc, remaining, sizeOpers[d]);
142+
}
143+
}
144+
145+
// Build remap and clone the body.
146+
IRMapping remap;
147+
for (unsigned d = 0; d < n; ++d) {
148+
remap.map(rankOp.getIds()[d], ids[d]);
149+
remap.map(rankOp.getSize()[d], sizeOpers[d]);
150+
}
151+
for (unsigned i = 0; i < rankOp.getNumKernelOperands(); ++i)
152+
remap.map(rankOp.getKernelArgument(i), rankOp.getKernelOperand(i));
153+
154+
auto &ops = rankOp.getBody().front().getOperations();
155+
for (auto oi = ops.begin(), oe = --ops.end(); oi != oe; ++oi)
156+
builder.clone(*oi, remap);
157+
158+
// Replace the async token (if any) with a synchronous wait_all.
159+
if (rankOp.getAsyncToken()) {
160+
auto waitAll = air::WaitAllOp::create(
161+
builder, loc, air::AsyncTokenType::get(builder.getContext()),
162+
ValueRange{});
163+
rankOp.getAsyncToken().replaceAllUsesWith(waitAll.getAsyncToken());
164+
}
165+
166+
rankOp.erase();
167+
}
168+
}
169+
};
170+
171+
} // namespace
172+
173+
namespace xilinx {
174+
namespace air {
175+
176+
std::unique_ptr<mlir::Pass> createAIRRankToMgpuPass() {
177+
return std::make_unique<AIRRankToMgpuPass>();
178+
}
179+
180+
} // namespace air
181+
} // namespace xilinx

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ if(AIR_ENABLE_GPU)
5757
AIRToROCDLPass.cpp
5858
AIRTranslateToLLVMPass.cpp
5959
GPUKernelOutlinePass.cpp
60+
AIRRankToMgpuPass.cpp
6061
)
6162
list(APPEND CONVERSION_LINK_LIBS
6263
MLIRGPUDialect

mlir/lib/Conversion/Passes.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "air/Conversion/Passes.h"
1010

1111
#if AIR_ENABLE_GPU
12+
#include "air/Conversion/AIRRankToMgpuPass.h"
1213
#include "air/Conversion/AIRToROCDLPass.h"
1314
#include "air/Conversion/AIRTranslateToLLVMPass.h"
1415
#include "air/Conversion/GPUKernelOutlinePass.h"

0 commit comments

Comments
 (0)