Skip to content

Commit 014a718

Browse files
committed
[CIR][CUDA] Register __global__ functions
1 parent d9eeb83 commit 014a718

File tree

2 files changed

+207
-59
lines changed

2 files changed

+207
-59
lines changed

clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp

+153-31
Original file line numberDiff line numberDiff line change
@@ -123,13 +123,16 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
123123
/// CUDA related
124124
/// ------------
125125

126-
// Maps CUDA device stub name to kernel name.
127-
llvm::DenseMap<llvm::StringRef, std::string> cudaKernelMap;
126+
// Maps CUDA kernel name to device stub function.
127+
std::unordered_map<std::string, FuncOp> cudaKernelMap;
128+
llvm::StringRef cudaPrefix;
128129

129130
void buildCUDAModuleCtor();
130131
void buildCUDAModuleDtor();
131132
std::optional<FuncOp> buildCUDARegisterGlobals();
132133

134+
std::string addUnderscoredPrefix(llvm::StringRef cudaFunctionName);
135+
133136
///
134137
/// AST related
135138
/// -----------
@@ -184,6 +187,8 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
184187
llvm::SmallVector<mlir::Attribute, 4> globalDtorList;
185188
/// List of annotations in the module
186189
llvm::SmallVector<mlir::Attribute, 4> globalAnnotations;
190+
191+
TypeSizeInfoAttr typeSizeInfo;
187192
};
188193
} // namespace
189194

