Skip to content

Commit 9e9230d

Browse files
authored
Merge pull request #1588 from CEED/jed/nvrtc-cubin
backends/cuda: NVRTC compile to CUBIN when supported (resolve #1587)
2 parents 38f3b71 + 29ec485 commit 9e9230d

File tree

1 file changed

+19
-3
lines changed

1 file changed

+19
-3
lines changed

backends/cuda/ceed-cuda-compile.cpp

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,19 @@ int CeedCompile_Cuda(Ceed ceed, const char *source, CUmodule *module, const Ceed
8080
opts[0] = "-default-device";
8181
CeedCallBackend(CeedGetData(ceed, &ceed_data));
8282
CeedCallCuda(ceed, cudaGetDeviceProperties(&prop, ceed_data->device_id));
83-
std::string arch_arg = "-arch=compute_" + std::to_string(prop.major) + std::to_string(prop.minor);
84-
opts[1] = arch_arg.c_str();
85-
opts[2] = "-Dint32_t=int";
83+
std::string arch_arg =
84+
#if CUDA_VERSION >= 11010
85+
// NVRTC used to accept only virtual architectures (compute_XX) through the
86+
// -arch option, since it emitted only PTX. It now also accepts real
87+
// architectures (sm_XX), in which case it can emit SASS directly.
88+
// https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#dynamic-code-generation
89+
"-arch=sm_"
90+
#else
91+
"-arch=compute_"
92+
#endif
93+
+ std::to_string(prop.major) + std::to_string(prop.minor);
94+
opts[1] = arch_arg.c_str();
95+
opts[2] = "-Dint32_t=int";
8696

8797
// Add string source argument provided in call
8898
code << source;
@@ -106,9 +116,15 @@ int CeedCompile_Cuda(Ceed ceed, const char *source, CUmodule *module, const Ceed
106116
return CeedError(ceed, CEED_ERROR_BACKEND, "%s\n%s", nvrtcGetErrorString(result), log);
107117
}
108118

119+
#if CUDA_VERSION >= 11010
120+
CeedCallNvrtc(ceed, nvrtcGetCUBINSize(prog, &ptx_size));
121+
CeedCallBackend(CeedMalloc(ptx_size, &ptx));
122+
CeedCallNvrtc(ceed, nvrtcGetCUBIN(prog, ptx));
123+
#else
109124
CeedCallNvrtc(ceed, nvrtcGetPTXSize(prog, &ptx_size));
110125
CeedCallBackend(CeedMalloc(ptx_size, &ptx));
111126
CeedCallNvrtc(ceed, nvrtcGetPTX(prog, ptx));
127+
#endif
112128
CeedCallNvrtc(ceed, nvrtcDestroyProgram(&prog));
113129

114130
CeedCallCuda(ceed, cuModuleLoadData(module, ptx));

0 commit comments

Comments
 (0)