Skip to content

Commit 3ad22db

Browse files
committed
[CIR][CUDA] Generate device stubs
1 parent 637f2f3 commit 3ad22db

File tree

5 files changed

+170
-13
lines changed

5 files changed

+170
-13
lines changed

clang/lib/CIR/CodeGen/CIRGenCUDA.cpp

+151
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
//===---- CIRGenCUDA.cpp - CUDA-specific logic for CIR generation ---------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This contains code dealing with CUDA-specific logic of CIR generation.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "CIRGenFunction.h"
14+
#include "clang/Basic/Cuda.h"
15+
#include "clang/CIR/Dialect/IR/CIRDialect.h"
16+
#include "clang/CIR/Dialect/IR/CIROpsEnums.h"
17+
#include "clang/CIR/Dialect/IR/CIRTypes.h"
18+
19+
using namespace clang;
20+
using namespace clang::CIRGen;
21+
22+
// Emit the host-side stub body for a CUDA `__global__` kernel.
//
// On the host, a `__global__` function compiles to a stub that:
//   1. packs the stub's own arguments into a `void *args[]` array,
//   2. pops the launch configuration previously saved by
//      `__cudaPushCallConfiguration` (grid dim, block dim, shared memory
//      size, stream), and
//   3. forwards everything to `cudaLaunchKernel`.
// This mirrors OG CodeGen's emitDeviceStubBodyNew (CGCUDANV.cpp).
void CIRGenFunction::emitCUDADeviceStubBody(cir::FuncOp fn,
                                            FunctionArgList &args) {
  // CUDA 9.0 changed the way to launch kernels; only the new
  // cudaLaunchKernel-based scheme is supported here.
  if (!CudaFeatureEnabled(CGM.getTarget().getSDKVersion(),
                          CudaFeature::CUDA_USES_NEW_LAUNCH))
    llvm_unreachable("NYI");

  // This requires arguments to be sent to kernels in a different way.
  if (getLangOpts().OffloadViaLLVM)
    llvm_unreachable("NYI");

  if (getLangOpts().HIP)
    llvm_unreachable("NYI");

  auto &builder = CGM.getBuilder();

  // For cudaLaunchKernel, we must add another layer of indirection
  // to arguments. For example, for function `add(int a, float b)`,
  // we need to pass it as `void *args[2] = { &a, &b }`.

  auto loc = fn.getLoc();
  auto voidPtrArrayTy =
      cir::ArrayType::get(&getMLIRContext(), CGM.VoidPtrTy, args.size());
  mlir::Value kernelArgs = builder.createAlloca(
      loc, cir::PointerType::get(voidPtrArrayTy), voidPtrArrayTy, "kernel_args",
      CharUnits::fromQuantity(16));

  // Store the address of each stub parameter's local slot into kernelArgs[i].
  for (auto [i, arg] : llvm::enumerate(args)) {
    mlir::Value index =
        builder.getConstInt(loc, llvm::APInt(/*numBits=*/32, i));
    mlir::Value storePos = builder.createPtrStride(loc, kernelArgs, index);
    builder.CIRBaseBuilderTy::createStore(
        loc, GetAddrOfLocalVar(arg).getPointer(), storePos);
  }

  // We retrieve dim3 type by looking into the second argument of
  // cudaLaunchKernel, as is done in OG.
  TranslationUnitDecl *tuDecl = getContext().getTranslationUnitDecl();
  DeclContext *dc = TranslationUnitDecl::castToDeclContext(tuDecl);

  // The default stream is usually stream 0 (the legacy default stream).
  // For per-thread default stream, we need a different LaunchKernel function.
  if (getLangOpts().GPUDefaultStream ==
      LangOptions::GPUDefaultStreamKind::PerThread)
    llvm_unreachable("NYI");

  std::string launchAPI = "cudaLaunchKernel";
  const IdentifierInfo &launchII = getContext().Idents.get(launchAPI);
  FunctionDecl *launchFD = nullptr;
  for (auto *result : dc->lookup(&launchII)) {
    if (FunctionDecl *fd = dyn_cast<FunctionDecl>(result))
      launchFD = fd;
  }

  // The declaration normally comes from the CUDA headers; without it we
  // cannot derive dim3/cudaStream_t types, so report and bail out.
  if (launchFD == nullptr) {
    CGM.Error(CurFuncDecl->getLocation(),
              "Can't find declaration for " + launchAPI);
    return;
  }

  // Use this function to retrieve arguments for cudaLaunchKernel:
  // int __cudaPopCallConfiguration(dim3 *gridDim, dim3 *blockDim, size_t
  // *sharedMem, cudaStream_t *stream)
  //
  // Here cudaStream_t, while also being the 6th argument of cudaLaunchKernel,
  // is a pointer to some opaque struct.

  mlir::Type dim3Ty =
      getTypes().convertType(launchFD->getParamDecl(1)->getType());
  mlir::Type streamTy =
      getTypes().convertType(launchFD->getParamDecl(5)->getType());

  // Stack slots that __cudaPopCallConfiguration fills in by pointer.
  mlir::Value gridDim =
      builder.createAlloca(loc, cir::PointerType::get(dim3Ty), dim3Ty,
                           "grid_dim", CharUnits::fromQuantity(8));
  mlir::Value blockDim =
      builder.createAlloca(loc, cir::PointerType::get(dim3Ty), dim3Ty,
                           "block_dim", CharUnits::fromQuantity(8));
  mlir::Value sharedMem =
      builder.createAlloca(loc, cir::PointerType::get(CGM.SizeTy), CGM.SizeTy,
                           "shared_mem", CGM.getSizeAlign());
  mlir::Value stream =
      builder.createAlloca(loc, cir::PointerType::get(streamTy), streamTy,
                           "stream", CGM.getPointerAlign());

  cir::FuncOp popConfig = CGM.createRuntimeFunction(
      cir::FuncType::get({gridDim.getType(), blockDim.getType(),
                          sharedMem.getType(), stream.getType()},
                         CGM.SInt32Ty),
      "__cudaPopCallConfiguration");
  emitRuntimeCall(loc, popConfig, {gridDim, blockDim, sharedMem, stream});

  // Now emit the call to cudaLaunchKernel
  // cudaError_t cudaLaunchKernel(const void *func, dim3 gridDim, dim3 blockDim,
  //                              void **args, size_t sharedMem,
  //                              cudaStream_t stream);
  auto kernelTy =
      cir::PointerType::get(&getMLIRContext(), fn.getFunctionType());

  mlir::Value kernel =
      builder.create<cir::GetGlobalOp>(loc, kernelTy, fn.getSymName());
  mlir::Value func = builder.createBitcast(kernel, CGM.VoidPtrTy);
  CallArgList launchArgs;

  // Decay `void *args[N]` to the `void **` expected by cudaLaunchKernel.
  mlir::Value kernelArgsDecayed =
      builder.createCast(cir::CastKind::array_to_ptrdecay, kernelArgs,
                         cir::PointerType::get(CGM.VoidPtrTy));

  launchArgs.add(RValue::get(func), launchFD->getParamDecl(0)->getType());
  launchArgs.add(
      RValue::getAggregate(Address(gridDim, CharUnits::fromQuantity(8))),
      launchFD->getParamDecl(1)->getType());
  launchArgs.add(
      RValue::getAggregate(Address(blockDim, CharUnits::fromQuantity(8))),
      launchFD->getParamDecl(2)->getType());
  launchArgs.add(RValue::get(kernelArgsDecayed),
                 launchFD->getParamDecl(3)->getType());
  launchArgs.add(
      RValue::get(builder.CIRBaseBuilderTy::createLoad(loc, sharedMem)),
      launchFD->getParamDecl(4)->getType());
  // Fix: `stream` is the *address* of the stack slot holding the
  // cudaStream_t, but parameter 5 of cudaLaunchKernel has type cudaStream_t
  // (i.e. streamTy) — load the stream value out of the slot, exactly as is
  // done for sharedMem above and as OG CodeGen does in
  // emitDeviceStubBodyNew. Passing the raw alloca pointer would hand the
  // runtime a pointer-to-stream instead of the stream.
  launchArgs.add(
      RValue::get(builder.CIRBaseBuilderTy::createLoad(loc, stream)),
      launchFD->getParamDecl(5)->getType());

  mlir::Type launchTy = CGM.getTypes().convertType(launchFD->getType());
  mlir::Operation *launchFn =
      CGM.createRuntimeFunction(cast<cir::FuncType>(launchTy), launchAPI);
  const auto &callInfo = CGM.getTypes().arrangeFunctionDeclaration(launchFD);
  emitCall(callInfo, CIRGenCallee::forDirect(launchFn), ReturnValueSlot(),
           launchArgs);
}

