Skip to content

Commit fe4f3c5

Browse files
committed
[CIR][CUDA] Generate device stubs
1 parent 637f2f3 commit fe4f3c5

File tree

7 files changed

+226
-15
lines changed

7 files changed

+226
-15
lines changed
+154
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
//===--- CIRGenCUDARuntime.cpp - Interface to CUDA Runtimes ----*- C++ -*--===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This provides an abstract class for CUDA CIR generation. Concrete
10+
// subclasses of this implement code generation for specific CUDA
11+
// runtime libraries.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#include "CIRGenCUDARuntime.h"
16+
#include "CIRGenFunction.h"
17+
#include "clang/Basic/Cuda.h"
18+
#include "clang/CIR/Dialect/IR/CIRTypes.h"
19+
20+
using namespace clang;
21+
using namespace clang::CIRGen;
22+
23+
CIRGenCUDARuntime::~CIRGenCUDARuntime() {}
24+
25+
void CIRGenCUDARuntime::emitDeviceStubBody(CIRGenFunction &cgf, cir::FuncOp fn,
26+
FunctionArgList &args) {
27+
// CUDA 9.0 changed the way to launch kernels.
28+
if (!CudaFeatureEnabled(cgm.getTarget().getSDKVersion(),
29+
CudaFeature::CUDA_USES_NEW_LAUNCH))
30+
llvm_unreachable("NYI");
31+
32+
// This requires arguments to be sent to kernels in a different way.
33+
if (cgm.getLangOpts().OffloadViaLLVM)
34+
llvm_unreachable("NYI");
35+
36+
if (cgm.getLangOpts().HIP)
37+
llvm_unreachable("NYI");
38+
39+
auto &builder = cgm.getBuilder();
40+
41+
// For cudaLaunchKernel, we must add another layer of indirection
42+
// to arguments. For example, for function `add(int a, float b)`,
43+
// we need to pass it as `void *args[2] = { &a, &b }`.
44+
45+
auto loc = fn.getLoc();
46+
auto voidPtrArrayTy =
47+
cir::ArrayType::get(&cgm.getMLIRContext(), cgm.VoidPtrTy, args.size());
48+
mlir::Value kernelArgs = builder.createAlloca(
49+
loc, cir::PointerType::get(voidPtrArrayTy), voidPtrArrayTy, "kernel_args",
50+
CharUnits::fromQuantity(16));
51+
52+
// Store arguments into kernelArgs
53+
for (auto [i, arg] : llvm::enumerate(args)) {
54+
mlir::Value index =
55+
builder.getConstInt(loc, llvm::APInt(/*numBits=*/32, i));
56+
mlir::Value storePos = builder.createPtrStride(loc, kernelArgs, index);
57+
builder.CIRBaseBuilderTy::createStore(
58+
loc, cgf.GetAddrOfLocalVar(arg).getPointer(), storePos);
59+
}
60+
61+
// We retrieve dim3 type by looking into the second argument of
62+
// cudaLaunchKernel, as is done in OG.
63+
TranslationUnitDecl *tuDecl = cgm.getASTContext().getTranslationUnitDecl();
64+
DeclContext *dc = TranslationUnitDecl::castToDeclContext(tuDecl);
65+
66+
// The default stream is usually stream 0 (the legacy default stream).
67+
// For per-thread default stream, we need a different LaunchKernel function.
68+
if (cgm.getLangOpts().GPUDefaultStream ==
69+
LangOptions::GPUDefaultStreamKind::PerThread)
70+
llvm_unreachable("NYI");
71+
72+
std::string launchAPI = "cudaLaunchKernel";
73+
const IdentifierInfo &launchII = cgm.getASTContext().Idents.get(launchAPI);
74+
FunctionDecl *launchFD = nullptr;
75+
for (auto *result : dc->lookup(&launchII)) {
76+
if (FunctionDecl *fd = dyn_cast<FunctionDecl>(result))
77+
launchFD = fd;
78+
}
79+
80+
if (launchFD == nullptr) {
81+
cgm.Error(cgf.CurFuncDecl->getLocation(),
82+
"Can't find declaration for " + launchAPI);
83+
return;
84+
}
85+
86+
// Use this function to retrieve arguments for cudaLaunchKernel:
87+
// int __cudaPopCallConfiguration(dim3 *gridDim, dim3 *blockDim, size_t
88+
// *sharedMem, cudaStream_t *stream)
89+
//
90+
// Here cudaStream_t, while also being the 6th argument of cudaLaunchKernel,
91+
// is a pointer to some opaque struct.
92+
93+
mlir::Type dim3Ty =
94+
cgf.getTypes().convertType(launchFD->getParamDecl(1)->getType());
95+
mlir::Type streamTy =
96+
cgf.getTypes().convertType(launchFD->getParamDecl(5)->getType());
97+
98+
mlir::Value gridDim =
99+
builder.createAlloca(loc, cir::PointerType::get(dim3Ty), dim3Ty,
100+
"grid_dim", CharUnits::fromQuantity(8));
101+
mlir::Value blockDim =
102+
builder.createAlloca(loc, cir::PointerType::get(dim3Ty), dim3Ty,
103+
"block_dim", CharUnits::fromQuantity(8));
104+
mlir::Value sharedMem =
105+
builder.createAlloca(loc, cir::PointerType::get(cgm.SizeTy), cgm.SizeTy,
106+
"shared_mem", cgm.getSizeAlign());
107+
mlir::Value stream =
108+
builder.createAlloca(loc, cir::PointerType::get(streamTy), streamTy,
109+
"stream", cgm.getPointerAlign());
110+
111+
cir::FuncOp popConfig = cgm.createRuntimeFunction(
112+
cir::FuncType::get({gridDim.getType(), blockDim.getType(),
113+
sharedMem.getType(), stream.getType()},
114+
cgm.SInt32Ty),
115+
"__cudaPopCallConfiguration");
116+
cgf.emitRuntimeCall(loc, popConfig, {gridDim, blockDim, sharedMem, stream});
117+
118+
// Now emit the call to cudaLaunchKernel
119+
// cudaError_t cudaLaunchKernel(const void *func, dim3 gridDim, dim3 blockDim,
120+
// void **args, size_t sharedMem,
121+
// cudaStream_t stream);
122+
auto kernelTy =
123+
cir::PointerType::get(&cgm.getMLIRContext(), fn.getFunctionType());
124+
125+
mlir::Value kernel =
126+
builder.create<cir::GetGlobalOp>(loc, kernelTy, fn.getSymName());
127+
mlir::Value func = builder.createBitcast(kernel, cgm.VoidPtrTy);
128+
CallArgList launchArgs;
129+
130+
mlir::Value kernelArgsDecayed =
131+
builder.createCast(cir::CastKind::array_to_ptrdecay, kernelArgs,
132+
cir::PointerType::get(cgm.VoidPtrTy));
133+
134+
launchArgs.add(RValue::get(func), launchFD->getParamDecl(0)->getType());
135+
launchArgs.add(
136+
RValue::getAggregate(Address(gridDim, CharUnits::fromQuantity(8))),
137+
launchFD->getParamDecl(1)->getType());
138+
launchArgs.add(
139+
RValue::getAggregate(Address(blockDim, CharUnits::fromQuantity(8))),
140+
launchFD->getParamDecl(2)->getType());
141+
launchArgs.add(RValue::get(kernelArgsDecayed),
142+
launchFD->getParamDecl(3)->getType());
143+
launchArgs.add(
144+
RValue::get(builder.CIRBaseBuilderTy::createLoad(loc, sharedMem)),
145+
launchFD->getParamDecl(4)->getType());
146+
launchArgs.add(RValue::get(stream), launchFD->getParamDecl(5)->getType());
147+
148+
mlir::Type launchTy = cgm.getTypes().convertType(launchFD->getType());
149+
mlir::Operation *launchFn =
150+
cgm.createRuntimeFunction(cast<cir::FuncType>(launchTy), launchAPI);
151+
const auto &callInfo = cgm.getTypes().arrangeFunctionDeclaration(launchFD);
152+
cgf.emitCall(callInfo, CIRGenCallee::forDirect(launchFn), ReturnValueSlot(),
153+
launchArgs);
154+
}
+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
//===------ CIRGenCUDARuntime.h - Interface to CUDA Runtimes ------*- C++
2+
//-*-==//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
//
10+
// This provides an abstract class for CUDA CIR generation. Concrete
11+
// subclasses of this implement code generation for specific CUDA
12+
// runtime libraries.
13+
//
14+
//===----------------------------------------------------------------------===//
15+
16+
#ifndef LLVM_CLANG_LIB_CIR_CIRGENCUDARUNTIME_H
17+
#define LLVM_CLANG_LIB_CIR_CIRGENCUDARUNTIME_H
18+
19+
#include "clang/CIR/Dialect/IR/CIRDialect.h"
20+
#include "clang/CIR/Dialect/IR/CIROpsEnums.h"
21+
22+
namespace clang::CIRGen {
23+
24+
class CIRGenFunction;
25+
class CIRGenModule;
26+
class FunctionArgList;
27+
28+
class CIRGenCUDARuntime {
29+
protected:
30+
CIRGenModule &cgm;
31+
32+
public:
33+
CIRGenCUDARuntime(CIRGenModule &cgm) : cgm(cgm) {}
34+
virtual ~CIRGenCUDARuntime();
35+
36+
virtual void emitDeviceStubBody(CIRGenFunction &cgf, cir::FuncOp fn,
37+
FunctionArgList &args);
38+
};
39+
40+
} // namespace clang::CIRGen
41+
42+
#endif // LLVM_CLANG_LIB_CIR_CIRGENCUDARUNTIME_H