@@ -983,6 +988,11 @@ void LoweringPreparePass::buildCUDAModuleCtor() {
983988
if (astCtx->getLangOpts().GPURelocatableDeviceCode)
984989
llvm_unreachable("NYI");
985990

991+
// For CUDA without -fgpu-rdc, it's safe to stop generating ctor
992+
// if there's nothing to register.
993+
if (cudaKernelMap.empty())
994+
return;
995+
986996
// There's no device-side binary, so no need to proceed for CUDA.
987997
// HIP has to create an external symbol in this case, which is NYI.
988998
auto cudaBinaryHandleAttr =
@@ -995,18 +1005,14 @@ void LoweringPreparePass::buildCUDAModuleCtor() {
9951005
std::string cudaGPUBinaryName =
9961006
cast<CUDABinaryHandleAttr>(cudaBinaryHandleAttr).getName();
9971007

998-
llvm::StringRef prefix = "cuda";
1008+
cudaPrefix = "cuda";
9991009

10001010
constexpr unsigned cudaFatMagic = 0x466243b1;
10011011
constexpr unsigned hipFatMagic = 0x48495046; // "HIPF"
10021012

10031013
const unsigned fatMagic =
10041014
astCtx->getLangOpts().HIP ? hipFatMagic : cudaFatMagic;
10051015

1006-
auto addUnderscoredPrefix = [&](llvm::StringRef name) -> std::string {
1007-
return ("__" + prefix + name).str();
1008-
};
1009-
10101016
// MAC OS X needs special care, but we haven't supported that in CIR yet.
10111017
assert(!cir::MissingFeatures::checkMacOSXTriple());
10121018

@@ -1015,15 +1021,11 @@ void LoweringPreparePass::buildCUDAModuleCtor() {
10151021

10161022
mlir::Location loc = theModule.getLoc();
10171023

1018-
// Extract types from the module.
1019-
auto typeSizesAttr = cast<TypeSizeInfoAttr>(
1020-
theModule->getAttr(CIRDialect::getTypeSizeInfoAttrName()));
1021-
10221024
auto voidTy = VoidType::get(&getContext());
10231025
auto voidPtrTy = PointerType::get(voidTy);
10241026
auto voidPtrPtrTy = PointerType::get(voidPtrTy);
1025-
auto intTy = typeSizesAttr.getIntType(&getContext());
1026-
auto charTy = typeSizesAttr.getCharType(&getContext());
1027+
auto intTy = typeSizeInfo.getIntType(&getContext());
1028+
auto charTy = typeSizeInfo.getCharType(&getContext());
10271029

10281030
// Read the GPU binary and create a constant array for it.
10291031
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> cudaGPUBinaryOrErr =
@@ -1066,22 +1068,30 @@ void LoweringPreparePass::buildCUDAModuleCtor() {
10661068

10671069
std::string fatbinWrapperName = addUnderscoredPrefix("_fatbin_wrapper");
10681070
GlobalOp fatbinWrapper = builder.create<GlobalOp>(
1069-
loc, fatbinWrapperName, fatbinWrapperType, /*isConstant=*/false,
1071+
loc, fatbinWrapperName, fatbinWrapperType, /*isConstant=*/true,
10701072
/*linkage=*/cir::GlobalLinkageKind::InternalLinkage);
10711073
fatbinWrapper.setPrivate();
10721074
fatbinWrapper.setSection(fatbinSectionName);
10731075

10741076
auto magicInit = IntAttr::get(intTy, fatMagic);
10751077
auto versionInit = IntAttr::get(intTy, 1);
1076-
// `fatbinInit` is only a placeholder. The value will be initialized at the
1077-
// beginning of module ctor.
1078-
auto fatbinInit = builder.getConstNullPtrAttr(voidPtrTy);
1078+
auto fatbinStrSymbol =
1079+
mlir::FlatSymbolRefAttr::get(fatbinStr.getSymNameAttr());
1080+
auto fatbinInit = GlobalViewAttr::get(voidPtrTy, fatbinStrSymbol);
10791081
auto unusedInit = builder.getConstNullPtrAttr(voidPtrTy);
10801082
fatbinWrapper.setInitialValueAttr(cir::ConstStructAttr::get(
10811083
fatbinWrapperType,
10821084
ArrayAttr::get(&getContext(),
10831085
{magicInit, versionInit, fatbinInit, unusedInit})));
10841086

1087+
// GPU fat binary handle is also a global variable in OG.
1088+
std::string gpubinHandleName = addUnderscoredPrefix("_gpubin_handle");
1089+
auto gpubinHandle = builder.create<GlobalOp>(
1090+
loc, gpubinHandleName, voidPtrPtrTy,
1091+
/*isConstant=*/false, /*linkage=*/GlobalLinkageKind::InternalLinkage);
1092+
gpubinHandle.setInitialValueAttr(builder.getConstNullPtrAttr(voidPtrPtrTy));
1093+
gpubinHandle.setPrivate();
1094+
10851095
// Declare this function:
10861096
// void **__{cuda|hip}RegisterFatBinary(void *);
10871097

@@ -1098,25 +1108,131 @@ void LoweringPreparePass::buildCUDAModuleCtor() {
10981108
globalCtorList.push_back(GlobalCtorAttr::get(&getContext(), moduleCtorName));
10991109
builder.setInsertionPointToStart(moduleCtor.addEntryBlock());
11001110

1101-
auto wrapper = builder.createGetGlobal(fatbinWrapper);
1102-
// Put fatbinStr inside fatbinWrapper.
1103-
mlir::Value fatbinStrValue = builder.createGetGlobal(fatbinStr);
1104-
mlir::Value fatbinField = builder.createGetMemberOp(loc, wrapper, "", 2);
1105-
builder.createStore(loc, fatbinStrValue, fatbinField);
1106-
11071111
// Register binary with CUDA runtime. This is substantially different in
11081112
// default mode vs. separate compilation.
11091113
// Corresponding code:
11101114
// gpuBinaryHandle = __cudaRegisterFatBinary(&fatbinWrapper);
1115+
auto wrapper = builder.createGetGlobal(fatbinWrapper);
11111116
auto fatbinVoidPtr = builder.createBitcast(wrapper, voidPtrTy);
1112-
auto gpuBinaryHandle = builder.createCallOp(loc, regFunc, fatbinVoidPtr);
1117+
auto gpuBinaryHandleCall = builder.createCallOp(loc, regFunc, fatbinVoidPtr);
1118+
auto gpuBinaryHandle = gpuBinaryHandleCall.getResult();
1119+
// Store the value back to the global `__cuda_gpubin_handle`.
1120+
auto gpuBinaryHandleGlobal = builder.createGetGlobal(gpubinHandle);
1121+
builder.createStore(loc, gpuBinaryHandle, gpuBinaryHandleGlobal);
1122+
1123+
// Generate __cuda_register_globals and call it.
1124+
std::optional<FuncOp> regGlobal = buildCUDARegisterGlobals();
1125+
if (regGlobal) {
1126+
builder.createCallOp(loc, *regGlobal, gpuBinaryHandle);
1127+
}
11131128

1114-
// This is currently incomplete.
1115-
// TODO(cir): create __cuda_register_globals(), and call it here.
1129+
// From CUDA 10.1 onwards, we must call this function to end registration:
1130+
// void __cudaRegisterFatBinaryEnd(void **fatbinHandle);
1131+
// This is CUDA-specific, so no need to use `addUnderscoredPrefix`.
1132+
if (clang::CudaFeatureEnabled(
1133+
astCtx->getTargetInfo().getSDKVersion(),
1134+
clang::CudaFeature::CUDA_USES_FATBIN_REGISTER_END)) {
1135+
cir::CIRBaseBuilderTy globalBuilder(getContext());
1136+
globalBuilder.setInsertionPointToStart(theModule.getBody());
1137+
FuncOp endFunc =
1138+
buildRuntimeFunction(globalBuilder, "__cudaRegisterFatBinaryEnd", loc,
1139+
FuncType::get({voidPtrPtrTy}, voidTy));
1140+
builder.createCallOp(loc, endFunc, gpuBinaryHandle);
1141+
}
11161142

11171143
builder.create<cir::ReturnOp>(loc);
11181144
}
11191145

1146+
std::string
1147+
LoweringPreparePass::addUnderscoredPrefix(llvm::StringRef cudaFunctionName) {
1148+
return ("__" + cudaPrefix + cudaFunctionName).str();
1149+
}
1150+
1151+
std::optional<FuncOp> LoweringPreparePass::buildCUDARegisterGlobals() {
1152+
// There is nothing to register.
1153+
if (cudaKernelMap.empty())
1154+
return {};
1155+
1156+
cir::CIRBaseBuilderTy builder(getContext());
1157+
builder.setInsertionPointToStart(theModule.getBody());
1158+
1159+
auto loc = theModule.getLoc();
1160+
1161+
auto voidTy = cir::VoidType::get(&getContext());
1162+
auto voidPtrTy = cir::PointerType::get(voidTy);
1163+
auto voidPtrPtrTy = cir::PointerType::get(voidPtrTy);
1164+
auto intTy = typeSizeInfo.getIntType(&getContext());
1165+
auto charTy = typeSizeInfo.getCharType(&getContext());
1166+
1167+
// Create the function:
1168+
// void __cuda_register_globals(void **fatbinHandle)
1169+
std::string regGlobalFuncName = addUnderscoredPrefix("_register_globals");
1170+
auto regGlobalFuncTy = FuncType::get({voidPtrPtrTy}, voidTy);
1171+
FuncOp regGlobalFunc =
1172+
buildRuntimeFunction(builder, regGlobalFuncName, loc, regGlobalFuncTy,
1173+
/*linkage=*/GlobalLinkageKind::InternalLinkage);
1174+
builder.setInsertionPointToStart(regGlobalFunc.addEntryBlock());
1175+
1176+
// Extract the GPU binary handle argument.
1177+
mlir::Value fatbinHandle = *regGlobalFunc.args_begin();
1178+
1179+
// Declare CUDA internal functions:
1180+
// int __cudaRegisterFunction(
1181+
// void **fatbinHandle,
1182+
// const char *hostFunc,
1183+
// char *deviceFunc,
1184+
// const char *deviceName,
1185+
// int threadLimit,
1186+
// uint3 *tid, uint3 *bid, dim3 *bDim, dim3 *gDim,
1187+
// int *wsize
1188+
// )
1189+
// OG doesn't care about the types at all. They're treated as void*.
1190+
cir::CIRBaseBuilderTy globalBuilder(getContext());
1191+
globalBuilder.setInsertionPointToStart(theModule.getBody());
1192+
1193+
FuncOp cudaRegisterFunction = buildRuntimeFunction(
1194+
globalBuilder, addUnderscoredPrefix("RegisterFunction"), loc,
1195+
FuncType::get({voidPtrPtrTy, voidPtrTy, voidPtrTy, voidPtrTy, intTy,
1196+
voidPtrTy, voidPtrTy, voidPtrTy, voidPtrTy, voidPtrTy},
1197+
intTy));
1198+
1199+
auto makeConstantString = [&](llvm::StringRef str) -> GlobalOp {
1200+
auto strType = ArrayType::get(&getContext(), charTy, 1 + str.size());
1201+
1202+
auto tmpString = globalBuilder.create<GlobalOp>(
1203+
loc, (".str" + str).str(), strType, /*isConstant=*/true,
1204+
/*linkage=*/cir::GlobalLinkageKind::PrivateLinkage);
1205+
1206+
// We must make the string zero-terminated.
1207+
tmpString.setInitialValueAttr(ConstArrayAttr::get(
1208+
strType, StringAttr::get(&getContext(), str + "\0")));
1209+
tmpString.setPrivate();
1210+
return tmpString;
1211+
};
1212+
1213+
auto cirNullPtr = builder.getNullPtr(voidPtrTy, loc);
1214+
for (auto [kernelName, deviceStub] : cudaKernelMap) {
1215+
GlobalOp deviceFuncStr = makeConstantString(kernelName);
1216+
mlir::Value deviceFunc = builder.createBitcast(
1217+
builder.createGetGlobal(deviceFuncStr), voidPtrTy);
1218+
mlir::Value hostFunc = builder.createBitcast(
1219+
builder.create<GetGlobalOp>(
1220+
loc, PointerType::get(deviceStub.getFunctionType()),
1221+
mlir::FlatSymbolRefAttr::get(deviceStub.getSymNameAttr())),
1222+
voidPtrTy);
1223+
builder.createCallOp(
1224+
loc, cudaRegisterFunction,
1225+
{fatbinHandle, hostFunc, deviceFunc, deviceFunc,
1226+
builder.create<ConstantOp>(loc, IntAttr::get(intTy, -1)), cirNullPtr,
1227+
cirNullPtr, cirNullPtr, cirNullPtr, cirNullPtr});
1228+
}
1229+
1230+
// TODO(cir): registration for global variables.
1231+
1232+
builder.create<ReturnOp>(loc);
1233+
return regGlobalFunc;
1234+
}
1235+
11201236
void LoweringPreparePass::lowerDynamicCastOp(DynamicCastOp op) {
11211237
CIRBaseBuilderTy builder(getContext());
11221238
builder.setInsertionPointAfter(op);
@@ -1378,11 +1494,10 @@ void LoweringPreparePass::runOnOp(Operation *op) {
13781494
globalDtorList.push_back(globalDtor);
13791495
}
13801496
if (auto attr = fnOp.getExtraAttrs().getElements().get(
1381-
CIRDialect::getCUDABinaryHandleAttrName())) {
1382-
auto cudaBinaryAttr = dyn_cast<CUDABinaryHandleAttr>(attr);
1383-
std::string kernelName = cudaBinaryAttr.getName();
1384-
llvm::StringRef stubName = fnOp.getSymName();
1385-
cudaKernelMap[stubName] = kernelName;
1497+
CUDAKernelNameAttr::getMnemonic())) {
1498+
auto cudaBinaryAttr = dyn_cast<CUDAKernelNameAttr>(attr);
1499+
std::string kernelName = cudaBinaryAttr.getKernelName();
1500+
cudaKernelMap[kernelName] = fnOp;
13861501
}
13871502
if (std::optional<mlir::ArrayAttr> annotations = fnOp.getAnnotations())
13881503
addGlobalAnnotations(fnOp, annotations.value());
@@ -1399,6 +1514,13 @@ void LoweringPreparePass::runOnOperation() {
13991514
datalayout.emplace(theModule);
14001515
}
14011516

1517+
if (astCtx->getLangOpts().CUDA) {
1518+
cudaPrefix = "cuda";
1519+
}
1520+
1521+
typeSizeInfo = cast<TypeSizeInfoAttr>(
1522+
theModule->getAttr(CIRDialect::getTypeSizeInfoAttrName()));
1523+
14021524
llvm::SmallVector<Operation *> opsToTransform;
14031525

14041526
op->walk([&](Operation *op) {

clang/test/CIR/CodeGen/CUDA/registration.cu

+54-28
Original file line numberDiff line numberDiff line change
@@ -13,56 +13,82 @@
1313
// RUN: %s -o %t.ll
1414
// RUN: FileCheck --check-prefix=LLVM-HOST --input-file=%t.ll %s
1515

16-
// COM: OG doesn't emit anything if there is nothing to register.
17-
// COM: Here we still emit the template for test purposes,
18-
// COM: and the behaviour will be fixed later.
19-
2016
// CIR-HOST: module @"{{.*}}" attributes {
2117
// CIR-HOST: cir.cu.binary_handle = #cir.cu.binary_handle<{{.*}}.fatbin>,
2218
// CIR-HOST: cir.global_ctors = [#cir.global_ctor<"__cuda_module_ctor", {{[0-9]+}}>]
2319
// CIR-HOST: }
2420

21+
// CIR-HOST: cir.global "private" constant cir_private @".str_Z2fnv" =
22+
// CIR-HOST-SAME: #cir.const_array<"_Z2fnv", trailing_zeros>
23+
24+
// COM: In OG this variable has an `unnamed_addr` attribute.
25+
// LLVM-HOST: @.str_Z2fnv = private constant [7 x i8] c"_Z2fnv\00"
26+
27+
// The corresponding CIR test for these three variables are down below.
28+
// They are here because LLVM IR puts global variables at the front of file.
29+
30+
// LLVM-HOST: @__cuda_fatbin_str = private constant [14 x i8] c"sample fatbin\0A", section ".nv_fatbin"
31+
// LLVM-HOST: @__cuda_fatbin_wrapper = internal constant {
32+
// LLVM-HOST: i32 1180844977, i32 1, ptr @__cuda_fatbin_str, ptr null
33+
// LLVM-HOST: }
34+
// LLVM-HOST: @llvm.global_ctors = {{.*}}ptr @__cuda_module_ctor
35+
36+
__global__ void fn() {}
37+
38+
// CIR-HOST: cir.func internal private @__cuda_register_globals(%[[FatbinHandle:[a-zA-Z0-9]+]]{{.*}}) {
39+
// CIR-HOST: %[[#NULL:]] = cir.const #cir.ptr<null>
40+
// CIR-HOST: %[[#T1:]] = cir.get_global @".str_Z2fnv"
41+
// CIR-HOST: %[[#DeviceFn:]] = cir.cast(bitcast, %[[#T1]]
42+
// CIR-HOST: %[[#T2:]] = cir.get_global @_Z17__device_stub__fnv
43+
// CIR-HOST: %[[#HostFn:]] = cir.cast(bitcast, %[[#T2]]
44+
// CIR-HOST: %[[#MinusOne:]] = cir.const #cir.int<-1>
45+
// CIR-HOST: cir.call @__cudaRegisterFunction(
46+
// CIR-HOST-SAME: %[[FatbinHandle]],
47+
// CIR-HOST-SAME: %[[#HostFn]],
48+
// CIR-HOST-SAME: %[[#DeviceFn]],
49+
// CIR-HOST-SAME: %[[#DeviceFn]],
50+
// CIR-HOST-SAME: %[[#MinusOne]],
51+
// CIR-HOST-SAME: %[[#NULL]], %[[#NULL]], %[[#NULL]], %[[#NULL]], %[[#NULL]])
52+
// CIR-HOST: }
53+
54+
// LLVM-HOST: define internal void @__cuda_register_globals(ptr %[[#LLVMFatbin:]]) {
55+
// LLVM-HOST: call i32 @__cudaRegisterFunction(
56+
// LLVM-HOST-SAME: ptr %[[#LLVMFatbin]],
57+
// LLVM-HOST-SAME: ptr @_Z17__device_stub__fnv,
58+
// LLVM-HOST-SAME: ptr @.str_Z2fnv,
59+
// LLVM-HOST-SAME: ptr @.str_Z2fnv,
60+
// LLVM-HOST-SAME: i32 -1,
61+
// LLVM-HOST-SAME: ptr null, ptr null, ptr null, ptr null, ptr null)
62+
// LLVM-HOST: }
63+
2564
// The content in const array should be the same as echoed above,
2665
// with a trailing line break ('\n', 0x0A).
2766
// CIR-HOST: cir.global "private" constant cir_private @__cuda_fatbin_str =
2867
// CIR-HOST-SAME: #cir.const_array<"sample fatbin\0A">
2968
// CIR-HOST-SAME: {{.*}}section = ".nv_fatbin"
3069

31-
// LLVM-HOST: @__cuda_fatbin_str = private constant [14 x i8] c"sample fatbin\0A", section ".nv_fatbin"
32-
3370
// The first value is CUDA file head magic number.
34-
// CIR-HOST: cir.global "private" internal @__cuda_fatbin_wrapper
71+
// CIR-HOST: cir.global "private" constant internal @__cuda_fatbin_wrapper
3572
// CIR-HOST: = #cir.const_struct<{
3673
// CIR-HOST: #cir.int<1180844977> : !s32i,
3774
// CIR-HOST: #cir.int<1> : !s32i,
38-
// CIR-HOST: #cir.ptr<null> : !cir.ptr<!void>,
75+
// CIR-HOST: #cir.global_view<@__cuda_fatbin_str> : !cir.ptr<!void>,
3976
// CIR-HOST: #cir.ptr<null> : !cir.ptr<!void>
4077
// CIR-HOST: }>
4178
// CIR-HOST-SAME: {{.*}}section = ".nvFatBinSegment"
4279

43-
// COM: @__cuda_fatbin_wrapper is constant for OG.
44-
// COM: However, as we don't have a way to put @__cuda_fatbin_str directly
45-
// COM: to its third field in Clang IR, we can't mark this variable as
46-
// COM: constant: we need to initialize it later, at the beginning
47-
// COM: of @__cuda_module_ctor.
48-
49-
// LLVM-HOST: @__cuda_fatbin_wrapper = internal global {
50-
// LLVM-HOST: i32 1180844977, i32 1, ptr null, ptr null
51-
// LLVM-HOST: }
52-
53-
// LLVM-HOST: @llvm.global_ctors = {{.*}}ptr @__cuda_module_ctor
54-
5580
// CIR-HOST: cir.func private @__cudaRegisterFatBinary
5681
// CIR-HOST: cir.func {{.*}} @__cuda_module_ctor() {
57-
// CIR-HOST: %[[#F0:]] = cir.get_global @__cuda_fatbin_wrapper
58-
// CIR-HOST: %[[#F1:]] = cir.get_global @__cuda_fatbin_str
59-
// CIR-HOST: %[[#F2:]] = cir.get_member %[[#F0]][2]
60-
// CIR-HOST: %[[#F3:]] = cir.cast(bitcast, %[[#F2]]
61-
// CIR-HOST: cir.store %[[#F1]], %[[#F3]]
62-
// CIR-HOST: cir.call @__cudaRegisterFatBinary
82+
// CIR-HOST: %[[#Fatbin:]] = cir.call @__cudaRegisterFatBinary
83+
// CIR-HOST: %[[#FatbinGlobal:]] = cir.get_global @__cuda_gpubin_handle
84+
// CIR-HOST: cir.store %[[#Fatbin]], %[[#FatbinGlobal]]
85+
// CIR-HOST: cir.call @__cuda_register_globals
86+
// CIR-HOTS: cir.call @__cudaRegisterFatBinaryEnd
6387
// CIR-HOST: }
6488

6589
// LLVM-HOST: define internal void @__cuda_module_ctor() {
66-
// LLVM-HOST: store ptr @__cuda_fatbin_str, ptr getelementptr {{.*}}, ptr @__cuda_fatbin_wrapper
67-
// LLVM-HOST: call ptr @__cudaRegisterFatBinary(ptr @__cuda_fatbin_wrapper)
90+
// LLVM-HOST: %[[#LLVMFatbin:]] = call ptr @__cudaRegisterFatBinary(ptr @__cuda_fatbin_wrapper)
91+
// LLVM-HOST: store ptr %[[#LLVMFatbin]], ptr @__cuda_gpubin_handle
92+
// LLVM-HOST: call void @__cuda_register_globals
93+
// LLVM-HOST: call void @__cudaRegisterFatBinaryEnd
6894
// LLVM-HOST: }

0 commit comments

Comments
 (0)