clang/lib/CIR/CodeGen/CIRGenFunction.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -753,7 +753,7 @@ cir::FuncOp CIRGenFunction::generateCode(clang::GlobalDecl GD, cir::FuncOp Fn,
753753
emitConstructorBody(Args);
754754
else if (getLangOpts().CUDA && !getLangOpts().CUDAIsDevice &&
755755
FD->hasAttr<CUDAGlobalAttr>())
756-
llvm_unreachable("NYI");
756+
emitCUDADeviceStubBody(Fn, Args);
757757
else if (isa<CXXMethodDecl>(FD) &&
758758
cast<CXXMethodDecl>(FD)->isLambdaStaticInvoker()) {
759759
// The lambda static invoker function is special, because it forwards or

clang/lib/CIR/CodeGen/CIRGenFunction.h

+3
Original file line numberDiff line numberDiff line change
@@ -1156,6 +1156,9 @@ class CIRGenFunction : public CIRGenTypeCache {
11561156
mlir::LogicalResult emitOMPTaskyieldDirective(const OMPTaskyieldDirective &S);
11571157
mlir::LogicalResult emitOMPBarrierDirective(const OMPBarrierDirective &S);
11581158

1159+
// CUDA gen functions:
1160+
void emitCUDADeviceStubBody(cir::FuncOp fn, FunctionArgList &args);
1161+
11591162
LValue emitOpaqueValueLValue(const OpaqueValueExpr *e);
11601163

11611164
/// Emit code to compute a designator that specifies the location

clang/lib/CIR/CodeGen/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ add_clang_library(clangCIR
1919
CIRGenClass.cpp
2020
CIRGenCleanup.cpp
2121
CIRGenCoroutine.cpp
22+
CIRGenCUDA.cpp
2223
CIRGenDecl.cpp
2324
CIRGenDeclCXX.cpp
2425
CIRGenException.cpp

clang/test/CIR/CodeGen/CUDA/simple.cu

+14-12
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
#include "../Inputs/cuda.h"
22

33
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir \
4-
// RUN: -x cuda -emit-cir %s -o %t.cir
4+
// RUN: -x cuda -emit-cir -target-sdk-version=12.3 \
5+
// RUN: %s -o %t.cir
56
// RUN: FileCheck --check-prefix=CIR-HOST --input-file=%t.cir %s
67

78
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir \
8-
// RUN: -fcuda-is-device -emit-cir %s -o %t.cir
9+
// RUN: -fcuda-is-device -emit-cir -target-sdk-version=12.3 \
10+
// RUN: %s -o %t.cir
911
// RUN: FileCheck --check-prefix=CIR-DEVICE --input-file=%t.cir %s
1012

1113
// Attribute for global_fn
12-
// CIR-HOST: [[Kernel:#[a-zA-Z_0-9]+]] = {{.*}}#cir.cuda_kernel_name<_Z9global_fnv>{{.*}}
14+
// CIR-HOST: [[Kernel:#[a-zA-Z_0-9]+]] = {{.*}}#cir.cuda_kernel_name<_Z9global_fni>{{.*}}
1315

1416
__host__ void host_fn(int *a, int *b, int *c) {}
1517
// CIR-HOST: cir.func @_Z7host_fnPiS_S_
@@ -19,13 +21,13 @@ __device__ void device_fn(int* a, double b, float c) {}
1921
// CIR-HOST-NOT: cir.func @_Z9device_fnPidf
2022
// CIR-DEVICE: cir.func @_Z9device_fnPidf
2123

22-
#ifdef __CUDA_ARCH__
23-
__global__ void global_fn() {}
24-
#else
25-
__global__ void global_fn();
26-
#endif
27-
// CIR-HOST: @_Z24__device_stub__global_fnv(){{.*}}extra([[Kernel]])
28-
// CIR-DEVICE: @_Z9global_fnv
24+
__global__ void global_fn(int a) {}
25+
// CIR-HOST: @_Z24__device_stub__global_fni{{.*}}extra([[Kernel]])
26+
// CIR-DEVICE: @_Z9global_fni
2927

30-
// Make sure `global_fn` indeed gets emitted
31-
__host__ void x() { auto v = global_fn; }
28+
// Check for device stub emission.
29+
30+
// CIR-HOST: cir.alloca {{.*}}"kernel_args"
31+
// CIR-HOST: cir.call @__cudaPopCallConfiguration
32+
// CIR-HOST: cir.get_global @_Z24__device_stub__global_fni
33+
// CIR-HOST: cir.call @cudaLaunchKernel

0 commit comments

Comments
 (0)