Skip to content

Commit 2857569

Browse files
committed
Integrate static OCL adapter into loader with namespace isolation
1 parent a40f51e commit 2857569

9 files changed

Lines changed: 232 additions & 18 deletions

File tree

unified-runtime/source/adapters/opencl/adapter.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,13 @@ urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters,
103103
}
104104

105105
if (pNumAdapters) {
106-
*pNumAdapters = liveAdapter ? 1 : 0;
106+
#ifdef UR_STATIC_ADAPTER_OPENCL
107+
// Probe libOpenCL for the count-only query pattern (NumEntries == 0);
108+
// loadOCLLibrary() is idempotent.
109+
*pNumAdapters = (liveAdapter || ocl::loadOCLLibrary()) ? 1 : 0;
110+
#else
111+
*pNumAdapters = 1;
112+
#endif
107113
}
108114

109115
return UR_RESULT_SUCCESS;

unified-runtime/source/adapters/opencl/adapter.hpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,19 @@ struct ur_adapter_handle_t_ : ur::opencl::handle_base {
3636
// versions of the OpenCL-ICD-Loader are tracked here and initialized by
3737
// dynamically loading the symbol by name.
3838
#ifdef UR_STATIC_ADAPTER_OPENCL
39-
// Temporarily undefine the OCL function macros from ocl_dynamic_lib.hpp
40-
// so we can use decltype on the original function names
39+
// Lift redirect macros so decltype resolves the real CL signatures
4140
#undef clSetProgramSpecializationConstant
4241
#undef clSetContextDestructorCallback
4342
#endif
4443
#define CL_CORE_FUNCTION(FUNC) decltype(::FUNC) *FUNC = nullptr;
4544
#include "core_functions.def"
4645
#undef CL_CORE_FUNCTION
46+
#ifdef UR_STATIC_ADAPTER_OPENCL
47+
// Restore redirect macros so direct calls in this TU still go via pointers
48+
#define clSetProgramSpecializationConstant \
49+
ocl::clSetProgramSpecializationConstant_ptr
50+
#define clSetContextDestructorCallback ocl::clSetContextDestructorCallback_ptr
51+
#endif
4752
};
4853

