Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 49 additions & 48 deletions src/runtime_src/core/common/api/xrt_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,7 @@ class kernel_command : public xrt_core::command
, m_hwqueue(std::move(hwqueue))
, m_hwctx(std::move(hwctx))
, m_execbuf(m_device->create_exec_buf<ert_start_kernel_cmd>())
, m_done(true)
, m_run_state(run_state::ready)
{
static unsigned int count = 0;
m_uid = count++;
Expand Down Expand Up @@ -758,8 +758,7 @@ class kernel_command : public xrt_core::command
bool
is_done() const
{
std::lock_guard<std::mutex> lk(m_mutex);
return m_done;
return (m_run_state != run_state::running) || (!m_managed && get_state() >= ERT_CMD_STATE_COMPLETED);
}

// Return state of command object. The underlying packet
Expand Down Expand Up @@ -812,50 +811,48 @@ class kernel_command : public xrt_core::command
ert_cmd_state state = ERT_CMD_STATE_MAX;
{
std::lock_guard<std::mutex> lk(m_mutex);
if (!m_managed && !m_done)
if (!m_managed && m_run_state != run_state::ready)
throw xrt_core::error(ENOTSUP, "Cannot add callback to running unmanaged command");
if (!m_callbacks)
m_callbacks = std::make_unique<callback_list>();
m_callbacks->emplace_back(std::move(fcn));
m_callbacks.emplace_back(std::move(fcn));
auto pkt = get_ert_packet();
state = static_cast<ert_cmd_state>(pkt->state);
complete = m_done && state >= ERT_CMD_STATE_COMPLETED;
complete = (m_run_state == run_state::ready) && (state >= ERT_CMD_STATE_COMPLETED);
}

// lock must not be helt while calling callback function
// lock must not be held while calling callback function
if (complete)
m_callbacks.get()->back()(state);
m_callbacks.back()(state);
}

// Remove last added callback
void
pop_callback()
{
if (m_callbacks && m_callbacks->size())
m_callbacks->pop_back();
if (!m_callbacks.empty())
m_callbacks.pop_back();
}

// Run registered callbacks.
void
run_callbacks(ert_cmd_state state) const
{
std::vector<const callback_function_type*> copy;

{
std::lock_guard<std::mutex> lk(m_mutex);
if (!m_callbacks)

if (m_callbacks.empty())
return;
}


// cannot lock mutex while calling the callbacks
// so copy address of callbacks while holding the lock
// then execute callbacks without lock
std::vector<callback_function_type*> copy;
copy.reserve(m_callbacks->size());
// cannot lock mutex while calling the callbacks
// so copy address of callbacks while holding the lock
// then execute callbacks without lock
copy.reserve(m_callbacks.size());

{
std::lock_guard<std::mutex> lk(m_mutex);
std::transform(m_callbacks->begin(),m_callbacks->end()
std::transform(m_callbacks.begin(),m_callbacks.end()
,std::back_inserter(copy)
,[](callback_function_type& cb) { return &cb; });
,[](const callback_function_type& cb) { return &cb; });
}

for (auto cb : copy)
Expand All @@ -868,10 +865,13 @@ class kernel_command : public xrt_core::command
{
{
std::lock_guard<std::mutex> lk(m_mutex);
if (!m_done)

run_state state = run_state::ready;
m_run_state.compare_exchange_strong(state, run_state::running);
if (state != run_state::ready)
throw std::runtime_error("bad command state, can't launch");
m_managed = (m_callbacks && !m_callbacks->empty());
m_done = false;

m_managed = (!m_callbacks.empty());
}

try {
Expand All @@ -881,10 +881,9 @@ class kernel_command : public xrt_core::command
m_hwqueue.unmanaged_start(this);
}
catch (...) {
// Start failed, m_done remains true
// Start failed, run_state remains ready
// command can be retried if needed
std::lock_guard<std::mutex> lk(m_mutex);
m_done = true;
m_run_state = run_state::ready;
throw;
}
}
Expand All @@ -895,7 +894,7 @@ class kernel_command : public xrt_core::command
{
if (m_managed) {
std::unique_lock<std::mutex> lk(m_mutex);
while (!m_done)
while (m_run_state == run_state::running)
m_exec_done.wait(lk);
}
else {
Expand All @@ -910,7 +909,7 @@ class kernel_command : public xrt_core::command
{
if (m_managed) {
std::unique_lock<std::mutex> lk(m_mutex);
while (!m_done)
while (m_run_state == run_state::running)
if (m_exec_done.wait_for(lk, timeout_ms) == std::cv_status::timeout)
return {get_state_raw(), std::cv_status::timeout};
}
Expand Down Expand Up @@ -954,27 +953,25 @@ class kernel_command : public xrt_core::command
void
notify(ert_cmd_state s) const override
{
run_state state = run_state::running;
bool complete = false;
bool callbacks = false;
if (s >= ERT_CMD_STATE_COMPLETED) {
std::lock_guard<std::mutex> lk(m_mutex);

// Handle potential race if multiple threads end up here. This
// condition is by design because there are multiple paths into
// this function and first conditional check should not be locked
if (m_done)
return;

complete = m_run_state.compare_exchange_strong(state, run_state::finishing);
XRT_DEBUGF("kernel_command::notify() m_uid(%d) m_state(%d)\n", m_uid, s);
complete = m_done = true;
callbacks = (m_callbacks && !m_callbacks->empty());
}

if (complete) {
// m_run_state is flipped to finishing to prevent m_managed becoming true at this point
// Otherwise there's a potential deadlock because run_callbacks() takes the lock,
// and we can get here while already holding the lock, but only when m_managed is false

if (m_managed && complete) {
m_exec_done.notify_all();
if (callbacks)
run_callbacks(s);
run_callbacks(s);
}

// if we were finishing, we're done done - ready to run more commands
state = run_state::finishing;
m_run_state.compare_exchange_strong(state, run_state::ready);
}

void
Expand All @@ -992,13 +989,17 @@ class kernel_command : public xrt_core::command
xrt::hw_context m_hwctx; // hw_context for command
execbuf_type m_execbuf; // underlying execution buffer
unsigned int m_uid = 0;
bool m_managed = false;
mutable bool m_done = false;
enum class run_state : int {
ready,
running,
finishing,
};
mutable std::atomic<run_state> m_run_state = run_state::ready;

mutable std::mutex m_mutex;
mutable std::condition_variable m_exec_done;

std::unique_ptr<callback_list> m_callbacks;
callback_list m_callbacks;
};

// class argument - get argument value from va_arg
Expand Down
Loading