Skip to content

Commit f6a4650

Browse files
erwei-xilinxclaude
andcommitted
[multi-gpu] Phase 2: air.translate op + air-translate-to-llvm lowering
Introduce an AIR primitive for the symmetric-heap pointer rebase, in preparation for the kernel-driven producer/consumer redesign per @mawad-amd's review feedback on PR Xilinx#1577. %peer = air.translate %src, %from, %to, %bases : memref<NxT, A>, !llvm.ptr Signature: - $source: memref on $from_rank's symmetric heap - $from_rank, $to_rank: index-typed rank ids - $heap_bases: !llvm.ptr to the per-rank base table from mgpuGetHeapBases() - result: same memref type, addressing $to_rank's slice of the same collective allocation The op is Pure and folds when from_rank == to_rank (statically equal SSA values or matching constant attrs). Naming follows IRIS's `__translate`. Lowering pass `air-translate-to-llvm` expands each op to the peer-VA arithmetic plus a freshly-built LLVM memref descriptor: byte_diff = ptrtoint(bases[to]) - ptrtoint(bases[from]) peer_aligned_ptr = src_aligned_ptr + byte_diff (i8 GEP) build descriptor { peer_ptr, peer_ptr, 0, sizes, strides } unrealized_conversion_cast back to result memref type The expansion is pure arithmetic (arith + memref + llvm dialect), no runtime calls — therefore valid both at host scope and inside `gpu.func`, provided heap_bases is threaded as a kernel argument. Tests: - mlir/test/Dialect/AIR/air_translate.mlir: parser/printer + folder - mlir/test/Conversion/AIRToROCDL/air_translate_to_llvm.mlir: lowering shape on 1D, 2D-addrspace, gpu.func body, and no-op cases Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent a170c5c commit f6a4650

11 files changed

Lines changed: 422 additions & 0 deletions

File tree

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//===- AIRTranslateToLLVMPass.h ----------------------------------*- C++ -*-===//
2+
//
3+
// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved.
4+
// SPDX-License-Identifier: MIT
5+
//
6+
//===-----------------------------------------------------------------------===//
7+
8+
#ifndef AIR_CONVERSION_AIR_TRANSLATE_TO_LLVM_PASS_H
9+
#define AIR_CONVERSION_AIR_TRANSLATE_TO_LLVM_PASS_H
10+
11+
#include "mlir/Pass/Pass.h"
12+
#include <memory>
13+
14+
namespace xilinx {
15+
namespace air {
16+
17+
std::unique_ptr<mlir::Pass> createAIRTranslateToLLVMPass();
18+
19+
} // namespace air
20+
} // namespace xilinx
21+
22+
#endif // AIR_CONVERSION_AIR_TRANSLATE_TO_LLVM_PASS_H