clang/lib/CIR/CodeGen/CIRGenFunction.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -753,7 +753,7 @@ cir::FuncOp CIRGenFunction::generateCode(clang::GlobalDecl GD, cir::FuncOp Fn,
753753
emitConstructorBody(Args);
754754
else if (getLangOpts().CUDA && !getLangOpts().CUDAIsDevice &&
755755
FD->hasAttr<CUDAGlobalAttr>())
756-
llvm_unreachable("NYI");
756+
CGM.getCUDARuntime().emitDeviceStubBody(*this, Fn, Args);
757757
else if (isa<CXXMethodDecl>(FD) &&
758758
cast<CXXMethodDecl>(FD)->isLambdaStaticInvoker()) {
759759
// The lambda static invoker function is special, because it forwards or

clang/lib/CIR/CodeGen/CIRGenModule.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
// This is the internal per-translation-unit state used for CIR translation.
1010
//
1111
//===----------------------------------------------------------------------===//
12+
#include "CIRGenCUDARuntime.h"
1213
#include "CIRGenCXXABI.h"
1314
#include "CIRGenCstEmitter.h"
1415
#include "CIRGenFunction.h"
@@ -108,7 +109,8 @@ CIRGenModule::CIRGenModule(mlir::MLIRContext &mlirContext,
108109
theModule{mlir::ModuleOp::create(builder.getUnknownLoc())}, Diags(Diags),
109110
target(astContext.getTargetInfo()), ABI(createCXXABI(*this)),
110111
genTypes{*this}, VTables{*this},
111-
openMPRuntime(new CIRGenOpenMPRuntime(*this)) {
112+
openMPRuntime(new CIRGenOpenMPRuntime(*this)),
113+
cudaRuntime(new CIRGenCUDARuntime(*this)) {
112114

113115
// Initialize CIR signed integer types cache.
114116
SInt8Ty = cir::IntType::get(&getMLIRContext(), 8, /*isSigned=*/true);

clang/lib/CIR/CodeGen/CIRGenModule.h

+11-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include "Address.h"
1717
#include "CIRGenBuilder.h"
18+
#include "CIRGenCUDARuntime.h"
1819
#include "CIRGenCall.h"
1920
#include "CIRGenOpenCLRuntime.h"
2021
#include "CIRGenTBAA.h"
@@ -113,6 +114,9 @@ class CIRGenModule : public CIRGenTypeCache {
113114
/// Holds the OpenMP runtime
114115
std::unique_ptr<CIRGenOpenMPRuntime> openMPRuntime;
115116

117+
/// Holds the CUDA runtime
118+
std::unique_ptr<CIRGenCUDARuntime> cudaRuntime;
119+
116120
/// Per-function codegen information. Updated everytime emitCIR is called
117121
/// for FunctionDecls's.
118122
CIRGenFunction *CurCGF = nullptr;
@@ -862,12 +866,18 @@ class CIRGenModule : public CIRGenTypeCache {
862866
/// Print out an error that codegen doesn't support the specified decl yet.
863867
void ErrorUnsupported(const Decl *D, const char *Type);
864868

865-
/// Return a reference to the configured OpenMP runtime.
869+
/// Return a reference to the configured OpenCL runtime.
866870
CIRGenOpenCLRuntime &getOpenCLRuntime() {
867871
assert(openCLRuntime != nullptr);
868872
return *openCLRuntime;
869873
}
870874

875+
/// Return a reference to the configured CUDA runtime.
876+
CIRGenCUDARuntime &getCUDARuntime() {
877+
assert(cudaRuntime != nullptr);
878+
return *cudaRuntime;
879+
}
880+
871881
void createOpenCLRuntime() {
872882
openCLRuntime.reset(new CIRGenOpenCLRuntime(*this));
873883
}

clang/lib/CIR/CodeGen/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ add_clang_library(clangCIR
1919
CIRGenClass.cpp
2020
CIRGenCleanup.cpp
2121
CIRGenCoroutine.cpp
22+
CIRGenCUDARuntime.cpp
2223
CIRGenDecl.cpp
2324
CIRGenDeclCXX.cpp
2425
CIRGenException.cpp

clang/test/CIR/CodeGen/CUDA/simple.cu

+14-12
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
#include "../Inputs/cuda.h"
22

33
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir \
4-
// RUN: -x cuda -emit-cir %s -o %t.cir
4+
// RUN: -x cuda -emit-cir -target-sdk-version=12.3 \
5+
// RUN: %s -o %t.cir
56
// RUN: FileCheck --check-prefix=CIR-HOST --input-file=%t.cir %s
67

78
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir \
8-
// RUN: -fcuda-is-device -emit-cir %s -o %t.cir
9+
// RUN: -fcuda-is-device -emit-cir -target-sdk-version=12.3 \
10+
// RUN: %s -o %t.cir
911
// RUN: FileCheck --check-prefix=CIR-DEVICE --input-file=%t.cir %s
1012

1113
// Attribute for global_fn
12-
// CIR-HOST: [[Kernel:#[a-zA-Z_0-9]+]] = {{.*}}#cir.cuda_kernel_name<_Z9global_fnv>{{.*}}
14+
// CIR-HOST: [[Kernel:#[a-zA-Z_0-9]+]] = {{.*}}#cir.cuda_kernel_name<_Z9global_fni>{{.*}}
1315

1416
__host__ void host_fn(int *a, int *b, int *c) {}
1517
// CIR-HOST: cir.func @_Z7host_fnPiS_S_
@@ -19,13 +21,13 @@ __device__ void device_fn(int* a, double b, float c) {}
1921
// CIR-HOST-NOT: cir.func @_Z9device_fnPidf
2022
// CIR-DEVICE: cir.func @_Z9device_fnPidf
2123

22-
#ifdef __CUDA_ARCH__
23-
__global__ void global_fn() {}
24-
#else
25-
__global__ void global_fn();
26-
#endif
27-
// CIR-HOST: @_Z24__device_stub__global_fnv(){{.*}}extra([[Kernel]])
28-
// CIR-DEVICE: @_Z9global_fnv
24+
__global__ void global_fn(int a) {}
25+
// CIR-DEVICE: @_Z9global_fni
2926

30-
// Make sure `global_fn` indeed gets emitted
31-
__host__ void x() { auto v = global_fn; }
27+
// Check for device stub emission.
28+
29+
// CIR-HOST: @_Z24__device_stub__global_fni{{.*}}extra([[Kernel]])
30+
// CIR-HOST: cir.alloca {{.*}}"kernel_args"
31+
// CIR-HOST: cir.call @__cudaPopCallConfiguration
32+
// CIR-HOST: cir.get_global @_Z24__device_stub__global_fni
33+
// CIR-HOST: cir.call @cudaLaunchKernel

0 commit comments

Comments
 (0)