Skip to content

Commit 324e4c2

Browse files
committed
[CIR][CUDA] Generate device stubs
1 parent 637f2f3 commit 324e4c2

File tree

7 files changed

+248
-15
lines changed

7 files changed

+248
-15
lines changed
+171
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
//===--- CIRGenCUDARuntime.cpp - Interface to CUDA Runtimes ----*- C++ -*--===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This provides an abstract class for CUDA CIR generation. Concrete
// subclasses of this implement code generation for specific CUDA
// runtime libraries.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#include "CIRGenCUDARuntime.h"
16+
#include "CIRGenFunction.h"
17+
#include "clang/Basic/Cuda.h"
18+
#include "clang/CIR/Dialect/IR/CIRTypes.h"
19+
20+
using namespace clang;
21+
using namespace clang::CIRGen;
22+
23+
// Out-of-line virtual destructor; anchors the class's vtable to this TU.
CIRGenCUDARuntime::~CIRGenCUDARuntime() = default;
24+
25+
// Legacy (pre-CUDA-9.0) kernel-launch stub body — not yet implemented.
void CIRGenCUDARuntime::emitDeviceStubBodyLegacy(CIRGenFunction &cgf,
                                                 cir::FuncOp fn,
                                                 FunctionArgList &args) {
  llvm_unreachable("NYI");
}
30+
31+
// Emits the body of a host-side device stub using the CUDA 9.0+ launch
// scheme: pack the stub's arguments into a `void *[]`, pop the launch
// configuration with __cudaPopCallConfiguration, then call cudaLaunchKernel.
// Reports a diagnostic (and emits nothing further) if no declaration of
// cudaLaunchKernel is visible in the translation unit.
void CIRGenCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf,
                                              cir::FuncOp fn,
                                              FunctionArgList &args) {
  // HIP uses a different launch path; not yet supported.
  if (cgm.getLangOpts().HIP)
    llvm_unreachable("NYI");

  // -foffload-via-llvm passes kernel arguments in a different way; NYI.
  if (cgm.getLangOpts().OffloadViaLLVM)
    llvm_unreachable("NYI");

  auto &builder = cgm.getBuilder();
  auto loc = fn.getLoc();

  // cudaLaunchKernel takes the kernel arguments through an extra level of
  // indirection: for `add(int a, float b)` we must build
  // `void *args[2] = { &a, &b }` and pass the array.
  auto voidPtrArrayTy =
      cir::ArrayType::get(&cgm.getMLIRContext(), cgm.VoidPtrTy, args.size());
  mlir::Value kernelArgs = builder.createAlloca(
      loc, cir::PointerType::get(voidPtrArrayTy), voidPtrArrayTy, "kernel_args",
      CharUnits::fromQuantity(16));

  // Fill the array with the addresses of the stub's local parameter copies.
  for (auto [argIndex, arg] : llvm::enumerate(args)) {
    mlir::Value slotIndex =
        builder.getConstInt(loc, llvm::APInt(/*numBits=*/32, argIndex));
    mlir::Value slot = builder.createPtrStride(loc, kernelArgs, slotIndex);
    builder.CIRBaseBuilderTy::createStore(
        loc, cgf.GetAddrOfLocalVar(arg).getPointer(), slot);
  }

  // The dim3 type is recovered from the second parameter of the
  // cudaLaunchKernel declaration, mirroring what OG CodeGen does.
  TranslationUnitDecl *tuDecl = cgm.getASTContext().getTranslationUnitDecl();
  DeclContext *dc = TranslationUnitDecl::castToDeclContext(tuDecl);

  // Only the legacy default stream (stream 0) is handled here; the
  // per-thread default stream requires a different LaunchKernel function.
  if (cgm.getLangOpts().GPUDefaultStream ==
      LangOptions::GPUDefaultStreamKind::PerThread)
    llvm_unreachable("NYI");

  std::string launchAPI = "cudaLaunchKernel";
  const IdentifierInfo &launchII = cgm.getASTContext().Idents.get(launchAPI);
  FunctionDecl *launchFD = nullptr;
  for (auto *decl : dc->lookup(&launchII))
    if (auto *fd = dyn_cast<FunctionDecl>(decl))
      launchFD = fd;

  // Without a visible declaration we cannot derive parameter types, so
  // report an error and bail out.
  if (launchFD == nullptr) {
    cgm.Error(cgf.CurFuncDecl->getLocation(),
              "Can't find declaration for " + launchAPI);
    return;
  }

  // Retrieve the pushed launch configuration via:
  //   int __cudaPopCallConfiguration(dim3 *gridDim, dim3 *blockDim,
  //                                  size_t *sharedMem, cudaStream_t *stream)
  // Note cudaStream_t — also the 6th argument of cudaLaunchKernel — is a
  // pointer to an opaque struct.
  mlir::Type dim3Ty =
      cgf.getTypes().convertType(launchFD->getParamDecl(1)->getType());
  mlir::Type streamTy =
      cgf.getTypes().convertType(launchFD->getParamDecl(5)->getType());

  mlir::Value gridDim =
      builder.createAlloca(loc, cir::PointerType::get(dim3Ty), dim3Ty,
                           "grid_dim", CharUnits::fromQuantity(8));
  mlir::Value blockDim =
      builder.createAlloca(loc, cir::PointerType::get(dim3Ty), dim3Ty,
                           "block_dim", CharUnits::fromQuantity(8));
  mlir::Value sharedMem =
      builder.createAlloca(loc, cir::PointerType::get(cgm.SizeTy), cgm.SizeTy,
                           "shared_mem", cgm.getSizeAlign());
  mlir::Value stream =
      builder.createAlloca(loc, cir::PointerType::get(streamTy), streamTy,
                           "stream", cgm.getPointerAlign());

  cir::FuncOp popConfig = cgm.createRuntimeFunction(
      cir::FuncType::get({gridDim.getType(), blockDim.getType(),
                          sharedMem.getType(), stream.getType()},
                         cgm.SInt32Ty),
      "__cudaPopCallConfiguration");
  cgf.emitRuntimeCall(loc, popConfig, {gridDim, blockDim, sharedMem, stream});

  // Now emit the launch itself:
  //   cudaError_t cudaLaunchKernel(const void *func, dim3 gridDim,
  //                                dim3 blockDim, void **args,
  //                                size_t sharedMem, cudaStream_t stream);
  auto kernelTy =
      cir::PointerType::get(&cgm.getMLIRContext(), fn.getFunctionType());
  mlir::Value kernel =
      builder.create<cir::GetGlobalOp>(loc, kernelTy, fn.getSymName());
  mlir::Value func = builder.createBitcast(kernel, cgm.VoidPtrTy);

  // Decay `void *args[N]` to the `void **` expected as the 4th argument.
  mlir::Value kernelArgsDecayed =
      builder.createCast(cir::CastKind::array_to_ptrdecay, kernelArgs,
                         cir::PointerType::get(cgm.VoidPtrTy));

  CallArgList launchArgs;
  launchArgs.add(RValue::get(func), launchFD->getParamDecl(0)->getType());
  launchArgs.add(
      RValue::getAggregate(Address(gridDim, CharUnits::fromQuantity(8))),
      launchFD->getParamDecl(1)->getType());
  launchArgs.add(
      RValue::getAggregate(Address(blockDim, CharUnits::fromQuantity(8))),
      launchFD->getParamDecl(2)->getType());
  launchArgs.add(RValue::get(kernelArgsDecayed),
                 launchFD->getParamDecl(3)->getType());
  launchArgs.add(
      RValue::get(builder.CIRBaseBuilderTy::createLoad(loc, sharedMem)),
      launchFD->getParamDecl(4)->getType());
  launchArgs.add(RValue::get(stream), launchFD->getParamDecl(5)->getType());

  mlir::Type launchTy = cgm.getTypes().convertType(launchFD->getType());
  mlir::Operation *launchFn =
      cgm.createRuntimeFunction(cast<cir::FuncType>(launchTy), launchAPI);
  const auto &callInfo = cgm.getTypes().arrangeFunctionDeclaration(launchFD);
  cgf.emitCall(callInfo, CIRGenCallee::forDirect(launchFn), ReturnValueSlot(),
               launchArgs);
}
157+
158+
// Entry point for host-side stub emission: selects between the modern
// (cudaLaunchKernel-based) and legacy launch code paths.
void CIRGenCUDARuntime::emitDeviceStub(CIRGenFunction &cgf, cir::FuncOp fn,
                                       FunctionArgList &args) {
  // For HIP the device stub and its handle might be different; NYI.
  if (cgm.getLangOpts().HIP)
    llvm_unreachable("NYI");

  // CUDA 9.0 changed the way to launch kernels; pick the body style from
  // the target SDK version (or the -foffload-via-llvm flag).
  bool useNewLaunch =
      CudaFeatureEnabled(cgm.getTarget().getSDKVersion(),
                         CudaFeature::CUDA_USES_NEW_LAUNCH) ||
      cgm.getLangOpts().OffloadViaLLVM;
  if (useNewLaunch)
    emitDeviceStubBodyNew(cgf, fn, args);
  else
    emitDeviceStubBodyLegacy(cgf, fn, args);
}
+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
//===------ CIRGenCUDARuntime.h - Interface to CUDA Runtimes -----*- C++ -*-==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This provides an abstract class for CUDA CIR generation. Concrete
// subclasses of this implement code generation for specific CUDA
// runtime libraries.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#ifndef LLVM_CLANG_LIB_CIR_CIRGENCUDARUNTIME_H
16+
#define LLVM_CLANG_LIB_CIR_CIRGENCUDARUNTIME_H
17+
18+
#include "clang/CIR/Dialect/IR/CIRDialect.h"
19+
#include "clang/CIR/Dialect/IR/CIROpsEnums.h"
20+
21+
namespace clang::CIRGen {
22+
23+
class CIRGenFunction;
24+
class CIRGenModule;
25+
class FunctionArgList;
26+
27+
class CIRGenCUDARuntime {
28+
protected:
29+
CIRGenModule &cgm;
30+
31+
private:
32+
void emitDeviceStubBodyLegacy(CIRGenFunction &cgf, cir::FuncOp fn,
33+
FunctionArgList &args);
34+
void emitDeviceStubBodyNew(CIRGenFunction &cgf, cir::FuncOp fn,
35+
FunctionArgList &args);
36+
37+
public:
38+
CIRGenCUDARuntime(CIRGenModule &cgm) : cgm(cgm) {}
39+
virtual ~CIRGenCUDARuntime();
40+
41+
virtual void emitDeviceStub(CIRGenFunction &cgf, cir::FuncOp fn,
42+
FunctionArgList &args);
43+
};
44+
45+
} // namespace clang::CIRGen
46+
47+
#endif // LLVM_CLANG_LIB_CIR_CIRGENCUDARUNTIME_H

