Add infrastructure for auto EP selection #24430

Merged
merged 36 commits into from
Apr 20, 2025
3a05559
Add infrastructure for auto EP selection.
skottmckay Apr 15, 2025
9d23198
Fix merge
skottmckay Apr 15, 2025
28cf01d
Apply suggestions from code review
skottmckay Apr 15, 2025
4acf072
Fix CI build failures
skottmckay Apr 16, 2025
bf9ead7
Merge remote-tracking branch 'origin/main' into skottmckay/AutoSelect…
skottmckay Apr 16, 2025
df9cfca
Fix merge
skottmckay Apr 16, 2025
db16855
Merge with main.
skottmckay Apr 17, 2025
613b06f
Address PR comments
skottmckay Apr 17, 2025
ed24b04
Address PR comments
skottmckay Apr 18, 2025
62d6cb2
Merge
skottmckay Apr 18, 2025
fec1d84
Merge branch 'main' into skottmckay/AutoSelectEpInfrastructure_PR
skottmckay Apr 18, 2025
36e44cb
Merge branch 'skottmckay/AutoSelectEpInfrastructure_PR' of https://gi…
skottmckay Apr 18, 2025
602022e
Fix minimal builds.
skottmckay Apr 18, 2025
fab34cd
Update comment for CreateEpApiFactoriesFn
skottmckay Apr 18, 2025
c00a1aa
lint
skottmckay Apr 18, 2025
1c9f8ef
lint
skottmckay Apr 18, 2025
6dcc943
Add TRT dependency to unit tests when it's enabled.
skottmckay Apr 18, 2025
0c21238
Add TRT library before CUDA. Might not matter but seems like a better…
skottmckay Apr 18, 2025
107173a
Apply suggestions from code review
skottmckay Apr 18, 2025
42eeb3b
Merge remote-tracking branch 'origin/main' into skottmckay/AutoSelect…
skottmckay Apr 18, 2025
e90ca4e
Move GetProviderOptionPrefix to OrtSessionOptions. Keeps the dependen…
skottmckay Apr 18, 2025
29b5daa
Split assert helpers to avoid dependency on OrtGetApi unless required.
skottmckay Apr 18, 2025
bb6d8b1
Update onnxruntime/core/session/ep_library.cc
skottmckay Apr 19, 2025
23e6a99
Add onnxruntime_common to autoep_test libs
skottmckay Apr 19, 2025
0a513d6
Merge branch 'skottmckay/AutoSelectEpInfrastructure_PR' of https://gi…
skottmckay Apr 19, 2025
7628efd
Add onnx_proto back in as a dependency for the autoep test project.
skottmckay Apr 19, 2025
eb2fd13
Fix filename
skottmckay Apr 19, 2025
85a770d
Fix autoep unit test dependencies
skottmckay Apr 19, 2025
38a9958
Add part 2
skottmckay Apr 19, 2025
9a25cdd
Skip building autoep tests on non-Windows platforms.
skottmckay Apr 19, 2025
97088b0
Fix minimal build.
skottmckay Apr 19, 2025
e6d11f8
Add all external libraries to link list
skottmckay Apr 19, 2025
a4cdae9
Update cmake/onnxruntime_unittests.cmake
skottmckay Apr 19, 2025
689bce9
Exclude LoadPluginOrProviderBridge in minimal build.
skottmckay Apr 19, 2025
5951350
Merge branch 'skottmckay/AutoSelectEpInfrastructure_PR' of https://gi…
skottmckay Apr 19, 2025
4905d5c
Fix build warning.
skottmckay Apr 19, 2025
6 changes: 5 additions & 1 deletion cmake/onnxruntime_common.cmake
@@ -11,6 +11,8 @@ set(onnxruntime_common_src_patterns
"${ONNXRUNTIME_ROOT}/core/common/logging/*.cc"
"${ONNXRUNTIME_ROOT}/core/common/logging/sinks/*.h"
"${ONNXRUNTIME_ROOT}/core/common/logging/sinks/*.cc"
"${ONNXRUNTIME_ROOT}/core/platform/device_discovery.h"
"${ONNXRUNTIME_ROOT}/core/platform/device_discovery.cc"
"${ONNXRUNTIME_ROOT}/core/platform/env.h"
"${ONNXRUNTIME_ROOT}/core/platform/env.cc"
"${ONNXRUNTIME_ROOT}/core/platform/env_time.h"
@@ -157,6 +159,9 @@ if(APPLE)
target_link_libraries(onnxruntime_common PRIVATE "-framework Foundation")
endif()

if(MSVC)
target_link_libraries(onnxruntime_common PRIVATE dxcore.lib)
endif()

if(MSVC)
if(onnxruntime_target_platform STREQUAL "ARM64")
@@ -205,7 +210,6 @@ elseif(NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
endif()
endif()


if (RISCV64 OR ARM64 OR ARM OR X86 OR X64 OR X86_64)
# Link cpuinfo if supported
# Using it mainly in ARM with Android.
8 changes: 7 additions & 1 deletion cmake/onnxruntime_session.cmake
@@ -18,11 +18,17 @@ if (onnxruntime_ENABLE_TRAINING_APIS)
list(APPEND onnxruntime_session_srcs ${training_api_srcs})
endif()


# disable for all minimal builds. enabling this pulls in all the provider bridge stuff,
# which is not enabled for any minimal builds.
if (onnxruntime_MINIMAL_BUILD)
file(GLOB autoep_srcs
"${ONNXRUNTIME_ROOT}/core/session/ep_*.*"
)

set(onnxruntime_session_src_exclude
"${ONNXRUNTIME_ROOT}/core/session/provider_bridge_ort.cc"
"${ONNXRUNTIME_ROOT}/core/session/model_builder_c_api.cc"
${autoep_srcs}
)

list(REMOVE_ITEM onnxruntime_session_srcs ${onnxruntime_session_src_exclude})
65 changes: 65 additions & 0 deletions cmake/onnxruntime_unittests.cmake
@@ -490,6 +490,7 @@ set (ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR "${TEST_SRC_DIR}/shared_lib")
set (ONNXRUNTIME_GLOBAL_THREAD_POOLS_TEST_SRC_DIR "${TEST_SRC_DIR}/global_thread_pools")
set (ONNXRUNTIME_CUSTOM_OP_REGISTRATION_TEST_SRC_DIR "${TEST_SRC_DIR}/custom_op_registration")
set (ONNXRUNTIME_LOGGING_APIS_TEST_SRC_DIR "${TEST_SRC_DIR}/logging_apis")
set (ONNXRUNTIME_AUTOEP_TEST_SRC_DIR "${TEST_SRC_DIR}/autoep")

set (onnxruntime_shared_lib_test_SRC
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_fixture.h
@@ -1798,6 +1799,70 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND (NOT onnxruntime_MINIMAL_BUI
${ONNXRUNTIME_CUSTOM_OP_lOCAL_FUNCTION_TEST_LIB_LINK_FLAG})
endif()

# Build library that can be used with RegisterExecutionProviderLibrary and automatic EP selection
# We need a shared lib build to use that as a dependency for the test library
# Currently we only have device discovery on Windows so no point building the test app on other platforms.
if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND
NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND
NOT onnxruntime_MINIMAL_BUILD)
onnxruntime_add_shared_library_module(example_plugin_ep
${TEST_SRC_DIR}/autoep/library/example_plugin_ep.cc)
target_include_directories(example_plugin_ep PRIVATE ${REPO_ROOT}/include/onnxruntime/core/session)
target_link_libraries(example_plugin_ep PRIVATE onnxruntime)

if(UNIX)
if (APPLE)
set(ONNXRUNTIME_AUTOEP_LIB_LINK_FLAG "-Xlinker -dead_strip")
elseif (NOT CMAKE_SYSTEM_NAME MATCHES "AIX")
string(CONCAT ONNXRUNTIME_AUTOEP_LIB_LINK_FLAG
"-Xlinker --version-script=${TEST_SRC_DIR}/autoep/library/example_plugin_ep_library.lds "
"-Xlinker --no-undefined -Xlinker --gc-sections -z noexecstack")
endif()
else()
set(ONNXRUNTIME_AUTOEP_LIB_LINK_FLAG
"-DEF:${TEST_SRC_DIR}/autoep/library/example_plugin_ep_library.def")
endif()

set_property(TARGET example_plugin_ep APPEND_STRING PROPERTY LINK_FLAGS
${ONNXRUNTIME_AUTOEP_LIB_LINK_FLAG})

# test library
file(GLOB_RECURSE onnxruntime_autoep_test_SRC "${ONNXRUNTIME_AUTOEP_TEST_SRC_DIR}/*.h"
"${ONNXRUNTIME_AUTOEP_TEST_SRC_DIR}/*.cc")

set(onnxruntime_autoep_test_LIBS onnxruntime_mocked_allocator ${ONNXRUNTIME_TEST_LIBS} onnxruntime_test_utils
onnx_proto onnx ${onnxruntime_EXTERNAL_LIBRARIES})

if (onnxruntime_USE_TENSORRT)
list(APPEND onnxruntime_autoep_test_LIBS ${TENSORRT_LIBRARY_INFER})
endif()

if (onnxruntime_USE_CUDA)
list(APPEND onnxruntime_autoep_test_LIBS CUDA::cudart)
endif()

if (onnxruntime_USE_DML)
list(APPEND onnxruntime_autoep_test_LIBS d3d12.lib)
endif()

if (CPUINFO_SUPPORTED)
list(APPEND onnxruntime_autoep_test_LIBS cpuinfo)
endif()

if (CMAKE_SYSTEM_NAME MATCHES "AIX")
list(APPEND onnxruntime_autoep_test_LIBS onnxruntime_graph onnxruntime_session onnxruntime_providers
onnxruntime_optimizer onnxruntime_mlas onnxruntime_framework onnxruntime_util onnxruntime_flatbuffers
iconv re2 onnx)
endif()

AddTest(DYN
TARGET onnxruntime_autoep_test
SOURCES ${onnxruntime_autoep_test_SRC} ${onnxruntime_unittest_main_src}
LIBS ${onnxruntime_autoep_test_LIBS}
DEPENDS ${all_dependencies} example_plugin_ep
)
endif()

if (onnxruntime_BUILD_SHARED_LIB AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND NOT onnxruntime_MINIMAL_BUILD)
set (onnxruntime_logging_apis_test_SRC
${ONNXRUNTIME_LOGGING_APIS_TEST_SRC_DIR}/test_logging_apis.cc)
7 changes: 6 additions & 1 deletion include/onnxruntime/core/common/logging/logging.h
@@ -51,8 +51,9 @@
*/

-namespace onnxruntime {
+struct OrtLogger; // opaque API type. is always an instance of Logger

+namespace onnxruntime {
namespace logging {

using Timestamp = std::chrono::time_point<std::chrono::system_clock>;
@@ -352,6 +353,10 @@ class Logger {
logging_manager_->SendProfileEvent(eventRecord);
}

// convert to API type for custom ops and plugin EPs
OrtLogger* ToExternal() { return reinterpret_cast<OrtLogger*>(this); }
const OrtLogger* ToExternal() const { return reinterpret_cast<const OrtLogger*>(this); }

private:
const LoggingManager* logging_manager_;
const std::string id_;
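Review note on `ToExternal()`: the hunk above hands out a `Logger` as an opaque `OrtLogger*` by reinterpreting the pointer. The sketch below illustrates that opaque-handle pattern in isolation. `OpaqueLogger`, the simplified `Logger`, `FromExternal`, and `GetLoggerId` are hypothetical names chosen for illustration and are not part of this PR or of the ONNX Runtime API.

```cpp
#include <cassert>
#include <string>
#include <utility>

// Opaque handle type: declared but never defined, so API consumers
// can only pass the pointer around (assumption: mirrors `struct OrtLogger;`).
struct OpaqueLogger;

// Hypothetical, heavily simplified stand-in for logging::Logger.
class Logger {
 public:
  explicit Logger(std::string id) : id_(std::move(id)) {}

  const std::string& Id() const { return id_; }

  // Convert to the opaque API type. Only the pointer type changes.
  OpaqueLogger* ToExternal() { return reinterpret_cast<OpaqueLogger*>(this); }
  const OpaqueLogger* ToExternal() const {
    return reinterpret_cast<const OpaqueLogger*>(this);
  }

  // Hypothetical inverse conversion, valid only for handles that
  // were produced by ToExternal().
  static const Logger* FromExternal(const OpaqueLogger* handle) {
    return reinterpret_cast<const Logger*>(handle);
  }

 private:
  std::string id_;
};

// C-style entry point that only ever sees the opaque handle.
const char* GetLoggerId(const OpaqueLogger* handle) {
  return Logger::FromExternal(handle)->Id().c_str();
}
```

The round trip is only sound because every handle is known to originate from a real `Logger`, which is what the "is always an instance of Logger" comment in the diff asserts.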
83 changes: 79 additions & 4 deletions include/onnxruntime/core/session/environment.h
@@ -4,16 +4,30 @@
#pragma once

#include <atomic>
#include <filesystem>

Check warning on line 7 in include/onnxruntime/core/session/environment.h (GitHub Actions / Optional Lint C++): [cpplint] <filesystem> is an unapproved C++17 header. [build/c++17] [5]
#include <memory>

#include "core/common/common.h"
-#include "core/common/status.h"
-#include "core/platform/threadpool.h"
+#include "core/common/basic_types.h"
#include "core/common/logging/logging.h"
+#include "core/common/status.h"
#include "core/framework/allocator.h"
+#include "core/framework/execution_provider.h"
+#include "core/platform/device_discovery.h"
+#include "core/platform/threadpool.h"

#include "core/session/abi_devices.h"
#include "core/session/ep_library.h"
#include "core/session/onnxruntime_c_api.h"

struct OrtThreadingOptions;
namespace onnxruntime {
-/** TODO: remove this class
+class EpFactoryInternal;
+class InferenceSession;
+struct IExecutionProviderFactory;
+struct SessionOptions;
+
+/**
Provides the runtime environment for onnxruntime.
Create one instance for the duration of execution.
*/
@@ -86,10 +100,30 @@
* Return an error if an allocator with the same OrtMemoryInfo is already registered.
* For provider_type please refer core/graph/constants.h
*/
-Status CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo& mem_info, const std::unordered_map<std::string, std::string>& options, const OrtArenaCfg* arena_cfg = nullptr);
+Status CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo& mem_info,
+const std::unordered_map<std::string, std::string>& options,
+const OrtArenaCfg* arena_cfg = nullptr);

#if !defined(ORT_MINIMAL_BUILD)
Status RegisterExecutionProviderLibrary(const std::string& registration_name, const ORTCHAR_T* lib_path);
Status UnregisterExecutionProviderLibrary(const std::string& registration_name);

// convert an OrtEpFactory* to EpFactoryInternal* if possible.
EpFactoryInternal* GetEpFactoryInternal(OrtEpFactory* factory) const {
// we're comparing pointers so the reinterpret_cast should be safe
auto it = internal_ep_factories_.find(reinterpret_cast<EpFactoryInternal*>(factory));
return it != internal_ep_factories_.end() ? *it : nullptr;
}

const std::vector<const OrtEpDevice*>& GetOrtEpDevices() const {
return execution_devices_;
}
#endif // !defined(ORT_MINIMAL_BUILD)
~Environment();

private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Environment);

Status Initialize(std::unique_ptr<logging::LoggingManager> logging_manager,
const OrtThreadingOptions* tp_options = nullptr,
bool create_global_thread_pools = false);
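Review note on `GetEpFactoryInternal`: the lookup in this hunk never dereferences the cast pointer. It only checks whether the address is a member of the set of known internal factories, which is why the `reinterpret_cast` is described as safe. A minimal sketch of that idea follows. `PublicFactory`, `InternalFactory`, and `FactoryRegistry` are hypothetical stand-ins, and the sketch assumes the derived type adds no data members so that base and derived pointers share the same address.

```cpp
#include <cassert>
#include <unordered_set>

// Hypothetical stand-in for the public OrtEpFactory struct.
struct PublicFactory {
  int version = 1;
};

// Hypothetical stand-in for EpFactoryInternal. It adds no data members,
// so a pointer to it and a pointer to its base compare equal by address.
struct InternalFactory : PublicFactory {
  const char* Name() const { return "internal"; }
};

class FactoryRegistry {
 public:
  void RegisterInternal(InternalFactory* factory) { internal_.insert(factory); }

  // Returns the internal factory if `factory` is one we registered,
  // otherwise nullptr. The cast pointer is only hashed and compared,
  // never dereferenced, so an unrelated PublicFactory is handled safely.
  InternalFactory* GetInternal(PublicFactory* factory) const {
    auto it = internal_.find(reinterpret_cast<InternalFactory*>(factory));
    return it != internal_.end() ? *it : nullptr;
  }

 private:
  std::unordered_set<InternalFactory*> internal_;
};
```

A `static_cast` downcast would be undefined behavior when the object is not actually an `InternalFactory`, so comparing addresses against a set of known-good pointers is the more defensive choice for factories that may come from an external library.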
@@ -99,5 +133,46 @@
std::unique_ptr<onnxruntime::concurrency::ThreadPool> inter_op_thread_pool_;
bool create_global_thread_pools_{false};
std::vector<AllocatorPtr> shared_allocators_;

#if !defined(ORT_MINIMAL_BUILD)
// register EPs that are built into the ORT binary so they can take part in AutoEP selection
// added to ep_libraries
Status CreateAndRegisterInternalEps();

Status RegisterExecutionProviderLibrary(const std::string& registration_name,
std::unique_ptr<EpLibrary> ep_library,
const std::vector<EpFactoryInternal*>& internal_factories = {});

struct EpInfo {
// calls EpLibrary::Load
// for each factory gets the OrtEpDevice instances and adds to execution_devices
// internal_factory is set if this is an internal EP
static Status Create(std::unique_ptr<EpLibrary> library_in, std::unique_ptr<EpInfo>& out,
const std::vector<EpFactoryInternal*>& internal_factories = {});

// removes entries for this library from execution_devices
// calls EpLibrary::Unload
~EpInfo();

std::unique_ptr<EpLibrary> library;
std::vector<std::unique_ptr<OrtEpDevice>> execution_devices;
std::vector<EpFactoryInternal*> internal_factories; // factories that can create IExecutionProvider instances

private:
EpInfo() = default;
};

// registration name to EpInfo for library
std::unordered_map<std::string, std::unique_ptr<EpInfo>> ep_libraries_;

Check warning on line 166 in include/onnxruntime/core/session/environment.h (GitHub Actions / Optional Lint C++): [cpplint] Add #include <string> for string [build/include_what_you_use] [4]
Check warning on line 166 in include/onnxruntime/core/session/environment.h (GitHub Actions / Optional Lint C++): [cpplint] Add #include <unordered_map> for unordered_map<> [build/include_what_you_use] [4]

// combined set of OrtEpDevices for all registered OrtEpFactory instances
// std::vector so we can use directly in GetEpDevices.
// inefficient when EPs are unregistered but that is not expected to be a common operation.
std::vector<const OrtEpDevice*> execution_devices_;

Check warning on line 171 in include/onnxruntime/core/session/environment.h (GitHub Actions / Optional Lint C++): [cpplint] Add #include <vector> for vector<> [build/include_what_you_use] [4]

// lookup set for internal EPs so we can create an IExecutionProvider directly
std::unordered_set<EpFactoryInternal*> internal_ep_factories_;

Check warning on line 174 in include/onnxruntime/core/session/environment.h (GitHub Actions / Optional Lint C++): [cpplint] Add #include <unordered_set> for unordered_set<> [build/include_what_you_use] [4]
#endif // !defined(ORT_MINIMAL_BUILD)
};

} // namespace onnxruntime
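Review note on the registration bookkeeping: the members added above keep one owning entry per registered library (`ep_libraries_`) plus a flat, non-owning vector of every device (`execution_devices_`), and the comments accept a linear-time removal on unregister. The sketch below shows one way such bookkeeping could work. It is not the PR's implementation: `Device`, `MiniEnvironment`, and its methods are hypothetical, and the cleanup is done in `UnregisterLibrary` rather than in a destructor as the `EpInfo` comments describe.

```cpp
#include <algorithm>
#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical stand-in for OrtEpDevice.
struct Device {
  std::string name;
};

class MiniEnvironment {
 public:
  // Registers a library under `name`. Returns false if the name is taken.
  bool RegisterLibrary(const std::string& name,
                       const std::vector<std::string>& device_names) {
    if (libraries_.count(name) != 0) return false;

    auto info = std::make_unique<LibraryInfo>();
    for (const auto& device_name : device_names) {
      info->devices.push_back(std::make_unique<Device>(Device{device_name}));
      // The flat vector holds non-owning pointers for cheap enumeration.
      all_devices_.push_back(info->devices.back().get());
    }
    libraries_.emplace(name, std::move(info));
    return true;
  }

  // Removes the library and its devices. Linear in the number of devices,
  // which is acceptable if unregistering is rare.
  bool UnregisterLibrary(const std::string& name) {
    auto it = libraries_.find(name);
    if (it == libraries_.end()) return false;

    for (const auto& owned : it->second->devices) {
      all_devices_.erase(
          std::remove(all_devices_.begin(), all_devices_.end(), owned.get()),
          all_devices_.end());
    }
    libraries_.erase(it);  // destroys the owned Device instances
    return true;
  }

  const std::vector<const Device*>& GetDevices() const { return all_devices_; }

 private:
  struct LibraryInfo {
    std::vector<std::unique_ptr<Device>> devices;  // owning storage
  };

  std::unordered_map<std::string, std::unique_ptr<LibraryInfo>> libraries_;
  std::vector<const Device*> all_devices_;  // combined, non-owning view
};
```

Removing the non-owning pointers before erasing the owning entry keeps the flat vector free of dangling pointers, which matters if callers hold the vector returned by the getter.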