Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 32 additions & 10 deletions unified-runtime/source/adapters/cuda/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,15 @@ ur_program_handle_t_::setMetadata(const ur_program_metadata_t *Metadata,
}

ur_result_t ur_program_handle_t_::setBinary(const char *Source, size_t Length) {
// Do not re-set program binary data which has already been set as that will
// delete the old binary data.
UR_ASSERT(Binary == nullptr && BinarySizeInBytes == 0,
// Do not re-set program binary data once initialized.
UR_ASSERT(Binary.empty() && BinarySizeInBytes == 0,
UR_RESULT_ERROR_INVALID_OPERATION);
Binary = Source;
UR_ASSERT(Source || Length == 0, UR_RESULT_ERROR_INVALID_NULL_POINTER);

Binary.resize(Length + 1, '\0');
if (Length) {
std::memcpy(Binary.data(), Source, Length);
}
BinarySizeInBytes = Length;
return UR_RESULT_SUCCESS;
}
Expand Down Expand Up @@ -149,9 +153,11 @@ ur_result_t ur_program_handle_t_::buildProgram(const char *BuildOptions) {
}
}

UR_CHECK_ERROR(cuModuleLoadDataEx(&Module, static_cast<const void *>(Binary),
Options.size(), Options.data(),
OptionVals.data()));
UR_ASSERT(Binary.size() > 1, UR_RESULT_ERROR_INVALID_PROGRAM);

UR_CHECK_ERROR(
cuModuleLoadDataEx(&Module, static_cast<const void *>(Binary.data()),
Options.size(), Options.data(), OptionVals.data()));

BuildStatus = UR_PROGRAM_BUILD_STATUS_SUCCESS;

Expand Down Expand Up @@ -274,7 +280,7 @@ urProgramLink(ur_context_handle_t hContext, uint32_t count,
for (size_t i = 0; i < count; ++i) {
ur_program_handle_t Program = phPrograms[i];
UR_CHECK_ERROR(cuLinkAddData(
State, CU_JIT_INPUT_PTX, const_cast<char *>(Program->Binary),
State, CU_JIT_INPUT_PTX, Program->Binary.data(),
Program->BinarySizeInBytes, nullptr, 0, nullptr, nullptr));
}
void *CuBin = nullptr;
Expand Down Expand Up @@ -366,8 +372,24 @@ urProgramGetInfo(ur_program_handle_t hProgram, ur_program_info_t propName,
return ReturnValue(&hProgram->Device, 1);
case UR_PROGRAM_INFO_BINARY_SIZES:
return ReturnValue(&hProgram->BinarySizeInBytes, 1);
case UR_PROGRAM_INFO_BINARIES:
return ReturnValue(&hProgram->Binary, 1);
case UR_PROGRAM_INFO_BINARIES: {
if (pPropSizeRet) {
*pPropSizeRet = sizeof(uint8_t *);
}

if (!pProgramInfo) {
return UR_RESULT_SUCCESS;
}

UR_ASSERT(propSize >= sizeof(uint8_t *), UR_RESULT_ERROR_INVALID_SIZE);

auto **ppBinaries = static_cast<uint8_t **>(pProgramInfo);
if (ppBinaries[0]) {
std::memcpy(ppBinaries[0], hProgram->Binary.data(),
hProgram->BinarySizeInBytes);
}
return UR_RESULT_SUCCESS;
}
case UR_PROGRAM_INFO_KERNEL_NAMES:
// CUDA has no way to query a list of kernels from a binary.
// In SYCL this is only used in kernel bundle when building from source
Expand Down
3 changes: 2 additions & 1 deletion unified-runtime/source/adapters/cuda/program.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@

#include <atomic>
#include <unordered_map>
#include <vector>

#include "common/ur_ref_count.hpp"
#include "context.hpp"

struct ur_program_handle_t_ : ur::cuda::handle_base {
using native_type = CUmodule;
native_type Module;
const char *Binary;
std::vector<char> Binary;
size_t BinarySizeInBytes;
ur::RefCount RefCount;
ur_context_handle_t Context;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ TEST_P(urProgramGetInfoTest, SuccessBinarySizes) {
TEST_P(urProgramGetInfoTest, SuccessBinaries) {
// Not implemented correctly on these targets - they copy their own pointer into the output rather than copying the
// binary
UUR_KNOWN_FAILURE_ON(uur::HIP{}, uur::CUDA{});
UUR_KNOWN_FAILURE_ON(uur::HIP{});

size_t binary_sizes_len = 0;
std::vector<char> property_value(0);
Expand Down
Loading