Commit 182c680

[CIR][CUDA] Miscellaneous bugfixes (#1462)
This PR fixes several issues currently present in CUDA CodeGen. Each of them requires only a few lines to fix, so they're combined in a single PR.

**Bug 1.** Suppose we write

```cpp
__global__ void kernel(int a, int b);
```

Then when we call this kernel with `cudaLaunchKernel`, the 4th argument to that function is something of the form `void *kernel_args[2] = {&a, &b}`. OG (the original Clang CodeGen) allocates the space for it with `alloca ptr, i32 2`, but that doesn't seem to be feasible in CIR, so we allocate `alloca [2 x ptr], i32 1` instead. This means there must be an extra GEP compared to OG. In CIR, it means we must add an `array_to_ptrdecay` cast before trying to access the array elements. I missed that in #1332.

**Bug 2.** We missed a load instruction for the 6th argument to `cudaLaunchKernel`. It's added back in this PR.

**Bug 3.** When we launch a kernel, we first retrieve the return value of `__cudaPushCallConfiguration`. If it's zero, then the call succeeded and we should proceed to call the device stub. In #1348 we did exactly the opposite, calling the device stub only if it's not zero. It's fixed here.

**Issue 4.** CallConvLowering is required to make `cudaLaunchKernel` correct. The codepath is unblocked by adding a `getIndirectResult` call at the same place as OG does -- the function is already implemented, so we can just call it.

After this (and other pending PRs), CIR is now able to compile real CUDA programs. There are still missing features, which will be followed up later.
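The array-versus-pointer distinction behind Bug 1 can be illustrated in plain C++. This is only a host-side sketch, not the CIR builder code: `sumTwoInts` and `stub` are hypothetical names, and `sumTwoInts` merely stands in for a launch API that receives its argument array as a `void **`, the way `cudaLaunchKernel` does.

```cpp
#include <cassert>

// Hypothetical stand-in for a launch API that takes the argument array as a
// pointer to its first element (void **).
int sumTwoInts(void **args) {
  assert(args != nullptr);
  int a = *static_cast<int *>(args[0]);
  int b = *static_cast<int *>(args[1]);
  return a + b;
}

int stub(int a, int b) {
  // A single object of array type, comparable to `alloca [2 x ptr], i32 1`.
  void *kernel_args[2];

  // Explicit array-to-pointer decay: `void *[2]` becomes `void **`.
  void **decayed = &kernel_args[0];

  // Store the address of each parameter through the decayed pointer,
  // one element at a time.
  decayed[0] = &a;
  decayed[1] = &b;

  return sumTwoInts(decayed);
}
```

C++ performs the decay implicitly when indexing an array, which is why the step is easy to overlook; in an IR with a distinct array type it has to be spelled out before any per-element pointer arithmetic.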
1 parent b776881 commit 182c680

3 files changed, +25 -15 lines changed

clang/lib/CIR/CodeGen/CIRGenCUDARuntime.cpp

+13 -8
@@ -69,11 +69,16 @@ void CIRGenCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf,
       loc, cir::PointerType::get(voidPtrArrayTy), voidPtrArrayTy, "kernel_args",
       CharUnits::fromQuantity(16));

+  mlir::Value kernelArgsDecayed =
+      builder.createCast(cir::CastKind::array_to_ptrdecay, kernelArgs,
+                         cir::PointerType::get(cgm.VoidPtrTy));
+
   // Store arguments into kernelArgs
   for (auto [i, arg] : llvm::enumerate(args)) {
     mlir::Value index =
         builder.getConstInt(loc, llvm::APInt(/*numBits=*/32, i));
-    mlir::Value storePos = builder.createPtrStride(loc, kernelArgs, index);
+    mlir::Value storePos =
+        builder.createPtrStride(loc, kernelArgsDecayed, index);
     builder.CIRBaseBuilderTy::createStore(
         loc, cgf.GetAddrOfLocalVar(arg).getPointer(), storePos);
   }
@@ -166,10 +171,6 @@ void CIRGenCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf,
   // mlir::Value func = builder.createBitcast(kernel, cgm.VoidPtrTy);
   CallArgList launchArgs;

-  mlir::Value kernelArgsDecayed =
-      builder.createCast(cir::CastKind::array_to_ptrdecay, kernelArgs,
-                         cir::PointerType::get(cgm.VoidPtrTy));
-
   launchArgs.add(RValue::get(kernel), launchFD->getParamDecl(0)->getType());
   launchArgs.add(
       RValue::getAggregate(Address(gridDim, CharUnits::fromQuantity(8))),
@@ -182,7 +183,8 @@ void CIRGenCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf,
   launchArgs.add(
       RValue::get(builder.CIRBaseBuilderTy::createLoad(loc, sharedMem)),
       launchFD->getParamDecl(4)->getType());
-  launchArgs.add(RValue::get(stream), launchFD->getParamDecl(5)->getType());
+  launchArgs.add(RValue::get(builder.CIRBaseBuilderTy::createLoad(loc, stream)),
+                 launchFD->getParamDecl(5)->getType());

   mlir::Type launchTy = cgm.getTypes().convertType(launchFD->getType());
   mlir::Operation *launchFn =
@@ -219,13 +221,16 @@ RValue CIRGenCUDARuntime::emitCUDAKernelCallExpr(CIRGenFunction &cgf,

   cgf.emitIfOnBoolExpr(
       expr->getConfig(),
+      [&](mlir::OpBuilder &b, mlir::Location l) {
+        b.create<cir::YieldOp>(loc);
+      },
+      loc,
       [&](mlir::OpBuilder &b, mlir::Location l) {
         CIRGenCallee callee = cgf.emitCallee(expr->getCallee());
         cgf.emitCall(expr->getCallee()->getType(), callee, expr, retValue);
         b.create<cir::YieldOp>(loc);
       },
-      loc, [](mlir::OpBuilder &b, mlir::Location l) {},
-      std::optional<mlir::Location>());
+      loc);

   return RValue::get(nullptr);
 }
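The last hunk above moves the stub call from the "then" region to the "else" region, so the stub runs only when the configuration call returns zero. A plain C++ sketch of the intended control flow follows. The names `pushConfig` and `launch` are hypothetical stand-ins, not part of the CUDA runtime or of CIR.

```cpp
#include <cassert>

// Hypothetical stand-in for the configuration call: returns 0 on success,
// non-zero on failure.
int pushConfig(bool ok) { return ok ? 0 : 1; }

// Returns true if the device stub would have been invoked.
bool launch(bool configOk) {
  bool stubCalled = false;
  if (pushConfig(configOk)) {
    // Non-zero result: configuration failed, so the "then" region does
    // nothing (in CIR it only yields).
  } else {
    // Zero result: configuration succeeded, so the "else" region calls the
    // device stub.
    stubCalled = true;
  }
  return stubCalled;
}
```

Because the branch condition is the raw return value converted to a boolean, "true" means failure here, which is the inversion that made the earlier version call the stub in the wrong region.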

clang/lib/CIR/Dialect/Transforms/TargetLowering/Targets/X86.cpp

+1 -1
@@ -751,7 +751,7 @@ void X86_64ABIInfo::computeInfo(LowerFunctionInfo &FI) const {
       if (cir::MissingFeatures::vectorType())
         cir_cconv_unreachable("NYI");
     } else {
-      cir_cconv_unreachable("Indirect results are NYI");
+      it->info = getIndirectResult(it->type, FreeIntRegs);
     }
   }
 }
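One plausible reading of why `cudaLaunchKernel` reaches this branch is that its arguments need more integer registers than the x86-64 System V calling convention provides (six). The toy model below is a sketch under that assumption, not the actual `computeInfo` logic: the enclosing condition is not visible in the hunk, `classify` and `Passing` are hypothetical names, and the per-argument register counts (one for each pointer or `size_t`, two for each `dim3`) are assumed.

```cpp
#include <cassert>
#include <vector>

enum class Passing { InRegisters, InMemory };

// Toy model: every argument needs some number of integer registers. An
// argument that does not fit in the remaining registers is passed in memory.
std::vector<Passing> classify(const std::vector<unsigned> &neededIntRegs,
                              unsigned freeIntRegs = 6) {
  std::vector<Passing> result;
  for (unsigned needed : neededIntRegs) {
    if (freeIntRegs >= needed) {
      freeIntRegs -= needed;
      result.push_back(Passing::InRegisters);
    } else {
      // Assumed to correspond to the branch that used to be unreachable.
      result.push_back(Passing::InMemory);
    }
  }
  return result;
}
```

With the assumed counts `{1, 2, 2, 1, 1, 1}` for the six parameters, the first four arguments use up all six registers and the last two fall through to the memory path in this model.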

clang/test/CIR/CodeGen/CUDA/simple.cu

+11 -6
@@ -28,12 +28,16 @@ __global__ void global_fn(int a) {}
 // Check for device stub emission.

 // CIR-HOST: @_Z24__device_stub__global_fni{{.*}}extra([[Kernel]])
-// CIR-HOST: cir.alloca {{.*}}"kernel_args"
+// CIR-HOST: %[[#CIRKernelArgs:]] = cir.alloca {{.*}}"kernel_args"
+// CIR-HOST: %[[#Decayed:]] = cir.cast(array_to_ptrdecay, %[[#CIRKernelArgs]]
 // CIR-HOST: cir.call @__cudaPopCallConfiguration
 // CIR-HOST: cir.get_global @_Z24__device_stub__global_fni
 // CIR-HOST: cir.call @cudaLaunchKernel

 // LLVM-HOST: void @_Z24__device_stub__global_fni
+// LLVM-HOST: %[[#KernelArgs:]] = alloca [1 x ptr], i64 1, align 16
+// LLVM-HOST: %[[#GEP1:]] = getelementptr ptr, ptr %[[#KernelArgs]], i32 0
+// LLVM-HOST: %[[#GEP2:]] = getelementptr ptr, ptr %[[#GEP1]], i64 0
 // LLVM-HOST: call i32 @__cudaPopCallConfiguration
 // LLVM-HOST: call i32 @cudaLaunchKernel(ptr @_Z24__device_stub__global_fni

@@ -48,6 +52,7 @@ int main() {
 // CIR-HOST: [[Push:%[0-9]+]] = cir.call @__cudaPushCallConfiguration
 // CIR-HOST: [[ConfigOK:%[0-9]+]] = cir.cast(int_to_bool, [[Push]]
 // CIR-HOST: cir.if [[ConfigOK]] {
+// CIR-HOST: } else {
 // CIR-HOST: [[Arg:%[0-9]+]] = cir.const #cir.int<1>
 // CIR-HOST: cir.call @_Z24__device_stub__global_fni([[Arg]])
 // CIR-HOST: }
@@ -58,9 +63,9 @@ int main() {
 // LLVM-HOST: call void @_ZN4dim3C1Ejjj
 // LLVM-HOST: call void @_ZN4dim3C1Ejjj
 // LLVM-HOST: [[LLVMConfigOK:%[0-9]+]] = call i32 @__cudaPushCallConfiguration
-// LLVM-HOST: br [[LLVMConfigOK]], label %[[Good:[0-9]+]], label [[Bad:[0-9]+]]
-// LLVM-HOST: [[Good]]:
+// LLVM-HOST: br [[LLVMConfigOK]], label %[[#Good:]], label [[#Bad:]]
+// LLVM-HOST: [[#Good]]:
+// LLVM-HOST: br label [[#End:]]
+// LLVM-HOST: [[#Bad]]:
 // LLVM-HOST: call void @_Z24__device_stub__global_fni
-// LLVM-HOST: br label [[Bad]]
-// LLVM-HOST: [[Bad]]:
-// LLVM-HOST: ret i32
+// LLVM-HOST: br label [[#End]]
