diff --git a/include/onnxruntime/ep/README.md b/include/onnxruntime/ep/README.md new file mode 100644 index 0000000000000..2c544d2221f96 --- /dev/null +++ b/include/onnxruntime/ep/README.md @@ -0,0 +1,19 @@ +## EP adapter + +This folder contains a set of C++ header files. They are used specifically for allowing ONNX Runtime internal kernel-based EPs to use the plugin-style EP API while keep minimal changes to existing code. + +### Folder Structure + +There are 2 types of header files: + +- General header files for plugin EP. This may include utilities, macros and shared routines that depending on ONNX Runtime public API only. There are multiple places for header files of this category (which we are going to unify them to one place. There is an ongoing discussion about unifying shared headers for plugin EPs): + - `include/onnxruntime/ep/` (#26919) + - `onnxruntime/test/autoep/library/plugin_ep_utils.h` + - `include/onnxruntime/core/providers/utils/` (#25753) + +- Header files specifically used for supporting WebGPU EP and CUDA EP to use EP APIs. These header files do not only depend on ONNX Runtime public API, but also depend on ONNX Runtime internal headers. + - `include/onnxruntime/ep/adapter/` + +### Usage + +Make sure to include "ep/_pch.h" for all source code in the implementation. Using PCH compiler flag is recommended. diff --git a/include/onnxruntime/ep/_pch.h b/include/onnxruntime/ep/_pch.h new file mode 100644 index 0000000000000..0299b9279b327 --- /dev/null +++ b/include/onnxruntime/ep/_pch.h @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "api.h" +#include "common.h" + +// This header is only used when building WebGPU/CUDA EP as a shared library. +// +// This header file is used as a precompiled header so it is always included first. + +#pragma push_macro("ORT_EP_API_ADAPTER_HEADER_INCLUDED") +#undef ORT_EP_API_ADAPTER_HEADER_INCLUDED +#define ORT_EP_API_ADAPTER_HEADER_INCLUDED + +#include "adapter/allocator.h" +#include "adapter/logging.h" +#include "adapter/ep.h" +#include "adapter/kernel_registry.h" + +#pragma pop_macro("ORT_EP_API_ADAPTER_HEADER_INCLUDED") + +// +// EP specific using declarations +// + +#define EP_SPECIFIC_USING_DECLARATIONS \ + using FuncManager = onnxruntime::ep::adapter::FuncManager; \ + using KernelCreatePtrFn = onnxruntime::ep::adapter::KernelCreatePtrFn; \ + using KernelDefBuilder = onnxruntime::ep::adapter::KernelDefBuilder; \ + using KernelRegistry = onnxruntime::ep::adapter::KernelRegistry; \ + using KernelCreateInfo = onnxruntime::ep::adapter::KernelCreateInfo; \ + using BuildKernelCreateInfoFn = onnxruntime::ep::adapter::KernelCreateInfo (*)(); \ + using OpKernelInfo = onnxruntime::ep::adapter::OpKernelInfo; \ + using OpKernelContext = onnxruntime::ep::adapter::OpKernelContext; \ + using OpKernel = onnxruntime::ep::adapter::OpKernel; \ + using DataTransferManager = onnxruntime::ep::adapter::DataTransferManager; \ + namespace logging { \ + using Logger = onnxruntime::ep::adapter::Logger; \ + } + +namespace onnxruntime { +namespace webgpu { +EP_SPECIFIC_USING_DECLARATIONS +} // namespace webgpu +namespace cuda { +EP_SPECIFIC_USING_DECLARATIONS +} // namespace cuda + +#ifndef DISABLE_CONTRIB_OPS +namespace contrib { +namespace webgpu { +EP_SPECIFIC_USING_DECLARATIONS +} // namespace webgpu +namespace cuda { +EP_SPECIFIC_USING_DECLARATIONS +} // namespace cuda +} // namespace contrib +#endif + +} // namespace onnxruntime diff --git a/include/onnxruntime/ep/adapter/allocator.h b/include/onnxruntime/ep/adapter/allocator.h new file mode 100644 index 0000000000000..42c8c7ba8802e --- /dev/null +++ b/include/onnxruntime/ep/adapter/allocator.h @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED) +#error "This header should not be included directly. Include ep/_pch.h instead." +#endif + +#include "core/framework/allocator.h" + +namespace onnxruntime { +namespace ep { +namespace adapter { + +/// +/// A bridge class between the EP API OrtAllocator and an IAllocator implementation. +/// +class Allocator : public OrtAllocator { + public: + explicit Allocator(const OrtMemoryInfo* memory_info, AllocatorPtr impl) + : OrtAllocator{}, memory_info_(memory_info), impl_(impl) { + version = ORT_API_VERSION; + Alloc = AllocImpl; + Free = FreeImpl; + Info = InfoImpl; + } + + private: + static void* ORT_API_CALL AllocImpl(OrtAllocator* this_ptr, size_t size) noexcept { + auto* allocator = static_cast(this_ptr); + return allocator->impl_->Alloc(size); + } + + static void ORT_API_CALL FreeImpl(OrtAllocator* this_ptr, void* p) noexcept { + auto* allocator = static_cast(this_ptr); + allocator->impl_->Free(p); + } + + static const OrtMemoryInfo* ORT_API_CALL InfoImpl(const OrtAllocator* this_ptr) noexcept { + auto* allocator = static_cast(this_ptr); + return allocator->memory_info_; + } + + const OrtMemoryInfo* memory_info_; + AllocatorPtr impl_; +}; + +} // namespace adapter +} // namespace ep +} // namespace onnxruntime diff --git a/include/onnxruntime/ep/adapter/data_transfer_manager.h b/include/onnxruntime/ep/adapter/data_transfer_manager.h new file mode 100644 index 0000000000000..7b98a440c7050 --- /dev/null +++ b/include/onnxruntime/ep/adapter/data_transfer_manager.h @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED) +#error "This header should not be included directly. Include ep/_pch.h instead." +#endif + +#include "core/common/status.h" +#include "core/common/common.h" +#include "core/framework/data_transfer.h" +#include "core/framework/tensor.h" + +namespace onnxruntime { +namespace ep { +namespace adapter { + +/// +/// An adapter class partially implementing the facade of `onnxruntime::DataTransferManager`. +/// +struct DataTransferManager { + explicit DataTransferManager(std::unique_ptr impl) : impl_{std::move(impl)} {} + + common::Status CopyTensor(const Tensor& src, Tensor& dst) const { + if (src.Shape().Size() != dst.Shape().Size()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, + FAIL, + "Tensor size mismatch: source tensor size is ", + src.Shape().Size(), + ", destination tensor size is ", + dst.Shape().Size()); + } + + if (impl_->CanCopy(src.Location().device, dst.Location().device)) { + return impl_->CopyTensor(src, dst); + } + + return ORT_MAKE_STATUS(ONNXRUNTIME, + FAIL, + "There's no data transfer registered for copying tensors from ", + src.Location().device.ToString(), + " to ", + dst.Location().device.ToString()); + } + + private: + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(DataTransferManager); + std::unique_ptr impl_; +}; + +} // namespace adapter +} // namespace ep +} // namespace onnxruntime diff --git a/include/onnxruntime/ep/adapter/ep.h b/include/onnxruntime/ep/adapter/ep.h new file mode 100644 index 0000000000000..02a6c2f07b0c3 --- /dev/null +++ b/include/onnxruntime/ep/adapter/ep.h @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED) +#error "This header should not be included directly. Include ep/_pch.h instead." +#endif + +#include "data_transfer_manager.h" + +#include "core/framework/execution_provider.h" + +namespace onnxruntime { +namespace ep { +namespace adapter { + +/// +/// Wrapper around IExecutionProvider to expose via OrtEp. +/// +class Ep : public OrtEp { + protected: + explicit Ep(IExecutionProvider* impl, AllocatorPtr temp_space_cpu_allocator, AllocatorPtr temp_space_allocator) + : OrtEp{}, + impl_(impl), + data_transfer_manager_{impl->GetDataTransfer()}, + profiler_{impl->GetProfiler()}, + temp_space_cpu_allocator_{temp_space_cpu_allocator}, + temp_space_allocator_{temp_space_allocator} { + } + + public: + inline IExecutionProvider* EpImpl() const noexcept { + return impl_.get(); + } + inline const DataTransferManager& GetDataTransferManager() const noexcept { + return data_transfer_manager_; + } + [[nodiscard]] Status GetTempSpaceCPUAllocator(AllocatorPtr* output) const { + *output = temp_space_cpu_allocator_; + return Status::OK(); + } + [[nodiscard]] Status GetTempSpaceAllocator(AllocatorPtr* output) const { + *output = temp_space_allocator_; + return Status::OK(); + } + + private: + std::unique_ptr impl_; + DataTransferManager data_transfer_manager_; + std::unique_ptr profiler_; + AllocatorPtr temp_space_cpu_allocator_; + AllocatorPtr temp_space_allocator_; +}; + +} // namespace adapter +} // namespace ep +} // namespace onnxruntime diff --git a/include/onnxruntime/ep/adapter/kernel_def.h b/include/onnxruntime/ep/adapter/kernel_def.h new file mode 100644 index 0000000000000..b3d3c83dd0e90 --- /dev/null +++ b/include/onnxruntime/ep/adapter/kernel_def.h @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED) +#error "This header should not be included directly. Include ep/_pch.h instead." +#endif + +#include + +namespace onnxruntime { +namespace ep { +namespace adapter { + +/// +/// An adapter class partially implementing the facade of `onnxruntime::KernelDef`. +/// +class KernelDef { + public: + explicit KernelDef(const OrtKernelInfo* kernel_info) : kernel_info_{kernel_info} {} + + const std::string OpName() const { + return kernel_info_.GetNodeName(); + } + + const std::string Domain() const { + return kernel_info_.GetOperatorDomain(); + } + + private: + const Ort::ConstKernelInfo kernel_info_; +}; + +} // namespace adapter +} // namespace ep +} // namespace onnxruntime diff --git a/include/onnxruntime/ep/adapter/kernel_def_builder.h b/include/onnxruntime/ep/adapter/kernel_def_builder.h new file mode 100644 index 0000000000000..664c88919cb8a --- /dev/null +++ b/include/onnxruntime/ep/adapter/kernel_def_builder.h @@ -0,0 +1,143 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED) +#error "This header should not be included directly. Include ep/_pch.h instead." +#endif + +#include + +#include "core/framework/data_types.h" + +namespace onnxruntime { +namespace ep { +namespace adapter { + +/// +/// Gets an OrtMLDataType for a tensor type. Throws on error. +/// +/// +/// +inline const OrtDataType* GetTensorType(ONNXTensorElementDataType elem_type) { + const OrtEpApi& ep_api = Ort::GetEpApi(); + const OrtDataType* result = nullptr; + + Ort::ThrowOnError(ep_api.GetTensorDataType(elem_type, &result)); + return result; +} + +inline const OrtDataType* MLDataTypeToOrtDataType(MLDataType ml_type) { + auto tensor_type = ml_type->AsTensorType(); + EP_ENFORCE(tensor_type != nullptr, "EP Kernel registration only supports tensor types."); + auto elem_type = tensor_type->GetElementType(); + auto primitive_type = static_cast(elem_type); + auto onnx_type = static_cast(primitive_type->GetDataType()); + return GetTensorType(onnx_type); +} + +/// +/// An adapter class partially implementing the facade of `onnxruntime::KernelDefBuilder`. +/// +struct KernelDefBuilder { + static std::unique_ptr Create() { return std::make_unique(); } + + explicit KernelDefBuilder() {} + + KernelDefBuilder& SetName(const char* op_name) { + builder_.SetOperatorType(op_name); + return *this; + } + + KernelDefBuilder& SetDomain(const char* domain) { + builder_.SetDomain(domain); + return *this; + } + + KernelDefBuilder& SinceVersion(int since_version) { + return SinceVersion(since_version, INT_MAX); + } + + KernelDefBuilder& SinceVersion(int since_version_start, int since_version_end) { + builder_.SetSinceVersion(since_version_start, since_version_end); + return *this; + } + + KernelDefBuilder& Provider(const char* provider_type) { + builder_.SetExecutionProvider(provider_type); + return *this; + } + + KernelDefBuilder& TypeConstraint(const char* arg_name, std::vector types) { + std::vector ort_types; + ort_types.reserve(types.size()); + for (const auto& type : types) { + ort_types.push_back(MLDataTypeToOrtDataType(type)); + } + builder_.AddTypeConstraint(arg_name, ort_types); + return *this; + } + + KernelDefBuilder& TypeConstraint(const char* arg_name, MLDataType type) { + builder_.AddTypeConstraint(arg_name, MLDataTypeToOrtDataType(type)); + return *this; + } + + KernelDefBuilder& MayInplace(const std::vector>& inplaces) { + for (const auto& pair : inplaces) { + builder_.AddInputOutputMutableAlias(pair.first, pair.second); + } + return *this; + } + KernelDefBuilder& MayInplace(int input_index, int output_index) { + builder_.AddInputOutputMutableAlias(input_index, output_index); + return *this; + } + + KernelDefBuilder& Alias(const std::vector>& aliases) { + for (const auto& pair : aliases) { + builder_.AddInputOutputAlias(pair.first, pair.second); + } + return *this; + } + KernelDefBuilder& Alias(int input_index, int output_index) { + builder_.AddInputOutputAlias(input_index, output_index); + return *this; + } + + KernelDefBuilder& InputMemoryType(OrtMemType type, int input_index) { + builder_.SetInputMemType(input_index, type); + return *this; + } + + KernelDefBuilder& InputMemoryType(OrtMemType type, const std::vector& input_indexes) { + for (int input_index : input_indexes) { + builder_.SetInputMemType(input_index, type); + } + return *this; + } + + KernelDefBuilder& OutputMemoryType(OrtMemType type, int output_index) { + builder_.SetOutputMemType(output_index, type); + return *this; + } + + KernelDefBuilder& OutputMemoryType(OrtMemType type, const std::vector& output_indexes) { + for (int output_index : output_indexes) { + builder_.SetOutputMemType(output_index, type); + } + return *this; + } + + KernelDefBuilder& ExecQueueId(int queue_id) { return *this; } + + Ort::KernelDef Build() { return builder_.Build(); } + + private: + Ort::KernelDefBuilder builder_; +}; + +} // namespace adapter +} // namespace ep +} // namespace onnxruntime diff --git a/include/onnxruntime/ep/adapter/kernel_registry.h b/include/onnxruntime/ep/adapter/kernel_registry.h new file mode 100644 index 0000000000000..8b146dffab309 --- /dev/null +++ b/include/onnxruntime/ep/adapter/kernel_registry.h @@ -0,0 +1,107 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED) +#error "This header should not be included directly. Include ep/_pch.h instead." +#endif + +#include + +#include "kernel_def_builder.h" +#include "op_kernel_info.h" +#include "op_kernel.h" + +#include "core/graph/basic_types.h" +#include "core/framework/error_code_helper.h" + +namespace onnxruntime { +namespace ep { +namespace adapter { + +struct FuncManager {}; +using KernelCreatePtrFn = std::add_pointer& out)>::type; + +/// +/// An adapter class partially implementing the facade of `onnxruntime::KernelCreateInfo`. +/// +struct KernelCreateInfo { + Ort::KernelDef kernel_def; + KernelCreatePtrFn kernel_create_func; + Status status; + + KernelCreateInfo(Ort::KernelDef definition, + KernelCreatePtrFn create_func) + : kernel_def(std::move(definition)), + kernel_create_func(create_func) { + assert(kernel_def != nullptr); + } + + KernelCreateInfo(KernelCreateInfo&& other) noexcept + : kernel_def(std::move(other.kernel_def)), + kernel_create_func(std::move(other.kernel_create_func)) {} + + KernelCreateInfo() = default; +}; + +/// +/// An adapter class partially implementing the facade of `onnxruntime::KernelRegistry`. +/// +struct KernelRegistry { + KernelRegistry() = default; + + static OrtStatus* CreateKernel(void* kernel_create_func_state, const OrtKernelInfo* info, OrtKernelImpl** out) { + FuncManager func_mgr; // not used + std::unique_ptr kernel; + KernelCreatePtrFn create_func = reinterpret_cast(kernel_create_func_state); + Status status = create_func(func_mgr, OpKernelInfo(info), kernel); + if (!status.IsOK()) { + return ToOrtStatus(status); + } + *out = nullptr; + + // Try to create a control flow kernel implementation if applicable. + // For kernel based plugin EPs, the implementation should create the control flow kernel directly using one of the + // following APIs: + // - `OrtEpApi::CreateIfKernel` + // - `OrtEpApi::CreateLoopKernel` + // - `OrtEpApi::CreateScanKernel` + // + // If the kernel being created is one of the control flow kernels, `CreateControlFlowKernelImpl` should be overriden + // to write the value of `out` to the created `OrtKernelImpl`, and the returned status should be OK. + status = kernel->CreateControlFlowKernelImpl(info, out); + if (!status.IsOK()) { + return ToOrtStatus(status); + } + if (*out == nullptr) { + // If the kernel is not a control flow kernel, create a regular kernel implementation. + *out = new KernelImpl(std::move(kernel)); + } + return nullptr; + } + + Status Register(KernelCreateInfo&& create_info) { + registry_.AddKernel(create_info.kernel_def, + KernelRegistry::CreateKernel, + reinterpret_cast(create_info.kernel_create_func)); + return Status::OK(); + } + + // Implicit conversion to OrtKernelRegistry* for compatibility with C API + operator OrtKernelRegistry*() const noexcept { + return registry_.operator OrtKernelRegistry*(); + } + + // Release ownership of the underlying OrtKernelRegistry* + OrtKernelRegistry* release() { + return registry_.release(); + } + + private: + Ort::KernelRegistry registry_; +}; + +} // namespace adapter +} // namespace ep +} // namespace onnxruntime diff --git a/include/onnxruntime/ep/adapter/logging.h b/include/onnxruntime/ep/adapter/logging.h new file mode 100644 index 0000000000000..b93c06bb3f12e --- /dev/null +++ b/include/onnxruntime/ep/adapter/logging.h @@ -0,0 +1,111 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED) +#error "This header should not be included directly. Include ep/_pch.h instead." +#endif + +#include "core/common/logging/logging.h" +#include "core/common/path_string.h" + +namespace onnxruntime { +namespace ep { +namespace adapter { + +struct Logger { + Logger(const OrtLogger* logger) : logger_(logger) {} + + bool OutputIsEnabled(logging::Severity severity, logging::DataType /* data_type */) const noexcept { + return ((OrtLoggingLevel)severity >= logger_.GetLoggingSeverityLevel()); + } + + void Log(logging::Severity severity, + const char* file_path, + int line_number, + const char* func_name, + const char* message) const noexcept { + auto path_string = onnxruntime::ToPathString(file_path); + logger_.LogMessage((OrtLoggingLevel)severity, + path_string.c_str(), + line_number, + func_name, + message); + } + + static const Logger& DefaultLogger() { return *instance_; } + static void CreateDefaultLogger(const OrtLogger* logger) { + instance_ = new Logger(logger); + } + static void DestroyDefaultLogger() { + delete instance_; + instance_ = nullptr; + } + + private: + Ort::Logger logger_; + inline static Logger* instance_ = nullptr; +}; + +namespace detail { +struct LoggerCapture { + LoggerCapture(const Logger& logger, + logging::Severity severity, + const char* category, + logging::DataType dataType, + const CodeLocation& location) : logger_{logger}, + severity_{severity}, + category_{category}, + data_type_{dataType}, + location_{location} {} + + ~LoggerCapture() { + logger_.Log(severity_, location_.file_and_path.c_str(), location_.line_num, + location_.function.c_str(), stream_.str().c_str()); + } + + std::ostream& Stream() noexcept { + return stream_; + } + + const Logger& logger_; + logging::Severity severity_; + const char* category_; + logging::DataType data_type_; + const CodeLocation& location_; + std::ostringstream stream_; +}; + +// Helper functions to dispatch to the correct Capture type based on logger type +inline ::onnxruntime::logging::Capture CreateMessageCapture( + const ::onnxruntime::logging::Logger& logger, + ::onnxruntime::logging::Severity severity, + const char* category, + ::onnxruntime::logging::DataType datatype, + const CodeLocation& location) { + return ::onnxruntime::logging::Capture(logger, severity, category, datatype, location); +} + +inline detail::LoggerCapture CreateMessageCapture( + const Logger& logger, + ::onnxruntime::logging::Severity severity, + const char* category, + ::onnxruntime::logging::DataType datatype, + const CodeLocation& location) { + return detail::LoggerCapture(logger, severity, category, datatype, location); +} + +} // namespace detail +} // namespace adapter +} // namespace ep +} // namespace onnxruntime + +// Undefine and redefine LOGS_DEFAULT +#undef LOGS_DEFAULT_CATEGORY +#define LOGS_DEFAULT_CATEGORY(severity, category) \ + LOGS_CATEGORY(::onnxruntime::ep::adapter::Logger::DefaultLogger(), severity, category) + +#undef CREATE_MESSAGE +#define CREATE_MESSAGE(logger, severity, category, datatype) \ + ::onnxruntime::ep::adapter::detail::CreateMessageCapture(logger, ::onnxruntime::logging::Severity::k##severity, category, datatype, ORT_WHERE) diff --git a/include/onnxruntime/ep/adapter/node.h b/include/onnxruntime/ep/adapter/node.h new file mode 100644 index 0000000000000..b46cc1ebe64d4 --- /dev/null +++ b/include/onnxruntime/ep/adapter/node.h @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED) +#error "This header should not be included directly. Include ep/_pch.h instead." +#endif + +namespace onnxruntime { +namespace ep { +namespace adapter { + +/// +/// An adapter class partially implementing the facade of `onnxruntime::Node`. +/// +struct Node { + explicit Node(const OrtKernelInfo* kernel_info) : kernel_info_{kernel_info} {} + /** Gets the Node's name. */ + const std::string Name() const noexcept { + return kernel_info_.GetNodeName(); + } + + /** Gets the Node's operator type. */ + const std::string OpType() const noexcept { + return kernel_info_.GetOperatorType(); + } + + /** Gets the since version of the operator. */ + int SinceVersion() const noexcept { + return kernel_info_.GetOperatorSinceVersion(); + } + + private: + const Ort::ConstKernelInfo kernel_info_; +}; + +} // namespace adapter +} // namespace ep +} // namespace onnxruntime diff --git a/include/onnxruntime/ep/adapter/op_kernel.h b/include/onnxruntime/ep/adapter/op_kernel.h new file mode 100644 index 0000000000000..be5d4501e182d --- /dev/null +++ b/include/onnxruntime/ep/adapter/op_kernel.h @@ -0,0 +1,212 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED) +#error "This header should not be included directly. Include ep/_pch.h instead." +#endif + +#include +#include + +#include "core/framework/allocator.h" +#include "core/framework/tensor.h" + +#include "node.h" +#include "op_kernel_info.h" +#include "tensor_helper.h" + +namespace onnxruntime { +struct PrePackedWeights; +struct TensorShape; +} // namespace onnxruntime + +namespace onnxruntime { +namespace ep { +namespace adapter { + +struct OpKernelContext; + +/// +/// An adapter class partially implementing the facade of `onnxruntime::OpKernel`. +/// +struct OpKernel { + explicit OpKernel(const OpKernelInfo& info) : op_kernel_info_{info} {} + virtual ~OpKernel() {} + + Node Node() const { + return op_kernel_info_.node(); + } + const OpKernelInfo& Info() const { + return op_kernel_info_; + } + + virtual Status CreateControlFlowKernelImpl(const OrtKernelInfo* info, OrtKernelImpl** impl) { + return Status::OK(); + } + + virtual Status Compute(OpKernelContext* p_op_kernel_context) const = 0; + virtual Status PrePack(const Tensor& tensor, + int input_idx, + AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) { + is_packed = false; + return Status::OK(); + } + + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OpKernel); + OpKernelInfo op_kernel_info_; +}; + +/// +/// An adapter class partially implementing the facade of `onnxruntime::OpKernelContext`. +/// +struct OpKernelContext { + explicit OpKernelContext(OrtKernelContext* context, const OpKernel& op_kernel) : context_{context}, + op_kernel_{op_kernel}, + constant_input_tensors_{op_kernel.Info().GetConstantInputTensors()} { + input_tensors_.resize(context_.GetInputCount()); + output_tensors_.resize(context_.GetOutputCount()); + } + + template >> + const T* Input(int index) const { + if (index < 0 || static_cast(index) >= input_tensors_.size()) { + return nullptr; + } + if (input_tensors_[index].DataRaw() != nullptr) { + return &input_tensors_[index]; + } + + if (index < constant_input_tensors_.size() && constant_input_tensors_[index].DataRaw() != nullptr) { + return &constant_input_tensors_[index]; + } + + auto input = context_.GetInput(index); + if (input == nullptr || !input.IsTensor()) { + return nullptr; + } + + input_tensors_[index] = CreateTensorFromApiValue(input); + return &input_tensors_[index]; + } + Tensor* Output(int index, const TensorShape& shape) { + if (index < 0 || static_cast(index) >= output_tensors_.size()) { + return nullptr; + } + if (output_tensors_[index].DataRaw() != nullptr) { + return &output_tensors_[index]; + } + + auto output = context_.GetOutput(index, shape.GetDims().data(), shape.GetDims().size()); + if (output == nullptr) { + return nullptr; + } + + output_tensors_[index] = CreateTensorFromApiValue(output); + return &output_tensors_[index]; + } + Tensor* Output(int index, const std::vector& shape) { + return Output(index, TensorShape{shape}); + } + Tensor* Output(int index, const std::initializer_list& shape) { + return Output(index, TensorShape{shape}); + } + [[nodiscard]] Status GetTempSpaceCPUAllocator(AllocatorPtr* output) const { + return static_cast(op_kernel_.Info().GetKernelInfo().GetEp())->GetTempSpaceCPUAllocator(output); + } + [[nodiscard]] Status GetTempSpaceAllocator(AllocatorPtr* output) const { + return static_cast(op_kernel_.Info().GetKernelInfo().GetEp())->GetTempSpaceAllocator(output); + } + int InputCount() const { + return static_cast(input_tensors_.size()); + } + int OutputCount() const { + return static_cast(output_tensors_.size()); + } + bool GetUseDeterministicCompute() const { + // TODO(fs-eire): Implement GetUseDeterministicCompute(). + return false; + } + + private: + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OpKernelContext); + Ort::KernelContext context_; + const OpKernel& op_kernel_; + const std::vector& constant_input_tensors_; + mutable std::vector input_tensors_; + std::vector output_tensors_; +}; + +/// +/// A bridge class between `onnxruntime::ep::adapter::OpKernel` and `onnxruntime::OrtKernelImpl`. +/// +struct KernelImpl : OrtKernelImpl { + explicit KernelImpl(std::unique_ptr impl) + : OrtKernelImpl{}, impl_(std::move(impl)) { + ort_version_supported = ORT_API_VERSION; + Compute = ComputeImpl; + Release = ReleaseImpl; + PrePackWeight = PrePackWeightImpl; + } + + private: + static OrtStatus* ORT_API_CALL ComputeImpl(_In_ OrtKernelImpl* this_ptr, + _In_ OrtKernelContext* context) noexcept { + const auto* kernel_impl = static_cast(this_ptr)->impl_.get(); + OpKernelContext ctx{context, *kernel_impl}; + Status status; + ORT_TRY { + status = kernel_impl->Compute(&ctx); + } + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION([&]() { + status = ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, ex.what()); + }); + } + if (status.IsOK()) { + return nullptr; + } else { + return Ort::Status{status.ErrorMessage().c_str(), static_cast(status.Code())}.release(); + } + } + + static void ORT_API_CALL ReleaseImpl(_In_ OrtKernelImpl* this_ptr) noexcept { + delete static_cast(this_ptr); + } + + static OrtStatus* ORT_API_CALL PrePackWeightImpl(_In_ OrtKernelImpl* this_ptr, + _In_ const OrtValue* weight, + int input_index, + _In_ OrtAllocator* /* allocator */, + _In_opt_ OrtSharedPrePackedWeightCache* /* prepacked_weight_cache */, + _Out_ bool* is_packed) noexcept { + auto* kernel_impl = static_cast(this_ptr)->impl_.get(); + const auto tensor = CreateTensorFromApiValue(weight); + Status status; + ORT_TRY { + status = kernel_impl->PrePack(tensor, input_index, AllocatorPtr{}, *is_packed, nullptr); + } + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION([&]() { + status = ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, ex.what()); + }); + } + if (!status.IsOK()) { + return Ort::Status{status.ErrorMessage().c_str(), static_cast(status.Code())}.release(); + } + return nullptr; + } + + ~KernelImpl() = default; + + private: + std::unique_ptr impl_; +}; + +} // namespace adapter +} // namespace ep +} // namespace onnxruntime diff --git a/include/onnxruntime/ep/adapter/op_kernel_info.h b/include/onnxruntime/ep/adapter/op_kernel_info.h new file mode 100644 index 0000000000000..6d5eb66c26153 --- /dev/null +++ b/include/onnxruntime/ep/adapter/op_kernel_info.h @@ -0,0 +1,164 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED) +#error "This header should not be included directly. Include ep/_pch.h instead." +#endif + +#include + +#include "core/common/status.h" +#include "core/framework/tensor_shape.h" +#include "core/framework/tensor.h" + +#include "node.h" +#include "kernel_def.h" +#include "tensor_helper.h" + +namespace onnxruntime { +struct DataTransferManager; +struct IExecutionProvider; +} // namespace onnxruntime + +namespace onnxruntime { +namespace ep { +namespace adapter { + +/// +/// An adapter class partially implementing the facade of `onnxruntime::OpKernelInfo`. +/// +struct OpKernelInfo { + // + // A helper struct to cache kernel info data + // + // Because `KernelCreatePtrFn` is defined to use `const OpKernelInfo&` as parameter type of the kernel creation function, `OpKernelInfo` has to be copyable. + // This means we cannot store cached data like `constant_input_tensors_` in `OpKernelInfo` directly to avoid ownership issues. + // + // As a workaround, we define this struct `KernelInfoCache` here to represent the cached data. We use a shared pointer to `KernelInfoCache` in `OpKernelInfo` + // to manage the lifetime of the cached data. + struct KernelInfoCache { + explicit KernelInfoCache(const OrtKernelInfo* kernel_info) : kernel_info_(kernel_info) { + Ort::ConstKernelInfo info{kernel_info}; + const int input_count = info.GetInputCount(); + constant_input_tensors.resize(input_count); + for (int i = 0; i < input_count; ++i) { + int is_constant = 0; + Ort::ConstValue const_input = info.GetTensorConstantInput(i, &is_constant); + if (is_constant && const_input != nullptr && const_input.IsTensor()) { + constant_input_tensors[i] = CreateTensorFromApiValue(const_input); + } + } + } + const OrtKernelInfo* kernel_info_; + std::vector constant_input_tensors; + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(KernelInfoCache); + }; + + explicit OpKernelInfo(const OrtKernelInfo* info) : info_(info), cache_{std::make_shared(info)} { + } + + const DataTransferManager& GetDataTransferManager() const noexcept { + return (static_cast(info_.GetEp()))->GetDataTransferManager(); + } + Node node() const noexcept { + return Node{cache_->kernel_info_}; + } + const IExecutionProvider* GetExecutionProvider() const noexcept { + return (static_cast(info_.GetEp()))->EpImpl(); + } + + KernelDef GetKernelDef() const noexcept { + return KernelDef{cache_->kernel_info_}; + } + + const Ort::ConstKernelInfo GetKernelInfo() const noexcept { + return info_; + } + + int GetInputCount() const noexcept { + return info_.GetInputCount(); + } + + const std::vector& GetConstantInputTensors() const noexcept { + return cache_->constant_input_tensors; + } + + bool TryGetConstantInput(int input_index, const Tensor** constant_input_value) const { + if (input_index < 0 || static_cast(input_index) >= cache_->constant_input_tensors.size()) { + return false; + } + const Tensor& tensor = cache_->constant_input_tensors[input_index]; + if (tensor.DataRaw() != nullptr) { + *constant_input_value = &tensor; + return true; + } + return false; + } + + template + [[nodiscard]] T GetAttrOrDefault(const std::string& name, const T& default_value) const { + T tmp; + return GetAttr(name, &tmp).IsOK() ? tmp : default_value; + } + template + void GetAttrOrDefault(const std::string& name, T* value, const T& default_value) const { + if (!GetAttr(name, value).IsOK()) + *value = default_value; + } + template + [[nodiscard]] T GetAttr(const std::string& name) const { + T value; + ORT_THROW_IF_ERROR(GetAttr(name, &value)); + return value; + } + template + Status GetAttr(const std::string& name, T* value) const { + try { + *value = info_.GetAttribute(name.c_str()); + return Status::OK(); + } catch (const Ort::Exception& ex) { + return Status(onnxruntime::common::ONNXRUNTIME, ex.GetOrtErrorCode(), ex.what()); + } + } + template + Status GetAttrs(const std::string& name, std::vector& values) const { + try { + values = info_.GetAttributes(name.c_str()); + return Status::OK(); + } catch (const Ort::Exception& ex) { + return Status(onnxruntime::common::ONNXRUNTIME, ex.GetOrtErrorCode(), ex.what()); + } + } + + Status GetAttrs(const std::string& name, TensorShapeVector& out) const { + std::vector shape; + Status status = GetAttrs(name, shape); + if (status.IsOK()) { + out.reserve(shape.size()); + out.assign(shape.begin(), shape.end()); + } + return status; + } + + template + [[nodiscard]] std::vector GetAttrsOrDefault(const std::string& name, + const std::vector& default_value = {}) const { + std::vector tmp; + return GetAttrs(name, tmp).IsOK() ? tmp : default_value; + } + [[nodiscard]] TensorShapeVector GetAttrsOrDefault(const std::string& name, + const TensorShapeVector& default_value = {}) const { + TensorShapeVector tmp; + return GetAttrs(name, tmp).IsOK() ? tmp : default_value; + } + + private: + const Ort::ConstKernelInfo info_; + std::shared_ptr cache_; +}; + +} // namespace adapter +} // namespace ep +} // namespace onnxruntime diff --git a/include/onnxruntime/ep/adapter/tensor_helper.h b/include/onnxruntime/ep/adapter/tensor_helper.h new file mode 100644 index 0000000000000..0b85a224fe430 --- /dev/null +++ b/include/onnxruntime/ep/adapter/tensor_helper.h @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED) +#error "This header should not be included directly. Include ep/_pch.h instead." +#endif + +#include +#include + +#include "core/framework/tensor.h" + +namespace onnxruntime { +namespace ep { +namespace adapter { + +/// +/// Create an unowned onnxruntime::Tensor from a tensor OrtValue from C API. +/// +inline onnxruntime::Tensor CreateTensorFromApiValue(const OrtValue* ort_value) { + Ort::ConstValue value{ort_value}; + EP_ENFORCE(value.IsTensor(), "Only tensor OrtValue is supported."); + + ONNXTensorElementDataType element_type; + Ort::Value::Shape shape{}; + value.GetTensorElementTypeAndShapeDataReference(element_type, shape); + + auto memory_info = value.GetTensorMemoryInfo(); + MLDataType data_type = DataTypeImpl::TensorTypeFromONNXEnum(element_type)->GetElementType(); + + OrtMemoryInfo tensor_memory_info{memory_info.GetAllocatorName(), + memory_info.GetAllocatorType(), + OrtDevice{ + static_cast(memory_info.GetDeviceType()), + static_cast(memory_info.GetMemoryType()), + static_cast(memory_info.GetVendorId()), + static_cast(memory_info.GetDeviceId()), + + }, + memory_info.GetMemoryType()}; + + return Tensor(data_type, + TensorShape{shape.shape, shape.shape_len}, + const_cast(value.GetTensorRawData()), + tensor_memory_info); +} + +} // namespace adapter +} // namespace ep +} // namespace onnxruntime diff --git a/include/onnxruntime/ep/api.h b/include/onnxruntime/ep/api.h new file mode 100644 index 0000000000000..ff0723489077e --- /dev/null +++ b/include/onnxruntime/ep/api.h @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#pragma push_macro("ORT_API_MANUAL_INIT") +#define ORT_API_MANUAL_INIT +#include "onnxruntime_cxx_api.h" +#pragma pop_macro("ORT_API_MANUAL_INIT") + +namespace onnxruntime { +namespace ep { + +struct ApiPtrs { + const OrtApi& ort; + const OrtEpApi& ep; + const OrtModelEditorApi& model_editor; +}; + +namespace detail { +inline std::optional g_api_ptrs; +} + +/// +/// Get the global instance of ApiPtrs. +/// +inline const ApiPtrs& Api() { + return *detail::g_api_ptrs; +} + +/// +/// Initialize the EP API pointers and global OrtEnv if not already done. +/// +inline void ApiInit(const OrtApiBase* ort_api_base) { + // Manual init for the C++ API + const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION); + const OrtEpApi* ep_api = ort_api->GetEpApi(); + const OrtModelEditorApi* model_editor_api = ort_api->GetModelEditorApi(); + Ort::InitApi(ort_api); + + // Initialize the global API instance + if (!detail::g_api_ptrs) { + detail::g_api_ptrs.emplace(*ort_api, *ep_api, *model_editor_api); + } +} + +} // namespace ep +} // namespace onnxruntime diff --git a/include/onnxruntime/ep/common.h b/include/onnxruntime/ep/common.h new file mode 100644 index 0000000000000..12118c938820c --- /dev/null +++ b/include/onnxruntime/ep/common.h @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#define RETURN_IF_ERROR(fn) \ + do { \ + OrtStatus* _status = (fn); \ + if (_status != nullptr) { \ + return _status; \ + } \ + } while (0) + +#define RETURN_IF(cond, ort_api, msg) \ + do { \ + if ((cond)) { \ + return (ort_api).CreateStatus(ORT_EP_FAIL, (msg)); \ + } \ + } while (0) + +// see ORT_ENFORCE for implementations that also capture a stack trace and work in builds with exceptions disabled +// NOTE: In this simplistic implementation you must provide an argument, even it if's an empty string +#define EP_ENFORCE(condition, ...) \ + do { \ + if (!(condition)) { \ + std::ostringstream oss; \ + oss << "EP_ENFORCE failed: " << #condition << " "; \ + oss << __VA_ARGS__; \ + throw std::runtime_error(oss.str()); \ + } \ + } while (false) + +// Ignores an OrtStatus* while taking ownership of it so that it does not get leaked. +#define IGNORE_ORTSTATUS(status_expr) \ + do { \ + OrtStatus* _status = (status_expr); \ + Ort::Status _ignored{_status}; \ + } while (false) diff --git a/include/onnxruntime/ep/get_capability_utils.h b/include/onnxruntime/ep/get_capability_utils.h new file mode 100644 index 0000000000000..2f6b9dfbe1d5b --- /dev/null +++ b/include/onnxruntime/ep/get_capability_utils.h @@ -0,0 +1,253 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace onnxruntime { +namespace ep { + +using NodeId = size_t; +constexpr int64_t kSmallInitializerThreshold = 100; + +constexpr inline bool MemTypeOnCpuExplicitly(OrtMemType mem_type) { + return mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput; +} + +// Get all output nodes that consume an output from the given node. +static OrtStatus* GetOutputNodes(gsl::span node_outputs, std::vector& result) { + try { + std::vector output_nodes; + output_nodes.reserve(node_outputs.size()); // May have more + + // Gather the OrtNode consumers of every output. + for (Ort::ConstValueInfo output : node_outputs) { + if (output == nullptr) continue; // Skip missing optional output + + auto consumers_info = output.GetConsumers(); + for (const auto& consumer : consumers_info) { + output_nodes.push_back(consumer.node); + } + } + + result = std::move(output_nodes); + return nullptr; + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); + } catch (...) { + Ort::Status status("Unknown exception", ORT_EP_FAIL); + return status.release(); + } +} + +// Returns nodes that should be assigned to CPU EP instead of this example EP to avoid costly I/O copies. +// Based on GetCpuPreferredNodes from onnxruntime/core/framework/fallback_cpu_capability.cc +OrtStatus* GetCpuPreferredNodes(const OrtGraph& ort_graph, OrtEpGraphSupportInfo& graph_support_info, + const OrtLogger& logger, gsl::span tentative_nodes, + /*out*/ std::unordered_set& cpu_preferred_nodes) { + try { + const OrtApi& ort_api = Ort::GetApi(); + const OrtEpApi& ep_api = Ort::GetEpApi(); + Ort::ConstGraph graph{&ort_graph}; + std::vector ordered_nodes = graph.GetNodes(); + + if (ordered_nodes.empty()) { + return nullptr; + } + + std::unordered_map node_id_to_node; + std::unordered_map node_id_to_order_map; + for (size_t i = 0, num_nodes = ordered_nodes.size(); i < num_nodes; i++) { + NodeId node_id = ordered_nodes[i].GetId(); + node_id_to_node[node_id] = ordered_nodes[i]; + node_id_to_order_map[node_id] = i; + } + + // If return false, n1 will be output first; If return true, n2 will be output first + auto greater_order_comp = [&](const NodeId node_id1, const NodeId node_id2) { + return node_id_to_order_map[node_id1] > node_id_to_order_map[node_id2]; + }; + std::priority_queue, decltype(greater_order_comp)> candidates(greater_order_comp); + std::unordered_set cpu_output_args; + + std::unordered_set provider_nodes; + provider_nodes.reserve(tentative_nodes.size()); + + std::unordered_map node_to_kernel; + node_to_kernel.reserve(tentative_nodes.size()); + + for (const OrtNode* ort_node : tentative_nodes) { + Ort::ConstNode node(ort_node); + NodeId node_id = node.GetId(); + + provider_nodes.insert(node_id); + + // Expect at least one registry has a target provider's kernel for this node. + const OrtKernelDef* ort_kernel_def = nullptr; + RETURN_IF_ERROR(ep_api.EpGraphSupportInfo_LookUpKernel(&graph_support_info, node, &ort_kernel_def)); + RETURN_IF(ort_kernel_def == nullptr, ort_api, "Must have a registered kernel definition on the target EP"); + + Ort::ConstKernelDef kernel_def(ort_kernel_def); + node_to_kernel.insert({node_id, kernel_def}); + + // Find all the direct consumers of CPU tensors. + std::vector outputs = node.GetOutputs(); + for (size_t out_index = 0; out_index < outputs.size(); out_index++) { + Ort::ConstValueInfo output = outputs[out_index]; + if (output == nullptr) continue; // Skip missing optional output + + bool is_output_on_cpu = MemTypeOnCpuExplicitly(kernel_def.GetOutputMemType(out_index)); + if (is_output_on_cpu) { + cpu_output_args.insert(output); + + auto consumer_infos = output.GetConsumers(); + for (const auto& consumer_info : consumer_infos) { + candidates.push(consumer_info.node.GetId()); + ORT_CXX_LOGF(Ort::Logger(&logger), ORT_LOGGING_LEVEL_INFO, "Candidate for fallback CPU execution: %s\n", + consumer_info.node.GetName().c_str()); + } + } + } + } + + std::unordered_set visited; + visited.reserve(candidates.size()); + + std::unordered_set cpu_nodes; + cpu_nodes.reserve(candidates.size()); + + // The algo below is trying to identity a subgraph that only depends on cpu tensors. + // Usually it is a subgraph that doing shape calculation based on a GPU tensor, then reshape it back. + // The detail: + // for each candidate, if one of its input is a cpu tensor and the Non-CPU kernel doesn't mark it as cpu input, + // force the node to CPU to avoid memory cpu and add its output to the small cpu tensors. + while (!candidates.empty()) { + NodeId cur = candidates.top(); + candidates.pop(); + + auto p = visited.insert(cur); + if (!p.second) { + continue; + } + + auto node_iter = node_id_to_node.find(cur); + RETURN_IF(node_iter == node_id_to_node.end(), ort_api, "Unable to get OrtNode for a given node ID"); + Ort::ConstNode node = node_iter->second; + + if (provider_nodes.find(cur) == provider_nodes.end()) { + // Nodes not in provider_nodes are either have EP assigned or no kernel found on target EP. + // we assume these nodes will fallback to CPU, so add all direct consumers of all outputs to candidates. + std::string ep_name = node.GetEpName(); + if (ep_name.empty() || ep_name == "CPUExecutionProvider") { + std::vector outputs = node.GetOutputs(); + + for (Ort::ConstValueInfo output : outputs) { + if (output == nullptr) continue; // Skip missing optional output + cpu_output_args.insert(output); + } + + std::vector output_nodes; + RETURN_IF_ERROR(GetOutputNodes(outputs, output_nodes)); + + for (Ort::ConstNode downstream_node : output_nodes) { + candidates.push(downstream_node.GetId()); + } + } + continue; + } + + std::vector inputs = node.GetInputs(); + bool place_in_cpu = true; + + for (size_t i = 0; i < inputs.size(); i++) { + Ort::ConstValueInfo input = inputs[i]; + if (input == nullptr) continue; // Skip missing optional input + + // skip placing on CPU if the data types is float16 or bfloat16 or + // float8e4m3fn, float8e4m3fnuz, float8e5m2, float8e5m2fnuz or float4e2m1 + Ort::ConstTypeInfo type_info = input.TypeInfo(); + auto type_shape_info = type_info.GetTensorTypeAndShapeInfo(); + auto elem_type = type_shape_info.GetElementType(); + if (elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 || + elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 || + elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN || + elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ || + elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2 || + elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ || + elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT4E2M1) { + place_in_cpu = false; + break; + } + + bool is_small_initializer = input.IsConstantInitializer() && + type_shape_info.GetElementCount() <= kSmallInitializerThreshold; + + // Allow placing on CPU if it's a small initializer or graph input + if (is_small_initializer || input.IsRequiredGraphInput() || input.IsOptionalGraphInput()) { + continue; + } + + // the input is not a CPU tensor + if (cpu_output_args.find(input) == cpu_output_args.end()) { + place_in_cpu = false; + break; + } + + // input is a CPU tensor, but it's intended to be consumed as CPU input by the target EP + bool is_input_on_cpu = MemTypeOnCpuExplicitly(node_to_kernel[cur].GetInputMemType(i)); + if (is_input_on_cpu) { + place_in_cpu = false; + break; + } + } + + if (place_in_cpu) { + cpu_nodes.insert(node); + ORT_CXX_LOGF(Ort::Logger(&logger), ORT_LOGGING_LEVEL_WARNING, + "EP optimization: Force fallback to CPU execution for node %s because the CPU execution path " + "is deemed faster than overhead involved with execution on other EPs capable of executing " + "this node.\n", + node.GetName().c_str()); + + std::vector outputs = node.GetOutputs(); + for (Ort::ConstValueInfo output : outputs) { + if (output == nullptr) continue; // Skip missing optional output + cpu_output_args.insert(output); + } + + std::vector output_nodes; + RETURN_IF_ERROR(GetOutputNodes(outputs, output_nodes)); + + for (Ort::ConstNode downstream_node : output_nodes) { + candidates.push(downstream_node.GetId()); + } + } + } + + cpu_preferred_nodes = std::move(cpu_nodes); + + return nullptr; + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); + } catch (...) { + Ort::Status status("Unknown exception", ORT_EP_FAIL); + return status.release(); + } +} + +} // namespace ep +} // namespace onnxruntime