Skip to content

Commit 1e3b942

Browse files
committed
add initial set of lowerings for MPI dialect
1 parent ac40463 commit 1e3b942

File tree

7 files changed

+321
-1
lines changed

7 files changed

+321
-1
lines changed

Diff for: mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
//
2+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
3+
// See https://llvm.org/LICENSE.txt for license information.
4+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
5+
//
6+
//===----------------------------------------------------------------------===//
7+
8+
#ifndef MLIR_CONVERSION_MPITOLLVM_H
9+
#define MLIR_CONVERSION_MPITOLLVM_H
10+
11+
#include "mlir/IR/DialectRegistry.h"
12+
13+
namespace mlir {
14+
15+
class LLVMTypeConverter;
16+
class RewritePatternSet;
17+
18+
#define GEN_PASS_DECL_MPITOLLVMCONVERSIONPASS
19+
#include "mlir/Conversion/Passes.h.inc"
20+
21+
namespace mpi {
22+
void populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
23+
RewritePatternSet &patterns);
24+
25+
void registerConvertMPIToLLVMInterface(DialectRegistry &registry);
26+
27+
} // namespace mpi
28+
} // namespace mlir
29+
30+
#endif // MLIR_CONVERSION_MPITOLLVM_H

Diff for: mlir/include/mlir/Dialect/MPI/IR/MPITypes.td

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class MPI_Type<string name, string typeMnemonic, list<Trait> traits = []>
3030
//===----------------------------------------------------------------------===//
3131

