Skip to content

Commit 8c30cb2

Browse files
authored
[OFFLOAD][L0] Add support for dynamic l0 fallbacks (llvm#200517)
The PR adds support to define fallbacks for DLWRAP routines that are not found when loading the library. It implements a fallback for zeCommandListAppendLaunchKernelWithArguments introduced in llvm#194333 which might not be available in older drivers.
1 parent 4b3249a commit 8c30cb2

1 file changed

Lines changed: 119 additions & 7 deletions

File tree

offload/plugins-nextgen/level_zero/dynamic_l0/L0DynWrapper.cpp

Lines changed: 119 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <level_zero/ze_api.h>
88
#include <level_zero/zes_api.h>
99
#include <memory>
10+
#include <mutex>
1011

1112
#include "DLWrap.h"
1213
#include "Shared/Debug.h"
@@ -106,6 +107,110 @@ DLWRAP_FINALIZE()
106107
#define DEBUG_PREFIX "TARGET " GETNAME(TARGET_NAME) " RTL"
107108
#endif
108109

110+
// Extension function pointer for getting argument sizes.
111+
static ze_result_t (*zexKernelGetArgumentSize_ptr)(ze_kernel_handle_t, uint32_t,
112+
uint32_t *) = nullptr;
113+
114+
static ze_result_t zeCommandListAppendLaunchKernelWithArgumentsFallback(
115+
ze_command_list_handle_t hCommandList, ze_kernel_handle_t hKernel,
116+
const ze_group_count_t groupCounts, const ze_group_size_t groupSizes,
117+
void **pArguments, const void *pNext, ze_event_handle_t hSignalEvent,
118+
uint32_t numWaitEvents, ze_event_handle_t *phWaitEvents) {
119+
120+
static std::once_flag zexKernelGetArgumentSize_once;
121+
ze_result_t Res;
122+
123+
// Load zexKernelGetArgumentSize extension if available.
124+
std::call_once(zexKernelGetArgumentSize_once, []() {
125+
uint32_t DriverCount = 0;
126+
if (zeDriverGet(&DriverCount, nullptr) == ZE_RESULT_SUCCESS &&
127+
DriverCount > 0) {
128+
ze_driver_handle_t Driver;
129+
DriverCount = 1;
130+
if (zeDriverGet(&DriverCount, &Driver) == ZE_RESULT_SUCCESS) {
131+
void *ExtFunc = nullptr;
132+
if (zeDriverGetExtensionFunctionAddress(
133+
Driver, "zexKernelGetArgumentSize", &ExtFunc) ==
134+
ZE_RESULT_SUCCESS &&
135+
ExtFunc) {
136+
zexKernelGetArgumentSize_ptr =
137+
reinterpret_cast<decltype(zexKernelGetArgumentSize_ptr)>(ExtFunc);
138+
ODBG(OLDT_Init) << "Loaded zexKernelGetArgumentSize extension";
139+
}
140+
}
141+
}
142+
});
143+
if (!zexKernelGetArgumentSize_ptr) {
144+
ODBG(OLDT_Kernel) << "zeCommandListAppendLaunchKernelWithArguments is not "
145+
"available, and no fallback is possible without "
146+
"argument size information.";
147+
return ZE_RESULT_ERROR_UNSUPPORTED_FEATURE;
148+
}
149+
150+
Res = zeKernelSetGroupSize(hKernel, groupSizes.groupSizeX,
151+
groupSizes.groupSizeY, groupSizes.groupSizeZ);
152+
if (Res != ZE_RESULT_SUCCESS)
153+
return Res;
154+
155+
ze_kernel_properties_t KernelProps = {};
156+
KernelProps.stype = ZE_STRUCTURE_TYPE_KERNEL_PROPERTIES;
157+
Res = zeKernelGetProperties(hKernel, &KernelProps);
158+
if (Res != ZE_RESULT_SUCCESS)
159+
return Res;
160+
161+
uint32_t NumKernelArgs = KernelProps.numKernelArgs;
162+
163+
for (uint32_t KernelArg = 0; KernelArg < NumKernelArgs; KernelArg++) {
164+
uint32_t ArgSize = 0;
165+
166+
Res = zexKernelGetArgumentSize_ptr(hKernel, KernelArg, &ArgSize);
167+
if (Res != ZE_RESULT_SUCCESS)
168+
return Res;
169+
170+
Res = zeKernelSetArgumentValue(hKernel, KernelArg, ArgSize,
171+
pArguments[KernelArg]);
172+
if (Res != ZE_RESULT_SUCCESS)
173+
return Res;
174+
}
175+
176+
bool IsCooperative = false;
177+
if (pNext) {
178+
const ze_command_list_append_launch_kernel_param_cooperative_desc_t
179+
*CoopDesc = static_cast<
180+
const ze_command_list_append_launch_kernel_param_cooperative_desc_t
181+
*>(pNext);
182+
if (CoopDesc->stype ==
183+
ZE_STRUCTURE_TYPE_COMMAND_LIST_APPEND_PARAM_COOPERATIVE_DESC)
184+
IsCooperative = CoopDesc->isCooperative;
185+
}
186+
187+
if (IsCooperative)
188+
return zeCommandListAppendLaunchCooperativeKernel(
189+
hCommandList, hKernel, &groupCounts, hSignalEvent, numWaitEvents,
190+
phWaitEvents);
191+
return zeCommandListAppendLaunchKernel(hCommandList, hKernel, &groupCounts,
192+
hSignalEvent, numWaitEvents,
193+
phWaitEvents);
194+
}
195+
196+
static struct {
197+
const char *Name;
198+
void *FallbackFunc;
199+
} ZeFallbacksTbl[] = {
200+
{"zeCommandListAppendLaunchKernelWithArguments",
201+
reinterpret_cast<void *>(
202+
&zeCommandListAppendLaunchKernelWithArgumentsFallback)}};
203+
constexpr size_t ZeFallbacksTblSz =
204+
sizeof(ZeFallbacksTbl) / sizeof(ZeFallbacksTbl[0]);
205+
206+
static void *findZeFallback(std::string_view Name) {
207+
for (size_t i = 0; i < ZeFallbacksTblSz; i++) {
208+
if (Name == ZeFallbacksTbl[i].Name)
209+
return ZeFallbacksTbl[i].FallbackFunc;
210+
}
211+
return nullptr;
212+
}
213+
109214
static bool loadLevelZero() {
110215
std::string L0Library{LEVEL_ZERO_LIBRARY};
111216
std::string ErrMsg;
@@ -150,16 +255,23 @@ static bool loadLevelZero() {
150255
const char *Sym = dlwrap::symbol(I);
151256

152257
void *P = DynlibHandle->getAddressOfSymbol(Sym);
258+
void *Fallback = nullptr;
153259
if (P == nullptr) {
154-
ODBG(OLDT_Init) << "Unable to find '" << Sym << "' in '" << L0Library
155-
<< "'!";
156-
emitCheckVersion();
157-
return false;
260+
Fallback = findZeFallback(Sym);
261+
if (!Fallback) {
262+
ODBG(OLDT_Init) << "Symbol '" << Sym << "' not found in '" << L0Library
263+
<< "' and no fallback is available!";
264+
emitCheckVersion();
265+
return false;
266+
}
267+
ODBG(OLDT_Init) << "Symbol '" << Sym << "' not found in '" << L0Library
268+
<< "'. Using fallback implementation -> " << Fallback;
158269
}
159-
ODBG(OLDT_Init) << "Implementing " << Sym << " with dlsym(" << Sym
160-
<< ") -> " << P;
270+
if (P)
271+
ODBG(OLDT_Init) << "Implementing " << Sym << " with dlsym(" << Sym
272+
<< ") -> " << P;
161273

162-
*dlwrap::pointer(I) = P;
274+
*dlwrap::pointer(I) = P ? P : Fallback;
163275
}
164276

165277
return true;

0 commit comments

Comments
 (0)