4954
namespace ur {

unified-runtime/source/adapters/opencl/ocl_dynamic_lib.cpp

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,23 +46,29 @@ static bool getSymbolAddr(void *handle, const char *name, T *funcPtr) {
4646

4747
static void loadOCLLibraryImpl() {
4848
#ifdef _WIN32
49-
OCLLibHandle = LoadLibraryExA("OpenCL.dll", NULL, LOAD_LIBRARY_SEARCH_SYSTEM32);
49+
OCLLibHandle =
50+
LoadLibraryExA("OpenCL.dll", NULL, LOAD_LIBRARY_SEARCH_SYSTEM32);
5051
if (!OCLLibHandle) {
5152
DWORD error = GetLastError();
52-
UR_LOG(ERR, "Failed to load OpenCL.dll from system directory (error code: {})", error);
53+
UR_LOG(ERR,
54+
"Failed to load OpenCL.dll from system directory (error code: {})",
55+
error);
5356
return;
5457
}
5558
UR_LOG(DEBUG, "Successfully loaded OpenCL.dll");
5659
#else
5760
OCLLibHandle = dlopen("libOpenCL.so.1", RTLD_NOW | RTLD_LOCAL);
5861
if (!OCLLibHandle) {
5962
const char *error1 = dlerror();
60-
UR_LOG(DEBUG, "Failed to load libOpenCL.so.1: {}", error1 ? error1 : "unknown error");
63+
UR_LOG(DEBUG, "Failed to load libOpenCL.so.1: {}",
64+
error1 ? error1 : "unknown error");
6165

6266
OCLLibHandle = dlopen("libOpenCL.so", RTLD_NOW | RTLD_LOCAL);
6367
if (!OCLLibHandle) {
6468
const char *error2 = dlerror();
65-
UR_LOG(ERR, "Failed to load OpenCL library. Tried libOpenCL.so.1 and libOpenCL.so: {}",
69+
UR_LOG(ERR,
70+
"Failed to load OpenCL library. Tried libOpenCL.so.1 and "
71+
"libOpenCL.so: {}",
6672
error2 ? error2 : "unknown error");
6773
return;
6874
}
@@ -94,15 +100,32 @@ static void loadOCLLibraryImpl() {
94100
#undef OCL_FUNC
95101

96102
if (required_missing > 0) {
97-
UR_LOG(ERR, "Failed to load {} required OpenCL function(s)", required_missing);
103+
UR_LOG(ERR, "Failed to load {} required OpenCL function(s)",
104+
required_missing);
98105
}
99106

100107
if (optional_missing > 0) {
101-
UR_LOG(DEBUG, "{} optional OpenCL function(s) not available (normal for older OpenCL versions)",
108+
UR_LOG(DEBUG,
109+
"{} optional OpenCL function(s) not available (normal for older "
110+
"OpenCL versions)",
102111
optional_missing);
103112
}
104113

105-
OCLLoadSuccess = success;
114+
if (!success) {
115+
// Required symbols missing — close the handle we opened to avoid a leak.
116+
#ifdef _WIN32
117+
FreeLibrary((HMODULE)OCLLibHandle);
118+
#else
119+
dlclose(OCLLibHandle);
120+
#endif
121+
OCLLibHandle = nullptr;
122+
#define OCL_FUNC(name, required) name##_ptr = nullptr;
123+
#include "ocl_functions.def"
124+
#undef OCL_FUNC
125+
return;
126+
}
127+
128+
OCLLoadSuccess = true;
106129
}
107130

108131
bool loadOCLLibrary() {

unified-runtime/source/adapters/opencl/ocl_dynamic_lib.hpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ void unloadOCLLibrary();
9292
#define clEnqueueBarrier ocl::clEnqueueBarrier_ptr
9393
#define clGetExtensionFunctionAddress ocl::clGetExtensionFunctionAddress_ptr
9494
#define clCreateSubBuffer ocl::clCreateSubBuffer_ptr
95-
#define clSetMemObjectDestructorCallback ocl::clSetMemObjectDestructorCallback_ptr
95+
#define clSetMemObjectDestructorCallback \
96+
ocl::clSetMemObjectDestructorCallback_ptr
9697
#define clCreateUserEvent ocl::clCreateUserEvent_ptr
9798
#define clSetUserEventStatus ocl::clSetUserEventStatus_ptr
9899
#define clSetEventCallback ocl::clSetEventCallback_ptr
@@ -109,8 +110,10 @@ void unloadOCLLibrary();
109110
#define clEnqueueMigrateMemObjects ocl::clEnqueueMigrateMemObjects_ptr
110111
#define clEnqueueMarkerWithWaitList ocl::clEnqueueMarkerWithWaitList_ptr
111112
#define clEnqueueBarrierWithWaitList ocl::clEnqueueBarrierWithWaitList_ptr
112-
#define clGetExtensionFunctionAddressForPlatform ocl::clGetExtensionFunctionAddressForPlatform_ptr
113-
#define clCreateCommandQueueWithProperties ocl::clCreateCommandQueueWithProperties_ptr
113+
#define clGetExtensionFunctionAddressForPlatform \
114+
ocl::clGetExtensionFunctionAddressForPlatform_ptr
115+
#define clCreateCommandQueueWithProperties \
116+
ocl::clCreateCommandQueueWithProperties_ptr
114117
#define clCreatePipe ocl::clCreatePipe_ptr
115118
#define clGetPipeInfo ocl::clGetPipeInfo_ptr
116119
#define clSVMAlloc ocl::clSVMAlloc_ptr
@@ -123,7 +126,8 @@ void unloadOCLLibrary();
123126
#define clEnqueueSVMMemFill ocl::clEnqueueSVMMemFill_ptr
124127
#define clEnqueueSVMMap ocl::clEnqueueSVMMap_ptr
125128
#define clEnqueueSVMUnmap ocl::clEnqueueSVMUnmap_ptr
126-
#define clSetProgramSpecializationConstant ocl::clSetProgramSpecializationConstant_ptr
129+
#define clSetProgramSpecializationConstant \
130+
ocl::clSetProgramSpecializationConstant_ptr
127131
#define clSetProgramReleaseCallback ocl::clSetProgramReleaseCallback_ptr
128132
#define clCreateBufferWithProperties ocl::clCreateBufferWithProperties_ptr
129133
#define clCreateImageWithProperties ocl::clCreateImageWithProperties_ptr

unified-runtime/source/adapters/opencl/ocl_functions.def

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ OCL_FUNC(clEnqueueMapBuffer, 1)
5858
OCL_FUNC(clEnqueueMapImage, 1)
5959
OCL_FUNC(clEnqueueUnmapMemObject, 1)
6060
OCL_FUNC(clEnqueueNDRangeKernel, 1)
61-
OCL_FUNC(clEnqueueNativeKernel, 1)
61+
OCL_FUNC(clEnqueueNativeKernel, 0)
6262
OCL_FUNC(clEnqueueMarker, 1)
6363
OCL_FUNC(clEnqueueWaitForEvents, 1)
6464
OCL_FUNC(clEnqueueBarrier, 1)

unified-runtime/source/adapters/opencl/ur_interface_loader.cpp

Lines changed: 141 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,11 @@ ur_result_t validateProcInputs(ur_api_version_t Version, void *pDdiTable) {
2828
}
2929
} // namespace
3030

31+
#ifdef UR_STATIC_ADAPTER_OPENCL
32+
namespace ur::opencl {
33+
#else
3134
extern "C" {
35+
#endif
3236

3337
UR_DLLEXPORT ur_result_t UR_APICALL urGetAdapterProcAddrTable(
3438
ur_api_version_t version, ur_adapter_dditable_t *pDdiTable) {
@@ -509,6 +513,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urGetGraphExpProcAddrTable(
509513
return UR_RESULT_SUCCESS;
510514
}
511515

516+
#ifndef UR_STATIC_ADAPTER_OPENCL
512517
UR_DLLEXPORT ur_result_t UR_APICALL urAllAddrTable(ur_api_version_t version,
513518
ur_dditable_t *pDdiTable) {
514519
urGetAdapterProcAddrTable(version, &pDdiTable->Adapter);
@@ -539,14 +544,147 @@ UR_DLLEXPORT ur_result_t UR_APICALL urAllAddrTable(ur_api_version_t version,
539544

540545
return UR_RESULT_SUCCESS;
541546
}
547+
#endif // UR_STATIC_ADAPTER_OPENCL
542548

549+
#ifdef UR_STATIC_ADAPTER_OPENCL
550+
} // namespace ur::opencl
551+
#else
543552
} // extern "C"
553+
#endif
544554

545-
const ur_dditable_t *ur::opencl::ddi_getter::value() {
555+
namespace {
556+
// Fills every sub-table of `ddi` by calling the adapter's
// urGet*ProcAddrTable entry points one after another, each with
// UR_API_VERSION_CURRENT.
//
// Returns UR_RESULT_ERROR_INVALID_NULL_POINTER when `ddi` is null. Otherwise
// returns the first non-success result from an entry point (the calls after
// it are not made), or the result of the last call when all of them succeed.
ur_result_t populateDdiTable(ur_dditable_t *ddi) {
557+
if (ddi == nullptr) {
558+
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
559+
}
560+
561+
ur_result_t result;
562+
563+
// ADAPTER_CALL is the scope prefix for the entry points. In the static build
// they are defined inside namespace ur::opencl; otherwise the macro is empty,
// so `ADAPTER_CALL::name` becomes the global-scope `::name`.
#ifdef UR_STATIC_ADAPTER_OPENCL
564+
#define ADAPTER_CALL ::ur::opencl
565+
#else
566+
#define ADAPTER_CALL
567+
#endif
568+
569+
// Each step below follows the same pattern: fill one sub-table, then bail out
// with the entry point's result if it is not UR_RESULT_SUCCESS.
result = ADAPTER_CALL::urGetAdapterProcAddrTable(UR_API_VERSION_CURRENT,
570+
&ddi->Adapter);
571+
if (result != UR_RESULT_SUCCESS)
572+
return result;
573+
result = ADAPTER_CALL::urGetBindlessImagesExpProcAddrTable(
574+
UR_API_VERSION_CURRENT, &ddi->BindlessImagesExp);
575+
if (result != UR_RESULT_SUCCESS)
576+
return result;
577+
result = ADAPTER_CALL::urGetCommandBufferExpProcAddrTable(
578+
UR_API_VERSION_CURRENT, &ddi->CommandBufferExp);
579+
if (result != UR_RESULT_SUCCESS)
580+
return result;
581+
result = ADAPTER_CALL::urGetContextProcAddrTable(UR_API_VERSION_CURRENT,
582+
&ddi->Context);
583+
if (result != UR_RESULT_SUCCESS)
584+
return result;
585+
result = ADAPTER_CALL::urGetEnqueueProcAddrTable(UR_API_VERSION_CURRENT,
586+
&ddi->Enqueue);
587+
if (result != UR_RESULT_SUCCESS)
588+
return result;
589+
result = ADAPTER_CALL::urGetEnqueueExpProcAddrTable(UR_API_VERSION_CURRENT,
590+
&ddi->EnqueueExp);
591+
if (result != UR_RESULT_SUCCESS)
592+
return result;
593+
result = ADAPTER_CALL::urGetEventProcAddrTable(UR_API_VERSION_CURRENT,
594+
&ddi->Event);
595+
if (result != UR_RESULT_SUCCESS)
596+
return result;
597+
result = ADAPTER_CALL::urGetGraphExpProcAddrTable(UR_API_VERSION_CURRENT,
598+
&ddi->GraphExp);
599+
if (result != UR_RESULT_SUCCESS)
600+
return result;
601+
result = ADAPTER_CALL::urGetIPCExpProcAddrTable(UR_API_VERSION_CURRENT,
602+
&ddi->IPCExp);
603+
if (result != UR_RESULT_SUCCESS)
604+
return result;
605+
result = ADAPTER_CALL::urGetKernelProcAddrTable(UR_API_VERSION_CURRENT,
606+
&ddi->Kernel);
607+
if (result != UR_RESULT_SUCCESS)
608+
return result;
609+
result =
610+
ADAPTER_CALL::urGetMemProcAddrTable(UR_API_VERSION_CURRENT, &ddi->Mem);
611+
if (result != UR_RESULT_SUCCESS)
612+
return result;
613+
result = ADAPTER_CALL::urGetMemoryExportExpProcAddrTable(
614+
UR_API_VERSION_CURRENT, &ddi->MemoryExportExp);
615+
if (result != UR_RESULT_SUCCESS)
616+
return result;
617+
result = ADAPTER_CALL::urGetPhysicalMemProcAddrTable(UR_API_VERSION_CURRENT,
618+
&ddi->PhysicalMem);
619+
if (result != UR_RESULT_SUCCESS)
620+
return result;
621+
result = ADAPTER_CALL::urGetPlatformProcAddrTable(UR_API_VERSION_CURRENT,
622+
&ddi->Platform);
623+
if (result != UR_RESULT_SUCCESS)
624+
return result;
625+
result = ADAPTER_CALL::urGetProgramProcAddrTable(UR_API_VERSION_CURRENT,
626+
&ddi->Program);
627+
if (result != UR_RESULT_SUCCESS)
628+
return result;
629+
result = ADAPTER_CALL::urGetProgramExpProcAddrTable(UR_API_VERSION_CURRENT,
630+
&ddi->ProgramExp);
631+
if (result != UR_RESULT_SUCCESS)
632+
return result;
633+
result = ADAPTER_CALL::urGetQueueProcAddrTable(UR_API_VERSION_CURRENT,
634+
&ddi->Queue);
635+
if (result != UR_RESULT_SUCCESS)
636+
return result;
637+
result = ADAPTER_CALL::urGetQueueExpProcAddrTable(UR_API_VERSION_CURRENT,
638+
&ddi->QueueExp);
639+
if (result != UR_RESULT_SUCCESS)
640+
return result;
641+
result = ADAPTER_CALL::urGetSamplerProcAddrTable(UR_API_VERSION_CURRENT,
642+
&ddi->Sampler);
643+
if (result != UR_RESULT_SUCCESS)
644+
return result;
645+
result =
646+
ADAPTER_CALL::urGetUSMProcAddrTable(UR_API_VERSION_CURRENT, &ddi->USM);
647+
if (result != UR_RESULT_SUCCESS)
648+
return result;
649+
result = ADAPTER_CALL::urGetUSMExpProcAddrTable(UR_API_VERSION_CURRENT,
650+
&ddi->USMExp);
651+
if (result != UR_RESULT_SUCCESS)
652+
return result;
653+
result = ADAPTER_CALL::urGetUsmP2PExpProcAddrTable(UR_API_VERSION_CURRENT,
654+
&ddi->UsmP2PExp);
655+
if (result != UR_RESULT_SUCCESS)
656+
return result;
657+
result = ADAPTER_CALL::urGetVirtualMemProcAddrTable(UR_API_VERSION_CURRENT,
658+
&ddi->VirtualMem);
659+
if (result != UR_RESULT_SUCCESS)
660+
return result;
661+
result = ADAPTER_CALL::urGetDeviceProcAddrTable(UR_API_VERSION_CURRENT,
662+
&ddi->Device);
663+
if (result != UR_RESULT_SUCCESS)
664+
return result;
665+
result = ADAPTER_CALL::urGetDeviceExpProcAddrTable(UR_API_VERSION_CURRENT,
666+
&ddi->DeviceExp);
667+
if (result != UR_RESULT_SUCCESS)
668+
return result;
669+
670+
// The scope-prefix macro is only meaningful inside this function.
#undef ADAPTER_CALL
671+
672+
// Reached only when every call above succeeded, so this is the result of the
// final urGetDeviceExpProcAddrTable call.
return result;
673+
}
674+
} // namespace
675+
676+
namespace ur::opencl {
677+
// Returns a pointer to a function-local static DDI table. The table is
// populated on the first call; std::call_once ensures the populate step runs
// only once, and later calls just return the same pointer.
const ur_dditable_t *ddi_getter::value() {
546678
static std::once_flag flag;
547679
static ur_dditable_t table;
548680

549-
std::call_once(flag,
550-
[]() { urAllAddrTable(UR_API_VERSION_CURRENT, &table); });
681+
// The result returned by populateDdiTable is discarded here, so a failure
// while filling the table is not reported to the caller of value().
std::call_once(flag, []() { populateDdiTable(&table); });
551682
return &table;
552683
}
684+
685+
#ifdef UR_STATIC_ADAPTER_OPENCL
686+
// Only compiled when UR_STATIC_ADAPTER_OPENCL is defined: fills the
// caller-provided table by forwarding to populateDdiTable and returns its
// result unchanged.
ur_result_t urAdapterGetDdiTables(ur_dditable_t *ddi) {
687+
return populateDdiTable(ddi);
688+
}
689+
#endif
690+
} // namespace ur::opencl

unified-runtime/source/adapters/opencl/ur_interface_loader.hpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
//===--------- ur_interface_loader.hpp - OpenCL Adapter ---------------===//
2+
//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
5+
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
#pragma once
10+
11+
#include <unified-runtime/ur_api.h>
12+
#include <unified-runtime/ur_ddi.h>
13+
14+
namespace ur::opencl {
15+
16+
// Accessor type for the OpenCL adapter's DDI table. Only the declaration
// lives in this header.
struct ddi_getter {
17+
// Returns a pointer to a ur_dditable_t; the pointee is const, so callers
// cannot modify the table through it.
static const ur_dditable_t *value();
18+
};
19+
20+
#ifdef UR_STATIC_ADAPTER_OPENCL
21+
// Declared only when UR_STATIC_ADAPTER_OPENCL is defined. Takes a
// caller-provided table to fill and reports the outcome as a ur_result_t.
ur_result_t urAdapterGetDdiTables(ur_dditable_t *ddi);
22+
#endif
23+
24+
} // namespace ur::opencl

unified-runtime/source/loader/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,11 @@ if(UR_STATIC_ADAPTER_L0)
7575
target_compile_definitions(ur_loader PRIVATE UR_STATIC_ADAPTER_LEVEL_ZERO)
7676
endif()
7777

78+
# When UR_STATIC_ADAPTER_OPENCL is set, link the ur_adapter_opencl target into
# ur_loader and define the UR_STATIC_ADAPTER_OPENCL macro for ur_loader's own
# sources. Both are PRIVATE, so neither is propagated to ur_loader's consumers.
if(UR_STATIC_ADAPTER_OPENCL)
79+
target_link_libraries(ur_loader PRIVATE ur_adapter_opencl)
80+
target_compile_definitions(ur_loader PRIVATE UR_STATIC_ADAPTER_OPENCL)
81+
endif()
82+
7883
if(UR_ENABLE_TRACING)
7984
target_link_libraries(ur_loader PRIVATE ${TARGET_XPTI})
8085
target_include_directories(ur_loader PRIVATE ${xpti_SOURCE_DIR}/include)

unified-runtime/source/loader/ur_loader.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
#ifdef UR_STATIC_ADAPTER_LEVEL_ZERO
1212
#include "adapters/level_zero/ur_interface_loader.hpp"
1313
#endif
14+
#ifdef UR_STATIC_ADAPTER_OPENCL
15+
#include "adapters/opencl/ur_interface_loader.hpp"
16+
#endif
1417

1518
namespace ur_loader {
1619
///////////////////////////////////////////////////////////////////////////////
@@ -39,6 +42,12 @@ ur_result_t context_t::init() {
3942
ur::level_zero::urAdapterGetDdiTables(&level_zero.dditable);
4043
}
4144
#endif
45+
#ifdef UR_STATIC_ADAPTER_OPENCL
46+
if (!adapter_registry.adaptersForceLoaded()) {
47+
auto &opencl = platforms.emplace_back(nullptr);
48+
ur::opencl::urAdapterGetDdiTables(&opencl.dditable);
49+
}
50+
#endif
4251

4352
for (const auto &adapterPaths : adapter_registry) {
4453
for (const auto &path : adapterPaths) {

0 commit comments

Comments
 (0)