|
7 | 7 | #include <level_zero/ze_api.h> |
8 | 8 | #include <level_zero/zes_api.h> |
9 | 9 | #include <memory> |
| 10 | +#include <mutex> |
10 | 11 |
|
11 | 12 | #include "DLWrap.h" |
12 | 13 | #include "Shared/Debug.h" |
@@ -106,6 +107,110 @@ DLWRAP_FINALIZE() |
106 | 107 | #define DEBUG_PREFIX "TARGET " GETNAME(TARGET_NAME) " RTL" |
107 | 108 | #endif |
108 | 109 |
|
| 110 | +// Extension function pointer for getting argument sizes. |
| 111 | +static ze_result_t (*zexKernelGetArgumentSize_ptr)(ze_kernel_handle_t, uint32_t, |
| 112 | + uint32_t *) = nullptr; |
| 113 | + |
| 114 | +static ze_result_t zeCommandListAppendLaunchKernelWithArgumentsFallback( |
| 115 | + ze_command_list_handle_t hCommandList, ze_kernel_handle_t hKernel, |
| 116 | + const ze_group_count_t groupCounts, const ze_group_size_t groupSizes, |
| 117 | + void **pArguments, const void *pNext, ze_event_handle_t hSignalEvent, |
| 118 | + uint32_t numWaitEvents, ze_event_handle_t *phWaitEvents) { |
| 119 | + |
| 120 | + static std::once_flag zexKernelGetArgumentSize_once; |
| 121 | + ze_result_t Res; |
| 122 | + |
| 123 | + // Load zexKernelGetArgumentSize extension if available. |
| 124 | + std::call_once(zexKernelGetArgumentSize_once, []() { |
| 125 | + uint32_t DriverCount = 0; |
| 126 | + if (zeDriverGet(&DriverCount, nullptr) == ZE_RESULT_SUCCESS && |
| 127 | + DriverCount > 0) { |
| 128 | + ze_driver_handle_t Driver; |
| 129 | + DriverCount = 1; |
| 130 | + if (zeDriverGet(&DriverCount, &Driver) == ZE_RESULT_SUCCESS) { |
| 131 | + void *ExtFunc = nullptr; |
| 132 | + if (zeDriverGetExtensionFunctionAddress( |
| 133 | + Driver, "zexKernelGetArgumentSize", &ExtFunc) == |
| 134 | + ZE_RESULT_SUCCESS && |
| 135 | + ExtFunc) { |
| 136 | + zexKernelGetArgumentSize_ptr = |
| 137 | + reinterpret_cast<decltype(zexKernelGetArgumentSize_ptr)>(ExtFunc); |
| 138 | + ODBG(OLDT_Init) << "Loaded zexKernelGetArgumentSize extension"; |
| 139 | + } |
| 140 | + } |
| 141 | + } |
| 142 | + }); |
| 143 | + if (!zexKernelGetArgumentSize_ptr) { |
| 144 | + ODBG(OLDT_Kernel) << "zeCommandListAppendLaunchKernelWithArguments is not " |
| 145 | + "available, and no fallback is possible without " |
| 146 | + "argument size information."; |
| 147 | + return ZE_RESULT_ERROR_UNSUPPORTED_FEATURE; |
| 148 | + } |
| 149 | + |
| 150 | + Res = zeKernelSetGroupSize(hKernel, groupSizes.groupSizeX, |
| 151 | + groupSizes.groupSizeY, groupSizes.groupSizeZ); |
| 152 | + if (Res != ZE_RESULT_SUCCESS) |
| 153 | + return Res; |
| 154 | + |
| 155 | + ze_kernel_properties_t KernelProps = {}; |
| 156 | + KernelProps.stype = ZE_STRUCTURE_TYPE_KERNEL_PROPERTIES; |
| 157 | + Res = zeKernelGetProperties(hKernel, &KernelProps); |
| 158 | + if (Res != ZE_RESULT_SUCCESS) |
| 159 | + return Res; |
| 160 | + |
| 161 | + uint32_t NumKernelArgs = KernelProps.numKernelArgs; |
| 162 | + |
| 163 | + for (uint32_t KernelArg = 0; KernelArg < NumKernelArgs; KernelArg++) { |
| 164 | + uint32_t ArgSize = 0; |
| 165 | + |
| 166 | + Res = zexKernelGetArgumentSize_ptr(hKernel, KernelArg, &ArgSize); |
| 167 | + if (Res != ZE_RESULT_SUCCESS) |
| 168 | + return Res; |
| 169 | + |
| 170 | + Res = zeKernelSetArgumentValue(hKernel, KernelArg, ArgSize, |
| 171 | + pArguments[KernelArg]); |
| 172 | + if (Res != ZE_RESULT_SUCCESS) |
| 173 | + return Res; |
| 174 | + } |
| 175 | + |
| 176 | + bool IsCooperative = false; |
| 177 | + if (pNext) { |
| 178 | + const ze_command_list_append_launch_kernel_param_cooperative_desc_t |
| 179 | + *CoopDesc = static_cast< |
| 180 | + const ze_command_list_append_launch_kernel_param_cooperative_desc_t |
| 181 | + *>(pNext); |
| 182 | + if (CoopDesc->stype == |
| 183 | + ZE_STRUCTURE_TYPE_COMMAND_LIST_APPEND_PARAM_COOPERATIVE_DESC) |
| 184 | + IsCooperative = CoopDesc->isCooperative; |
| 185 | + } |
| 186 | + |
| 187 | + if (IsCooperative) |
| 188 | + return zeCommandListAppendLaunchCooperativeKernel( |
| 189 | + hCommandList, hKernel, &groupCounts, hSignalEvent, numWaitEvents, |
| 190 | + phWaitEvents); |
| 191 | + return zeCommandListAppendLaunchKernel(hCommandList, hKernel, &groupCounts, |
| 192 | + hSignalEvent, numWaitEvents, |
| 193 | + phWaitEvents); |
| 194 | +} |
| 195 | + |
| 196 | +static struct { |
| 197 | + const char *Name; |
| 198 | + void *FallbackFunc; |
| 199 | +} ZeFallbacksTbl[] = { |
| 200 | + {"zeCommandListAppendLaunchKernelWithArguments", |
| 201 | + reinterpret_cast<void *>( |
| 202 | + &zeCommandListAppendLaunchKernelWithArgumentsFallback)}}; |
| 203 | +constexpr size_t ZeFallbacksTblSz = |
| 204 | + sizeof(ZeFallbacksTbl) / sizeof(ZeFallbacksTbl[0]); |
| 205 | + |
| 206 | +static void *findZeFallback(std::string_view Name) { |
| 207 | + for (size_t i = 0; i < ZeFallbacksTblSz; i++) { |
| 208 | + if (Name == ZeFallbacksTbl[i].Name) |
| 209 | + return ZeFallbacksTbl[i].FallbackFunc; |
| 210 | + } |
| 211 | + return nullptr; |
| 212 | +} |
| 213 | + |
109 | 214 | static bool loadLevelZero() { |
110 | 215 | std::string L0Library{LEVEL_ZERO_LIBRARY}; |
111 | 216 | std::string ErrMsg; |
@@ -150,16 +255,23 @@ static bool loadLevelZero() { |
150 | 255 | const char *Sym = dlwrap::symbol(I); |
151 | 256 |
|
152 | 257 | void *P = DynlibHandle->getAddressOfSymbol(Sym); |
| 258 | + void *Fallback = nullptr; |
153 | 259 | if (P == nullptr) { |
154 | | - ODBG(OLDT_Init) << "Unable to find '" << Sym << "' in '" << L0Library |
155 | | - << "'!"; |
156 | | - emitCheckVersion(); |
157 | | - return false; |
| 260 | + Fallback = findZeFallback(Sym); |
| 261 | + if (!Fallback) { |
| 262 | + ODBG(OLDT_Init) << "Symbol '" << Sym << "' not found in '" << L0Library |
| 263 | + << "' and no fallback is available!"; |
| 264 | + emitCheckVersion(); |
| 265 | + return false; |
| 266 | + } |
| 267 | + ODBG(OLDT_Init) << "Symbol '" << Sym << "' not found in '" << L0Library |
| 268 | + << "'. Using fallback implementation -> " << Fallback; |
158 | 269 | } |
159 | | - ODBG(OLDT_Init) << "Implementing " << Sym << " with dlsym(" << Sym |
160 | | - << ") -> " << P; |
| 270 | + if (P) |
| 271 | + ODBG(OLDT_Init) << "Implementing " << Sym << " with dlsym(" << Sym |
| 272 | + << ") -> " << P; |
161 | 273 |
|
162 | | - *dlwrap::pointer(I) = P; |
| 274 | + *dlwrap::pointer(I) = P ? P : Fallback; |
163 | 275 | } |
164 | 276 |
|
165 | 277 | return true; |
|
0 commit comments