[core] Make IAsyncInferRequest more customizable #29809

Open. Wants to merge 3 commits into base: master
107 changes: 23 additions & 84 deletions src/inference/dev_api/openvino/runtime/iasync_infer_request.hpp
@@ -9,12 +9,13 @@

#pragma once

#include <future>
#include <memory>

#include "openvino/runtime/common.hpp"
#include "openvino/runtime/exception.hpp"
#include "openvino/runtime/iinfer_request.hpp"
#include "openvino/runtime/infer_request_fsm.hpp"
#include "openvino/runtime/ipipeline_process.hpp"
#include "openvino/runtime/profiling_info.hpp"
#include "openvino/runtime/tensor.hpp"
#include "openvino/runtime/threading/itask_executor.hpp"
@@ -155,11 +156,24 @@ class OPENVINO_RUNTIME_API IAsyncInferRequest : public IInferRequest {
const std::vector<ov::Output<const ov::Node>>& get_outputs() const override;

protected:
using Stage = std::pair<std::shared_ptr<ov::threading::ITaskExecutor>, ov::threading::Task>;
using Stage = IPipelineProcess::Stage;
/**
* @brief Pipeline is vector of stages
*/
using Pipeline = std::vector<Stage>;
using Pipeline = IPipelineProcess::Pipeline;

/**
* @brief Constructor for IAsyncInferRequest
* @param request Synchronous infer request
* @param task_executor Task executor for pipeline stages
* @param callback_executor Task executor for callback
* @param fsm State machine for asynchronous request.
*/
IAsyncInferRequest(const std::shared_ptr<IInferRequest>& request,
const std::shared_ptr<ov::threading::ITaskExecutor>& task_executor,
const std::shared_ptr<ov::threading::ITaskExecutor>& callback_executor,
std::unique_ptr<InferRequestFsm> fsm,
std::unique_ptr<IPipelineProcess> pipeline_process);

/**
* @brief Forbids pipeline start and wait for all started pipelines.
@@ -177,7 +191,7 @@ class OPENVINO_RUNTIME_API IAsyncInferRequest : public IInferRequest {
*/
void check_cancelled_state() const;
/**
* @brief Performs inference of pipeline in syncronous mode
* @brief Performs inference of pipeline in synchronous mode
* @note Used by Infer which ensures thread-safety and calls this method after.
*/
virtual void infer_thread_unsafe();
@@ -193,90 +207,15 @@ class OPENVINO_RUNTIME_API IAsyncInferRequest : public IInferRequest {

Pipeline m_pipeline; //!< Pipeline variable that should be filled by inherited class.
Pipeline m_sync_pipeline; //!< Synchronous pipeline variable that should be filled by inherited class.
std::unique_ptr<InferRequestFsm> m_fsm; //!< State machine for asynchronous request.
std::unique_ptr<IPipelineProcess> m_pipeline_process; //!< Pipeline process for request.

private:
enum InferState { IDLE, BUSY, CANCELLED, STOP };
using Futures = std::vector<std::shared_future<void>>;
enum Stage_e : std::uint8_t { EXECUTOR, TASK };
InferState m_state = InferState::IDLE;
Futures m_futures;
std::promise<void> m_promise;

friend struct DisableCallbackGuard;
struct DisableCallbackGuard {
explicit DisableCallbackGuard(IAsyncInferRequest* this_) : _this{this_} {
std::lock_guard<std::mutex> lock{_this->m_mutex};
std::swap(m_callback, _this->m_callback);
}
~DisableCallbackGuard() {
std::lock_guard<std::mutex> lock{_this->m_mutex};
_this->m_callback = m_callback;
}
IAsyncInferRequest* _this = nullptr;
std::function<void(std::exception_ptr)> m_callback;
};

void run_first_stage(const Pipeline::iterator itBeginStage,
const Pipeline::iterator itEndStage,
const std::shared_ptr<ov::threading::ITaskExecutor> callbackExecutor = {});

ov::threading::Task make_next_stage_task(const Pipeline::iterator itStage,
const Pipeline::iterator itEndStage,
const std::shared_ptr<ov::threading::ITaskExecutor> callbackExecutor);

template <typename F>
void infer_impl(const F& f) {
check_tensors();
InferState state = InferState::IDLE;
{
std::lock_guard<std::mutex> lock{m_mutex};
state = m_state;
switch (m_state) {
case InferState::BUSY:
ov::Busy::create("Infer Request is busy");
case InferState::CANCELLED:
ov::Cancelled::create("Infer Request was canceled");
case InferState::IDLE: {
m_futures.erase(std::remove_if(std::begin(m_futures),
std::end(m_futures),
[](const std::shared_future<void>& future) {
if (future.valid()) {
return (std::future_status::ready ==
future.wait_for(std::chrono::milliseconds{0}));
} else {
return true;
}
}),
m_futures.end());
m_promise = {};
m_futures.emplace_back(m_promise.get_future().share());
} break;
case InferState::STOP:
break;
}
m_state = InferState::BUSY;
}
if (state != InferState::STOP) {
try {
f();
} catch (...) {
m_promise.set_exception(std::current_exception());
std::lock_guard<std::mutex> lock{m_mutex};
m_state = InferState::IDLE;
throw;
}
}
}

std::shared_ptr<IInferRequest> m_sync_request;

std::shared_ptr<ov::threading::ITaskExecutor> m_request_executor; //!< Used to run inference CPU tasks.
std::shared_ptr<ov::threading::ITaskExecutor>
m_callback_executor; //!< Used to run post inference callback in asynchronous pipeline
std::shared_ptr<ov::threading::ITaskExecutor>
threading::ITaskExecutor::Ptr m_request_executor; //!< Used to run inference CPU tasks.
threading::ITaskExecutor::Ptr m_callback_executor; //!< Used to run post inference callback in asynchronous pipeline
threading::ITaskExecutor::Ptr
m_sync_callback_executor; //!< Used to run post inference callback in synchronous pipeline
mutable std::mutex m_mutex;
std::function<void(std::exception_ptr)> m_callback;
};

} // namespace ov
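
For plugins that want to customize the request's behaviour, the new protected constructor takes ownership of both an InferRequestFsm and an IPipelineProcess. A minimal sketch of how a derived request could pass them in follows; `MyAsyncInferRequest` is a hypothetical plugin class, not part of this PR, and the stage lambda shown is only illustrative:

```cpp
// Hypothetical plugin-side async request (illustrative only, not from this PR).
#include <memory>

#include "openvino/runtime/iasync_infer_request.hpp"
#include "openvino/runtime/infer_request_fsm.hpp"
#include "openvino/runtime/ipipeline_process.hpp"

class MyAsyncInferRequest : public ov::IAsyncInferRequest {
public:
    MyAsyncInferRequest(const std::shared_ptr<ov::IInferRequest>& sync_request,
                        const std::shared_ptr<ov::threading::ITaskExecutor>& task_executor,
                        const std::shared_ptr<ov::threading::ITaskExecutor>& callback_executor)
        : ov::IAsyncInferRequest(sync_request,
                                 task_executor,
                                 callback_executor,
                                 std::make_unique<ov::InferRequestFsm>(),    // default FSM, could be a subclass
                                 std::make_unique<ov::PipelineProcess>()) {  // default pipeline handling
        // As before, the derived class fills m_pipeline with (executor, task) stages.
        m_pipeline = {{task_executor, [sync_request] {
                           sync_request->infer();  // run the synchronous request as the only stage
                       }}};
    }
};
```

The apparent intent is that passing the default types reproduces the previous hard-coded behaviour, while a plugin can substitute its own FSM or pipeline-processing policy by handing in subclasses instead.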
136 changes: 136 additions & 0 deletions src/inference/dev_api/openvino/runtime/infer_request_fsm.hpp
@@ -0,0 +1,136 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

/**
* @brief OpenVINO Runtime InferRequestFsm interface
* @file openvino/runtime/infer_request_fsm.hpp
*/

#pragma once

#include <chrono>
#include <condition_variable>
#include <mutex>
#include <variant>
#include <vector>

#include "openvino/runtime/threading/itask_executor.hpp"
#include "openvino/runtime/ipipeline_process.hpp"

namespace ov {

class OPENVINO_RUNTIME_API InferRequestFsm {
public:
using PipelineIter = IPipelineProcess::Pipeline::iterator;

InferRequestFsm();

// define base events
struct StartEvent {
const PipelineIter first_stage; // !< Iterator to the first stage of the pipeline.
const PipelineIter last_stage; // !< Iterator to the last stage of the pipeline.
const threading::ITaskExecutor::Ptr callback_executor; // !< Executor for the callback.
const IPipelineProcess::pipeline_process_func process_pipeline; // !< Function to process the pipeline.
};

struct StopEvent {};
struct CancelEvent {};
struct DoneEvent {};

template <class Event>
void on_event(const Event& event) {
std::visit(
[&event](auto& state) {
state.on_event(event);
},
m_state);
}

/**
* @brief Checks if the FSM is in Idle state.
* @return True if in Idle state, false otherwise.
*/
bool is_ready() const;
/**
* @brief Checks if the FSM is in busy state.
* @return True if in busy state, false otherwise.
*/
bool is_busy() const;

/**
* @brief Checks if the FSM is in cancelled state.
* @return True if in cancelled state, false otherwise.
*/
bool is_cancelled() const;

/**
* @brief Locks the FSM mutex.
* @return A unique lock for the mutex.
*/
std::unique_lock<std::mutex> lock();

private:
template <class... Actions>
struct EventHandlers : Actions... {
using Actions::on_event...;
};

struct NoAction {
template <class Event>
void on_event(const Event&) {}
};

using Event = std::variant<StartEvent, StopEvent, CancelEvent>;

struct StateBase : EventHandlers<NoAction> {
using EventHandlers::on_event;

StateBase() : m_fsm{} {};
StateBase(InferRequestFsm* fsm) : m_fsm{fsm} {}

InferRequestFsm* m_fsm;
};

// defined base states
struct Idle : public StateBase {
using StateBase::on_event;

Idle(InferRequestFsm* fsm) : StateBase{fsm} {}
Idle() = default;

void on_event(const StartEvent& event);
void on_event(const StopEvent& event);
};

struct Busy : public StateBase {
using StateBase::on_event;

Busy(InferRequestFsm* fsm) : StateBase{fsm} {}

void on_event(const StartEvent& event);
void on_event(const StopEvent& event);
void on_event(const CancelEvent& event);
void on_event(const DoneEvent& event);
};

struct Cancelled : public StateBase {
using StateBase::on_event;

Cancelled(InferRequestFsm* fsm) : StateBase{fsm} {}

void on_event(const StartEvent& event);
void on_event(const StopEvent& event);
void on_event(const DoneEvent& event);
};

struct Stop : public EventHandlers<NoAction> {
using EventHandlers::on_event;
};

using State = std::variant<Idle, Busy, Cancelled, Stop>;

State m_state{}; //!< State of the request.
mutable std::mutex m_mutex{}; //!< Mutex to protect state and callback.
};
} // namespace ov
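
The FSM dispatches events with the usual variant-of-states pattern: each state type derives from EventHandlers<NoAction>, adds on_event overloads for the events it reacts to, and std::visit routes an incoming event to whichever state the variant currently holds; anything a state does not override falls through to NoAction and is ignored. A self-contained sketch of that dispatch technique, using toy states and events rather than the OpenVINO types above:

```cpp
#include <cstdio>
#include <variant>

// Toy events, standing in for StartEvent/StopEvent/CancelEvent/DoneEvent.
struct StartEvent {};
struct DoneEvent {};
struct CancelEvent {};

// Overload-merging helper, same shape as in infer_request_fsm.hpp.
template <class... Actions>
struct EventHandlers : Actions... {
    using Actions::on_event...;
};

// Fallback handler: silently ignore any event a state does not care about.
struct NoAction {
    template <class Event>
    void on_event(const Event&) {}
};

struct Idle : EventHandlers<NoAction> {
    using EventHandlers::on_event;
    void on_event(const StartEvent&) { std::puts("Idle: start accepted"); }
};

struct Busy : EventHandlers<NoAction> {
    using EventHandlers::on_event;
    void on_event(const DoneEvent&)   { std::puts("Busy: pipeline finished"); }
    void on_event(const CancelEvent&) { std::puts("Busy: cancel requested"); }
};

using State = std::variant<Idle, Busy>;

// Dispatch an event to whichever state is currently active.
template <class Event>
void on_event(State& state, const Event& event) {
    std::visit([&event](auto& s) { s.on_event(event); }, state);
}

int main() {
    State state = Idle{};
    on_event(state, StartEvent{});   // handled by Idle::on_event(StartEvent)
    state = Busy{};                  // transition performed by the owner, as in the FSM
    on_event(state, CancelEvent{});  // handled by Busy::on_event(CancelEvent)
    on_event(state, StartEvent{});   // Busy has no StartEvent overload, NoAction ignores it
}
```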
92 changes: 92 additions & 0 deletions src/inference/dev_api/openvino/runtime/ipipeline_process.hpp
@@ -0,0 +1,92 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

/**
* @brief OpenVINO Runtime pipeline processing interface
* @file openvino/runtime/ipipeline_process.hpp
*/

#pragma once

#include <chrono>
#include <functional>
#include <future>
#include <mutex>
#include <vector>

#include "openvino/runtime/threading/itask_executor.hpp"

namespace ov {

class OPENVINO_RUNTIME_API IPipelineProcess {
protected:
struct DisableCallbackGuard {
DisableCallbackGuard() = delete;
DisableCallbackGuard(const DisableCallbackGuard&) = delete;
DisableCallbackGuard& operator=(const DisableCallbackGuard&) = delete;

explicit DisableCallbackGuard(IPipelineProcess& pipeline);
~DisableCallbackGuard();

IPipelineProcess* _this;
std::function<void(std::exception_ptr)> m_callback;
};

public:
using Stage = std::pair<threading::ITaskExecutor::Ptr, threading::Task>;
using Pipeline = std::vector<Stage>;
using pipeline_process_func = std::function<void(const Pipeline::iterator first_stage,
const Pipeline::iterator last_stage,
const threading::ITaskExecutor::Ptr callback_executor)>;
using callback_func = std::function<void(std::exception_ptr)>;

virtual ~IPipelineProcess();

virtual void wait() = 0;
virtual bool wait_for(const std::chrono::milliseconds& timeout) = 0;
virtual void stop() = 0;
virtual void prepare_sync() = 0;
virtual void prepare_async() = 0;
virtual pipeline_process_func sync_pipeline_func() = 0;
virtual pipeline_process_func async_pipeline_func() = 0;
virtual void set_exception(std::exception_ptr) = 0;
virtual void set_callback(std::function<void(std::exception_ptr)>) = 0;
virtual DisableCallbackGuard disable_callback() = 0;

protected:
enum Stage_e : std::uint8_t { EXECUTOR, TASK };
virtual void swap_callbacks(callback_func& other) = 0;
};

class OPENVINO_RUNTIME_API PipelineProcess : public IPipelineProcess {
public:
PipelineProcess();
PipelineProcess(std::function<void(void)> fsm_notify);

void wait() override;
bool wait_for(const std::chrono::milliseconds& timeout) override;
void stop() override;
void prepare_sync() override;
void prepare_async() override;
pipeline_process_func sync_pipeline_func() override;
pipeline_process_func async_pipeline_func() override;
void set_exception(std::exception_ptr) override;
void set_callback(std::function<void(std::exception_ptr)>) override;
DisableCallbackGuard disable_callback() override;

private:
void swap_callbacks(callback_func& other) override;
std::shared_future<void> get_last_future() const;
ov::threading::Task make_next_stage_task(const Pipeline::iterator itStage,
const Pipeline::iterator itEndStage,
const std::shared_ptr<ov::threading::ITaskExecutor> callbackExecutor);

using Futures = std::vector<std::shared_future<void>>;
mutable std::mutex m_mutex;
Futures m_futures;
std::promise<void> m_promise;
std::function<void(std::exception_ptr)> m_callback; //!< Called on success or failure of the asynchronous request.
std::function<void()> m_fsm_done_event; //!< Called when pipeline done to notify FSM.
};
} // namespace ov
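
PipelineProcess's job is to run a Pipeline, a vector of (executor, task) stages, stage by stage, and to expose completion through the promise/shared_future pair that wait() and wait_for() block on. A stripped-down sketch of that chaining idea, assuming an inline executor as a stand-in for ov::threading::ITaskExecutor (with a genuinely asynchronous executor the promise would need shared ownership rather than a reference to a local):

```cpp
#include <functional>
#include <future>
#include <iterator>
#include <memory>
#include <utility>
#include <vector>

// Toy stand-ins for the ov::threading types; illustrative only.
using Task = std::function<void()>;
struct InlineExecutor {
    void run(Task task) { task(); }  // runs the task immediately on the caller's thread
};
using Stage = std::pair<std::shared_ptr<InlineExecutor>, Task>;
using Pipeline = std::vector<Stage>;

// Chain the stages: each stage is posted to its own executor and, on success,
// posts the next stage; the promise is fulfilled (or poisoned with the exception)
// at the end, which is what wait()/wait_for() ultimately block on.
void run_stage(Pipeline::iterator it, Pipeline::iterator end, std::promise<void>& done) {
    if (it == end) {
        done.set_value();
        return;
    }
    it->first->run([it, end, &done] {
        try {
            it->second();                        // this stage's work
            run_stage(std::next(it), end, done); // schedule the following stage
        } catch (...) {
            done.set_exception(std::current_exception());
        }
    });
}

int main() {
    auto exec = std::make_shared<InlineExecutor>();
    Pipeline pipeline = {
        {exec, [] { /* preprocess */ }},
        {exec, [] { /* infer */ }},
        {exec, [] { /* postprocess */ }},
    };

    std::promise<void> done;
    auto finished = done.get_future().share();  // analogous to the futures kept per started run
    run_stage(pipeline.begin(), pipeline.end(), done);
    finished.wait();
}
```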