3232
def MPI_Retval : MPI_Type<"Retval", "retval"> {
33-
let summary = "MPI function call return value";
33+
let summary = "MPI function call return value (!mpi.retval)";
3434
let description = [{
3535
This type represents a return value from an MPI function call.
3636
This value can be MPI_SUCCESS, MPI_ERR_IN_STATUS, or any error code.

Diff for: mlir/include/mlir/InitAllExtensions.h

+2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#ifndef MLIR_INITALLEXTENSIONS_H_
1515
#define MLIR_INITALLEXTENSIONS_H_
1616

17+
#include "Conversion/MPIToLLVM/MPIToLLVM.h"
1718
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
1819
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
1920
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
@@ -62,6 +63,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
6263
registerConvertFuncToLLVMInterface(registry);
6364
index::registerConvertIndexToLLVMInterface(registry);
6465
registerConvertMathToLLVMInterface(registry);
66+
mpi::registerConvertMPIToLLVMInterface(registry);
6567
registerConvertMemRefToLLVMInterface(registry);
6668
registerConvertNVVMToLLVMInterface(registry);
6769
ub::registerConvertUBToLLVMInterface(registry);

Diff for: mlir/lib/Conversion/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ add_subdirectory(MathToSPIRV)
3939
add_subdirectory(MemRefToEmitC)
4040
add_subdirectory(MemRefToLLVM)
4141
add_subdirectory(MemRefToSPIRV)
42+
add_subdirectory(MPIToLLVM)
4243
add_subdirectory(NVGPUToNVVM)
4344
add_subdirectory(NVVMToLLVM)
4445
add_subdirectory(OpenACCToSCF)

Diff for: mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
add_mlir_conversion_library(MLIRMPIToLLVM
2+
MPIToLLVM.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MPIToLLVM
6+
7+
DEPENDS
8+
MLIRConversionPassIncGen
9+
10+
LINK_COMPONENTS
11+
Core
12+
13+
LINK_LIBS PUBLIC
14+
MLIRLLVMCommonConversion
15+
MLIRLLVMDialect
16+
MLIRMPIDialect
17+
)

Diff for: mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp

+230
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
//===- MPIToLLVM.cpp - MPI to LLVM dialect conversion ---------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
10+
11+
#include "mlir/Conversion/LLVMCommon/Pattern.h"
12+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13+
#include "mlir/Dialect/MPI/IR/MPI.h"
14+
#include "mlir/Pass/Pass.h"
15+
16+
#include <mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h>
17+
18+
using namespace mlir;
19+
20+
namespace {
21+
22+
struct InitOpLowering : ConvertOpToLLVMPattern<mpi::InitOp> {
23+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
24+
25+
LogicalResult
26+
matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
27+
ConversionPatternRewriter &rewriter) const override;
28+
};
29+
30+
struct CommRankOpLowering : ConvertOpToLLVMPattern<mpi::CommRankOp> {
31+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
32+
33+
LogicalResult
34+
matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
35+
ConversionPatternRewriter &rewriter) const override;
36+
};
37+
38+
struct FinalizeOpLowering : ConvertOpToLLVMPattern<mpi::FinalizeOp> {
39+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
40+
41+
LogicalResult
42+
matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
43+
ConversionPatternRewriter &rewriter) const override;
44+
};
45+
46+
// TODO: this was copied from GPUOpsLowering.cpp:288
47+
// is this okay, or should this be moved to some common file?
48+
LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp, const Location loc,
49+
ConversionPatternRewriter &rewriter,
50+
StringRef name,
51+
LLVM::LLVMFunctionType type) {
52+
LLVM::LLVMFuncOp ret;
53+
if (!(ret = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name))) {
54+
ConversionPatternRewriter::InsertionGuard guard(rewriter);
55+
rewriter.setInsertionPointToStart(moduleOp.getBody());
56+
ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
57+
LLVM::Linkage::External);
58+
}
59+
return ret;
60+
}
61+
62+
// TODO: this is pretty close to getOrDefineFunction, can probably be factored
63+
LLVM::GlobalOp getOrDefineExternalStruct(ModuleOp &moduleOp, const Location loc,
64+
ConversionPatternRewriter &rewriter,
65+
StringRef name,
66+
LLVM::LLVMStructType type) {
67+
LLVM::GlobalOp ret;
68+
if (!(ret = moduleOp.lookupSymbol<LLVM::GlobalOp>(name))) {
69+
ConversionPatternRewriter::InsertionGuard guard(rewriter);
70+
rewriter.setInsertionPointToStart(moduleOp.getBody());
71+
ret = rewriter.create<LLVM::GlobalOp>(
72+
loc, type, /*isConstant=*/false, LLVM::Linkage::External, name,
73+
/*value=*/Attribute(), /*alignment=*/0, 0);
74+
}
75+
return ret;
76+
}
77+
78+
} // namespace
79+
80+
//===----------------------------------------------------------------------===//
81+
// InitOpLowering
82+
//===----------------------------------------------------------------------===//
83+
84+
LogicalResult
85+
InitOpLowering::matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
86+
ConversionPatternRewriter &rewriter) const {
87+
// get loc
88+
auto loc = op.getLoc();
89+
90+
// ptrType `!llvm.ptr`
91+
Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
92+
93+
// instantiate nullptr `%nullptr = llvm.mlir.zero : !llvm.ptr`
94+
auto nullPtrOp = rewriter.create<LLVM::ZeroOp>(loc, ptrType);
95+
Value llvmnull = nullPtrOp.getRes();
96+
97+
// grab a reference to the global module op:
98+
auto moduleOp = op->getParentOfType<ModuleOp>();
99+
100+
// LLVM Function type representing `i32 MPI_Init(ptr, ptr)`
101+
auto initFuncType =
102+
LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
103+
// get or create function declaration:
104+
LLVM::LLVMFuncOp initDecl =
105+
getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Init", initFuncType);
106+
107+
// replace init with function call
108+
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl,
109+
ValueRange{llvmnull, llvmnull});
110+
111+
return success();
112+
}
113+
114+
//===----------------------------------------------------------------------===//
115+
// FinalizeOpLowering
116+
//===----------------------------------------------------------------------===//
117+
118+
LogicalResult
119+
FinalizeOpLowering::matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
120+
ConversionPatternRewriter &rewriter) const {
121+
// get loc
122+
auto loc = op.getLoc();
123+
124+
// grab a reference to the global module op:
125+
auto moduleOp = op->getParentOfType<ModuleOp>();
126+
127+
// LLVM Function type representing `i32 MPI_Finalize()`
128+
auto initFuncType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {});
129+
// get or create function declaration:
130+
LLVM::LLVMFuncOp initDecl = getOrDefineFunction(moduleOp, loc, rewriter,
131+
"MPI_Finalize", initFuncType);
132+
133+
// replace init with function call
134+
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl, ValueRange{});
135+
136+
return success();
137+
}
138+
139+
//===----------------------------------------------------------------------===//
140+
// CommRankLowering
141+
//===----------------------------------------------------------------------===//
142+
143+
LogicalResult
144+
CommRankOpLowering::matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
145+
ConversionPatternRewriter &rewriter) const {
146+
// get some helper vars
147+
auto loc = op.getLoc();
148+
auto context = rewriter.getContext();
149+
auto i32 = rewriter.getI32Type();
150+
151+
// ptrType `!llvm.ptr`
152+
Type ptrType = LLVM::LLVMPointerType::get(context);
153+
154+
// get external opaque struct pointer type
155+
auto commStructT = LLVM::LLVMStructType::getOpaque("MPI_ABI_Comm", context);
156+
157+
// grab a reference to the global module op:
158+
auto moduleOp = op->getParentOfType<ModuleOp>();
159+
160+
// make sure global op definition exists
161+
getOrDefineExternalStruct(moduleOp, loc, rewriter, "MPI_COMM_WORLD",
162+
commStructT);
163+
164+
// get address of @MPI_COMM_WORLD
165+
auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
166+
auto rankptr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, i32, one);
167+
auto commWorld = rewriter.create<LLVM::AddressOfOp>(
168+
loc, ptrType, SymbolRefAttr::get(context, "MPI_COMM_WORLD"));
169+
170+
// LLVM Function type representing `i32 MPI_Comm_rank(ptr, ptr)`
171+
auto rankFuncType = LLVM::LLVMFunctionType::get(i32, {ptrType, ptrType});
172+
// get or create function declaration:
173+
LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
174+
moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType);
175+
176+
// replace init with function call
177+
auto callOp = rewriter.create<LLVM::CallOp>(
178+
loc, initDecl, ValueRange{commWorld.getRes(), rankptr.getRes()});
179+
180+
// load the rank into a register
181+
auto loadedRank =
182+
rewriter.create<LLVM::LoadOp>(loc, i32, rankptr.getResult());
183+
184+
// if retval is checked, replace uses of retval with the results from the call
185+
// op
186+
SmallVector<Value> replacements;
187+
if (op.getRetval()) {
188+
replacements.push_back(callOp.getResult());
189+
}
190+
// replace all uses, then erase op
191+
replacements.push_back(loadedRank.getRes());
192+
rewriter.replaceOp(op, replacements);
193+
194+
return success();
195+
}
196+
197+
//===----------------------------------------------------------------------===//
198+
// Pattern Population
199+
//===----------------------------------------------------------------------===//
200+
201+
void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
202+
RewritePatternSet &patterns) {
203+
patterns.add<InitOpLowering>(converter);
204+
patterns.add<CommRankOpLowering>(converter);
205+
patterns.add<FinalizeOpLowering>(converter);
206+
}
207+
208+
//===----------------------------------------------------------------------===//
209+
// ConvertToLLVMPatternInterface implementation
210+
//===----------------------------------------------------------------------===//
211+
212+
namespace {
213+
/// Implement the interface to convert Func to LLVM.
214+
struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
215+
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
216+
/// Hook for derived dialect interface to provide conversion patterns
217+
/// and mark dialect legal for the conversion target.
218+
void populateConvertToLLVMConversionPatterns(
219+
ConversionTarget &target, LLVMTypeConverter &typeConverter,
220+
RewritePatternSet &patterns) const final {
221+
mpi::populateMPIToLLVMConversionPatterns(typeConverter, patterns);
222+
}
223+
};
224+
} // namespace
225+
226+
void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {
227+
registry.addExtension(+[](MLIRContext *ctx, mpi::MPIDialect *dialect) {
228+
dialect->addInterfaces<FuncToLLVMDialectInterface>();
229+
});
230+
}