mlir/include/air/Conversion/GPUPassDetail.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ namespace air {
2323
using namespace mlir;
2424

2525
#define GEN_PASS_DECL
26+
#define GEN_PASS_DEF_AIRTRANSLATETOLLVM
2627
#define GEN_PASS_DEF_CONVERTAIRTOROCDL
2728
#define GEN_PASS_DEF_CONVERTGPUKERNELOUTLINE
2829
#include "air/Conversion/GPUPasses.h.inc"

mlir/include/air/Conversion/GPUPasses.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,23 @@ def ConvertAIRToROCDL : Pass<"air-to-rocdl", "ModuleOp"> {
2121
let options = [];
2222
}
2323

24+
def AIRTranslateToLLVM : Pass<"air-translate-to-llvm", "ModuleOp"> {
25+
let summary = "Lower air.translate to memref.reinterpret_cast + LLVM-dialect address arithmetic";
26+
let description = [{
27+
Expands each `air.translate` op into the pointer-rebase computation:
28+
`bases[to_rank] - bases[from_rank]`, converted from bytes to elements
29+
of the source memref's element type, then applied as a new offset
30+
via `memref.reinterpret_cast`. The expansion is pure arithmetic; it
31+
works identically on host functions and inside `gpu.func`.
32+
}];
33+
let constructor = "xilinx::air::createAIRTranslateToLLVMPass()";
34+
let dependentDialects = [
35+
"mlir::arith::ArithDialect",
36+
"mlir::memref::MemRefDialect",
37+
"mlir::LLVM::LLVMDialect"
38+
];
39+
}
40+
2441
def ConvertGPUKernelOutline : Pass<"air-gpu-outlining", "ModuleOp"> {
2542
let summary = "Outline GPU Kernel Func from GPU Launch";
2643
let constructor = "xilinx::air::createGPUKernelOutlinePass()";

mlir/include/air/Dialect/AIR/AIR.td

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,14 @@ include "mlir/Interfaces/ControlFlowInterfaces.td"
1616
include "mlir/Interfaces/SideEffectInterfaces.td"
1717
include "mlir/Interfaces/TilingInterface.td"
1818

19+
// Type predicate for !llvm.ptr. Inlined here (instead of including
20+
// "mlir/Dialect/LLVMIR/LLVMOpBase.td") to avoid pulling the LLVM dialect
21+
// into our TableGen scope — that would confuse `mlir-tblgen
22+
// -gen-dialect-doc` which expects exactly one dialect per .td file.
23+
def air_LLVMPtr : Type<CPred<"::llvm::isa<::mlir::LLVM::LLVMPointerType>($_self)">,
24+
"LLVM pointer",
25+
"::mlir::LLVM::LLVMPointerType">;
26+
1927
class air_Op<string mnemonic, list<Trait> traits = []> :
2028
Op<air_Dialect, mnemonic, traits>;
2129

@@ -926,6 +934,42 @@ def air_ExecuteTerminatorOp : air_Op<"execute_terminator", [HasParent<"ExecuteOp
926934
[{ attr-dict ($results^ `:` type($results))? }];
927935
}
928936

937+
def air_TranslateOp : air_Op<"translate",
938+
[Pure, AllTypesMatch<["source", "result"]>]>,
939+
Arguments<(ins AnyMemRef:$source,
940+
Index:$from_rank,
941+
Index:$to_rank,
942+
air_LLVMPtr:$heap_bases)>,
943+
Results<(outs AnyMemRef:$result)> {
944+
let summary = "Re-express a symmetric-heap memref in another rank's address space";
945+
let description = [{
946+
Produces a memref of the same type as `$source` whose underlying
947+
pointer references the corresponding allocation on `$to_rank`. The
948+
`$source` memref is assumed to live on `$from_rank`'s symmetric heap.
949+
The translation is the pointer rebase
950+
951+
peer_va = bases[to_rank] + (source_ptr - bases[from_rank])
952+
953+
where `$heap_bases` is the per-rank base table obtained from the
954+
`mgpuGetHeapBases()` runtime hook (typically called once at host
955+
scope and threaded through `gpu.launch_func` as a kernel argument).
956+
No data is moved; this op produces a value-level "view" of peer
957+
memory.
958+
959+
Folds to `$source` when `$from_rank` and `$to_rank` are statically
960+
equal.
961+
962+
Both ranks must address the same collective allocation on the
963+
symmetric heap (i.e. `$source` must trace back to a
964+
`memref.alloc {air.symmetric}`). Using this op outside that contract
965+
is undefined.
966+
}];
967+
let assemblyFormat =
968+
[{ $source `,` $from_rank `,` $to_rank `,` $heap_bases
969+
attr-dict `:` type($source) `,` type($heap_bases) }];
970+
let hasFolder = 1;
971+
}
972+
929973
// AIR custom op, as a handle for a user-provided AIE kernel
930974

931975
def air_CustomOp : air_Op<"custom", [air_AsyncOpInterface,

mlir/include/air/Dialect/AIR/AIRDialect.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#ifndef MLIR_AIR_DIALECT_H
1010
#define MLIR_AIR_DIALECT_H
1111

12+
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
1213
#include "mlir/IR/Builders.h"
1314
#include "mlir/IR/BuiltinOps.h"
1415
#include "mlir/IR/BuiltinTypes.h"
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
//===- AIRTranslateToLLVMPass.cpp -------------------------------*- C++ -*-===//
2+
//
3+
// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved.
4+
// SPDX-License-Identifier: MIT
5+
//
6+
//===-----------------------------------------------------------------------===//
7+
//
8+
// Lower air.translate to memref-descriptor construction over a peer-rebased
9+
// pointer.
10+
//
11+
// For each `air.translate %src, %from, %to, %bases`:
12+
// 1. Extract the source memref's aligned pointer as !llvm.ptr.
13+
// 2. Compute the byte diff between the per-rank base pointers from the
14+
// `$heap_bases` table:
15+
// byte_diff = ptrtoint(bases[to]) - ptrtoint(bases[from])
16+
// 3. Apply the byte diff to the source aligned pointer (i8 GEP) to obtain
17+
// the peer aligned pointer.
18+
// 4. Build a fresh LLVM memref descriptor (poison + insertvalue chain)
19+
// whose allocated/aligned pointers both point at the peer address; the
20+
// offset is 0, and sizes/strides are taken from the source memref's
21+
// static type.
22+
// 5. unrealized_conversion_cast the descriptor back to the result memref
23+
// type so downstream uses keep working through the standard
24+
// memref-to-llvm pipeline.
25+
//
26+
// The lowering only uses arith + memref + llvm dialect ops — no runtime
27+
// calls. It is therefore valid both at host scope and inside `gpu.func`
28+
// (the kernel must already have been given the heap_bases pointer as a
29+
// kernel argument).
30+
//
31+
//===-----------------------------------------------------------------------===//
32+
33+
#include "air/Conversion/AIRTranslateToLLVMPass.h"
34+
#include "air/Conversion/GPUPassDetail.h"
35+
#include "air/Dialect/AIR/AIRDialect.h"
36+
37+
#include "mlir/Dialect/Arith/IR/Arith.h"
38+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
39+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
40+
#include "mlir/IR/Builders.h"
41+
#include "mlir/IR/BuiltinOps.h"
42+
#include "mlir/Pass/Pass.h"
43+
44+
using namespace mlir;
45+
using namespace xilinx;
46+
47+
namespace {
48+
49+
// Build a fresh LLVM memref descriptor for `memrefTy` whose
50+
// allocated_ptr and aligned_ptr both reference `ptr`, offset is 0, and
51+
// sizes/strides come from the static type (row-major).
52+
//
53+
// Mirrors buildMemrefDescriptor in AIRSymmetricAllocToMgpuPass.
54+
static Value buildPeerDescriptor(OpBuilder &b, Location loc,
55+
MemRefType memrefTy, Value ptr) {
56+
ArrayRef<int64_t> shape = memrefTy.getShape();
57+
unsigned rank = shape.size();
58+
auto i64Ty = b.getI64Type();
59+
auto ptrTy = LLVM::LLVMPointerType::get(b.getContext());
60+
61+
SmallVector<Type, 5> descFields;
62+
descFields.push_back(ptrTy);
63+
descFields.push_back(ptrTy);
64+
descFields.push_back(i64Ty);
65+
if (rank > 0) {
66+
descFields.push_back(LLVM::LLVMArrayType::get(i64Ty, rank));
67+
descFields.push_back(LLVM::LLVMArrayType::get(i64Ty, rank));
68+
}
69+
auto structTy = LLVM::LLVMStructType::getLiteral(b.getContext(), descFields);
70+
71+
Value desc = LLVM::PoisonOp::create(b, loc, structTy);
72+
desc = LLVM::InsertValueOp::create(b, loc, desc, ptr, ArrayRef<int64_t>{0});
73+
desc = LLVM::InsertValueOp::create(b, loc, desc, ptr, ArrayRef<int64_t>{1});
74+
Value zero = LLVM::ConstantOp::create(b, loc, i64Ty, b.getI64IntegerAttr(0));
75+
desc = LLVM::InsertValueOp::create(b, loc, desc, zero, ArrayRef<int64_t>{2});
76+
77+
if (rank > 0) {
78+
SmallVector<int64_t> strides(rank, 1);
79+
for (int i = static_cast<int>(rank) - 2; i >= 0; --i)
80+
strides[i] = strides[i + 1] * shape[i + 1];
81+
for (unsigned i = 0; i < rank; ++i) {
82+
Value sz = LLVM::ConstantOp::create(b, loc, i64Ty,
83+
b.getI64IntegerAttr(shape[i]));
84+
desc = LLVM::InsertValueOp::create(b, loc, desc, sz,
85+
ArrayRef<int64_t>{3, (int64_t)i});
86+
Value st = LLVM::ConstantOp::create(b, loc, i64Ty,
87+
b.getI64IntegerAttr(strides[i]));
88+
desc = LLVM::InsertValueOp::create(b, loc, desc, st,
89+
ArrayRef<int64_t>{4, (int64_t)i});
90+
}
91+
}
92+
return desc;
93+
}
94+
95+
struct AIRTranslateToLLVMPass
96+
: public xilinx::air::impl::AIRTranslateToLLVMBase<AIRTranslateToLLVMPass> {
97+
98+
AIRTranslateToLLVMPass() = default;
99+
AIRTranslateToLLVMPass(const AIRTranslateToLLVMPass &) {}
100+
101+
void runOnOperation() override {
102+
auto module = getOperation();
103+
auto *ctx = module.getContext();
104+
OpBuilder builder(ctx);
105+
auto i64Ty = builder.getI64Type();
106+
auto ptrTy = LLVM::LLVMPointerType::get(ctx);
107+
108+
SmallVector<air::TranslateOp> translates;
109+
module.walk([&](air::TranslateOp op) { translates.push_back(op); });
110+
if (translates.empty())
111+
return;
112+
113+
for (air::TranslateOp op : translates) {
114+
builder.setInsertionPoint(op);
115+
Location loc = op.getLoc();
116+
117+
auto memrefTy = cast<MemRefType>(op.getSource().getType());
118+
if (!memrefTy.hasStaticShape()) {
119+
op.emitOpError("air.translate requires a static-shape source memref");
120+
signalPassFailure();
121+
return;
122+
}
123+
124+
// Extract source aligned pointer as !llvm.ptr.
125+
Value srcAlignedIdx = memref::ExtractAlignedPointerAsIndexOp::create(
126+
builder, loc, op.getSource());
127+
Value srcAlignedI64 = arith::IndexCastOp::create(builder, loc, i64Ty,
128+
srcAlignedIdx);
129+
Value srcAlignedPtr =
130+
LLVM::IntToPtrOp::create(builder, loc, ptrTy, srcAlignedI64);
131+
132+
// Load bases[from] and bases[to].
133+
Value fromI64 = arith::IndexCastOp::create(builder, loc, i64Ty,
134+
op.getFromRank());
135+
Value toI64 = arith::IndexCastOp::create(builder, loc, i64Ty,
136+
op.getToRank());
137+
Value fromBaseAddr = LLVM::GEPOp::create(
138+
builder, loc, ptrTy, ptrTy, op.getHeapBases(), ValueRange{fromI64});
139+
Value fromBase = LLVM::LoadOp::create(builder, loc, ptrTy, fromBaseAddr);
140+
Value toBaseAddr = LLVM::GEPOp::create(builder, loc, ptrTy, ptrTy,
141+
op.getHeapBases(),
142+
ValueRange{toI64});
143+
Value toBase = LLVM::LoadOp::create(builder, loc, ptrTy, toBaseAddr);
144+
145+
// byte_diff = ptrtoint(toBase) - ptrtoint(fromBase)
146+
Value fromInt = LLVM::PtrToIntOp::create(builder, loc, i64Ty, fromBase);
147+
Value toInt = LLVM::PtrToIntOp::create(builder, loc, i64Ty, toBase);
148+
Value byteDiff = arith::SubIOp::create(builder, loc, toInt, fromInt);
149+
150+
// peer_aligned_ptr = srcAlignedPtr + byteDiff (as i8 GEP)
151+
auto i8Ty = builder.getI8Type();
152+
Value peerAlignedPtr = LLVM::GEPOp::create(
153+
builder, loc, ptrTy, i8Ty, srcAlignedPtr, ValueRange{byteDiff});
154+
155+
// Build a fresh memref descriptor with the peer aligned pointer.
156+
Value desc = buildPeerDescriptor(builder, loc, memrefTy, peerAlignedPtr);
157+
Value newMemref =
158+
UnrealizedConversionCastOp::create(builder, loc,
159+
TypeRange{memrefTy},
160+
ValueRange{desc})
161+
.getResult(0);
162+
163+
op.getResult().replaceAllUsesWith(newMemref);
164+
op.erase();
165+
}
166+
}
167+
};
168+
169+
} // namespace
170+
171+
namespace xilinx {
172+
namespace air {
173+
174+
std::unique_ptr<mlir::Pass> createAIRTranslateToLLVMPass() {
175+
return std::make_unique<AIRTranslateToLLVMPass>();
176+
}
177+
178+
} // namespace air
179+
} // namespace xilinx

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ if(AIR_ENABLE_GPU)
5555
set(GPU_PASS_DEPENDS AIRGPUConversionPassIncGen)
5656
list(APPEND CONVERSION_SOURCES
5757
AIRToROCDLPass.cpp
58+
AIRTranslateToLLVMPass.cpp
5859
GPUKernelOutlinePass.cpp
5960
)
6061
list(APPEND CONVERSION_LINK_LIBS

mlir/lib/Conversion/Passes.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#if AIR_ENABLE_GPU
1212
#include "air/Conversion/AIRToROCDLPass.h"
13+
#include "air/Conversion/AIRTranslateToLLVMPass.h"
1314
#include "air/Conversion/GPUKernelOutlinePass.h"
1415
#endif
1516

mlir/lib/Dialect/AIR/IR/AIRDialect.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3598,6 +3598,20 @@ ParseResult air::CustomOp::parse(OpAsmParser &parser, OperationState &result) {
35983598
return success();
35993599
}
36003600

3601+
//
3602+
// TranslateOp
3603+
//
3604+
3605+
OpFoldResult air::TranslateOp::fold(FoldAdaptor adaptor) {
3606+
if (getFromRank() == getToRank())
3607+
return getSource();
3608+
auto fromAttr = dyn_cast_if_present<IntegerAttr>(adaptor.getFromRank());
3609+
auto toAttr = dyn_cast_if_present<IntegerAttr>(adaptor.getToRank());
3610+
if (fromAttr && toAttr && fromAttr.getValue() == toAttr.getValue())
3611+
return getSource();
3612+
return {};
3613+
}
3614+
36013615
} // namespace xilinx
36023616

36033617
#include "air/Dialect/AIR/AIROpInterfaces.cpp.inc"

0 commit comments

Comments
 (0)