Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ ur_result_t EnqueueMemCopyRectHelper(

UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(Queue, Events.size(),
Events.data(), Event));

for (const auto &E : Events)
UR_CALL(getContext()->urDdiTable.Event.pfnRelease(E));
getAsanInterceptor()
->getContextInfo(GetContext(Queue))
->DeferredEvents.add(Events);

return UR_RESULT_SUCCESS;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,8 @@ ProgramInfo::getKernelMetadata(ur_kernel_handle_t Kernel) const {
}

ContextInfo::~ContextInfo() {
DeferredEvents.releaseAll();

Stats.Print(Handle);

InternalQueueMap.clear();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ struct ContextInfo {

std::optional<Quarantine> m_Quarantine;

DeferredEventList DeferredEvents;

AsanStatsWrapper Stats;

explicit ContextInfo(ur_context_handle_t Context) : Handle(Context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,11 @@ ur_result_t EnqueueMemCopyRectHelper(
UR_CALL(getContext()->urDdiTable.Event.pfnWait(Events.size(), &Events[0]));
}

if (Event) {
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(Queue, Events.size(),
&Events[0], Event));
}

for (const auto &E : Events)
UR_CALL(getContext()->urDdiTable.Event.pfnRelease(E));
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(Queue, Events.size(),
&Events[0], Event));
getMsanInterceptor()
->getContextInfo(GetContext(Queue))
->DeferredEvents.add(Events);

return UR_RESULT_SUCCESS;
}
Expand Down
108 changes: 44 additions & 64 deletions unified-runtime/source/loader/layers/sanitizer/msan/msan_ddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,12 @@ ur_result_t urEnqueueUSMFill2DFallback(ur_queue_handle_t hQueue, void *pMem,
WaitEvents.push_back(Event);
}

if (phEvent) {
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
hQueue, WaitEvents.size(), WaitEvents.data(), phEvent));
}
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
hQueue, WaitEvents.size(), WaitEvents.data(), phEvent));

for (const auto Event : WaitEvents) {
UR_CALL(getContext()->urDdiTable.Event.pfnRelease(Event));
}
getMsanInterceptor()
->getContextInfo(GetContext(hQueue))
->DeferredEvents.add(WaitEvents);

return UR_RESULT_SUCCESS;
}
Expand Down Expand Up @@ -830,21 +828,17 @@ ur_result_t urEnqueueMemBufferWrite(
// Update shadow memory
std::shared_ptr<DeviceInfo> DeviceInfo =
getMsanInterceptor()->getDeviceInfo(Device);
const char Val = 0;
uptr ShadowAddr = DeviceInfo->Shadow->MemToShadow((uptr)pDst + offset);
Event = nullptr;
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMFill(
hQueue, (void *)ShadowAddr, 1, &Val, size, numEventsInWaitList,
phEventWaitList, &Event));
UR_CALL(EnqueueUSMSetZero(hQueue, (void *)ShadowAddr, size,
numEventsInWaitList, phEventWaitList, &Event));
Events.push_back(Event);

if (phEvent) {
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
hQueue, Events.size(), Events.data(), phEvent));
}

for (const auto &E : Events)
UR_CALL(getContext()->urDdiTable.Event.pfnRelease(E));
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
hQueue, Events.size(), Events.data(), phEvent));
getMsanInterceptor()
->getContextInfo(GetContext(hQueue))
->DeferredEvents.add(Events);
} else {
UR_CALL(pfnMemBufferWrite(hQueue, hBuffer, blockingWrite, offset, size,
pSrc, numEventsInWaitList, phEventWaitList,
Expand Down Expand Up @@ -1039,13 +1033,11 @@ ur_result_t urEnqueueMemBufferCopy(
numEventsInWaitList, phEventWaitList, &Event));
Events.push_back(Event);

if (phEvent) {
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
hQueue, Events.size(), Events.data(), phEvent));
}

for (const auto &E : Events)
UR_CALL(getContext()->urDdiTable.Event.pfnRelease(E));
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
hQueue, Events.size(), Events.data(), phEvent));
getMsanInterceptor()
->getContextInfo(GetContext(hQueue))
->DeferredEvents.add(Events);
} else {
UR_CALL(pfnMemBufferCopy(hQueue, hBufferSrc, hBufferDst, srcOffset,
dstOffset, size, numEventsInWaitList,
Expand Down Expand Up @@ -1164,21 +1156,17 @@ ur_result_t urEnqueueMemBufferFill(
// Update shadow memory
std::shared_ptr<DeviceInfo> DeviceInfo =
getMsanInterceptor()->getDeviceInfo(Device);
const char Val = 0;
uptr ShadowAddr = DeviceInfo->Shadow->MemToShadow((uptr)Handle + offset);
Event = nullptr;
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMFill(
hQueue, (void *)ShadowAddr, 1, &Val, size, numEventsInWaitList,
phEventWaitList, &Event));
UR_CALL(EnqueueUSMSetZero(hQueue, (void *)ShadowAddr, size,
numEventsInWaitList, phEventWaitList, &Event));
Events.push_back(Event);

if (phEvent) {
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
hQueue, Events.size(), Events.data(), phEvent));
}

for (const auto &E : Events)
UR_CALL(getContext()->urDdiTable.Event.pfnRelease(E));
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
hQueue, Events.size(), Events.data(), phEvent));
getMsanInterceptor()
->getContextInfo(GetContext(hQueue))
->DeferredEvents.add(Events);
} else {
UR_CALL(pfnMemBufferFill(hQueue, hBuffer, pPattern, patternSize, offset,
size, numEventsInWaitList, phEventWaitList,
Expand Down Expand Up @@ -1417,13 +1405,11 @@ ur_result_t urEnqueueUSMFill(

// NOTE: No need to set origin, since its shadow is clean

if (phEvent) {
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
hQueue, Events.size(), Events.data(), phEvent));
}

for (const auto &E : Events)
UR_CALL(getContext()->urDdiTable.Event.pfnRelease(E));
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
hQueue, Events.size(), Events.data(), phEvent));
getMsanInterceptor()
->getContextInfo(GetContext(hQueue))
->DeferredEvents.add(Events);

