Skip to content

Commit 087fcb7

Browse files
erwei-xilinxclaude
andcommitted
[multi-gpu] Phase 4: air-symmetric-alloc-to-mgpu lowering pass
New conversion pass that replaces `memref.alloc` carrying the unit attribute `air.symmetric` with a call to `mgpuSymmetricAlloc(size, stream)`. The returned `!llvm.ptr` is wrapped in an LLVM memref descriptor (struct) and projected back to the original memref type via `builtin.unrealized_conversion_cast` so downstream uses keep working through the standard `convert-to-llvm` pipeline.

`memref.dealloc` ops whose operand traces back (through the cast) to a symmetric alloc are rewritten to `mgpuSymmetricFree`. The pass is a no-op when no `air.symmetric` allocations are present.

## Files

- `mlir/include/air/Conversion/AIRSymmetricAllocToMgpuPass.h` — header
- `mlir/include/air/Conversion/GPUPasses.td` — `air-symmetric-alloc-to-mgpu` def
- `mlir/include/air/Conversion/GPUPassDetail.h` — `GEN_PASS_DEF_AIRSYMMETRICALLOCTOMGPU`
- `mlir/lib/Conversion/AIRSymmetricAllocToMgpuPass.cpp` — implementation
- `mlir/lib/Conversion/{CMakeLists.txt,Passes.cpp}` — registration
- `mlir/test/Conversion/AIRSymmetricAllocToMgpu/symmetric_alloc.mlir` — FileCheck
- `test/gpu/symmetric_heap_dma/air_sym_with_alloc.mlir` — high-level e2e using `memref.alloc {air.symmetric}` (Phase 3 + Phase 4 chained)
- `test/gpu/symmetric_heap_dma/run.sh` — `INPUT=alloc` selector

## Test plan

FileCheck unit tests:

- 1D alloc + dealloc shape (size, descriptor, cast, free)
- 2D alloc with row-major strides in descriptor
- Element type byte-size: f32 (4B), f64 (8B), i32 (4B)
- Multiple symmetric allocs share one decl pair
- Pass is a no-op for non-symmetric allocs
- Pass is a no-op when there are zero symmetric allocs

End-to-end on rad-mi300a-sh5-1 (SHARE_GPU=1, 2 ranks):

- INPUT=handwritten — PASS (Phase 2 baseline)
- INPUT=rank — PASS (Phase 3)
- INPUT=alloc — PASS (Phase 4: chained Phase 4 + Phase 3 lowering)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 47466d3 commit 087fcb7

9 files changed

Lines changed: 484 additions & 1 deletion

File tree

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//===- AIRSymmetricAllocToMgpuPass.h ----------------------------*- C++ -*-===//
2+
//
3+
// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved.
4+
// SPDX-License-Identifier: MIT
5+
//
6+
//===-----------------------------------------------------------------------===//
7+
8+
#ifndef AIR_CONVERSION_AIR_SYMMETRIC_ALLOC_TO_MGPU_PASS_H
9+
#define AIR_CONVERSION_AIR_SYMMETRIC_ALLOC_TO_MGPU_PASS_H
10+
11+
#include "mlir/Pass/Pass.h"
12+
#include <memory>
13+
14+
namespace xilinx {
15+
namespace air {
16+
17+
/// Creates the `air-symmetric-alloc-to-mgpu` pass, which lowers `memref.alloc`
/// ops tagged with the `air.symmetric` attribute to `mgpuSymmetricAlloc`
/// runtime calls and the matching `memref.dealloc` ops to `mgpuSymmetricFree`.
std::unique_ptr<mlir::Pass> createAIRSymmetricAllocToMgpuPass();
18+
19+
} // namespace air
20+
} // namespace xilinx
21+
22+
#endif // AIR_CONVERSION_AIR_SYMMETRIC_ALLOC_TO_MGPU_PASS_H

mlir/include/air/Conversion/GPUPassDetail.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ using namespace mlir;
2727
#define GEN_PASS_DEF_CONVERTAIRTOROCDL
2828
#define GEN_PASS_DEF_CONVERTGPUKERNELOUTLINE
2929
#define GEN_PASS_DEF_AIRRANKTOMGPU
30+
#define GEN_PASS_DEF_AIRSYMMETRICALLOCTOMGPU
3031
#include "air/Conversion/GPUPasses.h.inc"
3132

3233
} // namespace air

mlir/include/air/Conversion/GPUPasses.td

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,29 @@ def ConvertGPUKernelOutline : Pass<"air-gpu-outlining", "ModuleOp"> {
4949
let options = [];
5050
}
5151