Diff for: mlir/test/Conversion/MPIToLLVM/ops.mlir

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// RUN: mlir-opt -convert-to-llvm %s | FileCheck %s
2+
3+
module {
4+
// CHECK: llvm.func @MPI_Finalize() -> i32
5+
// CHECK: llvm.func @MPI_Comm_rank(!llvm.ptr, !llvm.ptr) -> i32
6+
// CHECK: llvm.mlir.global external @MPI_COMM_WORLD() {addr_space = 0 : i32} : !llvm.struct<"MPI_ABI_Comm", opaque>
7+
// CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
8+
9+
func.func @mpi_test(%arg0: memref<100xf32>) {
10+
%0 = mpi.init : !mpi.retval
11+
// CHECK: %7 = llvm.mlir.zero : !llvm.ptr
12+
// CHECK-NEXT: %8 = llvm.call @MPI_Init(%7, %7) : (!llvm.ptr, !llvm.ptr) -> i32
13+
// CHECK-NEXT: %9 = builtin.unrealized_conversion_cast %8 : i32 to !mpi.retval
14+
15+
16+
%retval, %rank = mpi.comm_rank : !mpi.retval, i32
17+
// CHECK: %10 = llvm.mlir.constant(1 : i32) : i32
18+
// CHECK-NEXT: %11 = llvm.alloca %10 x i32 : (i32) -> !llvm.ptr
19+
// CHECK-NEXT: %12 = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr
20+
// CHECK-NEXT: %13 = llvm.call @MPI_Comm_rank(%12, %11) : (!llvm.ptr, !llvm.ptr) -> i32
21+
// CHECK-NEXT: %14 = llvm.load %11 : !llvm.ptr -> i32
22+
// CHECK-NEXT: %15 = builtin.unrealized_conversion_cast %13 : i32 to !mpi.retval
23+
24+
mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
25+
26+
%1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
27+
28+
mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
29+
30+
%2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
31+
32+
%3 = mpi.finalize : !mpi.retval
33+
// CHECK: %18 = llvm.call @MPI_Finalize() : () -> i32
34+
35+
%4 = mpi.retval_check %retval = <MPI_SUCCESS> : i1
36+
37+
%5 = mpi.error_class %0 : !mpi.retval
38+
return
39+
}
40+
}

0 commit comments

Comments
 (0)