return UR_RESULT_SUCCESS;
}
Expand Down Expand Up @@ -1509,13 +1495,11 @@ ur_result_t urEnqueueUSMMemcpy(
}
}

if (phEvent) {
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
hQueue, Events.size(), Events.data(), phEvent));
}

for (const auto &E : Events)
UR_CALL(getContext()->urDdiTable.Event.pfnRelease(E));
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
hQueue, Events.size(), Events.data(), phEvent));
getMsanInterceptor()
->getContextInfo(GetContext(hQueue))
->DeferredEvents.add(Events);

return UR_RESULT_SUCCESS;
}
Expand Down Expand Up @@ -1574,13 +1558,11 @@ ur_result_t urEnqueueUSMFill2D(

// NOTE: No need to set origin, since its shadow is clean

if (phEvent) {
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
hQueue, Events.size(), Events.data(), phEvent));
}

for (const auto &E : Events)
UR_CALL(getContext()->urDdiTable.Event.pfnRelease(E));
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
hQueue, Events.size(), Events.data(), phEvent));
getMsanInterceptor()
->getContextInfo(GetContext(hQueue))
->DeferredEvents.add(Events);

return UR_RESULT_SUCCESS;
}
Expand Down Expand Up @@ -1682,13 +1664,11 @@ ur_result_t urEnqueueUSMMemcpy2D(
Events.push_back(Event);
}

if (phEvent) {
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
hQueue, Events.size(), Events.data(), phEvent));
}

for (const auto E : Events)
UR_CALL(getContext()->urDdiTable.Event.pfnRelease(E));
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
hQueue, Events.size(), Events.data(), phEvent));
getMsanInterceptor()
->getContextInfo(GetContext(hQueue))
->DeferredEvents.add(Events);

return UR_RESULT_SUCCESS;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,7 @@ ProgramInfo::getKernelMetadata(ur_kernel_handle_t Kernel) const {
}

ContextInfo::~ContextInfo() {
DeferredEvents.releaseAll();
[[maybe_unused]] auto Result =
getContext()->urDdiTable.Context.pfnRelease(Handle);
assert(Result == UR_RESULT_SUCCESS);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "msan_shadow.hpp"
#include "sanitizer_common/sanitizer_common.hpp"
#include "sanitizer_common/sanitizer_options.hpp"
#include "sanitizer_common/sanitizer_utils.hpp"
#include "ur_sanitizer_layer.hpp"

#include <memory>
Expand Down Expand Up @@ -141,6 +142,7 @@ struct ContextInfo {
std::atomic<int32_t> RefCount = 1;

std::vector<ur_device_handle_t> DeviceList;
DeferredEventList DeferredEvents;

explicit ContextInfo(ur_context_handle_t Context) : Handle(Context) {
[[maybe_unused]] auto Result =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,38 @@

#include "sanitizer_libdevice.hpp"
#include "unified-runtime/ur_api.h"
#include "ur/ur.hpp"
#include "ur_sanitizer_layer.hpp"

#include <string>
#include <vector>

namespace ur_sanitizer_layer {

// Accumulates events whose release must be deferred until a safe point
// (e.g., context release). L0 may not retain input events passed to
// pfnEventsWait long enough for the caller to release them immediately.
struct DeferredEventList {
void add(const std::vector<ur_event_handle_t> &Events) {
std::scoped_lock<ur_shared_mutex> Lock(Mutex);
List.insert(List.end(), Events.begin(), Events.end());
}

void releaseAll() {
std::scoped_lock<ur_shared_mutex> Lock(Mutex);
for (auto &E : List) {
[[maybe_unused]] auto Result =
getContext()->urDdiTable.Event.pfnRelease(E);
assert(Result == UR_RESULT_SUCCESS);
}
List.clear();
}

private:
ur_shared_mutex Mutex;
std::vector<ur_event_handle_t> List;
};

struct ManagedQueue {
ManagedQueue(ur_context_handle_t Context, ur_device_handle_t Device,
bool IsOutOfOrder = false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,11 @@ ur_result_t EnqueueMemCopyRectHelper(
UR_CALL(getContext()->urDdiTable.Event.pfnWait(Events.size(), &Events[0]));
}

if (Event) {
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(Queue, Events.size(),
&Events[0], Event));
}

for (const auto &E : Events)
UR_CALL(getContext()->urDdiTable.Event.pfnRelease(E));
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(Queue, Events.size(),
&Events[0], Event));
getTsanInterceptor()
->getContextInfo(GetContext(Queue))
->DeferredEvents.add(Events);

return UR_RESULT_SUCCESS;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,16 @@ struct ContextInfo {
std::unordered_map<ur_device_handle_t, std::optional<ManagedQueue>>
InternalQueueMap;

DeferredEventList DeferredEvents;

explicit ContextInfo(ur_context_handle_t Context) : Handle(Context) {
[[maybe_unused]] auto Result =
getContext()->urDdiTable.Context.pfnRetain(Context);
assert(Result == UR_RESULT_SUCCESS);
}

~ContextInfo() {
DeferredEvents.releaseAll();
InternalQueueMap.clear();
[[maybe_unused]] auto Result =
getContext()->urDdiTable.Context.pfnRelease(Handle);
Expand Down
Loading