52+
// Pass definition for `air-symmetric-alloc-to-mgpu`, anchored on `ModuleOp`.
// The pass object is produced by the factory named in `constructor`. The
// dependent dialects are the ones whose ops the lowering creates: `func`
// (runtime calls and extern declarations), `arith` (the byte-size constant)
// and `llvm` (null stream pointer and memref descriptor construction).
def AIRSymmetricAllocToMgpu : Pass<"air-symmetric-alloc-to-mgpu", "ModuleOp"> {
53+
let summary = "Lower memref.alloc {air.symmetric} to mgpuSymmetricAlloc and "
54+
"memref.dealloc of the result to mgpuSymmetricFree";
55+
let constructor = "xilinx::air::createAIRSymmetricAllocToMgpuPass()";
56+
let description = [{
57+
Replaces each `memref.alloc` carrying the unit attribute `air.symmetric`
58+
with a call to `mgpuSymmetricAlloc(size_in_bytes, stream)` returning
59+
`!llvm.ptr`, then builds an LLVM memref descriptor (struct) wrapping that
60+
pointer and projects it back to the original memref type via
61+
`builtin.unrealized_conversion_cast` so downstream uses keep working.
62+
63+
For every `memref.dealloc` whose operand traces back (through a single
64+
`unrealized_conversion_cast`) to such a symmetric alloc, the pass emits
65+
`mgpuSymmetricFree(ptr, stream)` and erases the dealloc.
66+
67+
Should run before `convert-to-llvm`. Does nothing if no `air.symmetric`
68+
allocations are present.
69+
}];
70+
let dependentDialects = [
71+
"func::FuncDialect", "arith::ArithDialect", "LLVM::LLVMDialect"
72+
];
73+
}
74+
5275
def AIRRankToMgpu : Pass<"air-rank-to-mgpu", "ModuleOp"> {
5376
let summary = "Lower air.rank to mgpu* runtime calls (multi-GPU process model)";
5477
let constructor = "xilinx::air::createAIRRankToMgpuPass()";
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
//===- AIRSymmetricAllocToMgpuPass.cpp -------------------------*- C++ -*-===//
2+
//
3+
// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved.
4+
// SPDX-License-Identifier: MIT
5+
//
6+
//===-----------------------------------------------------------------------===//
7+
//
8+
// Lower memref.alloc carrying the `air.symmetric` attribute to a call to the
9+
// runtime function `mgpuSymmetricAlloc`. The returned `!llvm.ptr` is wrapped
10+
// in an LLVM memref descriptor (struct) and projected back to the original
11+
// memref type via `builtin.unrealized_conversion_cast` so that downstream
12+
// uses keep working.
13+
//
14+
// `memref.dealloc` ops whose operand traces (through a single
15+
// `unrealized_conversion_cast`) back to a symmetric alloc are rewritten to
16+
// `mgpuSymmetricFree`.
17+
//
18+
//===-----------------------------------------------------------------------===//
19+
20+
#include "air/Conversion/AIRSymmetricAllocToMgpuPass.h"
21+
#include "air/Conversion/GPUPassDetail.h"
22+
23+
#include "mlir/Dialect/Arith/IR/Arith.h"
24+
#include "mlir/Dialect/Func/IR/FuncOps.h"
25+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
26+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
27+
#include "mlir/IR/Builders.h"
28+
#include "mlir/Pass/Pass.h"
29+
30+
using namespace mlir;
31+
using namespace xilinx;
32+
33+
namespace {
34+
35+
// Ensure a private extern func declaration exists at module scope.
36+
static func::FuncOp ensureExternFunc(ModuleOp module, OpBuilder &builder,
37+
StringRef name, FunctionType type) {
38+
if (auto fn = module.lookupSymbol<func::FuncOp>(name))
39+
return fn;
40+
OpBuilder::InsertionGuard guard(builder);
41+
builder.setInsertionPointToStart(module.getBody());
42+
auto fn = func::FuncOp::create(builder, module.getLoc(), name, type);
43+
fn.setPrivate();
44+
return fn;
45+
}
46+
47+
// Compute the byte size of a static-shaped memref as an i64 SSA value.
48+
// Returns nullptr if the memref is dynamically shaped.
49+
static Value computeMemrefByteSize(OpBuilder &b, Location loc, MemRefType ty) {
50+
if (!ty.hasStaticShape())
51+
return nullptr;
52+
int64_t numElts = 1;
53+
for (int64_t d : ty.getShape())
54+
numElts *= d;
55+
unsigned eltBits = ty.getElementType().getIntOrFloatBitWidth();
56+
if (eltBits == 0 || (eltBits % 8) != 0)
57+
return nullptr;
58+
int64_t totalBytes = numElts * (eltBits / 8);
59+
return arith::ConstantOp::create(b, loc, b.getI64Type(),
60+
b.getI64IntegerAttr(totalBytes));
61+
}
62+
63+
// Build an LLVM memref descriptor struct populated with the given pointer.
64+
// For now we support only static-shape, contiguous, identity-layout memrefs
65+
// without an offset. For dimensions: sizes from the type, strides as
66+
// row-major (innermost stride = 1).
67+
static Value buildMemrefDescriptor(OpBuilder &b, Location loc,
68+
MemRefType memrefTy, Value ptr) {
69+
ArrayRef<int64_t> shape = memrefTy.getShape();
70+
unsigned rank = shape.size();
71+
auto i64Ty = b.getI64Type();
72+
auto ptrTy = LLVM::LLVMPointerType::get(b.getContext());
73+
74+
// Build the descriptor type: !llvm.struct<(ptr, ptr, i64, array<R x i64>,
75+
// array<R x i64>)>. For rank-0 memrefs, MLIR omits the size/stride arrays.
76+
SmallVector<Type, 5> descFields;
77+
descFields.push_back(ptrTy);
78+
descFields.push_back(ptrTy);
79+
descFields.push_back(i64Ty);
80+
if (rank > 0) {
81+
descFields.push_back(LLVM::LLVMArrayType::get(i64Ty, rank));
82+
descFields.push_back(LLVM::LLVMArrayType::get(i64Ty, rank));
83+
}
84+
auto structTy = LLVM::LLVMStructType::getLiteral(b.getContext(), descFields);
85+
86+
Value desc = LLVM::PoisonOp::create(b, loc, structTy);
87+
desc = LLVM::InsertValueOp::create(b, loc, desc, ptr, ArrayRef<int64_t>{0});
88+
desc = LLVM::InsertValueOp::create(b, loc, desc, ptr, ArrayRef<int64_t>{1});
89+
Value zero = LLVM::ConstantOp::create(b, loc, i64Ty, b.getI64IntegerAttr(0));
90+
desc = LLVM::InsertValueOp::create(b, loc, desc, zero, ArrayRef<int64_t>{2});
91+
92+
if (rank > 0) {
93+
// Compute row-major strides from shape (innermost = 1).
94+
SmallVector<int64_t> strides(rank, 1);
95+
for (int i = static_cast<int>(rank) - 2; i >= 0; --i)
96+
strides[i] = strides[i + 1] * shape[i + 1];
97+
for (unsigned i = 0; i < rank; ++i) {
98+
Value sz = LLVM::ConstantOp::create(b, loc, i64Ty,
99+
b.getI64IntegerAttr(shape[i]));
100+
desc = LLVM::InsertValueOp::create(b, loc, desc, sz,
101+
ArrayRef<int64_t>{3, (int64_t)i});
102+
Value st = LLVM::ConstantOp::create(b, loc, i64Ty,
103+
b.getI64IntegerAttr(strides[i]));
104+
desc = LLVM::InsertValueOp::create(b, loc, desc, st,
105+
ArrayRef<int64_t>{4, (int64_t)i});
106+
}
107+
}
108+
return desc;
109+
}
110+
111+
struct AIRSymmetricAllocToMgpuPass
112+
: public xilinx::air::impl::AIRSymmetricAllocToMgpuBase<
113+
AIRSymmetricAllocToMgpuPass> {
114+
115+
AIRSymmetricAllocToMgpuPass() = default;
116+
AIRSymmetricAllocToMgpuPass(const AIRSymmetricAllocToMgpuPass &) {}
117+
118+
void runOnOperation() override {
119+
auto module = getOperation();
120+
OpBuilder builder(module.getContext());
121+
auto i64Ty = builder.getI64Type();
122+
auto ptrTy = LLVM::LLVMPointerType::get(module.getContext());
123+
124+
// Collect symmetric allocs.
125+
SmallVector<memref::AllocOp> symAllocs;
126+
module.walk([&](memref::AllocOp op) {
127+
if (op->hasAttr("air.symmetric"))
128+
symAllocs.push_back(op);
129+
});
130+
131+
if (symAllocs.empty())
132+
return;
133+
134+
auto allocFn = ensureExternFunc(
135+
module, builder, "mgpuSymmetricAlloc",
136+
builder.getFunctionType({i64Ty, ptrTy}, {ptrTy}));
137+
auto freeFn = ensureExternFunc(
138+
module, builder, "mgpuSymmetricFree",
139+
builder.getFunctionType({ptrTy, ptrTy}, {}));
140+
141+
// Track the !llvm.ptr backing each lowered memref so deallocs can look
142+
// them up.
143+
DenseMap<Value, Value> symmetricMemrefToPtr;
144+
145+
for (memref::AllocOp alloc : symAllocs) {
146+
auto memrefTy = alloc.getType();
147+
Location loc = alloc.getLoc();
148+
builder.setInsertionPoint(alloc);
149+
150+
Value sizeBytes = computeMemrefByteSize(builder, loc, memrefTy);
151+
if (!sizeBytes) {
152+
alloc.emitOpError(
153+
"air.symmetric memref.alloc requires a static-shape memref with "
154+
"byte-aligned element type");
155+
signalPassFailure();
156+
return;
157+
}
158+
Value nullPtr = LLVM::ZeroOp::create(builder, loc, ptrTy);
159+
Value ptr = func::CallOp::create(builder, loc, allocFn,
160+
ValueRange{sizeBytes, nullPtr})
161+
.getResult(0);
162+
163+
Value desc = buildMemrefDescriptor(builder, loc, memrefTy, ptr);
164+
Value newMemref = UnrealizedConversionCastOp::create(
165+
builder, loc, TypeRange{memrefTy}, ValueRange{desc})
166+
.getResult(0);
167+
symmetricMemrefToPtr[newMemref] = ptr;
168+
alloc.getResult().replaceAllUsesWith(newMemref);
169+
alloc.erase();
170+
}
171+
172+
// Lower deallocs whose operand traces back to a symmetric alloc.
173+
SmallVector<memref::DeallocOp> deallocs;
174+
module.walk([&](memref::DeallocOp op) { deallocs.push_back(op); });
175+
for (memref::DeallocOp d : deallocs) {
176+
Value src = d.getMemref();
177+
auto it = symmetricMemrefToPtr.find(src);
178+
if (it == symmetricMemrefToPtr.end())
179+
continue; // not a symmetric memref
180+
builder.setInsertionPoint(d);
181+
Value nullPtr = LLVM::ZeroOp::create(builder, d.getLoc(), ptrTy);
182+
func::CallOp::create(builder, d.getLoc(), freeFn,
183+
ValueRange{it->second, nullPtr});
184+
d.erase();
185+
}
186+
}
187+
};
188+
189+
} // namespace
190+
191+
namespace xilinx {
192+
namespace air {
193+
194+
std::unique_ptr<mlir::Pass> createAIRSymmetricAllocToMgpuPass() {
195+
return std::make_unique<AIRSymmetricAllocToMgpuPass>();
196+
}
197+
198+
} // namespace air
199+
} // namespace xilinx

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ if(AIR_ENABLE_GPU)
5858
AIRTranslateToLLVMPass.cpp
5959
GPUKernelOutlinePass.cpp
6060
AIRRankToMgpuPass.cpp
61+
AIRSymmetricAllocToMgpuPass.cpp
6162
)
6263
list(APPEND CONVERSION_LINK_LIBS
6364
MLIRGPUDialect

mlir/lib/Conversion/Passes.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#if AIR_ENABLE_GPU
1212
#include "air/Conversion/AIRRankToMgpuPass.h"
13+
#include "air/Conversion/AIRSymmetricAllocToMgpuPass.h"
1314
#include "air/Conversion/AIRToROCDLPass.h"
1415
#include "air/Conversion/AIRTranslateToLLVMPass.h"
1516
#include "air/Conversion/GPUKernelOutlinePass.h"
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
//===- symmetric_alloc.mlir -------------------------------------*- MLIR -*-===//
2+
//
3+
// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved.
4+
// SPDX-License-Identifier: MIT
5+
//
6+
//===-----------------------------------------------------------------------===//
7+
8+
// RUN: air-opt %s --split-input-file -air-symmetric-alloc-to-mgpu | FileCheck %s
9+
10+
// Basic 1D alloc + dealloc.
11+
// CHECK-LABEL: func.func @basic_alloc_dealloc
12+
// CHECK: %[[SZ:.*]] = arith.constant 4096 : i64
13+
// CHECK: %[[NULL:.*]] = llvm.mlir.zero : !llvm.ptr
14+
// CHECK: %[[PTR:.*]] = call @mgpuSymmetricAlloc(%[[SZ]], %[[NULL]]) : (i64, !llvm.ptr) -> !llvm.ptr
15+
// Descriptor build (poison + insertvalue) then unrealized cast.
16+
// CHECK: llvm.mlir.poison
17+
// CHECK: llvm.insertvalue %[[PTR]]
18+
// CHECK: llvm.insertvalue %[[PTR]]
19+
// CHECK: builtin.unrealized_conversion_cast {{.*}} : !llvm.struct<{{.*}}> to memref<1024xf32>
20+
// Dealloc -> mgpuSymmetricFree.
21+
// CHECK: call @mgpuSymmetricFree(%[[PTR]],
22+
// CHECK-NOT: memref.alloc
23+
// CHECK-NOT: memref.dealloc
24+
func.func @basic_alloc_dealloc() {
25+
%buf = memref.alloc() {air.symmetric} : memref<1024xf32>
26+
memref.dealloc %buf : memref<1024xf32>
27+
return
28+
}
29+
30+
// -----
31+
32+
// 2D alloc: 64*64*4 = 16384 bytes; descriptor strides should be [64, 1].
33+
// CHECK-LABEL: func.func @alloc_2d
34+
// CHECK: arith.constant 16384 : i64
35+
// CHECK: call @mgpuSymmetricAlloc
36+
// Strides 64 then 1 in the descriptor (row-major: the innermost dimension is contiguous).
37+
// CHECK: llvm.mlir.constant(64 : i64)
38+
// CHECK: llvm.insertvalue
39+
// CHECK: llvm.mlir.constant(1 : i64)
40+
// CHECK: llvm.insertvalue
41+
// CHECK: builtin.unrealized_conversion_cast {{.*}} : !llvm.struct<{{.*}}> to memref<64x64xf32>
42+
func.func @alloc_2d() -> memref<64x64xf32> {
43+
%buf = memref.alloc() {air.symmetric} : memref<64x64xf32>
44+
return %buf : memref<64x64xf32>
45+
}
46+
47+
// -----
48+
49+
// f64 element type (8 bytes): 1024 * 8 = 8192 bytes.
50+
// CHECK-LABEL: func.func @f64_element
51+
// CHECK: arith.constant 8192 : i64
52+
func.func @f64_element() {
53+
%buf = memref.alloc() {air.symmetric} : memref<1024xf64>
54+
memref.dealloc %buf : memref<1024xf64>
55+
return
56+
}
57+
58+
// -----
59+
60+
// i32 element type (4 bytes): 256 * 4 = 1024 bytes.
61+
// CHECK-LABEL: func.func @i32_element
62+
// CHECK: arith.constant 1024 : i64
63+
func.func @i32_element() {
64+
%buf = memref.alloc() {air.symmetric} : memref<256xi32>
65+
memref.dealloc %buf : memref<256xi32>
66+
return
67+
}
68+
69+
// -----
70+
71+
// Multiple symmetric allocs in one function: each lowered independently;
72+
// extern decls are emitted exactly once at module scope.
73+
// Match the actual emission order: Free decl before Alloc decl.
74+
// CHECK-COUNT-1: func.func private @mgpuSymmetricFree
75+
// CHECK-NOT: func.func private @mgpuSymmetricFree
76+
// CHECK-COUNT-1: func.func private @mgpuSymmetricAlloc
77+
// CHECK-NOT: func.func private @mgpuSymmetricAlloc
78+
// CHECK-LABEL: func.func @two_allocs
79+
// CHECK-COUNT-2: call @mgpuSymmetricAlloc
80+
// CHECK-COUNT-2: call @mgpuSymmetricFree
81+
func.func @two_allocs() {
82+
%a = memref.alloc() {air.symmetric} : memref<32xf32>
83+
%b = memref.alloc() {air.symmetric} : memref<64xf32>
84+
memref.dealloc %a : memref<32xf32>
85+
memref.dealloc %b : memref<64xf32>
86+
return
87+
}
88+
89+
// -----
90+
91+
// LAST partition: cases that test the pass leaves things untouched.
92+
// Both `ignores_non_symmetric` and `no_symmetric_alloc` are folded here
93+
// so the trailing CHECK-NOTs only need to match against this one (final)
94+
// partition's text.
95+
// CHECK-LABEL: func.func @no_symmetric_changes
96+
// CHECK: memref.alloc() : memref<1024xf32>
97+
// CHECK: memref.alloc() : memref<32xf32>
98+
// CHECK-NOT: mgpuSymmetricAlloc
99+
// CHECK-NOT: mgpuSymmetricFree
100+
func.func @no_symmetric_changes() {
101+
%a = memref.alloc() : memref<1024xf32>
102+
memref.dealloc %a : memref<1024xf32>
103+
%b = memref.alloc() : memref<32xf32>
104+
memref.dealloc %b : memref<32xf32>
105+
return
106+
}

0 commit comments

Comments
 (0)