@@ -123,13 +123,16 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
123
123
// / CUDA related
124
124
// / ------------
125
125
126
- // Maps CUDA device stub name to kernel name.
127
- llvm::DenseMap<llvm::StringRef, std::string> cudaKernelMap;
126
+ // Maps CUDA kernel name to device stub function.
127
+ std::unordered_map<std::string, FuncOp> cudaKernelMap;
128
+ llvm::StringRef cudaPrefix;
128
129
129
130
void buildCUDAModuleCtor ();
130
131
void buildCUDAModuleDtor ();
131
132
std::optional<FuncOp> buildCUDARegisterGlobals ();
132
133
134
+ std::string addUnderscoredPrefix (llvm::StringRef cudaFunctionName);
135
+
133
136
// /
134
137
// / AST related
135
138
// / -----------
@@ -184,6 +187,8 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
184
187
llvm::SmallVector<mlir::Attribute, 4 > globalDtorList;
185
188
// / List of annotations in the module
186
189
llvm::SmallVector<mlir::Attribute, 4 > globalAnnotations;
190
+
191
+ TypeSizeInfoAttr typeSizeInfo;
187
192
};
188
193
} // namespace
189
194
@@ -983,6 +988,11 @@ void LoweringPreparePass::buildCUDAModuleCtor() {
983
988
if (astCtx->getLangOpts ().GPURelocatableDeviceCode )
984
989
llvm_unreachable (" NYI" );
985
990
991
+ // For CUDA without -fgpu-rdc, it's safe to stop generating ctor
992
+ // if there's nothing to register.
993
+ if (cudaKernelMap.empty ())
994
+ return ;
995
+
986
996
// There's no device-side binary, so no need to proceed for CUDA.
987
997
// HIP has to create an external symbol in this case, which is NYI.
988
998
auto cudaBinaryHandleAttr =
@@ -995,18 +1005,14 @@ void LoweringPreparePass::buildCUDAModuleCtor() {
995
1005
std::string cudaGPUBinaryName =
996
1006
cast<CUDABinaryHandleAttr>(cudaBinaryHandleAttr).getName ();
997
1007
998
- llvm::StringRef prefix = " cuda" ;
1008
+ cudaPrefix = " cuda" ;
999
1009
1000
1010
constexpr unsigned cudaFatMagic = 0x466243b1 ;
1001
1011
constexpr unsigned hipFatMagic = 0x48495046 ; // "HIPF"
1002
1012
1003
1013
const unsigned fatMagic =
1004
1014
astCtx->getLangOpts ().HIP ? hipFatMagic : cudaFatMagic;
1005
1015
1006
- auto addUnderscoredPrefix = [&](llvm::StringRef name) -> std::string {
1007
- return (" __" + prefix + name).str ();
1008
- };
1009
-
1010
1016
// MAC OS X needs special care, but we haven't supported that in CIR yet.
1011
1017
assert (!cir::MissingFeatures::checkMacOSXTriple ());
1012
1018
@@ -1015,15 +1021,11 @@ void LoweringPreparePass::buildCUDAModuleCtor() {
1015
1021
1016
1022
mlir::Location loc = theModule.getLoc ();
1017
1023
1018
- // Extract types from the module.
1019
- auto typeSizesAttr = cast<TypeSizeInfoAttr>(
1020
- theModule->getAttr (CIRDialect::getTypeSizeInfoAttrName ()));
1021
-
1022
1024
auto voidTy = VoidType::get (&getContext ());
1023
1025
auto voidPtrTy = PointerType::get (voidTy);
1024
1026
auto voidPtrPtrTy = PointerType::get (voidPtrTy);
1025
- auto intTy = typeSizesAttr .getIntType (&getContext ());
1026
- auto charTy = typeSizesAttr .getCharType (&getContext ());
1027
+ auto intTy = typeSizeInfo .getIntType (&getContext ());
1028
+ auto charTy = typeSizeInfo .getCharType (&getContext ());
1027
1029
1028
1030
// Read the GPU binary and create a constant array for it.
1029
1031
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> cudaGPUBinaryOrErr =
@@ -1066,22 +1068,30 @@ void LoweringPreparePass::buildCUDAModuleCtor() {
1066
1068
1067
1069
std::string fatbinWrapperName = addUnderscoredPrefix (" _fatbin_wrapper" );
1068
1070
GlobalOp fatbinWrapper = builder.create <GlobalOp>(
1069
- loc, fatbinWrapperName, fatbinWrapperType, /* isConstant=*/ false ,
1071
+ loc, fatbinWrapperName, fatbinWrapperType, /* isConstant=*/ true ,
1070
1072
/* linkage=*/ cir::GlobalLinkageKind::InternalLinkage);
1071
1073
fatbinWrapper.setPrivate ();
1072
1074
fatbinWrapper.setSection (fatbinSectionName);
1073
1075
1074
1076
auto magicInit = IntAttr::get (intTy, fatMagic);
1075
1077
auto versionInit = IntAttr::get (intTy, 1 );
1076
- // `fatbinInit` is only a placeholder. The value will be initialized at the
1077
- // beginning of module ctor.
1078
- auto fatbinInit = builder. getConstNullPtrAttr (voidPtrTy);
1078
+ auto fatbinStrSymbol =
1079
+ mlir::FlatSymbolRefAttr::get (fatbinStr. getSymNameAttr ());
1080
+ auto fatbinInit = GlobalViewAttr::get (voidPtrTy, fatbinStrSymbol );
1079
1081
auto unusedInit = builder.getConstNullPtrAttr (voidPtrTy);
1080
1082
fatbinWrapper.setInitialValueAttr (cir::ConstStructAttr::get (
1081
1083
fatbinWrapperType,
1082
1084
ArrayAttr::get (&getContext (),
1083
1085
{magicInit, versionInit, fatbinInit, unusedInit})));
1084
1086
1087
+ // GPU fat binary handle is also a global variable in OG.
1088
+ std::string gpubinHandleName = addUnderscoredPrefix (" _gpubin_handle" );
1089
+ auto gpubinHandle = builder.create <GlobalOp>(
1090
+ loc, gpubinHandleName, voidPtrPtrTy,
1091
+ /* isConstant=*/ false , /* linkage=*/ GlobalLinkageKind::InternalLinkage);
1092
+ gpubinHandle.setInitialValueAttr (builder.getConstNullPtrAttr (voidPtrPtrTy));
1093
+ gpubinHandle.setPrivate ();
1094
+
1085
1095
// Declare this function:
1086
1096
// void **__{cuda|hip}RegisterFatBinary(void *);
1087
1097
@@ -1098,25 +1108,131 @@ void LoweringPreparePass::buildCUDAModuleCtor() {
1098
1108
globalCtorList.push_back (GlobalCtorAttr::get (&getContext (), moduleCtorName));
1099
1109
builder.setInsertionPointToStart (moduleCtor.addEntryBlock ());
1100
1110
1101
- auto wrapper = builder.createGetGlobal (fatbinWrapper);
1102
- // Put fatbinStr inside fatbinWrapper.
1103
- mlir::Value fatbinStrValue = builder.createGetGlobal (fatbinStr);
1104
- mlir::Value fatbinField = builder.createGetMemberOp (loc, wrapper, " " , 2 );
1105
- builder.createStore (loc, fatbinStrValue, fatbinField);
1106
-
1107
1111
// Register binary with CUDA runtime. This is substantially different in
1108
1112
// default mode vs. separate compilation.
1109
1113
// Corresponding code:
1110
1114
// gpuBinaryHandle = __cudaRegisterFatBinary(&fatbinWrapper);
1115
+ auto wrapper = builder.createGetGlobal (fatbinWrapper);
1111
1116
auto fatbinVoidPtr = builder.createBitcast (wrapper, voidPtrTy);
1112
- auto gpuBinaryHandle = builder.createCallOp (loc, regFunc, fatbinVoidPtr);
1117
+ auto gpuBinaryHandleCall = builder.createCallOp (loc, regFunc, fatbinVoidPtr);
1118
+ auto gpuBinaryHandle = gpuBinaryHandleCall.getResult ();
1119
+ // Store the value back to the global `__cuda_gpubin_handle`.
1120
+ auto gpuBinaryHandleGlobal = builder.createGetGlobal (gpubinHandle);
1121
+ builder.createStore (loc, gpuBinaryHandle, gpuBinaryHandleGlobal);
1122
+
1123
+ // Generate __cuda_register_globals and call it.
1124
+ std::optional<FuncOp> regGlobal = buildCUDARegisterGlobals ();
1125
+ if (regGlobal) {
1126
+ builder.createCallOp (loc, *regGlobal, gpuBinaryHandle);
1127
+ }
1113
1128
1114
- // This is currently incomplete.
1115
- // TODO(cir): create __cuda_register_globals(), and call it here.
1129
+ // From CUDA 10.1 onwards, we must call this function to end registration:
1130
+ // void __cudaRegisterFatBinaryEnd(void **fatbinHandle);
1131
+ // This is CUDA-specific, so no need to use `addUnderscoredPrefix`.
1132
+ if (clang::CudaFeatureEnabled (
1133
+ astCtx->getTargetInfo ().getSDKVersion (),
1134
+ clang::CudaFeature::CUDA_USES_FATBIN_REGISTER_END)) {
1135
+ cir::CIRBaseBuilderTy globalBuilder (getContext ());
1136
+ globalBuilder.setInsertionPointToStart (theModule.getBody ());
1137
+ FuncOp endFunc =
1138
+ buildRuntimeFunction (globalBuilder, " __cudaRegisterFatBinaryEnd" , loc,
1139
+ FuncType::get ({voidPtrPtrTy}, voidTy));
1140
+ builder.createCallOp (loc, endFunc, gpuBinaryHandle);
1141
+ }
1116
1142
1117
1143
builder.create <cir::ReturnOp>(loc);
1118
1144
}
1119
1145
1146
// Builds a prefixed symbol name for a CUDA runtime entity.
// The result is the concatenation of the leading underscore literal, the
// `cudaPrefix` member, and `cudaFunctionName`, materialized as a std::string.
// `cudaFunctionName` is the suffix to append after the prefix.
+ std::string
1147
+ LoweringPreparePass::addUnderscoredPrefix (llvm::StringRef cudaFunctionName) {
1148
// The concatenation is converted to an owning std::string via .str() before
// being returned, so the result does not alias either input.
+ return (" __" + cudaPrefix + cudaFunctionName).str ();
1149
+ }
1150
+
1151
// Builds the module-level helper function that registers every recorded CUDA
// kernel, and returns it.
// - Returns an empty optional when `cudaKernelMap` has no entries.
// - Otherwise creates an internal-linkage function (named through
//   addUnderscoredPrefix) that takes one void** argument and returns void,
//   and fills its entry block with one call to the runtime registration
//   function per (kernel name, device stub) entry, followed by a return.
+ std::optional<FuncOp> LoweringPreparePass::buildCUDARegisterGlobals () {
1152
+ // There is nothing to register.
1153
+ if (cudaKernelMap.empty ())
1154
+ return {};
1155
+
1156
+ cir::CIRBaseBuilderTy builder (getContext ());
1157
+ builder.setInsertionPointToStart (theModule.getBody ());
1158
+
1159
+ auto loc = theModule.getLoc ();
1160
+
1161
+ auto voidTy = cir::VoidType::get (&getContext ());
1162
+ auto voidPtrTy = cir::PointerType::get (voidTy);
1163
+ auto voidPtrPtrTy = cir::PointerType::get (voidPtrTy);
1164
+ auto intTy = typeSizeInfo.getIntType (&getContext ());
1165
+ auto charTy = typeSizeInfo.getCharType (&getContext ());
1166
+
1167
+ // Create the function:
1168
+ // void __cuda_register_globals(void **fatbinHandle)
1169
+ std::string regGlobalFuncName = addUnderscoredPrefix (" _register_globals" );
1170
+ auto regGlobalFuncTy = FuncType::get ({voidPtrPtrTy}, voidTy);
1171
+ FuncOp regGlobalFunc =
1172
+ buildRuntimeFunction (builder, regGlobalFuncName, loc, regGlobalFuncTy,
1173
+ /* linkage=*/ GlobalLinkageKind::InternalLinkage);
1174
// From here on `builder` emits into the body of the new function.
+ builder.setInsertionPointToStart (regGlobalFunc.addEntryBlock ());
1175
+
1176
+ // Extract the GPU binary handle argument.
1177
+ mlir::Value fatbinHandle = *regGlobalFunc.args_begin ();
1178
+
1179
+ // Declare CUDA internal functions:
1180
+ // int __cudaRegisterFunction(
1181
+ // void **fatbinHandle,
1182
+ // const char *hostFunc,
1183
+ // char *deviceFunc,
1184
+ // const char *deviceName,
1185
+ // int threadLimit,
1186
+ // uint3 *tid, uint3 *bid, dim3 *bDim, dim3 *gDim,
1187
+ // int *wsize
1188
+ // )
1189
+ // OG doesn't care about the types at all. They're treated as void*.
1190
// A second builder is kept at module scope so that declarations and string
// globals are created at the start of the module body, independent of
// `builder`, which now points inside the new function.
+ cir::CIRBaseBuilderTy globalBuilder (getContext ());
1191
+ globalBuilder.setInsertionPointToStart (theModule.getBody ());
1192
+
1193
+ FuncOp cudaRegisterFunction = buildRuntimeFunction (
1194
+ globalBuilder, addUnderscoredPrefix (" RegisterFunction" ), loc,
1195
+ FuncType::get ({voidPtrPtrTy, voidPtrTy, voidPtrTy, voidPtrTy, intTy,
1196
+ voidPtrTy, voidPtrTy, voidPtrTy, voidPtrTy, voidPtrTy},
1197
+ intTy));
1198
+
1199
// Creates a private, constant global char array holding `str` and returns it.
// The array type has one more element than str.size().
+ auto makeConstantString = [&](llvm::StringRef str) -> GlobalOp {
1200
+ auto strType = ArrayType::get (&getContext (), charTy, 1 + str.size ());
1201
+
1202
+ auto tmpString = globalBuilder.create <GlobalOp>(
1203
+ loc, (" .str" + str).str (), strType, /* isConstant=*/ true ,
1204
+ /* linkage=*/ cir::GlobalLinkageKind::PrivateLinkage);
1205
+
1206
// NOTE(review): if the appended literal is just an escaped NUL, a C-string
// based concatenation measures it as length zero, so the terminator would
// have to come from the extra array element instead — TODO confirm.
+ // We must make the string zero-terminated.
1207
+ tmpString.setInitialValueAttr (ConstArrayAttr::get (
1208
+ strType, StringAttr::get (&getContext (), str + " \0 " )));
1209
+ tmpString.setPrivate ();
1210
+ return tmpString;
1211
+ };
1212
+
1213
// One shared null void* value is reused for every unused pointer argument.
+ auto cirNullPtr = builder.getNullPtr (voidPtrTy, loc);
1214
// Bind each map entry by reference so the std::string key is not copied on
// every iteration. The reference is non-const because the accessors called on
// `deviceStub` below are not known to be const-qualified.
// NOTE(review): `cudaKernelMap` is a std::unordered_map, whose iteration order
// is unspecified, so the emitted registration calls follow that order —
// confirm this is acceptable for reproducible output.
+ for (auto &[kernelName, deviceStub] : cudaKernelMap) {
1215
+ GlobalOp deviceFuncStr = makeConstantString (kernelName);
1216
+ mlir::Value deviceFunc = builder.createBitcast (
1217
+ builder.createGetGlobal (deviceFuncStr), voidPtrTy);
1218
+ mlir::Value hostFunc = builder.createBitcast (
1219
+ builder.create <GetGlobalOp>(
1220
+ loc, PointerType::get (deviceStub.getFunctionType ()),
1221
+ mlir::FlatSymbolRefAttr::get (deviceStub.getSymNameAttr ())),
1222
+ voidPtrTy);
1223
// The kernel-name string is passed in both the deviceFunc and deviceName
// positions; the int argument is -1 and the last five pointers are null.
+ builder.createCallOp (
1224
+ loc, cudaRegisterFunction,
1225
+ {fatbinHandle, hostFunc, deviceFunc, deviceFunc,
1226
+ builder.create <ConstantOp>(loc, IntAttr::get (intTy, -1 )), cirNullPtr,
1227
+ cirNullPtr, cirNullPtr, cirNullPtr, cirNullPtr});
1228
+ }
1229
+
1230
+ // TODO(cir): registration for global variables.
1231
+
1232
+ builder.create <ReturnOp>(loc);
1233
+ return regGlobalFunc;
1234
+ }
1235
+
1120
1236
void LoweringPreparePass::lowerDynamicCastOp (DynamicCastOp op) {
1121
1237
CIRBaseBuilderTy builder (getContext ());
1122
1238
builder.setInsertionPointAfter (op);
@@ -1378,11 +1494,10 @@ void LoweringPreparePass::runOnOp(Operation *op) {
1378
1494
globalDtorList.push_back (globalDtor);
1379
1495
}
1380
1496
if (auto attr = fnOp.getExtraAttrs ().getElements ().get (
1381
- CIRDialect::getCUDABinaryHandleAttrName ())) {
1382
- auto cudaBinaryAttr = dyn_cast<CUDABinaryHandleAttr>(attr);
1383
- std::string kernelName = cudaBinaryAttr.getName ();
1384
- llvm::StringRef stubName = fnOp.getSymName ();
1385
- cudaKernelMap[stubName] = kernelName;
1497
+ CUDAKernelNameAttr::getMnemonic ())) {
1498
+ auto cudaBinaryAttr = dyn_cast<CUDAKernelNameAttr>(attr);
1499
+ std::string kernelName = cudaBinaryAttr.getKernelName ();
1500
+ cudaKernelMap[kernelName] = fnOp;
1386
1501
}
1387
1502
if (std::optional<mlir::ArrayAttr> annotations = fnOp.getAnnotations ())
1388
1503
addGlobalAnnotations (fnOp, annotations.value ());
@@ -1399,6 +1514,13 @@ void LoweringPreparePass::runOnOperation() {
1399
1514
datalayout.emplace (theModule);
1400
1515
}
1401
1516
1517
+ if (astCtx->getLangOpts ().CUDA ) {
1518
+ cudaPrefix = " cuda" ;
1519
+ }
1520
+
1521
+ typeSizeInfo = cast<TypeSizeInfoAttr>(
1522
+ theModule->getAttr (CIRDialect::getTypeSizeInfoAttrName ()));
1523
+
1402
1524
llvm::SmallVector<Operation *> opsToTransform;
1403
1525
1404
1526
op->walk ([&](Operation *op) {
0 commit comments