clang/lib/CIR/CodeGen/CIRGenFunction.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -753,7 +753,7 @@ cir::FuncOp CIRGenFunction::generateCode(clang::GlobalDecl GD, cir::FuncOp Fn,
753753
emitConstructorBody(Args);
754754
else if (getLangOpts().CUDA && !getLangOpts().CUDAIsDevice &&
755755
FD->hasAttr<CUDAGlobalAttr>())
756-
llvm_unreachable("NYI");
756+
CGM.getCUDARuntime().emitDeviceStub(*this, Fn, Args);
757757
else if (isa<CXXMethodDecl>(FD) &&
758758
cast<CXXMethodDecl>(FD)->isLambdaStaticInvoker()) {
759759
// The lambda static invoker function is special, because it forwards or

clang/lib/CIR/CodeGen/CIRGenModule.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
// This is the internal per-translation-unit state used for CIR translation.
1010
//
1111
//===----------------------------------------------------------------------===//
12+
#include "CIRGenCUDARuntime.h"
1213
#include "CIRGenCXXABI.h"
1314
#include "CIRGenCstEmitter.h"
1415
#include "CIRGenFunction.h"
@@ -108,7 +109,8 @@ CIRGenModule::CIRGenModule(mlir::MLIRContext &mlirContext,
108109
theModule{mlir::ModuleOp::create(builder.getUnknownLoc())}, Diags(Diags),
109110
target(astContext.getTargetInfo()), ABI(createCXXABI(*this)),
110111
genTypes{*this}, VTables{*this},
111-
openMPRuntime(new CIRGenOpenMPRuntime(*this)) {
112+
openMPRuntime(new CIRGenOpenMPRuntime(*this)),
113+
cudaRuntime(new CIRGenCUDARuntime(*this)) {
112114

113115
// Initialize CIR signed integer types cache.
114116
SInt8Ty = cir::IntType::get(&getMLIRContext(), 8, /*isSigned=*/true);

clang/lib/CIR/CodeGen/CIRGenModule.h

+11-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include "Address.h"
1717
#include "CIRGenBuilder.h"
18+
#include "CIRGenCUDARuntime.h"
1819
#include "CIRGenCall.h"
1920
#include "CIRGenOpenCLRuntime.h"
2021
#include "CIRGenTBAA.h"
@@ -113,6 +114,9 @@ class CIRGenModule : public CIRGenTypeCache {
113114
/// Holds the OpenMP runtime
114115
std::unique_ptr<CIRGenOpenMPRuntime> openMPRuntime;
115116

117+
/// Holds the CUDA runtime
118+
std::unique_ptr<CIRGenCUDARuntime> cudaRuntime;
119+
116120
/// Per-function codegen information. Updated everytime emitCIR is called
117121
/// for FunctionDecls's.
118122
CIRGenFunction *CurCGF = nullptr;
@@ -862,12 +866,18 @@ class CIRGenModule : public CIRGenTypeCache {
862866
/// Print out an error that codegen doesn't support the specified decl yet.
863867
void ErrorUnsupported(const Decl *D, const char *Type);
864868

865-
/// Return a reference to the configured OpenMP runtime.
869+
/// Return a reference to the configured OpenCL runtime.
866870
CIRGenOpenCLRuntime &getOpenCLRuntime() {
867871
assert(openCLRuntime != nullptr);
868872
return *openCLRuntime;
869873
}
870874

875+
/// Return a reference to the configured CUDA runtime.
876+
CIRGenCUDARuntime &getCUDARuntime() {
877+
assert(cudaRuntime != nullptr);
878+
return *cudaRuntime;
879+
}
880+
871881
void createOpenCLRuntime() {
872882
openCLRuntime.reset(new CIRGenOpenCLRuntime(*this));
873883
}

clang/lib/CIR/CodeGen/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ add_clang_library(clangCIR
1919
CIRGenClass.cpp
2020
CIRGenCleanup.cpp
2121
CIRGenCoroutine.cpp
22+
CIRGenCUDARuntime.cpp
2223
CIRGenDecl.cpp
2324
CIRGenDeclCXX.cpp
2425
CIRGenException.cpp

clang/test/CIR/CodeGen/CUDA/simple.cu

+14-12
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
#include "../Inputs/cuda.h"
22

33
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir \
4-
// RUN: -x cuda -emit-cir %s -o %t.cir
4+
// RUN: -x cuda -emit-cir -target-sdk-version=12.3 \
5+
// RUN: %s -o %t.cir
56
// RUN: FileCheck --check-prefix=CIR-HOST --input-file=%t.cir %s
67

78
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir \
8-
// RUN: -fcuda-is-device -emit-cir %s -o %t.cir
9+
// RUN: -fcuda-is-device -emit-cir -target-sdk-version=12.3 \
10+
// RUN: %s -o %t.cir
911
// RUN: FileCheck --check-prefix=CIR-DEVICE --input-file=%t.cir %s
1012

1113
// Attribute for global_fn
12-
// CIR-HOST: [[Kernel:#[a-zA-Z_0-9]+]] = {{.*}}#cir.cuda_kernel_name<_Z9global_fnv>{{.*}}
14+
// CIR-HOST: [[Kernel:#[a-zA-Z_0-9]+]] = {{.*}}#cir.cuda_kernel_name<_Z9global_fni>{{.*}}
1315

1416
__host__ void host_fn(int *a, int *b, int *c) {}
1517
// CIR-HOST: cir.func @_Z7host_fnPiS_S_
@@ -19,13 +21,13 @@ __device__ void device_fn(int* a, double b, float c) {}
1921
// CIR-HOST-NOT: cir.func @_Z9device_fnPidf
2022
// CIR-DEVICE: cir.func @_Z9device_fnPidf
2123

22-
#ifdef __CUDA_ARCH__
23-
__global__ void global_fn() {}
24-
#else
25-
__global__ void global_fn();
26-
#endif
27-
// CIR-HOST: @_Z24__device_stub__global_fnv(){{.*}}extra([[Kernel]])
28-
// CIR-DEVICE: @_Z9global_fnv
24+
__global__ void global_fn(int a) {}
25+
// CIR-DEVICE: @_Z9global_fni
2926

30-
// Make sure `global_fn` indeed gets emitted
31-
__host__ void x() { auto v = global_fn; }
27+
// Check for device stub emission.
28+
29+
// CIR-HOST: @_Z24__device_stub__global_fni{{.*}}extra([[Kernel]])
30+
// CIR-HOST: cir.alloca {{.*}}"kernel_args"
31+
// CIR-HOST: cir.call @__cudaPopCallConfiguration
32+
// CIR-HOST: cir.get_global @_Z24__device_stub__global_fni
33+
// CIR-HOST: cir.call @cudaLaunchKernel

0 commit comments

Comments
 (0)