diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 9c6551ad5e792..751e17702574e 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -2104,11 +2104,12 @@ if (onnxruntime_BUILD_SHARED_LIB AND "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_factory.cc" "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep.h" "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep.cc" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_allocator.h" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_data_transfer.h" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_data_transfer.cc" "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.h" "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.cc" "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/utils.h" - "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/base.h" - "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/base.cc" "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.h" "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.cc" "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.h" diff --git a/include/onnxruntime/core/framework/op_kernel.h b/include/onnxruntime/core/framework/op_kernel.h index 5f391432ce503..8ec94c67cc0a4 100644 --- a/include/onnxruntime/core/framework/op_kernel.h +++ b/include/onnxruntime/core/framework/op_kernel.h @@ -4,6 +4,7 @@ #pragma once #include "boost/mp11.hpp" +#include // It is safe to include the below header even if SHARED_PROVIDER macro is enabled // as it doesn't include any pb headers. @@ -26,7 +27,6 @@ #include "core/graph/constants.h" #include "core/graph/graph_viewer.h" #include "core/graph/onnx_protobuf.h" -#include namespace onnxruntime { class OpKernelContext; } @@ -105,6 +105,7 @@ class OpKernel { return Status::OK(); } + // Note: New implementations should override OpKernel::UseSharedPrePackedBuffers_V2 instead. // Override this function to use provided pre-packed weight. // Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, // int input_idx, @@ -130,6 +131,27 @@ class OpKernel { return Status::OK(); } + /// + /// Version 2 of OpKernel::UseSharedPrePackedBuffers() that additionally accepts the buffer sizes as a parameter. + /// The default implementation of this function just calls directly to OpKernel::UseSharedPrePackedBuffers() + /// to avoid the need to update all existing kernel-based provider-bridge EPs. + /// + /// TODO: Consolidate UseSharedPrePackedBuffers and UseSharedPrePackedBuffers_V2 into a single function, + /// which will require updating kernel-based provider-bridge EPs (cpu, cuda, webgpu). 
+ /// + /// + /// + /// + /// + /// + /// + virtual Status UseSharedPrePackedBuffers_V2(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, + int input_idx, + /*out*/ bool& used_shared_buffers) { + return UseSharedPrePackedBuffers(prepacked_buffers, input_idx, used_shared_buffers); + } + const OrtDevice GetDevice(OrtMemType mem_type) const; const OpKernelInfo& Info() const { return *op_kernel_info_; diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index bc75aabc7e229..d72d08e5bb249 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -3379,5 +3379,34 @@ struct KernelRegistry : detail::Base { Status AddKernel(const OrtKernelDef* kernel_def, OrtKernelCreateFunc kernel_create_func, void* kernel_create_func_state); }; + +namespace detail { +template +struct SharedPrePackedWeightCacheImpl : Ort::detail::Base { + using B = Ort::detail::Base; + using B::B; + + //< Wraps SharedPrePackedWeightCache_StoreWeightData + Status StoreWeightData(void** buffer_data_ptrs, size_t* buffer_sizes, size_t num_buffers); +}; +} // namespace detail + +/** \brief Convenience C++ wrapper class around a ::OrtSharedPrePackedWeightCache instance owned by ORT. + * + * An `OrtSharedPrePackedWeightCache*` instance is passed as an argument to OrtKernelImpl::PrePackWeight. + * Example use: + * OrtStatus* MyKernel::PrePackWeightImpl(OrtKernelImpl*, ..., OrtSharedPrePackedWeightCache* c_cache, ...) { + * ... + * if (c_cache != nullptr) { + * Ort::UnownedSharedPrePackedWeightCache cpp_cache(c_cache); + * Ort::Status status = cpp_cache.StoreWeightData(...); + * } + * ... + * } + * + * \remarks OrtSharedPrePackedWeightCache is always unowned, but mutable, for EpApi users. + */ +using UnownedSharedPrePackedWeightCache = + detail::SharedPrePackedWeightCacheImpl>; } // namespace Ort #include "onnxruntime_cxx_inline.h" diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index aff1061a67fea..1267cdb405f3a 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -3713,4 +3713,13 @@ inline Status KernelRegistry::AddKernel(const OrtKernelDef* kernel_def, OrtKerne void* kernel_create_func_state) { return Status{GetEpApi().KernelRegistry_AddKernel(p_, kernel_def, kernel_create_func, kernel_create_func_state)}; } + +namespace detail { +template +inline Status SharedPrePackedWeightCacheImpl::StoreWeightData(void** buffer_data_ptrs, size_t* buffer_sizes, + size_t num_buffers) { + return Status{GetEpApi().SharedPrePackedWeightCache_StoreWeightData(this->p_, buffer_data_ptrs, buffer_sizes, + num_buffers)}; +} +} // namespace detail } // namespace Ort diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 6fa5c8dea04e6..c67e73d1cd4a0 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -29,6 +29,7 @@ ORT_RUNTIME_CLASS(KernelRegistry); ORT_RUNTIME_CLASS(KernelDefBuilder); ORT_RUNTIME_CLASS(KernelDef); ORT_RUNTIME_CLASS(DataType); // combination of ONNXType (e.g., Tensor, Map, Sequence) and ONNXTensorElementDataType +ORT_RUNTIME_CLASS(SharedPrePackedWeightCache); /** \brief Struct that an EP implements for IDataTransfer to copy between devices it uses and CPU. 
* @@ -308,6 +309,101 @@ struct OrtKernelImpl { * \since Version 1.24. */ ORT_API_T(void, Release, _In_ OrtKernelImpl* this_ptr); + + /** \brief Optional function to pre-pack a constant tensor (i.e., a weight) to the kernel's preferred data layout. + * + * For example, a Conv kernel can define this function to pack input W to the channel-last data layout + * before inference. + * + * Pre-packing can operate in three different modes: no pre-packing mode, sharing mode, and non-sharing mode. + * 1) No pre-packing mode: The kernel can forgo any weight pre-packing for the given `input_index` by setting + * `is_packed` to false and returning a successful OrtStatus. In this mode, the kernel's + * OrtKernelImpl::SetSharedPrePackedWeight() function is not called for that specific + * `input_index`. + * 2) Sharing mode: Sharing is allowed if the `prepacked_weight_cache` argument is not NULL and the EP stores + * weight data in CPU-accessible memory. In this case, the kernel can optionally choose + * to share the packed weight with other kernels that use the same weight + * (compared by content hash). To do so, the kernel must allocate the packed weight with the + * provided `allocator`, store the packed weight data into `prepacked_weight_cache` + * via SharedPrePackedWeightCache_StoreWeightData(), set `is_packed` to true, and return a + * successful OrtStatus. ORT will subsequently call OrtKernelImpl::SetSharedPrePackedWeight() + * to provide this kernel with the actual shared weight data, whose memory location could + * differ (i.e., if shared data was allocated by a previously processed kernel). + * 3) Non-sharing mode: The `prepacked_weight_cache` argument is ignored. In this mode, + * the implementation allocates the packed data with the provided `allocator`, sets + * `is_packed` to true, and returns a successful OrtStatus. The kernel is ultimately + * responsible for releasing the packed data for the weight with `allocator`. + * ORT may release the original (unpacked) weight, which must not be accessed in + * OrtKernelImpl::Compute(). Note that in this mode, the kernel's + * OrtKernelImpl::SetSharedPrePackedWeight() function is not called by ORT for that specific + * `input_index`. + * + * \note This function is based on the internal OpKernel::PrePack() virtual function used within ORT. + * + * \param[in] this_ptr The OrtKernelImpl instance. + * \param[in] tensor The OrtValue instance representing the constant tensor (weight). Do not cache in the kernel. + * \param[in] input_index The input index of the tensor in this kernel. + * \param[in] allocator Allocator for allocating the pre-packed data. Its use is required in sharing mode and + * recommended, but not required, in the non-sharing mode. This will be an allocator set by + * the application for the session/environment (e.g., via CreateAndRegisterAllocator[V2] + * or RegisterAllocator), or an allocator on the OrtEpDevice (read-only or default) otherwise. + * The allocator remains valid throughout the lifetime of the OrtKernelImpl instance. + * \param[in] prepacked_weight_cache May be NULL. If not NULL, the kernel may choose to share a packed weight by + * first storing it in the OrtSharedPrePackedWeightCache instance and then + * receiving the actual shared weight data in the call to + * OrtKernelImpl::SetSharedPrePackedWeight(). See the above description for + * "sharing mode". + * \param[out] is_packed Output parameter that the implementation sets to true if the kernel packed the tensor data.
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \note Implementation of this function is optional. If not implemented (set to NULL), ORT assumes the kernel + * does not pre-pack weight data (i.e., `is_packed` defaults to false). + * + * \since Version 1.24. + */ + ORT_API2_STATUS(PrePackWeight, _In_ OrtKernelImpl* this_ptr, _In_ const OrtValue* tensor, + _In_ int input_index, _Inout_ OrtAllocator* allocator, + _In_opt_ OrtSharedPrePackedWeightCache* prepacked_weight_cache, _Out_ bool* is_packed); + + /** \brief Optional function that receives data for a shared pre-packed weight from ORT. + * + * ORT calls this function after calling OrtKernelImpl::PrePackWeight for a specific `input_index` if: + * - OrtKernelImpl::PrePackWeight set the output parameter `is_packed` to true. + * - OrtKernelImpl::PrePackWeight stored weight data to share into the provided OrtSharedPrePackedWeightCache + * parameter (`prepacked_weight_cache`) via the API SharedPrePackedWeightCache_StoreWeightData. + * + * Refer to the description of "sharing mode" in the documentation for OrtKernelImpl::PrePackWeight(). + * + * \note ORT will not call this function for an `input_index` that a previous call to + * OrtKernelImpl::PrePackWeight() did not elect to pre-pack and share. + * + * \note This function is based on the internal OpKernel::UseSharedPrePackedBuffers() virtual function used + * within ORT. + * + * \param[in] this_ptr The OrtKernelImpl instance. + * \param[in] buffer_data_ptrs An array of buffer data pointers that collectively hold the pre-packed data for a + * single shared weight. The buffers are provided in the same order and with the same + * contents (in a potentially different memory location) as the buffers + * passed into SharedPrePackedWeightCache_StoreWeightData() within the + * OrtKernelImpl::PrePackWeight() call for the same `input_index`. + * \param[in] buffer_data_sizes An array of buffer byte sizes, one per element in `buffer_data_ptrs`. + * \param[in] num_buffers The number of buffers used to store the data for the shared pre-packed weight. + * Specifies the number of elements in the `buffer_data_ptrs` and `buffer_data_sizes` arrays. + * \param[in] input_index The input index of the tensor in this kernel. This index identifies which weight + * the data belongs to. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \note Implementation of this function is generally optional. It is only required if OrtKernelImpl::PrePackWeight() + * elects to share pre-packed weights. + * + * \since Version 1.24. + */ + ORT_API2_STATUS(SetSharedPrePackedWeight, _In_ OrtKernelImpl* this_ptr, + _In_reads_(num_buffers) const void* const* buffer_data_ptrs, + _In_reads_(num_buffers) const size_t* buffer_data_sizes, + _In_ size_t num_buffers, _In_ int input_index); }; /** \brief Type definition for a function that creates an OrtKernelImpl instance for an operator kernel. @@ -846,6 +942,35 @@ struct OrtEpApi { */ ORT_API2_STATUS(EpGraphSupportInfo_LookUpKernel, _In_ OrtEpGraphSupportInfo* graph_support_info, _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtKernelDef** out_kernel_def); + + /** \brief Sets one or more data buffers that collectively hold the pre-packed data for a single shared weight. + * + * \note Used within the implementation of OrtKernelImpl::PrePackWeight() when the kernel wants to share pre-packed + * weight data with other kernels. The buffer data MUST be allocated with the OrtAllocator provided to + * OrtKernelImpl::PrePackWeight.
+ * + * \note Ownership of weight data transfers to the OrtSharedPrePackedWeightCache instance on success. + * If this function returns an error status, the caller retains ownership of the weight data. + * + * \note Subsequent calls with the same OrtSharedPrePackedWeightCache instance release and replace the old data. + * + * \param[in] prepacked_weight_cache The OrtSharedPrePackedWeightCache instance into which the weight data is stored. + * \param[in] buffer_data_ptrs An array of buffer data pointers that collectively hold the pre-packed data for a + * single shared weight. Note that a single weight may sometimes have multiple pre-packed + * buffers, and it is up to the kernel implementation to determine how to split the data + * into multiple buffers (if desired). + * \param[in] buffer_data_sizes An array of buffer byte sizes, one per element in `buffer_data_ptrs`. + * \param[in] num_buffers The number of buffers used to store the data for the shared pre-packed weight. + * Specifies the number of elements in the `buffer_data_ptrs` and `buffer_data_sizes` arrays. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(SharedPrePackedWeightCache_StoreWeightData, + _In_ OrtSharedPrePackedWeightCache* prepacked_weight_cache, + _In_reads_(num_buffers) void** buffer_data_ptrs, _In_reads_(num_buffers) size_t* buffer_data_sizes, + _In_ size_t num_buffers); }; /** diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.wgsl.template index 1b74862515c69..6a66d2eb402e5 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.wgsl.template @@ -84,6 +84,21 @@ $MAIN { #if n_bits == 4 var sum = output_element_t(0); var a_offset = idx * (8 / component_a) * component_b; +#if component_b == 1 + let b_value_lower = vec4(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4(zero); + let b_value_upper = vec4(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4(zero); + let b0 = vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]) * scale_b; + let b1 = vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]) * scale_b; +#if component_a == 1 + sum += dot(vec4(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b0) + + dot(vec4(tile_A[a_offset + 4], tile_A[a_offset + 5], tile_A[a_offset + 6], tile_A[a_offset + 7]), b1); +#elif component_a == 2 + sum += dot(vec4(tile_A[a_offset], tile_A[a_offset + 1]), b0) + + dot(vec4(tile_A[a_offset + 2], tile_A[a_offset + 3]), b1); +#elif component_a == 4 + sum += dot(tile_A[a_offset], b0) + dot(tile_A[a_offset + 1], b1); +#endif +#else for (var i = 0u; i < component_b; i++) { let b_value_lower = vec4(unpack4xU8(b_value[i] & 0x0F0F0F0Fu)) - vec4(zero); let b_value_upper = vec4(unpack4xU8((b_value[i] >> 4) & 0x0F0F0F0Fu)) - vec4(zero); @@ -102,25 +117,63 @@ $MAIN { a_offset += 2; #endif } +#endif #elif n_bits == 8 var sum = output_element_t(0); var a_offset = idx * (4 / component_a) * component_b; +#if component_b == 1 + let b_value_unpacked = (vec4(unpack4xU8(b_value)) - vec4(zero)) * scale_b; +#if component_a == 1 + sum += dot(vec4(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b_value_unpacked); +#elif component_a == 2 + sum += dot(vec4(tile_A[a_offset], tile_A[a_offset + 1]), b_value_unpacked); +#elif component_a == 4 + sum += dot(tile_A[a_offset], b_value_unpacked); +#endif +#else for (var i = 0u; i < component_b;
i++) { - let b_value = (vec4(unpack4xU8(b_value[i])) - vec4(zero)) * scale_b; + let b_value_unpacked = (vec4(unpack4xU8(b_value[i])) - vec4(zero)) * scale_b; #if component_a == 1 - sum += dot(vec4(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b_value); + sum += dot(vec4(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b_value_unpacked); a_offset += 4; #elif component_a == 2 - sum += dot(vec4(tile_A[a_offset], tile_A[a_offset + 1]), b_value); + sum += dot(vec4(tile_A[a_offset], tile_A[a_offset + 1]), b_value_unpacked); a_offset += 2; #elif component_a == 4 - sum += dot(tile_A[a_offset], b_value); + sum += dot(tile_A[a_offset], b_value_unpacked); a_offset += 1; #endif } +#endif #elif n_bits == 2 var sum = output_element_t(0); var a_offset = idx * (16 / component_a) * component_b; +#if component_b == 1 + let b_data_0 = vec4(unpack4xU8(b_value & 0x03030303u)) - vec4(zero); + let b_data_1 = vec4(unpack4xU8((b_value >> 2) & 0x03030303u)) - vec4(zero); + let b_data_2 = vec4(unpack4xU8((b_value >> 4) & 0x03030303u)) - vec4(zero); + let b_data_3 = vec4(unpack4xU8((b_value >> 6) & 0x03030303u)) - vec4(zero); + + let b0 = vec4(b_data_0[0], b_data_1[0], b_data_2[0], b_data_3[0]) * scale_b; + let b1 = vec4(b_data_0[1], b_data_1[1], b_data_2[1], b_data_3[1]) * scale_b; + let b2 = vec4(b_data_0[2], b_data_1[2], b_data_2[2], b_data_3[2]) * scale_b; + let b3 = vec4(b_data_0[3], b_data_1[3], b_data_2[3], b_data_3[3]) * scale_b; + +#if component_a == 1 + sum += dot(vec4(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b0) + + dot(vec4(tile_A[a_offset + 4], tile_A[a_offset + 5], tile_A[a_offset + 6], tile_A[a_offset + 7]), b1) + + dot(vec4(tile_A[a_offset + 8], tile_A[a_offset + 9], tile_A[a_offset + 10], tile_A[a_offset + 11]), b2) + + dot(vec4(tile_A[a_offset + 12], tile_A[a_offset + 13], tile_A[a_offset + 14], tile_A[a_offset + 15]), b3); +#elif component_a == 2 + sum += dot(vec4(tile_A[a_offset], tile_A[a_offset + 1]), b0) + + dot(vec4(tile_A[a_offset + 2], tile_A[a_offset + 3]), b1) + + dot(vec4(tile_A[a_offset + 4], tile_A[a_offset + 5]), b2) + + dot(vec4(tile_A[a_offset + 6], tile_A[a_offset + 7]), b3); +#elif component_a == 4 + sum += dot(tile_A[a_offset], b0) + dot(tile_A[a_offset + 1], b1) + + dot(tile_A[a_offset + 2], b2) + dot(tile_A[a_offset + 3], b3); +#endif +#else for (var i = 0u; i < component_b; i++) { let b_data_0 = vec4(unpack4xU8(b_value[i] & 0x03030303u)) - vec4(zero); let b_data_1 = vec4(unpack4xU8((b_value[i] >> 2) & 0x03030303u)) - vec4(zero); @@ -150,6 +203,7 @@ $MAIN { a_offset += 4; #endif } +#endif #endif inter_results[local_row_offset + idy][idx] += sum; diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index a14e219d9c039..be51e19023037 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -421,16 +421,23 @@ void SessionState::CleanInitializedTensorsFromGraph() { static Status KernelUseSharedPrePackedBuffers(OpKernel& kernel, int input_idx, const PrePackedWeights& prepacked_weights, const std::string& node_name) { + const size_t num_buffers = prepacked_weights.buffers_.size(); + assert(prepacked_weights.buffer_sizes_.size() == num_buffers); + std::vector shared_prepacked_buffers; - shared_prepacked_buffers.reserve(4); // Unlikely to see more than 4 prepacked buffers per initializer + std::vector shared_prepacked_buffer_sizes; + 
shared_prepacked_buffers.reserve(num_buffers); + shared_prepacked_buffer_sizes.reserve(num_buffers); - for (const auto& prepacked_buffer : prepacked_weights.buffers_) { + for (size_t i = 0; i < num_buffers; i++) { // BufferDeleter is nullptr because the kernel should not delete the shared buffer - it can only use it - shared_prepacked_buffers.emplace_back(prepacked_buffer.get(), BufferDeleter(nullptr)); + shared_prepacked_buffers.emplace_back(prepacked_weights.buffers_[i].get(), BufferDeleter(nullptr)); + shared_prepacked_buffer_sizes.push_back(prepacked_weights.buffer_sizes_[i]); } bool used_shared_buffers = false; - ORT_RETURN_IF_ERROR(kernel.UseSharedPrePackedBuffers(shared_prepacked_buffers, input_idx, used_shared_buffers)); + ORT_RETURN_IF_ERROR(kernel.UseSharedPrePackedBuffers_V2(shared_prepacked_buffers, shared_prepacked_buffer_sizes, + input_idx, used_shared_buffers)); // BUG CHECK: Ensure that the kernel used the provided shared buffers // Mostly a debug check to ensure that the kernel has an overridden implementation of the diff --git a/onnxruntime/core/optimizer/relu_clip_fusion.cc b/onnxruntime/core/optimizer/relu_clip_fusion.cc index 494c646778d10..6f150cd29e90f 100644 --- a/onnxruntime/core/optimizer/relu_clip_fusion.cc +++ b/onnxruntime/core/optimizer/relu_clip_fusion.cc @@ -57,6 +57,12 @@ Status FuseReluClip::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff data_type = initializer->data_type(); // construct an initializer to gracefully handle typed or raw data in the TensorProto Initializer i(graph, *initializer, graph.ModelPath()); + + // Empty tensor is invalid for 'min' input - skip optimization to avoid null pointer dereference + if (i.size() == 0) { + return Status::OK(); + } + switch (data_type) { case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: if (*i.data() < 0.f) { diff --git a/onnxruntime/core/providers/webgpu/compute_context.cc b/onnxruntime/core/providers/webgpu/compute_context.cc index 5c30631882fe2..b90ce34be1db4 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.cc +++ b/onnxruntime/core/providers/webgpu/compute_context.cc @@ -20,22 +20,6 @@ const webgpu::BufferManager& ComputeContextBase::BufferManagerAccessor::Get(cons return context.ep_.BufferManager(); } -Status ComputeContextBase::CreateUnmappedGPUTensor(AllocatorPtr alloc, MLDataType data_type, const TensorShape& shape, std::unique_ptr& tensor) const { - ORT_RETURN_IF_NOT(alloc != nullptr, "Allocator must not be null when creating GPU tensor."); - - tensor = std::make_unique(data_type, shape, alloc); - ORT_RETURN_IF_NOT(tensor != nullptr, "Failed to allocate GPU tensor."); - - void* data = tensor->MutableDataRaw(); - ORT_RETURN_IF_NOT(data != nullptr, "Failed to get GPU tensor buffer."); - - auto buffer = reinterpret_cast(data); - if (wgpuBufferGetMapState(buffer) != WGPUBufferMapState_Unmapped) { - wgpuBufferUnmap(buffer); - } - return Status::OK(); -} - ComputeContext::ComputeContext(WebGpuContext& webgpu_context, const WebGpuExecutionProvider& ep, const OpKernel& op_kernel, diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index 1a61bd5b32cd9..5b694a7a2e3f1 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -56,9 +56,6 @@ class ComputeContextBase { return op_kernel_.Node().Name(); } - Status CreateUnmappedGPUTensor(AllocatorPtr alloc, MLDataType data_type, const TensorShape& shape, - std::unique_ptr& tensor) const; - // // Get the 
operator type. // diff --git a/onnxruntime/core/providers/webgpu/nn/conv.cc b/onnxruntime/core/providers/webgpu/nn/conv.cc index 48342d2b84fec..b435986f8cc7a 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv.cc +++ b/onnxruntime/core/providers/webgpu/nn/conv.cc @@ -354,11 +354,9 @@ Status Conv::PrePackInternal(ComputeContextBase& con } TensorShape transposed_kernel_shape(transposed_kernel_shape_vector); - ORT_ENFORCE(alloc != nullptr, "Allocator must be provided for WebGPU pre-pack."); - - // Create the transposed kernel tensor using the WebGPU allocator. - // Both input tensor and output tensor are GPU tensors, ready for GPU operations. - ORT_RETURN_IF_ERROR(context.CreateUnmappedGPUTensor(alloc, tensor.DataType(), transposed_kernel_shape, transposed_kernel_)); + // Create the transposed kernel tensor using the prepack allocator. + // This allocator creates GPU buffers without mapping, suitable for GPU-based operations. + transposed_kernel_ = std::make_unique(tensor.DataType(), transposed_kernel_shape, alloc); // Perform GPU-based transpose directly from the input GPU tensor ORT_RETURN_IF_ERROR(Transpose::DoTranspose(context, perm, tensor, *transposed_kernel_)); diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 457867061d6a7..2f50fd8051b9c 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -37,13 +37,13 @@ namespace onnxruntime { namespace webgpu { -void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type, bool enable_pix_capture) { - std::call_once(init_flag_, [this, &buffer_cache_config, backend_type, enable_pix_capture]() { +void WebGpuContext::Initialize(const WebGpuContextConfig& config) { + std::call_once(init_flag_, [this, &config]() { if (device_ == nullptr) { // Create wgpu::Adapter wgpu::RequestAdapterOptions req_adapter_options = {}; - req_adapter_options.backendType = static_cast(backend_type); - req_adapter_options.powerPreference = static_cast(power_preference_); + req_adapter_options.backendType = static_cast(config.backend_type); + req_adapter_options.powerPreference = static_cast(config.power_preference); #if !defined(__wasm__) auto enabled_adapter_toggles = GetEnabledAdapterToggles(); @@ -134,9 +134,9 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi // create buffer manager buffer_mgr_ = BufferManagerFactory::Create(*this, - buffer_cache_config.storage.mode, - buffer_cache_config.uniform.mode, - buffer_cache_config.query_resolve.mode); + config.buffer_cache_config.storage.mode, + config.buffer_cache_config.uniform.mode, + config.buffer_cache_config.query_resolve.mode); // create initializer buffer manager. 
cache is always disabled for initializer buffer manager initializer_buffer_mgr_ = BufferManagerFactory::Create(*this, @@ -161,15 +161,6 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi } else { query_type_ = TimestampQueryType::None; } - if (enable_pix_capture) { -#if defined(ENABLE_PIX_FOR_WEBGPU_EP) - // set pix frame generator - pix_frame_generator_ = std::make_unique(instance_, - Device()); -#else - ORT_THROW("Support PIX capture requires extra build flags (--enable_pix_capture)"); -#endif // ENABLE_PIX_FOR_WEBGPU_EP - } }); } @@ -757,14 +748,6 @@ void WebGpuContext::Flush(const webgpu::BufferManager& buffer_mgr) { num_pending_dispatches_ = 0; } -void WebGpuContext::OnRunEnd() { -#if defined(ENABLE_PIX_FOR_WEBGPU_EP) - if (pix_frame_generator_) { - pix_frame_generator_->GeneratePIXFrame(); - } -#endif // ENABLE_PIX_FOR_WEBGPU_EP -} - void WebGpuContext::LaunchComputePipeline(const wgpu::ComputePassEncoder& compute_pass_encoder, const std::vector& bind_buffers, const std::vector& bind_buffers_segments, @@ -979,8 +962,7 @@ WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& co device, config.validation_mode, config.preserve_device, - config.max_storage_buffer_binding_size, - config.power_preference)); + config.max_storage_buffer_binding_size)); it = contexts_.emplace(context_id, WebGpuContextFactory::WebGpuContextInfo{std::move(context), 0}).first; } else if (context_id != 0) { ORT_ENFORCE(it->second.context->instance_.Get() == instance && @@ -988,6 +970,10 @@ WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& co "WebGPU EP context ID ", context_id, " is already created with different WebGPU instance or device."); } it->second.ref_count++; + + // perform initialization + it->second.context->Initialize(config); + return *it->second.context; } @@ -1017,6 +1003,11 @@ void WebGpuContextFactory::Cleanup() { default_instance_ = nullptr; } +WebGpuContext& WebGpuContextFactory::DefaultContext() { + WebGpuContextConfig config{}; + return WebGpuContextFactory::CreateContext(config); +} + void CleanupWebGpuContexts() { WebGpuContextFactory::Cleanup(); } diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h index 84dfb47ef4687..8cc513680142d 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.h +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -34,26 +34,50 @@ struct CapturedCommandInfo { WGPUBuffer indirect_buffer; // WGPUBuffer for indirect dispatch, nullptr if not using indirect dispatch }; -struct WebGpuContextConfig { - int context_id; - WGPUInstance instance; - WGPUDevice device; - const void* dawn_proc_table; - ValidationMode validation_mode; - bool preserve_device; - uint64_t max_storage_buffer_binding_size; - int power_preference; -}; - struct WebGpuBufferCacheConfig { struct ConfigEntry { BufferCacheMode mode; - std::string config_string; + std::string config_string; // preserved for customized configuration, eg. bucket sizes + }; + ConfigEntry storage{BufferCacheMode::Bucket, {}}; + ConfigEntry uniform{BufferCacheMode::Simple, {}}; + ConfigEntry query_resolve{BufferCacheMode::Disabled, {}}; + ConfigEntry default_entry{BufferCacheMode::Disabled, {}}; +}; + +/// +/// Represents the configuration options for creating a WebGpuContext. 
+/// +struct WebGpuContextConfig { + int context_id{0}; + WGPUInstance instance{nullptr}; + WGPUDevice device{nullptr}; + const void* dawn_proc_table{nullptr}; + ValidationMode validation_mode{ +#ifndef NDEBUG + webgpu::ValidationMode::Full // for debug build, enable full validation by default +#else + webgpu::ValidationMode::Basic // for release build, enable basic validation by default +#endif // !NDEBUG + }; + bool preserve_device{false}; + uint64_t max_storage_buffer_binding_size{0}; + WebGpuBufferCacheConfig buffer_cache_config{}; + int power_preference{static_cast(WGPUPowerPreference_HighPerformance)}; + int backend_type{ +#ifdef _WIN32 + // Setup Windows default backend type based on the build configuration +#if defined(DAWN_ENABLE_D3D12) + static_cast(WGPUBackendType_D3D12) +#elif defined(DAWN_ENABLE_VULKAN) + static_cast(WGPUBackendType_Vulkan) +#else + 0 +#endif +#else + 0 +#endif }; - ConfigEntry storage; - ConfigEntry uniform; - ConfigEntry query_resolve; - ConfigEntry default_entry; }; class WebGpuContextFactory { @@ -63,13 +87,28 @@ class WebGpuContextFactory { int ref_count; }; + /// + /// Create a new WebGPU context for the specified context ID if not present, or return the existing one. (ref-count based) + /// static WebGpuContext& CreateContext(const WebGpuContextConfig& config); + + /// + /// Get the WebGPU context for the specified context ID. Throw if not present. + /// static WebGpuContext& GetContext(int context_id); + /// + /// Release the WebGPU context. (ref-count based) + /// static void ReleaseContext(int context_id); static void Cleanup(); + /// + /// Return the default context. Create if not present. + /// + static WebGpuContext& DefaultContext(); + private: WebGpuContextFactory() {} @@ -82,8 +121,6 @@ class WebGpuContextFactory { // Class WebGpuContext includes all necessary resources for the context. 
class WebGpuContext final { public: - void Initialize(const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type, bool enable_pix_capture); - Status Wait(wgpu::Future f); const wgpu::Device& Device() const { return device_; } @@ -177,7 +214,13 @@ class WebGpuContext final { Status PopErrorScope(); Status Run(ComputeContextBase& context, const ProgramBase& program); - void OnRunEnd(); + +#if defined(ENABLE_PIX_FOR_WEBGPU_EP) + std::unique_ptr CreatePIXFrameGenerator() { + return std::make_unique(instance_, + Device()); + } +#endif // ENABLE_PIX_FOR_WEBGPU_EP private: enum class TimestampQueryType { @@ -190,20 +233,20 @@ class WebGpuContext final { WGPUDevice device, webgpu::ValidationMode validation_mode, bool preserve_device, - uint64_t max_storage_buffer_binding_size, - int power_preference = static_cast(wgpu::PowerPreference::HighPerformance)) + uint64_t max_storage_buffer_binding_size) : instance_{instance}, device_{device}, validation_mode_{validation_mode}, query_type_{TimestampQueryType::None}, preserve_device_{preserve_device}, - max_storage_buffer_binding_size_{max_storage_buffer_binding_size}, - power_preference_{power_preference} { + max_storage_buffer_binding_size_{max_storage_buffer_binding_size} { ORT_ENFORCE(max_storage_buffer_binding_size_ == 0 || max_storage_buffer_binding_size_ >= 134217728, "max_storage_buffer_binding_size must be 0 or at least 128MB"); } ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WebGpuContext); + void Initialize(const WebGpuContextConfig& config); + void LaunchComputePipeline(const wgpu::ComputePassEncoder& compute_pass_encoder, const std::vector& bind_buffers, const std::vector& bind_buffers_segments, @@ -292,15 +335,10 @@ class WebGpuContext final { bool is_profiling_ = false; bool preserve_device_; uint64_t max_storage_buffer_binding_size_; - int power_preference_; GraphCaptureState graph_capture_state_{GraphCaptureState::Default}; // External vector to store captured commands, owned by EP std::vector* external_captured_commands_ = nullptr; - -#if defined(ENABLE_PIX_FOR_WEBGPU_EP) - std::unique_ptr pix_frame_generator_ = nullptr; -#endif // ENABLE_PIX_FOR_WEBGPU_EP }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 6b764d51bcf75..15263a87a17b6 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -799,7 +799,8 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id, context_{context}, preferred_data_layout_{config.data_layout}, force_cpu_node_names_{std::move(config.force_cpu_node_names)}, - enable_graph_capture_{config.enable_graph_capture} { + enable_graph_capture_{config.enable_graph_capture}, + prepack_allocator_{std::make_shared(context_.InitializerBufferManager(), false)} { // If graph capture is enabled, create a dedicated buffer manager for graph mode if (enable_graph_capture_) { // Create buffer manager for graph capture mode with appropriate cache modes @@ -809,6 +810,15 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id, webgpu::BufferCacheMode::GraphSimple, webgpu::BufferCacheMode::Disabled); } + + if (config.enable_pix_capture) { +#if defined(ENABLE_PIX_FOR_WEBGPU_EP) + // set pix frame generator + pix_frame_generator_ = context_.CreatePIXFrameGenerator(); +#else + ORT_THROW("Support PIX capture requires extra build flags (--enable_pix_capture)"); +#endif // ENABLE_PIX_FOR_WEBGPU_EP + } } 
std::vector WebGpuExecutionProvider::CreatePreferredAllocators() { @@ -1007,7 +1017,11 @@ Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxrunti context_.CollectProfilingData(profiler_->Events()); } - context_.OnRunEnd(); +#if defined(ENABLE_PIX_FOR_WEBGPU_EP) + if (pix_frame_generator_) { + pix_frame_generator_->GeneratePIXFrame(); + } +#endif // ENABLE_PIX_FOR_WEBGPU_EP if (context_.ValidationMode() >= ValidationMode::Basic) { return context_.PopErrorScope(); diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index a9282a028c803..bf0963f67cf1e 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -10,6 +10,10 @@ #include "core/providers/providers.h" #include "core/providers/webgpu/buffer_manager.h" +#if defined(ENABLE_PIX_FOR_WEBGPU_EP) +#include "core/providers/webgpu/webgpu_pix_frame_generator.h" +#endif // ENABLE_PIX_FOR_WEBGPU_EP + struct pthreadpool; namespace onnxruntime { namespace webgpu { @@ -27,18 +31,10 @@ struct CapturedCommandInfo; } // namespace webgpu struct WebGpuExecutionProviderConfig { - WebGpuExecutionProviderConfig(DataLayout data_layout, bool enable_graph_capture, bool enable_pix_capture) - : data_layout{data_layout}, - enable_graph_capture{enable_graph_capture}, - enable_pix_capture{enable_pix_capture} {} - WebGpuExecutionProviderConfig(WebGpuExecutionProviderConfig&&) = default; - WebGpuExecutionProviderConfig& operator=(WebGpuExecutionProviderConfig&&) = default; - ORT_DISALLOW_COPY_AND_ASSIGNMENT(WebGpuExecutionProviderConfig); - - DataLayout data_layout; - bool enable_graph_capture; - bool enable_pix_capture; - std::vector force_cpu_node_names; + DataLayout data_layout{DataLayout::NHWC}; // preferred layout is NHWC by default + bool enable_graph_capture{false}; // graph capture feature is disabled by default + bool enable_pix_capture{false}; // PIX capture is disabled by default + std::vector force_cpu_node_names{}; }; class WebGpuExecutionProvider : public IExecutionProvider { @@ -84,6 +80,7 @@ class WebGpuExecutionProvider : public IExecutionProvider { bool IsGraphCaptured(int graph_annotation_id) const override; Status ReplayGraph(int graph_annotation_id) override; webgpu::BufferManager& BufferManager() const; + AllocatorPtr PrepackAllocator() const { return prepack_allocator_; } private: bool IsGraphCaptureAllowed() const; @@ -100,11 +97,18 @@ class WebGpuExecutionProvider : public IExecutionProvider { const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. 
int m_current_graph_annotation_id = 0; +#if defined(ENABLE_PIX_FOR_WEBGPU_EP) + std::unique_ptr pix_frame_generator_ = nullptr; +#endif // ENABLE_PIX_FOR_WEBGPU_EP + // Buffer manager specifically for graph capture mode std::unique_ptr graph_buffer_mgr_ = nullptr; // Store captured commands directly in the EP instead of in WebGpuContext std::vector captured_commands_; + + // Allocator for prepacked weights (uses buffers without mapping) + AllocatorPtr prepack_allocator_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_kernel.cc b/onnxruntime/core/providers/webgpu/webgpu_kernel.cc index ea38e9415e1fe..8303d2ff4293f 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_kernel.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_kernel.cc @@ -34,7 +34,7 @@ Status WebGpuKernel::Compute(OpKernelContext* p_op_kernel_context) const { return s; } -Status WebGpuKernel::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, +Status WebGpuKernel::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr /*alloc*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /* prepacked_weights */) { ComputeContextBase context{webgpu_context_, ep_, *this}; @@ -45,8 +45,9 @@ Status WebGpuKernel::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr a // Currently, ORT does not allow using prepacked weights in non-CPU EPs. // So we do not pass prepacked_weights to PrePackInternal. // Kernel implementation that supports prepacking should manage its own storage. + // Use the EP's prepack allocator which creates unmapped GPU buffers. - Status s = PrePackInternal(context, tensor, input_idx, alloc, is_packed); + Status s = PrePackInternal(context, tensor, input_idx, ep_.PrepackAllocator(), is_packed); if (webgpu_context_.ValidationMode() >= ValidationMode::Full) { ORT_RETURN_IF_ERROR(webgpu_context_.PopErrorScope()); diff --git a/onnxruntime/core/providers/webgpu/webgpu_kernel.h b/onnxruntime/core/providers/webgpu/webgpu_kernel.h index 2c57991c6ee35..854b77ba4876b 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_kernel.h +++ b/onnxruntime/core/providers/webgpu/webgpu_kernel.h @@ -44,7 +44,7 @@ class WebGpuKernel : public OpKernel { // @param context The WebGPU compute context base providing access to the execution environment. // @param tensor The constant tensor to potentially pre-process. // @param input_idx The index of this input in the kernel's input list. - // @param alloc The allocator to use for any new tensor allocations. + // @param alloc The allocator to use for any new tensor allocations (prepack allocator). // @param is_packed Output parameter. Set to true if the tensor was pre-packed/processed, // false otherwise. The default implementation sets this to false. // diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc index fdd7caa1706f5..cd791e31dcc2f 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc @@ -14,51 +14,13 @@ #include "core/providers/webgpu/webgpu_provider_options.h" #include "core/providers/webgpu/data_transfer.h" + +using namespace onnxruntime::webgpu; using namespace onnxruntime::webgpu::options; namespace onnxruntime { -// Helper struct that holds configuration parameters for creating a WebGPU context with default settings. -// This is used during lazy initialization of the data transfer to create a context if one doesn't exist. 
-struct WebGpuContextParams { - webgpu::WebGpuContextConfig context_config; // WebGPU context configuration - webgpu::WebGpuBufferCacheConfig buffer_cache_config; // Buffer cache settings - int backend_type; // Dawn backend type (D3D12, Vulkan, etc.) - bool enable_pix_capture; // Enable PIX GPU capture for debugging -}; - -static WebGpuContextParams GetDefaultWebGpuContextParams() { - WebGpuContextParams params; - params.context_config.context_id = 0; - params.context_config.instance = nullptr; - params.context_config.device = nullptr; - params.context_config.dawn_proc_table = nullptr; - params.context_config.validation_mode = webgpu::ValidationMode::Disabled; - params.context_config.preserve_device = false; - params.context_config.max_storage_buffer_binding_size = 0; - params.context_config.power_preference = static_cast(WGPUPowerPreference_HighPerformance); - - params.buffer_cache_config.storage.mode = webgpu::BufferCacheMode::Bucket; - params.buffer_cache_config.uniform.mode = webgpu::BufferCacheMode::Simple; - params.buffer_cache_config.query_resolve.mode = webgpu::BufferCacheMode::Disabled; - params.buffer_cache_config.default_entry.mode = webgpu::BufferCacheMode::Disabled; - -#ifdef _WIN32 -#if defined(DAWN_ENABLE_D3D12) - params.backend_type = static_cast(WGPUBackendType_D3D12); -#elif defined(DAWN_ENABLE_VULKAN) - params.backend_type = static_cast(WGPUBackendType_Vulkan); -#else - params.backend_type = static_cast(WGPUBackendType_D3D12); -#endif -#else - params.backend_type = 0; -#endif - params.enable_pix_capture = false; - return params; -} - struct WebGpuProviderFactory : IExecutionProviderFactory { - WebGpuProviderFactory(int context_id, webgpu::WebGpuContext& context, WebGpuExecutionProviderConfig&& webgpu_ep_config) + WebGpuProviderFactory(int context_id, WebGpuContext& context, WebGpuExecutionProviderConfig&& webgpu_ep_config) : context_id_{context_id}, context_{context}, config_{std::move(webgpu_ep_config)} { } @@ -68,25 +30,17 @@ struct WebGpuProviderFactory : IExecutionProviderFactory { private: int context_id_; - webgpu::WebGpuContext& context_; + WebGpuContext& context_; WebGpuExecutionProviderConfig config_; }; -std::shared_ptr WebGpuProviderFactoryCreator::Create(const ConfigOptions& config_options) { - // - // STEP.1 - prepare WebGpuExecutionProviderConfig - // - WebGpuExecutionProviderConfig webgpu_ep_config{ - // preferred layout is NHWC by default - DataLayout::NHWC, - // graph capture feature is disabled by default - false, - // enable pix capture feature is diabled by default - false, - }; +namespace { + +WebGpuExecutionProviderConfig ParseEpConfig(const ConfigOptions& config_options) { + WebGpuExecutionProviderConfig webgpu_ep_config{}; - std::string preferred_layout_str; - if (config_options.TryGetConfigEntry(kPreferredLayout, preferred_layout_str)) { + if (std::string preferred_layout_str; + config_options.TryGetConfigEntry(kPreferredLayout, preferred_layout_str)) { if (preferred_layout_str == kPreferredLayout_NHWC) { webgpu_ep_config.data_layout = DataLayout::NHWC; } else if (preferred_layout_str == kPreferredLayout_NCHW) { @@ -95,11 +49,9 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( ORT_THROW("Invalid preferred layout: ", preferred_layout_str); } } - LOGS_DEFAULT(VERBOSE) << "WebGPU EP preferred layout: " << int(webgpu_ep_config.data_layout) << " (parsed from \"" - << preferred_layout_str << "\")"; - std::string enable_graph_capture_str; - if (config_options.TryGetConfigEntry(kEnableGraphCapture, enable_graph_capture_str)) { + if (std::string 
enable_graph_capture_str; + config_options.TryGetConfigEntry(kEnableGraphCapture, enable_graph_capture_str)) { if (enable_graph_capture_str == kEnableGraphCapture_ON) { webgpu_ep_config.enable_graph_capture = true; } else if (enable_graph_capture_str == kEnableGraphCapture_OFF) { @@ -108,13 +60,13 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( ORT_THROW("Invalid enable graph capture: ", enable_graph_capture_str); } } - LOGS_DEFAULT(VERBOSE) << "WebGPU EP graph capture enable: " << webgpu_ep_config.enable_graph_capture; // parse force CPU node names // The force CPU node names are separated by EOL (\n or \r\n) in the config entry. // each line is a node name that will be forced to run on CPU. - std::string force_cpu_node_names_str; - if (config_options.TryGetConfigEntry(kForceCpuNodeNames, force_cpu_node_names_str)) { + + if (std::string force_cpu_node_names_str; + config_options.TryGetConfigEntry(kForceCpuNodeNames, force_cpu_node_names_str)) { // split the string by EOL (\n or \r\n) std::istringstream ss(force_cpu_node_names_str); std::string line; @@ -127,209 +79,181 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( webgpu_ep_config.force_cpu_node_names.push_back(line); } } + + // enable pix capture + if (std::string enable_pix_capture_str; + config_options.TryGetConfigEntry(kEnablePIXCapture, enable_pix_capture_str)) { + if (enable_pix_capture_str == kEnablePIXCapture_ON) { + webgpu_ep_config.enable_pix_capture = true; + } else if (enable_pix_capture_str == kEnablePIXCapture_OFF) { + webgpu_ep_config.enable_pix_capture = false; + } else { + ORT_THROW("Invalid enable pix capture: ", enable_pix_capture_str); + } + } + + LOGS_DEFAULT(VERBOSE) << "WebGPU EP preferred layout: " << int(webgpu_ep_config.data_layout); + LOGS_DEFAULT(VERBOSE) << "WebGPU EP graph capture enable: " << webgpu_ep_config.enable_graph_capture; LOGS_DEFAULT(VERBOSE) << "WebGPU EP force CPU node count: " << webgpu_ep_config.force_cpu_node_names.size(); + LOGS_DEFAULT(VERBOSE) << "WebGPU EP pix capture enable: " << webgpu_ep_config.enable_pix_capture; + + return webgpu_ep_config; +} + +WebGpuContextConfig ParseWebGpuContextConfig(const ConfigOptions& config_options) { + WebGpuContextConfig config{}; - // - // STEP.2 - prepare WebGpuContextConfig - // - int context_id = 0; - std::string context_id_str; - if (config_options.TryGetConfigEntry(kDeviceId, context_id_str)) { + if (std::string context_id_str; + config_options.TryGetConfigEntry(kDeviceId, context_id_str)) { ORT_ENFORCE(std::errc{} == - std::from_chars(context_id_str.data(), context_id_str.data() + context_id_str.size(), context_id).ec); + std::from_chars(context_id_str.data(), context_id_str.data() + context_id_str.size(), config.context_id).ec); } - size_t webgpu_instance = 0; - std::string webgpu_instance_str; - if (config_options.TryGetConfigEntry(kWebGpuInstance, webgpu_instance_str)) { + if (std::string webgpu_instance_str; + config_options.TryGetConfigEntry(kWebGpuInstance, webgpu_instance_str)) { static_assert(sizeof(WGPUInstance) == sizeof(size_t), "WGPUInstance size mismatch"); + size_t webgpu_instance = 0; ORT_ENFORCE(std::errc{} == std::from_chars(webgpu_instance_str.data(), webgpu_instance_str.data() + webgpu_instance_str.size(), webgpu_instance).ec); + config.instance = reinterpret_cast(webgpu_instance); } - size_t webgpu_device = 0; - std::string webgpu_device_str; - if (config_options.TryGetConfigEntry(kWebGpuDevice, webgpu_device_str)) { + if (std::string webgpu_device_str; + config_options.TryGetConfigEntry(kWebGpuDevice, 
webgpu_device_str)) { static_assert(sizeof(WGPUDevice) == sizeof(size_t), "WGPUDevice size mismatch"); + size_t webgpu_device = 0; ORT_ENFORCE(std::errc{} == std::from_chars(webgpu_device_str.data(), webgpu_device_str.data() + webgpu_device_str.size(), webgpu_device).ec); + config.device = reinterpret_cast(webgpu_device); } - size_t dawn_proc_table = 0; - std::string dawn_proc_table_str; - if (config_options.TryGetConfigEntry(kDawnProcTable, dawn_proc_table_str)) { + if (std::string dawn_proc_table_str; + config_options.TryGetConfigEntry(kDawnProcTable, dawn_proc_table_str)) { + size_t dawn_proc_table = 0; ORT_ENFORCE(std::errc{} == std::from_chars(dawn_proc_table_str.data(), dawn_proc_table_str.data() + dawn_proc_table_str.size(), dawn_proc_table).ec); + config.dawn_proc_table = reinterpret_cast(dawn_proc_table); } - webgpu::ValidationMode validation_mode = -#ifndef NDEBUG - webgpu::ValidationMode::Full // for debug build, enable full validation by default -#else - webgpu::ValidationMode::Basic // for release build, enable basic validation by default -#endif // !NDEBUG - ; - std::string validation_mode_str; - if (config_options.TryGetConfigEntry(kValidationMode, validation_mode_str)) { + if (std::string validation_mode_str; + config_options.TryGetConfigEntry(kValidationMode, validation_mode_str)) { if (validation_mode_str == kValidationMode_Disabled) { - validation_mode = webgpu::ValidationMode::Disabled; + config.validation_mode = ValidationMode::Disabled; } else if (validation_mode_str == kValidationMode_wgpuOnly) { - validation_mode = webgpu::ValidationMode::WGPUOnly; + config.validation_mode = ValidationMode::WGPUOnly; } else if (validation_mode_str == kValidationMode_basic) { - validation_mode = webgpu::ValidationMode::Basic; + config.validation_mode = ValidationMode::Basic; } else if (validation_mode_str == kValidationMode_full) { - validation_mode = webgpu::ValidationMode::Full; + config.validation_mode = ValidationMode::Full; } else { ORT_THROW("Invalid validation mode: ", validation_mode_str); } } - std::string preserve_device_str; - bool preserve_device = false; - if (config_options.TryGetConfigEntry(kPreserveDevice, preserve_device_str)) { + if (std::string preserve_device_str; + config_options.TryGetConfigEntry(kPreserveDevice, preserve_device_str)) { if (preserve_device_str == kPreserveDevice_ON) { - preserve_device = true; + config.preserve_device = true; } else if (preserve_device_str == kPreserveDevice_OFF) { - preserve_device = false; + config.preserve_device = false; } else { ORT_THROW("Invalid preserve device: ", preserve_device_str); } } - uint64_t max_storage_buffer_binding_size = 0; std::string max_storage_buffer_binding_size_str; if (config_options.TryGetConfigEntry(kMaxStorageBufferBindingSize, max_storage_buffer_binding_size_str)) { ORT_ENFORCE( std::errc{} == std::from_chars( max_storage_buffer_binding_size_str.data(), max_storage_buffer_binding_size_str.data() + max_storage_buffer_binding_size_str.size(), - max_storage_buffer_binding_size) + config.max_storage_buffer_binding_size) .ec, "Invalid maxStorageBufferBindingSize value: ", max_storage_buffer_binding_size_str); } - LOGS_DEFAULT(VERBOSE) << "WebGPU EP max storage buffer binding size: " << max_storage_buffer_binding_size; - // power preference - int power_preference = static_cast(WGPUPowerPreference_HighPerformance); // default - std::string power_preference_str; - if (config_options.TryGetConfigEntry(kPowerPreference, power_preference_str)) { - if (power_preference_str == kPowerPreference_HighPerformance) 
{ - power_preference = static_cast(WGPUPowerPreference_HighPerformance); - } else if (power_preference_str == kPowerPreference_LowPower) { - power_preference = static_cast(WGPUPowerPreference_LowPower); - } else { - ORT_THROW("Invalid power preference: ", power_preference_str); - } - } - LOGS_DEFAULT(VERBOSE) << "WebGPU EP power preference: " << power_preference; - - webgpu::WebGpuContextConfig context_config{ - context_id, - reinterpret_cast(webgpu_instance), - reinterpret_cast(webgpu_device), - reinterpret_cast(dawn_proc_table), - validation_mode, - preserve_device, - max_storage_buffer_binding_size, - power_preference, - }; - - LOGS_DEFAULT(VERBOSE) << "WebGPU EP Device ID: " << context_id; - LOGS_DEFAULT(VERBOSE) << "WebGPU EP WGPUInstance: " << webgpu_instance; - LOGS_DEFAULT(VERBOSE) << "WebGPU EP WGPUDevice: " << webgpu_device; - LOGS_DEFAULT(VERBOSE) << "WebGPU EP DawnProcTable: " << dawn_proc_table; - LOGS_DEFAULT(VERBOSE) << "WebGPU EP ValidationMode: " << validation_mode; - LOGS_DEFAULT(VERBOSE) << "WebGPU EP PreserveDevice: " << preserve_device; - LOGS_DEFAULT(VERBOSE) << "WebGPU EP PowerPreference: " << power_preference; - - // - // STEP.3 - prepare parameters for WebGPU context initialization. - // - - int backend_type = 0; -#ifdef _WIN32 - // Setup Windows default backend type based on the build configuration -#if defined(DAWN_ENABLE_D3D12) - backend_type = static_cast(WGPUBackendType_D3D12); -#elif defined(DAWN_ENABLE_VULKAN) - backend_type = static_cast(WGPUBackendType_Vulkan); -#endif -#endif - - std::string backend_type_str; - if (config_options.TryGetConfigEntry(kDawnBackendType, backend_type_str)) { - if (backend_type_str == kDawnBackendType_D3D12) { - backend_type = static_cast(WGPUBackendType_D3D12); - } else if (backend_type_str == kDawnBackendType_Vulkan) { - backend_type = static_cast(WGPUBackendType_Vulkan); - } else { - ORT_THROW("Invalid Dawn backend type: ", backend_type_str); - } - } - LOGS_DEFAULT(VERBOSE) << "WebGPU EP Dawn backend type: " << backend_type; + LOGS_DEFAULT(VERBOSE) << "WebGPU EP Device ID: " << config.context_id; + LOGS_DEFAULT(VERBOSE) << "WebGPU EP WGPUInstance: " << reinterpret_cast(config.instance); + LOGS_DEFAULT(VERBOSE) << "WebGPU EP WGPUDevice: " << reinterpret_cast(config.device); + LOGS_DEFAULT(VERBOSE) << "WebGPU EP DawnProcTable: " << reinterpret_cast(config.dawn_proc_table); + LOGS_DEFAULT(VERBOSE) << "WebGPU EP ValidationMode: " << config.validation_mode; + LOGS_DEFAULT(VERBOSE) << "WebGPU EP PreserveDevice: " << config.preserve_device; + LOGS_DEFAULT(VERBOSE) << "WebGPU EP max storage buffer binding size: " << config.max_storage_buffer_binding_size; // buffer cache modes auto parse_buffer_cache_mode = [&config_options](const std::string& config_entry_str, - webgpu::BufferCacheMode default_value) -> webgpu::BufferCacheMode { + BufferCacheMode& value) -> void { std::string buffer_cache_mode_str; if (config_options.TryGetConfigEntry(config_entry_str, buffer_cache_mode_str)) { if (buffer_cache_mode_str == kBufferCacheMode_Disabled) { - return webgpu::BufferCacheMode::Disabled; + value = BufferCacheMode::Disabled; } else if (buffer_cache_mode_str == kBufferCacheMode_LazyRelease) { - return webgpu::BufferCacheMode::LazyRelease; + value = BufferCacheMode::LazyRelease; } else if (buffer_cache_mode_str == kBufferCacheMode_Simple) { - return webgpu::BufferCacheMode::Simple; + value = BufferCacheMode::Simple; } else if (buffer_cache_mode_str == kBufferCacheMode_Bucket) { - return webgpu::BufferCacheMode::Bucket; + value = 
BufferCacheMode::Bucket; } else { - ORT_THROW("Invalid buffer cache mode: ", config_entry_str); + ORT_THROW("Invalid buffer cache mode: ", buffer_cache_mode_str); } - } else { - return default_value; } }; - webgpu::WebGpuBufferCacheConfig buffer_cache_config; - - buffer_cache_config.storage.mode = parse_buffer_cache_mode(kStorageBufferCacheMode, - webgpu::BufferCacheMode::Bucket); - LOGS_DEFAULT(VERBOSE) << "WebGPU EP storage buffer cache mode: " << buffer_cache_config.storage.mode; - - buffer_cache_config.uniform.mode = parse_buffer_cache_mode(kUniformBufferCacheMode, - webgpu::BufferCacheMode::Simple); - LOGS_DEFAULT(VERBOSE) << "WebGPU EP uniform buffer cache mode: " << buffer_cache_config.uniform.mode; + WebGpuBufferCacheConfig& buffer_cache_config = config.buffer_cache_config; + parse_buffer_cache_mode(kStorageBufferCacheMode, buffer_cache_config.storage.mode); + parse_buffer_cache_mode(kUniformBufferCacheMode, buffer_cache_config.uniform.mode); + parse_buffer_cache_mode(kQueryResolveBufferCacheMode, buffer_cache_config.query_resolve.mode); + parse_buffer_cache_mode(kDefaultBufferCacheMode, buffer_cache_config.default_entry.mode); - buffer_cache_config.query_resolve.mode = parse_buffer_cache_mode(kQueryResolveBufferCacheMode, webgpu::BufferCacheMode::Disabled); - LOGS_DEFAULT(VERBOSE) << "WebGPU EP query resolve buffer cache mode: " << buffer_cache_config.query_resolve.mode; - - buffer_cache_config.default_entry.mode = parse_buffer_cache_mode(kDefaultBufferCacheMode, webgpu::BufferCacheMode::Disabled); - LOGS_DEFAULT(VERBOSE) << "WebGPU EP default buffer cache mode: " << buffer_cache_config.default_entry.mode; + // power preference + if (std::string power_preference_str; + config_options.TryGetConfigEntry(kPowerPreference, power_preference_str)) { + if (power_preference_str == kPowerPreference_HighPerformance) { + config.power_preference = static_cast(WGPUPowerPreference_HighPerformance); + } else if (power_preference_str == kPowerPreference_LowPower) { + config.power_preference = static_cast(WGPUPowerPreference_LowPower); + } else { + ORT_THROW("Invalid power preference: ", power_preference_str); + } + } - bool enable_pix_capture = false; - std::string enable_pix_capture_str; - if (config_options.TryGetConfigEntry(kEnablePIXCapture, enable_pix_capture_str)) { - if (enable_pix_capture_str == kEnablePIXCapture_ON) { - enable_pix_capture = true; - } else if (enable_pix_capture_str == kEnablePIXCapture_OFF) { - enable_pix_capture = false; + // backend type + if (std::string backend_type_str; + config_options.TryGetConfigEntry(kDawnBackendType, backend_type_str)) { + if (backend_type_str == kDawnBackendType_D3D12) { + config.backend_type = static_cast(WGPUBackendType_D3D12); + } else if (backend_type_str == kDawnBackendType_Vulkan) { + config.backend_type = static_cast(WGPUBackendType_Vulkan); } else { - ORT_THROW("Invalid enable pix capture: ", enable_pix_capture_str); + ORT_THROW("Invalid Dawn backend type: ", backend_type_str); } } - LOGS_DEFAULT(VERBOSE) << "WebGPU EP pix capture enable: " << enable_pix_capture; - // - // STEP.4 - start initialization. 
- // + LOGS_DEFAULT(VERBOSE) << "WebGPU EP storage buffer cache mode: " << config.buffer_cache_config.storage.mode; + LOGS_DEFAULT(VERBOSE) << "WebGPU EP uniform buffer cache mode: " << config.buffer_cache_config.uniform.mode; + LOGS_DEFAULT(VERBOSE) << "WebGPU EP query resolve buffer cache mode: " << config.buffer_cache_config.query_resolve.mode; + LOGS_DEFAULT(VERBOSE) << "WebGPU EP default buffer cache mode: " << config.buffer_cache_config.default_entry.mode; - // Load the Dawn library and create the WebGPU instance. - auto& context = webgpu::WebGpuContextFactory::CreateContext(context_config); + LOGS_DEFAULT(VERBOSE) << "WebGPU EP power preference: " << config.power_preference; + LOGS_DEFAULT(VERBOSE) << "WebGPU EP Dawn backend type: " << config.backend_type; - // Create WebGPU device and initialize the context. - context.Initialize(buffer_cache_config, backend_type, enable_pix_capture); + return config; +} + +} // namespace + +std::shared_ptr WebGpuProviderFactoryCreator::Create(const ConfigOptions& config_options) { + // prepare WebGpuExecutionProviderConfig + WebGpuExecutionProviderConfig webgpu_ep_config = ParseEpConfig(config_options); + + // prepare WebGpuContextConfig + WebGpuContextConfig config = ParseWebGpuContextConfig(config_options); + + // Load the Dawn library and create the WebGPU instance. + auto& context = WebGpuContextFactory::CreateContext(config); // Create WebGPU EP factory. - return std::make_shared(context_id, context, std::move(webgpu_ep_config)); + return std::make_shared(config.context_id, context, std::move(webgpu_ep_config)); } // WebGPU DataTransfer implementation wrapper for the C API with lazy initialization @@ -406,16 +330,17 @@ struct WebGpuDataTransferImpl : OrtDataTransferImpl { std::lock_guard lock(impl.init_mutex_); if (impl.data_transfer_ == nullptr) { // Always create a new context with context_id 0 - WebGpuContextParams params = GetDefaultWebGpuContextParams(); - params.context_config.context_id = impl.context_id_; - auto* context_ptr = &webgpu::WebGpuContextFactory::CreateContext(params.context_config); - context_ptr->Initialize(params.buffer_cache_config, params.backend_type, params.enable_pix_capture); + if (impl.context_id_ != 0) { + return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, "Shared data transfer can only be created for the default device (0)."); + } + + auto& context = WebGpuContextFactory::DefaultContext(); // Create the DataTransfer instance // Note: The DataTransfer holds a const reference to BufferManager. The BufferManager's lifecycle // is managed by the WebGpuContext, which is stored in a static WebGpuContextFactory and persists // for the lifetime of the application, ensuring the reference remains valid. 
- impl.data_transfer_ = std::make_unique(context_ptr->BufferManager()); + impl.data_transfer_ = std::make_unique(context.BufferManager()); } } @@ -441,15 +366,15 @@ struct WebGpuDataTransferImpl : OrtDataTransferImpl { } delete p_impl; if (data_transfer_initialized) { - webgpu::WebGpuContextFactory::ReleaseContext(context_id); + WebGpuContextFactory::ReleaseContext(context_id); } } const OrtApi& ort_api; const OrtEpApi& ep_api; - std::unique_ptr data_transfer_; // Lazy-initialized - int context_id_; // Track which context we're using - std::mutex init_mutex_; // Protects lazy initialization + std::unique_ptr data_transfer_; // Lazy-initialized + int context_id_; // Track which context we're using + std::mutex init_mutex_; // Protects lazy initialization }; OrtDataTransferImpl* OrtWebGpuCreateDataTransfer() { diff --git a/onnxruntime/core/session/plugin_ep/ep_api.cc b/onnxruntime/core/session/plugin_ep/ep_api.cc index e89944394aaec..b0059f87da207 100644 --- a/onnxruntime/core/session/plugin_ep/ep_api.cc +++ b/onnxruntime/core/session/plugin_ep/ep_api.cc @@ -582,6 +582,47 @@ ORT_API_STATUS_IMPL(EpGraphSupportInfo_LookUpKernel, _In_ OrtEpGraphSupportInfo* API_IMPL_END } +ORT_API_STATUS_IMPL(SharedPrePackedWeightCache_StoreWeightData, + _In_ OrtSharedPrePackedWeightCache* prepacked_weight_cache, + _In_reads_(num_buffers) void** buffer_data_ptrs, _In_reads_(num_buffers) size_t* buffer_data_sizes, + _In_ size_t num_buffers) { + API_IMPL_BEGIN + if (prepacked_weight_cache == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Must specify a valid OrtPrePackedWeightsCache instance"); + } + + if (buffer_data_ptrs == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid array of buffer data pointers"); + } + + if (buffer_data_sizes == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid array of buffer data sizes"); + } + + if (num_buffers == 0) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify at least one weight data buffer"); + } + + OrtStatus* status = nullptr; + + ORT_TRY { + prepacked_weight_cache->SetBuffers(buffer_data_ptrs, buffer_data_sizes, num_buffers); + } + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION([&]() { + // This API function promises that ORT will take ownership of the data only if it returns successfully. + // If any exception occurred while filling out `prepacked_weight_cache`, we try to release ownership so that + // the caller retains ownership of all of the original data and can delete it. + prepacked_weight_cache->ReleaseAllData(); + status = OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, ex.what()); + }); + } + + return status; + API_IMPL_END +} + static constexpr OrtEpApi ort_ep_api = { // NOTE: ABI compatibility depends on the order within this struct so all additions must be at the end, // and no functions can be removed (the implementation needs to change to return an error). 
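As an illustration of how a plugin EP kernel might consume this new API directly through the C interface, here is a minimal sketch of a PrePackWeight callback. The kernel name, buffer contents, and sizes are hypothetical; only the OrtEpApi::SharedPrePackedWeightCache_StoreWeightData signature and its ownership contract (ORT takes ownership of the buffers only on success, and the buffers must come from the OrtAllocator passed to PrePackWeight) are taken from the change above. The example Mul kernel later in this change uses the Ort::UnownedSharedPrePackedWeightCache C++ wrapper instead.

// Hypothetical PrePackWeight callback; illustrative only, not part of this change.
static OrtStatus* ORT_API_CALL ExamplePrePackWeight(OrtKernelImpl* /*this_ptr*/, const OrtValue* /*tensor*/,
                                                    int /*input_index*/, OrtAllocator* allocator,
                                                    OrtSharedPrePackedWeightCache* cache,
                                                    /*out*/ bool* is_packed) noexcept {
  const size_t packed_bytes = 64;  // assume the packed layout needs 64 bytes (hypothetical)
  void* packed_data = allocator->Alloc(allocator, packed_bytes);  // must use the allocator ORT provided
  // ... transform the weight into packed_data ...

  if (cache != nullptr) {  // a null cache means storing/sharing is not allowed for this weight
    void* buffers[] = {packed_data};
    size_t sizes[] = {packed_bytes};
    // On success ORT takes ownership of packed_data; on failure the kernel still owns it.
    if (OrtStatus* status = Ort::GetEpApi().SharedPrePackedWeightCache_StoreWeightData(cache, buffers, sizes, 1)) {
      allocator->Free(allocator, packed_data);
      return status;
    }
  }
  // When cache is null, a real kernel would keep packed_data in a member and free it on Release
  // (see the Mul kernel below); that bookkeeping is omitted here.

  *is_packed = true;
  return nullptr;
}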
@@ -636,6 +677,7 @@ static constexpr OrtEpApi ort_ep_api = { &OrtExecutionProviderApi::KernelDef_GetOutputMemType, &OrtExecutionProviderApi::GetTensorDataType, &OrtExecutionProviderApi::EpGraphSupportInfo_LookUpKernel, + &OrtExecutionProviderApi::SharedPrePackedWeightCache_StoreWeightData, }; // checks that we don't violate the rule that the functions must remain in the slots they were originally assigned diff --git a/onnxruntime/core/session/plugin_ep/ep_api.h b/onnxruntime/core/session/plugin_ep/ep_api.h index b6a7262ec2008..b2abad622c9a6 100644 --- a/onnxruntime/core/session/plugin_ep/ep_api.h +++ b/onnxruntime/core/session/plugin_ep/ep_api.h @@ -100,4 +100,9 @@ ORT_API_STATUS_IMPL(GetTensorDataType, _In_ ONNXTensorElementDataType elem_type, _Outptr_ const OrtDataType** out); ORT_API_STATUS_IMPL(EpGraphSupportInfo_LookUpKernel, _In_ OrtEpGraphSupportInfo* graph_support_info, _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtKernelDef** out_kernel_def); + +ORT_API_STATUS_IMPL(SharedPrePackedWeightCache_StoreWeightData, + _In_ OrtSharedPrePackedWeightCache* prepacked_weight_cache, + _In_reads_(num_buffers) void** buffer_data_ptrs, _In_reads_(num_buffers) size_t* buffer_data_sizes, + _In_ size_t num_buffers); } // namespace OrtExecutionProviderApi diff --git a/onnxruntime/core/session/plugin_ep/ep_kernel_registration.cc b/onnxruntime/core/session/plugin_ep/ep_kernel_registration.cc index 8dfb0a7ab06b4..fe96bb577d925 100644 --- a/onnxruntime/core/session/plugin_ep/ep_kernel_registration.cc +++ b/onnxruntime/core/session/plugin_ep/ep_kernel_registration.cc @@ -3,13 +3,52 @@ #include "core/session/plugin_ep/ep_kernel_registration.h" +#include #include +#include #include +#include #include "core/framework/error_code_helper.h" #include "core/framework/kernel_registry.h" +#include "core/framework/tensor.h" +#include "core/session/allocator_adapters.h" #include "core/session/plugin_ep/ep_api.h" +// +// OrtSharedPrePackedWeightCache +// +OrtSharedPrePackedWeightCache::OrtSharedPrePackedWeightCache(onnxruntime::PrePackedWeights& container, + onnxruntime::AllocatorPtr allocator) + : container_(container), allocator_(std::move(allocator)) {} + +void OrtSharedPrePackedWeightCache::SetBuffers(void** data_ptrs, size_t* data_sizes, size_t num_buffers) { + container_.buffers_.clear(); + container_.buffer_sizes_.clear(); + + container_.buffers_.reserve(num_buffers); + container_.buffer_sizes_.reserve(num_buffers); + + for (size_t i = 0; i < num_buffers; i++) { + auto data_unique_ptr = onnxruntime::IAllocatorUniquePtr(data_ptrs[i], onnxruntime::BufferDeleter(allocator_)); + container_.buffers_.push_back(std::move(data_unique_ptr)); + container_.buffer_sizes_.push_back(data_sizes[i]); + } +} + +bool OrtSharedPrePackedWeightCache::HasData() const noexcept { + return !container_.buffers_.empty(); +} + +void OrtSharedPrePackedWeightCache::ReleaseAllData() noexcept { + for (onnxruntime::IAllocatorUniquePtr& data_unique_ptr : container_.buffers_) { + data_unique_ptr.release(); + } + + container_.buffers_.clear(); + container_.buffer_sizes_.clear(); +} + namespace onnxruntime { /// @@ -17,6 +56,7 @@ namespace onnxruntime { /// class PluginEpOpKernel final : public OpKernel { private: + // Prevents calling constructor directly without having to make it private (required by std::make_unique). 
struct PrivateTag {}; public: @@ -37,8 +77,108 @@ class PluginEpOpKernel final : public OpKernel { return ToStatusAndRelease(kernel_impl_->Compute(kernel_impl_, reinterpret_cast(ctx))); } + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override { + assert(kernel_impl_ != nullptr); // Should be ensured by PluginEpOpKernel::Create(). + + if (kernel_impl_->ort_version_supported < 24 || kernel_impl_->PrePackWeight == nullptr) { + // OrtKernelImpl does not define a PrePack implementation. + is_packed = false; + return Status::OK(); + } + + // Convert AllocatorPtr to an OrtAllocator* (that wraps the AllocatorPtr) and cache it. + OrtAllocator* ort_allocator = GetPrePackOrtAllocator(alloc); + + // Create a non-owning OrtValue that wraps the const Tensor& with an empty deleter. + // This is passed to OrtKernelImpl::PrePackWeight() as a const OrtValue*. + // The above reasons make the const_cast relatively "safe". + // Note: Documentation for OrtKernelImpl::PrePackWeight disallows caching the OrtValue pointer. + auto empty_tensor_deleter = [](void* /*data*/) -> void { /* do not delete Tensor (not owned) */ }; + const OrtValue ort_value(const_cast(&tensor), DataTypeImpl::GetType(), empty_tensor_deleter); + + // Only allow kernel to store/share pre-packed weights if the weight data will be stored in cpu-accessible memory. + // ORT requires that the data reside in cpu memory to be able to compute the hash of the weight's contents. + // + // If the allocator does not use CPU memory, we pass a NULL OrtSharedPrePackedWeightCache instance to the kernel to + // indicate that storing/sharing is not allowed and the kernel should manage the memory for the pre-packed weight. + std::optional shared_weight_cache; + + if (prepacked_weights != nullptr && alloc->Info().device.UsesCpuMemory()) { + ORT_RETURN_IF(!prepacked_weights->buffers_.empty() || !prepacked_weights->buffer_sizes_.empty(), + "PluginEpOpKernel::PrePack() expected PrePackedWeights instance to be initially empty"); + shared_weight_cache.emplace(OrtSharedPrePackedWeightCache(*prepacked_weights, alloc)); + } + + ORT_RETURN_IF_ERROR(ToStatusAndRelease( + kernel_impl_->PrePackWeight(kernel_impl_, &ort_value, input_idx, + ort_allocator, + shared_weight_cache.has_value() ? &*shared_weight_cache : nullptr, + &is_packed))); + + const bool tried_to_share = shared_weight_cache.has_value() && shared_weight_cache->HasData(); + ORT_RETURN_IF(tried_to_share && !is_packed, "OrtKernelImpl::PrePackWeight() tried to share packed weight data ", + "but did not set the `is_packed` output parameter to true."); + + return Status::OK(); + } + + Status UseSharedPrePackedBuffers_V2(std::vector& buffer_unique_ptrs, + gsl::span buffer_sizes, + int input_idx, /*out*/ bool& used_shared_buffers) override { + assert(kernel_impl_ != nullptr); // Should be ensured by PluginEpOpKernel::Create(). + + if (kernel_impl_->ort_version_supported < 24 || kernel_impl_->SetSharedPrePackedWeight == nullptr) { + // OrtKernelImpl does not define an implementation. The session state, which calls this function, + // generates an error if necessary (i.e., kernel indicated it wanted to share weights but did not define this). 
+ used_shared_buffers = false; + return Status::OK(); + } + + std::vector buffer_data_ptrs; + + buffer_data_ptrs.reserve(buffer_unique_ptrs.size()); + std::transform(buffer_unique_ptrs.begin(), buffer_unique_ptrs.end(), std::back_inserter(buffer_data_ptrs), + [](const BufferUniquePtr& buff) -> const void* { return buff.get(); }); + + ORT_RETURN_IF_ERROR(ToStatusAndRelease( + kernel_impl_->SetSharedPrePackedWeight(kernel_impl_, buffer_data_ptrs.data(), buffer_sizes.data(), + buffer_data_ptrs.size(), input_idx))); + + used_shared_buffers = true; + return Status::OK(); + } + private: + /// + /// Gets the cached OrtAllocator for the given AllocatorPtr passed to PrePack(). + /// + /// + /// + OrtAllocator* GetPrePackOrtAllocator(AllocatorPtr alloc) { + IAllocator* i_allocator = alloc.get(); + + // Try to find an existing OrtAllocator* that wraps the given IAllocator* + for (auto& ort_allocator_wrapper : prepack_ort_allocators_) { + if (ort_allocator_wrapper->GetWrappedIAllocator().get() == i_allocator) { + return ort_allocator_wrapper.get(); + } + } + + // Create a new OrtAllocatorImplWrappingIAllocator + auto ort_allocator_wrapper = std::make_unique(std::move(alloc)); + + prepack_ort_allocators_.push_back(std::move(ort_allocator_wrapper)); + return prepack_ort_allocators_.back().get(); + } + OrtKernelImpl* kernel_impl_ = nullptr; + + // We create and cache a OrtAllocator that wraps each unique IAllocator passed to PrePack(). Need to keep these + // OrtAllocator instances alive because the plugin EP kernel implementation uses the OrtAllocators to allocate + // and free packed weight data. Note: use a vector instead of an unordered_map because this will almost always + // contain only one element and we want to limit the size of this class. + std::vector> prepack_ort_allocators_; }; /*static*/ diff --git a/onnxruntime/core/session/plugin_ep/ep_kernel_registration.h b/onnxruntime/core/session/plugin_ep/ep_kernel_registration.h index a7fd1759697df..271068059ed72 100644 --- a/onnxruntime/core/session/plugin_ep/ep_kernel_registration.h +++ b/onnxruntime/core/session/plugin_ep/ep_kernel_registration.h @@ -4,12 +4,57 @@ #pragma once #include +#include "core/common/inlined_containers_fwd.h" #include "core/session/onnxruntime_c_api.h" +#include "core/framework/allocator.h" #include "core/framework/data_types.h" #include "core/framework/error_code_helper.h" #include "core/framework/kernel_def_builder.h" #include "core/framework/kernel_registry.h" #include "core/framework/op_kernel.h" +#include "core/framework/prepacked_weights.h" + +/// +/// Implementation of the public C API opaque type OrtSharedPrePackedWeightCache used by plugin EP kernels. +/// This wraps and fills out an instance of onnxruntime::PrePackedWeights via the +/// C API SharedPrePackedWeightCache_StoreWeightData. +/// +struct OrtSharedPrePackedWeightCache { + /// + /// Constructs an OrtSharedPrePackedWeightCache that will fill out the provided PrePackedWeights object. + /// + /// The PrePackedWeights container to fill out. + /// The allocator that will be used to free buffers set by the call to SetBuffers(). + OrtSharedPrePackedWeightCache(onnxruntime::PrePackedWeights& container, onnxruntime::AllocatorPtr allocator); + + /// + /// Sets data buffers for the shared weight. Ownership of the buffers is transferred to this class's contained + /// PrePackedWeights instance, which will delete the buffers with `this->allocator_`. + /// The buffer data is required to have been allocated with `this->allocator_`. 
+ /// Refer to OrtKernelImpl::PrePackWeight and OrtEpApi::SharedPrePackedWeightCache_StoreWeightData. + /// + /// + /// + /// + void SetBuffers(void** data_ptrs, size_t* data_sizes, size_t num_buffers); + + /// + /// Returns true if this instance has any weight buffer data. + /// + /// + bool HasData() const noexcept; + + /// + /// Releases all buffer data. + /// Used within OrtEpApi::SharedPrePackedWeightCache_StoreWeightData() if an error occurs and ORT wants to + /// release all data to allow caller to retain ownership of data. + /// + void ReleaseAllData() noexcept; + + private: + onnxruntime::PrePackedWeights& container_; + onnxruntime::AllocatorPtr allocator_; +}; namespace onnxruntime { diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_allocator.h b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_allocator.h new file mode 100644 index 0000000000000..186a44b5ce1c4 --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_allocator.h @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "../plugin_ep_utils.h" + +#include +#include + +// `OrtAllocator` is a C API struct. `BaseAllocator` is a minimal C++ struct which inherits from `OrtAllocator`. +// Notably, `BaseAllocator` has a virtual destructor to enable a derived class to be deleted through a `BaseAllocator` +// pointer. Allocators which need to be deleted through a base class pointer should inherit from `BaseAllocator`. +struct BaseAllocator : OrtAllocator { + virtual ~BaseAllocator() = default; +}; + +using AllocatorUniquePtr = std::unique_ptr; + +struct CustomAllocator : BaseAllocator { + CustomAllocator(const OrtMemoryInfo* mem_info) : memory_info{mem_info} { + version = ORT_API_VERSION; + Alloc = AllocImpl; + Free = FreeImpl; + Info = InfoImpl; + Reserve = AllocImpl; // no special reserve logic and most likely unnecessary unless you have your own arena + GetStats = nullptr; + AllocOnStream = nullptr; + } + + static void* ORT_API_CALL AllocImpl(struct OrtAllocator* /*this_*/, size_t size) { + return malloc(size); + } + + /// Free a block of memory previously allocated with OrtAllocator::Alloc + static void ORT_API_CALL FreeImpl(struct OrtAllocator* /*this_*/, void* p) { + return free(p); + } + + /// Return a pointer to an ::OrtMemoryInfo that describes this allocator + static const struct OrtMemoryInfo* ORT_API_CALL InfoImpl(const struct OrtAllocator* this_) { + const CustomAllocator& impl = *static_cast(this_); + return impl.memory_info; + } + + private: + const OrtMemoryInfo* memory_info; +}; + +using AllocationUniquePtr = std::unique_ptr>; + +inline AllocationUniquePtr AllocateBytes(OrtAllocator* allocator, size_t num_bytes) { + void* p = allocator->Alloc(allocator, num_bytes); + return AllocationUniquePtr(p, [allocator](void* d) { allocator->Free(allocator, d); }); +} diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_data_transfer.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_data_transfer.cc new file mode 100644 index 0000000000000..dbadc6141c063 --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_data_transfer.cc @@ -0,0 +1,115 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
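Before the data transfer implementation, a brief usage sketch for the allocator helpers defined in ep_allocator.h above. The memory-info arguments are illustrative; only CustomAllocator, AllocationUniquePtr, and AllocateBytes come from the new header.

// Hypothetical usage of the example allocator, e.g. from a test or an EP factory:
Ort::MemoryInfo mem_info("ExampleKernelEp CPU", OrtDeviceAllocator, /*device_id*/ 0, OrtMemTypeDefault);
CustomAllocator allocator(mem_info);                           // OrtAllocator backed by malloc/free, reporting mem_info
AllocationUniquePtr buffer = AllocateBytes(&allocator, 256);   // released through the allocator's Free when it goes out of scope

The new ep_data_transfer.cc below then provides the OrtDataTransferImpl that the example kernels use for copies.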
+ +#include "ep_data_transfer.h" + +#include +#include + +/*static*/ +bool ORT_API_CALL ExampleDataTransfer::CanCopyImpl(const OrtDataTransferImpl* this_ptr, + const OrtMemoryDevice* src_memory_device, + const OrtMemoryDevice* dst_memory_device) noexcept { + const auto& impl = *static_cast(this_ptr); + bool src_is_our_device = impl.ep_api_.MemoryDevice_AreEqual(src_memory_device, impl.device_mem_info); + bool dst_is_our_device = impl.ep_api_.MemoryDevice_AreEqual(dst_memory_device, impl.device_mem_info); + + if (src_is_our_device && dst_is_our_device) { + return true; + } + + // implementation should check if the copy is possible, which may require checking the device type, the memory type + // and the vendor and device IDs as needed. + OrtMemoryInfoDeviceType src_device_type = impl.ep_api_.MemoryDevice_GetDeviceType(src_memory_device); + OrtMemoryInfoDeviceType dst_device_type = impl.ep_api_.MemoryDevice_GetDeviceType(dst_memory_device); + OrtDeviceMemoryType src_mem_type = impl.ep_api_.MemoryDevice_GetMemoryType(src_memory_device); + OrtDeviceMemoryType dst_mem_type = impl.ep_api_.MemoryDevice_GetMemoryType(dst_memory_device); + + // we can copy to/from CPU or CPU accessible memory + if (src_is_our_device) { + return (dst_device_type == OrtMemoryInfoDeviceType_CPU || dst_mem_type == OrtDeviceMemoryType_HOST_ACCESSIBLE); + } + + if (dst_is_our_device) { + return (src_device_type == OrtMemoryInfoDeviceType_CPU || src_mem_type == OrtDeviceMemoryType_HOST_ACCESSIBLE); + } + + return false; +} + +namespace { +void CopyImpl(const void* src_data, void* dst_data, size_t bytes, OrtSyncStream* stream) { + // in our example setup this is really CPU to CPU + + if (stream) { + // EP can do an async copy using the stream. e.g. an NVIDIA EP would provide the stream to cudaMemcpyAsync + } + + if (src_data != dst_data) { + memcpy(dst_data, src_data, bytes); + } +} +} // namespace + +// function to copy one or more tensors. +// implementation can optionally use async copy if a stream is available for the input. +/*static*/ +OrtStatus* ORT_API_CALL ExampleDataTransfer::CopyTensorsImpl(OrtDataTransferImpl* this_ptr, + const OrtValue** src_tensors_ptr, + OrtValue** dst_tensors_ptr, + OrtSyncStream** streams_ptr, + size_t num_tensors) noexcept { + auto& impl = *static_cast(this_ptr); + + auto src_tensors = gsl::make_span(src_tensors_ptr, num_tensors); + auto dst_tensors = gsl::make_span(dst_tensors_ptr, num_tensors); + + for (size_t i = 0; i < num_tensors; ++i) { + // the implementation for a 'real' EP would be something along these lines. 
+ // See CudaDataTransferImpl in onnxruntime\core\providers\cuda\cuda_provider_factory.cc + const OrtMemoryDevice* src_device = impl.ep_api_.Value_GetMemoryDevice(src_tensors[i]); + const OrtMemoryDevice* dst_device = impl.ep_api_.Value_GetMemoryDevice(dst_tensors[i]); + + OrtMemoryInfoDeviceType src_device_type = impl.ep_api_.MemoryDevice_GetDeviceType(src_device); + OrtMemoryInfoDeviceType dst_device_type = impl.ep_api_.MemoryDevice_GetDeviceType(dst_device); + + // OrtDeviceMemoryType src_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(src_device); + // OrtDeviceMemoryType dst_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(dst_device); + // bool copy_involves_host_accessible_memory = src_mem_type == OrtDeviceMemoryType_HOST_ACCESSIBLE || + // dst_mem_type == OrtDeviceMemoryType_HOST_ACCESSIBLE; + + const void* src_data = nullptr; + void* dst_data = nullptr; + size_t bytes; + + RETURN_IF_ERROR(impl.ort_api_.GetTensorData(src_tensors[i], &src_data)); + RETURN_IF_ERROR(impl.ort_api_.GetTensorMutableData(dst_tensors[i], &dst_data)); + RETURN_IF_ERROR(impl.ort_api_.GetTensorSizeInBytes(src_tensors[i], &bytes)); + + if (dst_device_type == OrtMemoryInfoDeviceType_GPU) { + if (src_device_type == OrtMemoryInfoDeviceType_GPU) { + // GPU -> GPU + } else { + // CPU -> GPU + } + } else if (src_device_type == OrtMemoryInfoDeviceType_GPU) { + // GPU -> CPU + } else { + // CPU -> CPU. may involve copy a to/from host accessible memory and a synchronize may be required first + } + + // but in our example EP it's simpler as it's really a (fake) CPU to CPU copy + CopyImpl(src_data, dst_data, bytes, streams_ptr ? streams_ptr[i] : nullptr); + } + + return nullptr; +} + +/*static*/ +void ORT_API_CALL ExampleDataTransfer::ReleaseImpl(OrtDataTransferImpl* /*this_ptr*/) noexcept { + // In our setup the factory owns a shared ExampleDataTransfer instance so it will do the cleanup, and we ignore + // the call to Release from the plugin_ep::DataTransfer dtor (see /onnxruntime/core/framework/plugin_data_transfer.h) + // + // If you create a new instance on each call to OrtEpFactory::CreateDataTransfer you call `delete` here + // delete static_cast(this_ptr); +} diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_data_transfer.h b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_data_transfer.h new file mode 100644 index 0000000000000..780fe8e379109 --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_data_transfer.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "../plugin_ep_utils.h" + +struct ExampleDataTransfer : OrtDataTransferImpl { + ExampleDataTransfer(const OrtApi& ort_api, const OrtEpApi& ep_api, + const OrtMemoryDevice* device_mem_info_) + : ort_api_(ort_api), ep_api_(ep_api), device_mem_info{device_mem_info_} { + CanCopy = CanCopyImpl; + CopyTensors = CopyTensorsImpl; + Release = ReleaseImpl; + } + + static bool ORT_API_CALL CanCopyImpl(const OrtDataTransferImpl* this_ptr, + const OrtMemoryDevice* src_memory_device, + const OrtMemoryDevice* dst_memory_device) noexcept; + + // function to copy one or more tensors. + // implementation can optionally use async copy if a stream is available for the input. 
+ static OrtStatus* ORT_API_CALL CopyTensorsImpl(OrtDataTransferImpl* this_ptr, + const OrtValue** src_tensors_ptr, + OrtValue** dst_tensors_ptr, + OrtSyncStream** streams_ptr, + size_t num_tensors) noexcept; + static void ORT_API_CALL ReleaseImpl(OrtDataTransferImpl* this_ptr) noexcept; + + private: + const OrtApi& ort_api_; + const OrtEpApi& ep_api_; + const OrtMemoryDevice* device_mem_info; // device our EP runs on +}; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_factory.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_factory.cc index 6017bf9dd9d1e..a520b02c20cba 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_factory.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_factory.cc @@ -8,6 +8,7 @@ #include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #include "ep.h" +#include "ep_allocator.h" #include "ep_kernel_registration.h" #include "../plugin_ep_utils.h" @@ -15,7 +16,9 @@ ExampleKernelEpFactory::ExampleKernelEpFactory(const OrtApi& ort_api, const OrtE const OrtLogger& /*default_logger*/) : OrtEpFactory{}, ort_api_(ort_api), - ep_api_(ep_api) { + ep_api_(ep_api), + default_memory_info_{nullptr}, + readonly_memory_info_{nullptr} { ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. GetName = GetNameImpl; GetVendor = GetVendorImpl; @@ -34,6 +37,32 @@ ExampleKernelEpFactory::ExampleKernelEpFactory(const OrtApi& ort_api, const OrtE IsStreamAware = IsStreamAwareImpl; CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; + + // Define the default memory info. Allows creating custom OrtAllocators and OrtDataTransferImpls. + // This is not strictly required for cpu-based EPs, like this example EP. However, we define it here + // to serve as an example for non-cpu EPs. + default_memory_info_ = Ort::MemoryInfo{"ExampleKernelEp CPU", + OrtMemoryInfoDeviceType_CPU, + // Use vendor ID 0 for generic allocator (e.g., webGPU) + /* vendor */ 0, + /* device_id */ 0, + OrtDeviceMemoryType_DEFAULT, + /*alignment*/ 0, + OrtAllocatorType::OrtDeviceAllocator}; + + // create data transfer for the device + const OrtMemoryDevice* device = ep_api.MemoryInfo_GetMemoryDevice(default_memory_info_); + data_transfer_impl_ = std::make_unique(ort_api, ep_api, device); + + // Create read-only allocator for use with initializers. same info as DEFAULT memory apart from the allocator type. + // This is optional. It is only required if the readonly allocator differs from the default device allocator. + // This is not required for this cpu-based example EP, but show it as an example. + readonly_memory_info_ = Ort::MemoryInfo{"ExampleKernelEp CPU readonly", + OrtMemoryInfoDeviceType_CPU, + /*vendor*/ 0, /* device_id */ 0, + OrtDeviceMemoryType_DEFAULT, + /*alignment*/ 0, + OrtAllocatorType::OrtReadOnlyAllocator}; } ExampleKernelEpFactory::~ExampleKernelEpFactory() { @@ -51,7 +80,9 @@ OrtStatus* ExampleKernelEpFactory::GetKernelRegistryForEp(ExampleKernelEp& ep, } if (kernel_registry_ == nullptr) { - void* op_kernel_state = nullptr; // Optional state that is provided to kernels on creation (can be null). + // Optional state that is provided to kernels on creation (can be null). + // We pass the OrtDataTransferImpl created by this factory to allow kernels to copy data between devices. 
+ void* op_kernel_state = static_cast(data_transfer_impl_.get()); const char* ep_name = ep.GetName(static_cast(&ep)); // This statement creates the kernel registry and caches it in the OrtEpFactory instance. @@ -103,7 +134,8 @@ OrtStatus* ORT_API_CALL ExampleKernelEpFactory::GetSupportedDevicesImpl(OrtEpFac for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { const OrtHardwareDevice& device = *hw_devices[i]; - if (factory->ort_api_.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) { + auto hw_type = factory->ort_api_.HardwareDevice_Type(&device); + if (hw_type == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) { // these can be returned as nullptr if you have nothing to add. OrtKeyValuePairs* ep_metadata = nullptr; OrtKeyValuePairs* ep_options = nullptr; @@ -129,8 +161,8 @@ OrtStatus* ORT_API_CALL ExampleKernelEpFactory::GetSupportedDevicesImpl(OrtEpFac // register the allocator info required by the EP. // registering OrtMemoryInfo for host accessible memory would be done in an additional call. // OrtReadOnlyAllocator + OrtDeviceMemoryType_DEFAULT allocator for use with initializers is optional. - // RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, factory->default_memory_info_.get())); - // RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, factory->readonly_memory_info_.get())); + RETURN_IF_ERROR(factory->ep_api_.EpDevice_AddAllocatorInfo(ep_device, factory->default_memory_info_)); + RETURN_IF_ERROR(factory->ep_api_.EpDevice_AddAllocatorInfo(ep_device, factory->readonly_memory_info_)); ep_devices[num_ep_devices++] = ep_device; } @@ -167,26 +199,40 @@ void ORT_API_CALL ExampleKernelEpFactory::ReleaseEpImpl(OrtEpFactory* /*this_ptr } /*static*/ -OrtStatus* ORT_API_CALL ExampleKernelEpFactory::CreateAllocatorImpl(OrtEpFactory* /*this_ptr*/, - const OrtMemoryInfo* /*memory_info*/, +OrtStatus* ORT_API_CALL ExampleKernelEpFactory::CreateAllocatorImpl(OrtEpFactory* this_ptr, + const OrtMemoryInfo* memory_info, const OrtKeyValuePairs* /*allocator_options*/, OrtAllocator** allocator) noexcept { - // Don't support custom allocators in this example for simplicity. A GPU EP would normally support allocators. + auto& factory = *static_cast(this_ptr); *allocator = nullptr; + + bool is_default_allocator = memory_info == factory.default_memory_info_; + bool is_readonly_allocator = memory_info == factory.readonly_memory_info_; + + if (!is_default_allocator && !is_readonly_allocator) { + return factory.ort_api_.CreateStatus(ORT_INVALID_ARGUMENT, + "INTERNAL ERROR! Unknown memory info provided to CreateAllocator. " + "Value did not come directly from an OrtEpDevice returned by this factory."); + } + + // Note: the same allocator handles both default and readonly allocations. A readonly only allocator would + // typically be different. + auto custom_allocator = std::make_unique(memory_info); + *allocator = custom_allocator.release(); return nullptr; } /*static*/ void ORT_API_CALL ExampleKernelEpFactory::ReleaseAllocatorImpl(OrtEpFactory* /*this_ptr*/, - OrtAllocator* /*allocator*/) noexcept { - // Do nothing. + OrtAllocator* allocator) noexcept { + delete static_cast(allocator); } /*static*/ -OrtStatus* ORT_API_CALL ExampleKernelEpFactory::CreateDataTransferImpl(OrtEpFactory* /*this_ptr*/, +OrtStatus* ORT_API_CALL ExampleKernelEpFactory::CreateDataTransferImpl(OrtEpFactory* this_ptr, OrtDataTransferImpl** data_transfer) noexcept { - // Don't support data transfer in this example for simplicity. 
A GPU EP would normally support it. - *data_transfer = nullptr; + auto& factory = *static_cast(this_ptr); + *data_transfer = factory.data_transfer_impl_.get(); return nullptr; } diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_factory.h b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_factory.h index 9ddbeee585115..a2340b8b1499d 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_factory.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_factory.h @@ -9,6 +9,8 @@ #include "onnxruntime_cxx_api.h" #undef ORT_API_MANUAL_INIT +#include "ep_data_transfer.h" + class ExampleKernelEp; /// @@ -75,6 +77,10 @@ class ExampleKernelEpFactory : public OrtEpFactory { const uint32_t vendor_id_{0xB358}; // EP vendor ID const std::string ep_version_{"0.1.0"}; // EP version + Ort::MemoryInfo default_memory_info_; + Ort::MemoryInfo readonly_memory_info_; + std::unique_ptr data_transfer_impl_; // data transfer implementation for this factory + // Cached kernel registry used by all OrtEp instances created by this factory. Refer to OrtEp::GetKernelRegistry. // // Note: If this factory instead created EP instances that each supported different hardware configurations, then diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/base.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/base.cc deleted file mode 100644 index 30f83e1771dd7..0000000000000 --- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/base.cc +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "base.h" - -BaseKernelImpl::BaseKernelImpl(const OrtKernelInfo* info, void* state) : info_{info}, state_{state} { - ort_version_supported = ORT_API_VERSION; - Compute = ComputeImpl; - Release = ReleaseImpl; -} - -/*static*/ -OrtStatus* ORT_API_CALL BaseKernelImpl::ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept { - try { - BaseKernelImpl* base_kernel = static_cast(this_ptr); - return base_kernel->DoCompute(kernel_ctx); - } catch (const Ort::Exception& ex) { - Ort::Status status(ex); - return status.release(); - } catch (const std::exception& ex) { - Ort::Status status(ex.what(), ORT_EP_FAIL); - return status.release(); - } -} - -/*static*/ -void ORT_API_CALL BaseKernelImpl::ReleaseImpl(OrtKernelImpl* this_ptr) noexcept { - delete static_cast(this_ptr); -} diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/base.h b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/base.h deleted file mode 100644 index c4afe1b2e0670..0000000000000 --- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/base.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "../../plugin_ep_utils.h" - -// Base class for kernel implementations. -// -// Note: BaseKernelImpl has virtual functions so care should be taken when casting BaseKernelImpl to a OrtKernelImpl, -// which is a C API struct type. Specifically, a static_cast or implicit cast should be used. A reinterpret_cast -// will result in an invalid object due to the presence of the vtable. 
-class BaseKernelImpl : public OrtKernelImpl { - public: - BaseKernelImpl(const OrtKernelInfo* info, void* state); - virtual ~BaseKernelImpl() = default; - - static OrtStatus* ORT_API_CALL ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept; - static void ORT_API_CALL ReleaseImpl(OrtKernelImpl* this_ptr) noexcept; - - private: - // Derived classes implement DoCompute. - // DoCompute is called by BaseKernelImpl::ComputeImpl, which also catches exceptions thrown by DoCompute - // implementations and converts them into OrtStatus*. - virtual OrtStatus* DoCompute(OrtKernelContext* kernel_ctx) = 0; - - protected: - const OrtKernelInfo* info_; - void* state_; // Custom state passed from OrtEp -}; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/mul.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/mul.cc index 979dc5e9c1303..046ee04f37786 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/mul.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/mul.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include +#include #include "mul.h" #include "utils.h" @@ -14,28 +15,71 @@ ONNX_OPERATOR_KERNEL_EX( .AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))), Mul) -Mul::Mul(const OrtKernelInfo* info, void* state, PrivateTag) : BaseKernelImpl(info, state) {} +Mul::Mul(const OrtKernelInfo* info, void* state, PrivateTag) + : OrtKernelImpl{}, // Initialize all OrtKernelImpl functions to NULL + info_{info}, + data_transfer_impl_{reinterpret_cast(state)} { + ort_version_supported = ORT_API_VERSION; + Compute = ComputeImpl; + Release = ReleaseImpl; + + // Optional functions that are only needed to pre-pack weights. This Mul kernel pre-packs + // input[1] weights as an example (not typically done by an actual implementation of Mul). + PrePackWeight = PrePackWeightImpl; + SetSharedPrePackedWeight = SetSharedPrePackedWeightImpl; +} /*static*/ OrtStatus* Mul::Create(const OrtKernelInfo* info, void* state, - /*out*/ std::unique_ptr& result) { + /*out*/ std::unique_ptr& result) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN // Note: can do basic validation or preprocessing via the OrtKernelInfo APIs. result = std::make_unique(info, state, PrivateTag{}); return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END +} + +/*static*/ +void ORT_API_CALL Mul::ReleaseImpl(OrtKernelImpl* this_ptr) noexcept { + delete static_cast(this_ptr); } -OrtStatus* Mul::DoCompute(OrtKernelContext* kernel_ctx) { +/*static*/ +OrtStatus* ORT_API_CALL Mul::ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + Mul* mul_kernel = static_cast(this_ptr); + static_cast(mul_kernel->info_); // NOTE: Unused in this example. + Ort::KernelContext kernel_context(kernel_ctx); - static_cast(this->state_); // NOTE: Unused in this example. - static_cast(this->info_); // NOTE: Unused in this example. + // Get first input's data. gsl::span input0; - gsl::span input1; std::vector shape0; + RETURN_IF_ERROR(GetKernelInputDataAndShape(kernel_context, 0, input0, shape0)); + + // Get second input's data. + // This second input may have been pre-packed if it is a constant weight. 
+ gsl::span input1; std::vector shape1; - RETURN_IF_ERROR(GetKernelInputDataAndShape(kernel_context, 0, input0, shape0)); - RETURN_IF_ERROR(GetKernelInputDataAndShape(kernel_context, 1, input1, shape1)); + if (mul_kernel->packed_weight_1_info_.has_value()) { + const PackedWeightInfo& packed_weight_info = *mul_kernel->packed_weight_1_info_; + shape1 = packed_weight_info.shape; + + size_t num_elems = 1; + for (auto s : shape1) { + num_elems *= s; + } + + const float* input1_data = packed_weight_info.shared_data != nullptr + ? reinterpret_cast(packed_weight_info.shared_data) + : reinterpret_cast(packed_weight_info.owned_data.get()); + + input1 = gsl::span(input1_data, num_elems); + } else { + RETURN_IF_ERROR(GetValueDataAndShape(kernel_context.GetInput(1), input1, shape1)); + } + RETURN_IF(shape0 != shape1, Ort::GetApi(), "Mul kernel doesn't support broadcasting."); // Checked by GetCapability Ort::UnownedValue output = kernel_context.GetOutput(0, shape0); @@ -46,4 +90,99 @@ OrtStatus* Mul::DoCompute(OrtKernelContext* kernel_ctx) { } return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END +} + +/*static*/ +OrtStatus* ORT_API_CALL Mul::PrePackWeightImpl(OrtKernelImpl* this_ptr, const OrtValue* tensor, + int input_index, OrtAllocator* allocator, + OrtSharedPrePackedWeightCache* prepacked_weight_cache, + /*out*/ bool* is_packed) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + Mul* mul_kernel = static_cast(this_ptr); + + // This example Mul kernel does not really need to pre-pack mul initializers, but we show it here as an example. + // This implementation just copies original tensor without modification. An actual implementation would, for example, + // transform to an appropriate data layout. + + if (input_index != 1) { + *is_packed = false; + return nullptr; + } + + Ort::ConstValue original_weight(tensor); + auto type_shape_info = original_weight.GetTensorTypeAndShapeInfo(); + size_t num_bytes = original_weight.GetTensorSizeInBytes(); + + PackedWeightInfo weight_info = {}; + weight_info.mem_info = Ort::ConstMemoryInfo(allocator->Info(allocator)); + weight_info.shape = type_shape_info.GetShape(); + weight_info.elem_type = type_shape_info.GetElementType(); + weight_info.num_bytes = num_bytes; + weight_info.owned_data = AllocateBytes(allocator, num_bytes); + + // Note: This Ort::Value does not own the underlying data. + Ort::Value packed_weight = Ort::Value::CreateTensor(weight_info.mem_info, weight_info.owned_data.get(), + weight_info.num_bytes, weight_info.shape.data(), + weight_info.shape.size(), weight_info.elem_type); + + RETURN_IF_ERROR(CopyTensor(*mul_kernel->data_transfer_impl_, original_weight, packed_weight.GetUnowned())); + + const bool sharing_allowed = prepacked_weight_cache != nullptr; + + if (sharing_allowed) { + std::array buffer_data_ptrs = {weight_info.owned_data.get()}; + std::array buffer_data_sizes = {weight_info.num_bytes}; + + Ort::UnownedSharedPrePackedWeightCache weight_cache(prepacked_weight_cache); + + // weight_cache takes ownership of the data. As the API documentation states, this requires that the + // weight data is allocated with the OrtAllocator provided as a parameter to OrtKernelImpl::PrePackWeight. + RETURN_IF_ERROR(weight_cache.StoreWeightData(buffer_data_ptrs.data(), + buffer_data_sizes.data(), + buffer_data_ptrs.size())); + + // IMPORTANT: This kernel no longer owns the packed weight data. + // weight_info.shared_data will be initialized in the call to SetSharedPrePackedWeightImpl. 
+ weight_info.owned_data.release(); + } + + mul_kernel->packed_weight_1_info_ = std::move(weight_info); + *is_packed = true; + + return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END +} + +/*static*/ +OrtStatus* ORT_API_CALL Mul::SetSharedPrePackedWeightImpl(OrtKernelImpl* this_ptr, + const void* const* buffer_data_ptrs, + const size_t* buffer_data_sizes, + size_t num_buffers, int input_index) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + Mul* mul_kernel = static_cast(this_ptr); + + if (input_index != 1) { + std::ostringstream oss; + oss << "ExampleKernelEp did not expect a call to OrtKernelImpl::SetSharedPrePackedWeight for input index " + << input_index << " of the Mul kernel."; + return Ort::GetApi().CreateStatus(ORT_EP_FAIL, oss.str().c_str()); + } + + RETURN_IF(num_buffers != 1, Ort::GetApi(), "Invalid number of pre-packed data buffers for Mul kernel's 2nd input"); + RETURN_IF(!mul_kernel->packed_weight_1_info_.has_value(), Ort::GetApi(), + "ERROR! OrtKernelImpl::PrePackWeight should have " + "initialized a valid PackedWeightInfo struct for use in SetSharedPrePackedWeight."); + + // Check that the buffer size is what we expect. + RETURN_IF(buffer_data_sizes[0] != mul_kernel->packed_weight_1_info_->num_bytes, Ort::GetApi(), + "ExampleKernelEp received an unexpected buffer size in a call to OrtKernelImpl::SetSharedPrePackedWeight " + "for the Mul kernel."); + + // Update buffer data pointer because the shared memory could potentially originate from a different + // kernel instance. + mul_kernel->packed_weight_1_info_->shared_data = buffer_data_ptrs[0]; + + return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END } diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/mul.h b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/mul.h index 882a19a13e23e..f84fda6a8b0ec 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/mul.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/mul.h @@ -3,17 +3,45 @@ #pragma once -#include "base.h" +#include #include "../../plugin_ep_utils.h" +#include "../ep_allocator.h" -class Mul : public BaseKernelImpl { +class Mul : public OrtKernelImpl { private: struct PrivateTag {}; + struct PackedWeightInfo { + Ort::ConstMemoryInfo mem_info{nullptr}; + std::vector shape; + ONNXTensorElementDataType elem_type; + size_t num_bytes; + + // Only one of the following data fields will be set. + // If pre-packed data is shared with other kernels, `shared_data` will be non-null. Otherwise, this kernel + // sets `owned_data`, whose lifetime it manages. + AllocationUniquePtr owned_data{}; + const void* shared_data{nullptr}; // not owned by this kernel. 
+ }; + public: - static OrtStatus* Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr& kernel); + static OrtStatus* Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr& kernel) noexcept; Mul(const OrtKernelInfo* info, void* state, PrivateTag); + // Static functions assigned to the OrtKernelImpl fields: + static OrtStatus* ORT_API_CALL ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept; + static void ORT_API_CALL ReleaseImpl(OrtKernelImpl* this_ptr) noexcept; + static OrtStatus* ORT_API_CALL PrePackWeightImpl(OrtKernelImpl* this_ptr, const OrtValue* tensor, + int input_index, OrtAllocator* alloc, + OrtSharedPrePackedWeightCache* prepacked_weight_cache, + /*out*/ bool* is_packed) noexcept; + static OrtStatus* ORT_API_CALL SetSharedPrePackedWeightImpl(OrtKernelImpl* this_ptr, + const void* const* buffer_data_ptrs, + const size_t* buffer_data_sizes, + size_t num_buffers, int input_index) noexcept; + private: - OrtStatus* DoCompute(OrtKernelContext* kernel_ctx) override; + const OrtKernelInfo* info_; + OrtDataTransferImpl* data_transfer_impl_; // Custom state passed from OrtEp + std::optional packed_weight_1_info_ = std::nullopt; }; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.cc index 82444b815a1ee..89f52c4b53dc3 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.cc @@ -19,19 +19,29 @@ ONNX_OPERATOR_KERNEL_EX( .AddInputOutputMutableAlias(0, 0)), Relu) -Relu::Relu(const OrtKernelInfo* info, void* state, PrivateTag) : BaseKernelImpl(info, state) {} +Relu::Relu(const OrtKernelInfo* info, void* /*state*/, PrivateTag) + : OrtKernelImpl{}, // Initialize all OrtKernelImpl functions to NULL + info_{info} { + ort_version_supported = ORT_API_VERSION; + Compute = ComputeImpl; + Release = ReleaseImpl; +} /*static*/ -OrtStatus* Relu::Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr& kernel) { +OrtStatus* Relu::Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr& kernel) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN Ort::ConstKernelInfo kernel_info(info); kernel = std::make_unique(info, state, PrivateTag{}); return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END } -OrtStatus* Relu::DoCompute(OrtKernelContext* kernel_ctx) { +/*static*/ +OrtStatus* ORT_API_CALL Relu::ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + Relu* relu_kernel = static_cast(this_ptr); Ort::KernelContext kernel_context(kernel_ctx); - static_cast(this->state_); // NOTE: Unused in this example. - static_cast(this->info_); // NOTE: Unused in this example. + static_cast(relu_kernel->info_); // NOTE: Unused in this example. 
gsl::span input0; std::vector shape0; @@ -45,4 +55,10 @@ OrtStatus* Relu::DoCompute(OrtKernelContext* kernel_ctx) { } return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END +} + +/*static*/ +void ORT_API_CALL Relu::ReleaseImpl(OrtKernelImpl* this_ptr) noexcept { + delete static_cast(this_ptr); } diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.h b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.h index 4f5ba8bc0e77b..cdeb450435c29 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.h @@ -3,17 +3,20 @@ #pragma once -#include "base.h" #include "../../plugin_ep_utils.h" -class Relu : public BaseKernelImpl { +class Relu : public OrtKernelImpl { private: struct PrivateTag {}; public: - static OrtStatus* Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr& kernel); + static OrtStatus* Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr& kernel) noexcept; Relu(const OrtKernelInfo* info, void* state, PrivateTag); + // Static functions assigned to the OrtKernelImpl fields: + static OrtStatus* ORT_API_CALL ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept; + static void ORT_API_CALL ReleaseImpl(OrtKernelImpl* this_ptr) noexcept; + private: - OrtStatus* DoCompute(OrtKernelContext* kernel_ctx) override; + const OrtKernelInfo* info_; }; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.cc index 5311911a8c413..3d6a2527476e8 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.cc @@ -21,6 +21,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( (Ort::KernelDefBuilder() .AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)) .AddTypeConstraint("axes", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)) + .SetInputMemType(1, OrtMemTypeCPU) .AddInputOutputAlias(0, 0)), Squeeze) @@ -32,6 +33,7 @@ ONNX_OPERATOR_KERNEL_EX( (Ort::KernelDefBuilder() .AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)) .AddTypeConstraint("axes", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)) + .SetInputMemType(1, OrtMemTypeCPU) .AddInputOutputAlias(0, 0)), Squeeze) @@ -43,16 +45,26 @@ ONNX_OPERATOR_KERNEL_EX( (Ort::KernelDefBuilder() .AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)) .AddTypeConstraint("axes", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)) + .SetInputMemType(1, OrtMemTypeCPU) .AddInputOutputAlias(0, 0)), Squeeze) -Squeeze::Squeeze(const OrtKernelInfo* info, void* state, PrivateTag) : BaseKernelImpl(info, state) {} +Squeeze::Squeeze(const OrtKernelInfo* info, void* state, PrivateTag) + : OrtKernelImpl{}, // Initialize all OrtKernelImpl functions to NULL + info_{info}, + data_transfer_impl_{reinterpret_cast(state)} { + ort_version_supported = ORT_API_VERSION; + Compute = ComputeImpl; + Release = ReleaseImpl; +} /*static*/ -OrtStatus* Squeeze::Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr& kernel) { +OrtStatus* Squeeze::Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr& kernel) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN Ort::ConstKernelInfo kernel_info(info); kernel = 
std::make_unique(info, state, PrivateTag{}); return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END } static int64_t HandleNegativeAxis(int64_t axis, int64_t tensor_rank) { @@ -84,12 +96,16 @@ static std::vector ComputeOutputShape(gsl::span input_sh return output_shape; } -OrtStatus* Squeeze::DoCompute(OrtKernelContext* kernel_ctx) { - Ort::KernelContext kernel_context(kernel_ctx); - static_cast(this->state_); // NOTE: Unused in this example. +/*static*/ +OrtStatus* ORT_API_CALL Squeeze::ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + Squeeze* squeeze_kernel = static_cast(this_ptr); + static_cast(squeeze_kernel->info_); // NOTE: Unused in this example. + Ort::KernelContext kernel_context(kernel_ctx); gsl::span input0; std::vector shape0; + Ort::ConstValue input = kernel_context.GetInput(0); RETURN_IF_ERROR(GetKernelInputDataAndShape(kernel_context, 0, input0, shape0)); size_t num_inputs = kernel_context.GetInputCount(); @@ -107,15 +123,15 @@ OrtStatus* Squeeze::DoCompute(OrtKernelContext* kernel_ctx) { std::vector output_shape = ComputeOutputShape(shape0, axes); Ort::UnownedValue output = kernel_context.GetOutput(0, output_shape); - float* output_data = output.GetTensorMutableData(); - size_t num_bytes = output.GetTensorSizeInBytes(); - - if (input0.data() != output_data) { // Don't copy if src == dst - // This uses a memcpy because the input and output are both located in the EP's device memory (i.e., cpu memory). - // Normally, an EP would use a OrtDataTransferImpl to generically handle copies where the source and destination - // could be on different devices. - memcpy(output_data, input0.data(), num_bytes); - } + // This kernel aliases the input and output, so a copy is not really necessary. + // CopyTensor() will not do a copy if the source and destination buffers are the same. 
+ RETURN_IF_ERROR(CopyTensor(*squeeze_kernel->data_transfer_impl_, input, output)); return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END +} + +/*static*/ +void ORT_API_CALL Squeeze::ReleaseImpl(OrtKernelImpl* this_ptr) noexcept { + delete static_cast(this_ptr); } diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.h b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.h index 9faf91c1d2b3c..d179b95d73f80 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.h @@ -3,17 +3,21 @@ #pragma once -#include "base.h" #include "../../plugin_ep_utils.h" -class Squeeze : public BaseKernelImpl { +class Squeeze : public OrtKernelImpl { private: struct PrivateTag {}; public: - static OrtStatus* Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr& kernel); + static OrtStatus* Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr& kernel) noexcept; Squeeze(const OrtKernelInfo* info, void* state, PrivateTag); + // Static functions assigned to the OrtKernelImpl fields: + static OrtStatus* ORT_API_CALL ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept; + static void ORT_API_CALL ReleaseImpl(OrtKernelImpl* this_ptr) noexcept; + private: - OrtStatus* DoCompute(OrtKernelContext* kernel_ctx) override; + const OrtKernelInfo* info_; + OrtDataTransferImpl* data_transfer_impl_; // Custom state passed from OrtEp }; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/utils.h b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/utils.h index 615ee3911108a..506392abb6149 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/utils.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/utils.h @@ -18,6 +18,39 @@ inline const OrtDataType* GetTensorType(ONNXTensorElementDataType elem_type) { return result; } +/// +/// Copy a tensor using an OrtDataTransferImpl instance. Used by kernel implementations to copy +/// tensors that may reside on different devices. 
+/// +/// +/// +/// +/// +inline OrtStatus* CopyTensor(OrtDataTransferImpl& data_transfer_impl, + Ort::ConstValue src_tensor, Ort::UnownedValue dst_tensor) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + const OrtMemoryDevice* src_device = Ort::GetEpApi().MemoryInfo_GetMemoryDevice(src_tensor.GetTensorMemoryInfo()); + const OrtMemoryDevice* dst_device = Ort::GetEpApi().MemoryInfo_GetMemoryDevice(dst_tensor.GetTensorMemoryInfo()); + + RETURN_IF(!data_transfer_impl.CanCopy(&data_transfer_impl, src_device, dst_device), Ort::GetApi(), + "OrtDataTransferImpl cannot copy src tensor to dst tensor."); + + auto src_type_shape = src_tensor.GetTensorTypeAndShapeInfo(); + auto dst_type_shape = dst_tensor.GetTensorTypeAndShapeInfo(); + bool same_elem_type = src_type_shape.GetElementType() == dst_type_shape.GetElementType(); + bool same_elem_count = src_type_shape.GetElementCount() == dst_type_shape.GetElementCount(); + RETURN_IF(!same_elem_type || !same_elem_count, Ort::GetApi(), "Cannot copy tensors of different types or size."); + + std::array src_tensors = {src_tensor}; + std::array dst_tensors = {dst_tensor}; + + RETURN_IF_ERROR(data_transfer_impl.CopyTensors(&data_transfer_impl, src_tensors.data(), dst_tensors.data(), + /*streams*/ nullptr, src_tensors.size())); + + return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END +} + /// /// Contains information to create a kernel: kernel definition, creation function + state. /// diff --git a/onnxruntime/test/autoep/library/plugin_ep_utils.h b/onnxruntime/test/autoep/library/plugin_ep_utils.h index d14186425458f..f7b8dc4d2be0d 100644 --- a/onnxruntime/test/autoep/library/plugin_ep_utils.h +++ b/onnxruntime/test/autoep/library/plugin_ep_utils.h @@ -76,6 +76,22 @@ ss << __VA_ARGS__; \ throw std::runtime_error(ss.str()) +#define EXCEPTION_TO_RETURNED_STATUS_BEGIN try { +#define EXCEPTION_TO_RETURNED_STATUS_END \ + } \ + catch (const Ort::Exception& ex) { \ + Ort::Status status(ex); \ + return status.release(); \ + } \ + catch (const std::exception& ex) { \ + Ort::Status status(ex.what(), ORT_EP_FAIL); \ + return status.release(); \ + } \ + catch (...) 
+    Ort::Status status("Unknown exception", ORT_EP_FAIL); \
+    return status.release();                              \
+  }
+
 struct ApiPtrs {
   const OrtApi& ort_api;
   const OrtEpApi& ep_api;
@@ -161,20 +177,27 @@ inline ONNXTensorElementDataType GetTensorElemDataType() {
 }
 
 template <typename T>
-inline OrtStatus* GetKernelInputDataAndShape(Ort::KernelContext kernel_context, size_t index,
-                                             /*out*/ gsl::span<const T>& data,
-                                             /*out*/ std::vector<int64_t>& shape) {
-  Ort::ConstValue input = kernel_context.GetInput(index);
-  auto type_shape = input.GetTensorTypeAndShapeInfo();
+inline OrtStatus* GetValueDataAndShape(Ort::ConstValue value,
+                                       /*out*/ gsl::span<const T>& data,
+                                       /*out*/ std::vector<int64_t>& shape) {
+  auto type_shape = value.GetTensorTypeAndShapeInfo();
   ONNXTensorElementDataType elem_type = type_shape.GetElementType();
   RETURN_IF(elem_type != GetTensorElemDataType<T>(), Ort::GetApi(), "EP expected kernel input of tensor type");
-  const T* float_data = input.GetTensorData<T>();
+  const T* elem_data = value.GetTensorData<T>();
   size_t num_elems = type_shape.GetElementCount();
-  data = gsl::span<const T>(float_data, num_elems);
+  data = gsl::span<const T>(elem_data, num_elems);
   shape = type_shape.GetShape();
   return nullptr;
 }
+
+template <typename T>
+inline OrtStatus* GetKernelInputDataAndShape(Ort::KernelContext kernel_context, size_t index,
+                                             /*out*/ gsl::span<const T>& data,
+                                             /*out*/ std::vector<int64_t>& shape) {
+  Ort::ConstValue input = kernel_context.GetInput(index);
+  return GetValueDataAndShape(input, data, shape);
+}
diff --git a/onnxruntime/test/common/cuda_op_test_utils.cc b/onnxruntime/test/common/cuda_op_test_utils.cc
index bab4e9a60e2ed..fbd9b0a33c7c0 100644
--- a/onnxruntime/test/common/cuda_op_test_utils.cc
+++ b/onnxruntime/test/common/cuda_op_test_utils.cc
@@ -1,7 +1,9 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-#ifdef USE_CUDA
+#include <iostream>
+
+#if defined(USE_CUDA) || defined(USE_NV)
 #include "cuda_runtime_api.h"
 #endif
 
@@ -13,7 +15,7 @@ int GetCudaArchitecture() {
   // Usually, we test on a single GPU or multiple GPUs of same architecture, so it's fine to cache the result.
   static int cuda_arch = -1;
 
-#ifdef USE_CUDA
+#if defined(USE_CUDA) || defined(USE_NV)
   if (cuda_arch == -1) {
     int current_device_id = 0;
    cudaGetDevice(&current_device_id);
@@ -26,6 +28,15 @@ int GetCudaArchitecture() {
     if (cudaSuccess == cudaGetDeviceProperties(&prop, current_device_id)) {
       cuda_arch = prop.major * 100 + prop.minor * 10;
     }
+
+    // Log GPU compute capability
+    if (cuda_arch == -1) {
+      std::cout << "WARNING: CUDA is not available or failed to initialize" << std::endl;
+    } else {
+      std::cout << "GPU Compute Capability: SM "
+                << cuda_arch / 100 << "." << (cuda_arch % 100) / 10
+                << " (value: " << cuda_arch << ")" << std::endl;
+    }
   }
 #endif
diff --git a/onnxruntime/test/contrib_ops/matmul_2bits_test.cc b/onnxruntime/test/contrib_ops/matmul_2bits_test.cc
index 04a4c95dd478b..b41231fc7f37a 100644
--- a/onnxruntime/test/contrib_ops/matmul_2bits_test.cc
+++ b/onnxruntime/test/contrib_ops/matmul_2bits_test.cc
@@ -197,10 +197,13 @@ void RunTest2Bits(const TestOptions2Bits& opts) {
   std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
   if constexpr (std::is_same::value) {
-    execution_providers.emplace_back(DefaultCpuExecutionProvider());
 #ifdef USE_WEBGPU
-    execution_providers.push_back(DefaultWebGpuExecutionProvider());
+    if (!opts.has_zero_point) {
+      execution_providers.push_back(DefaultWebGpuExecutionProvider());
+    }
 #endif
+    // CPU EP should be added last so that other EPs get tested first
+    execution_providers.emplace_back(DefaultCpuExecutionProvider());
     test.ConfigEps(std::move(execution_providers));
     test.RunWithConfig();
   }
diff --git a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc
index 85538efbefd28..a87fe8fe30b7b 100644
--- a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc
+++ b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc
@@ -94,12 +94,14 @@ static void RunOneTest(
           sum_output_data);
     }
 
-    if (cpu_ep != nullptr) {
-      execution_providers.push_back(DefaultCpuExecutionProvider());
-    }
+    // Add WebGPU EP first so it gets tested before CPU EP
+    // (ConfigEps runs the first available EP for the operator)
     if (webgpu_ep != nullptr) {
      execution_providers.push_back(DefaultWebGpuExecutionProvider());
     }
+    if (cpu_ep != nullptr) {
+      execution_providers.push_back(DefaultCpuExecutionProvider());
+    }
 
     test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
   } else if (CudaHasBF16Support() && use_bfloat16) {
diff --git a/onnxruntime/test/providers/cpu/tensor/gather_op_test.cc b/onnxruntime/test/providers/cpu/tensor/gather_op_test.cc
index 997ff2869592c..87c9ca09804d2 100644
--- a/onnxruntime/test/providers/cpu/tensor/gather_op_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/gather_op_test.cc
@@ -434,10 +434,12 @@ TEST(GatherOpTest, Gather_axis1_scalar_indices) {
 
 TEST(ShrunkenGatherOpTest, ShrunkenGather_PositiveAxis) {
   std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
-  execution_providers.emplace_back(DefaultCpuExecutionProvider());
+  // Add CUDA EP first so it gets tested before CPU EP
+  // (ConfigEps runs the first available EP for the operator)
 #ifdef USE_CUDA
   execution_providers.emplace_back(DefaultCudaExecutionProvider());
 #endif
+  execution_providers.emplace_back(DefaultCpuExecutionProvider());
 
   OpTester test("ShrunkenGather", 1, onnxruntime::kMSDomain);
   test.AddAttribute("axis", 0LL);
@@ -455,10 +457,12 @@ TEST(ShrunkenGatherOpTest, ShrunkenGather_PositiveAxis) {
   std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
-  execution_providers.emplace_back(DefaultCpuExecutionProvider());
+  // Add CUDA EP first so it gets tested before CPU EP
+  // (ConfigEps runs the first available EP for the operator)
 #ifdef USE_CUDA
   execution_providers.emplace_back(DefaultCudaExecutionProvider());
 #endif
+  execution_providers.emplace_back(DefaultCpuExecutionProvider());
 
   OpTester test("ShrunkenGather", 1, onnxruntime::kMSDomain);
   test.AddAttribute("axis", 0LL);
@@ -497,10 +503,12 @@ TEST(ShrunkenGatherOpTest, ShrunkenGather_InvalidIndicesRank) {
   std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
-  execution_providers.emplace_back(DefaultCpuExecutionProvider());
+  // Add CUDA EP first so it gets tested before CPU EP
+  // (ConfigEps runs the first available EP for the operator)
 #ifdef USE_CUDA
   execution_providers.emplace_back(DefaultCudaExecutionProvider());
 #endif
+  execution_providers.emplace_back(DefaultCpuExecutionProvider());
 
   OpTester test("ShrunkenGather", 1, onnxruntime::kMSDomain);
   test.AddAttribute("axis", 0LL);
diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc
index 5ae610a842679..1a987ab4f411a 100644
--- a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc
+++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc
@@ -9,6 +9,7 @@
 #include "test/util/include/scoped_env_vars.h"
 #include "test/common/trt_op_test_utils.h"
 #include "test/common/random_generator.h"
+#include "test/common/cuda_op_test_utils.h"
 #include "test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h"
 
 #include
 
@@ -22,6 +23,21 @@ namespace onnxruntime {
 namespace test {
 
+// Helper function to check if GPU is Blackwell (SM 12.0+) or above
+// Returns true if requirement is met
+// Returns false if CUDA is unavailable or GPU is below SM 12.0
+static bool IsBlackwellOrAbove() {
+  constexpr int kBlackwellMinCapability = 1200;  // SM 12.0 = 12 * 100 + 0 * 10
+  int cuda_arch = GetCudaArchitecture();
+
+  // Check if CUDA is available
+  if (cuda_arch == -1) {
+    return false;
+  }
+
+  return cuda_arch >= kBlackwellMinCapability;
+}
+
 TEST(NvExecutionProviderTest, ContextEmbedAndReload) {
   PathString model_name = ORT_TSTR("nv_execution_provider_test.onnx");
   PathString model_name_ctx = ORT_TSTR("nv_execution_provider_test_ctx.onnx");
@@ -442,6 +458,10 @@ TEST(NvExecutionProviderTest, DataTransfer) {
 }
 
 TEST(NvExecutionProviderTest, FP8CustomOpModel) {
+  if (!IsBlackwellOrAbove()) {
+    GTEST_SKIP() << "Test requires SM 12.0+ GPU (Blackwell+)";
+  }
+
   PathString model_name = ORT_TSTR("nv_execution_provider_fp8_quantize_dequantize_test.onnx");
   clearFileIfExists(model_name);
   std::string graph_name = "nv_execution_provider_fp8_quantize_dequantize_graph";
@@ -509,6 +529,10 @@ TEST(NvExecutionProviderTest, FP8CustomOpModel) {
 }
 
 TEST(NvExecutionProviderTest, FP4CustomOpModel) {
+  if (!IsBlackwellOrAbove()) {
+    GTEST_SKIP() << "Test requires SM 12.0+ GPU (Blackwell+)";
+  }
+
   PathString model_name = ORT_TSTR("nv_execution_provider_fp4_dynamic_quantize_test.onnx");
   clearFileIfExists(model_name);
   std::string graph_name = "nv_execution_provider_fp4_dynamic_quantize_graph";
diff --git a/setup.py b/setup.py
index c095452fef768..df62fdb78622b 100644
--- a/setup.py
+++ b/setup.py
@@ -297,8 +297,12 @@ def run(self):
             else:
                 pass
 
+        # qnn links libc++ rather than libstdc++ for its x86_64 dependencies, which we currently do not
+        # support for manylinux. This is not the case for other platforms.
+        qnn_run_audit = environ.get("AUDITWHEEL_ARCH", "x86_64") != "x86_64"
+
         _bdist_wheel.run(self)
-        if is_manylinux and not disable_auditwheel_repair and not is_openvino:
+        if is_manylinux and not disable_auditwheel_repair and not is_openvino and (not is_qnn or qnn_run_audit):
             assert self.dist_dir is not None
             file = glob(path.join(self.dist_dir, "*linux*.whl"))[0]
             logger.info("repairing %s for manylinux1", file)
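
Editor's note (not part of the diff): the kernel-side pattern these changes establish is to wrap each exported compute entry point in the new EXCEPTION_TO_RETURNED_STATUS_BEGIN/END macros and delegate device-aware copies to the CopyTensor helper added in kernels/utils.h. The minimal sketch below illustrates that pattern; MyIdentityKernel and its data_transfer_impl_ member are hypothetical stand-ins for a concrete kernel such as Squeeze, and the sketch assumes the RETURN_IF_ERROR macro and helpers shown in the hunks above.

    // Minimal sketch of a plugin-EP kernel compute entry point (illustrative only).
    /*static*/
    OrtStatus* ORT_API_CALL MyIdentityKernel::ComputeImpl(OrtKernelImpl* this_ptr,
                                                          OrtKernelContext* kernel_ctx) noexcept {
      EXCEPTION_TO_RETURNED_STATUS_BEGIN  // any Ort::Exception / std::exception becomes an OrtStatus*
      auto* kernel = static_cast<MyIdentityKernel*>(this_ptr);
      Ort::KernelContext ctx(kernel_ctx);

      Ort::ConstValue input = ctx.GetInput(0);
      std::vector<int64_t> shape = input.GetTensorTypeAndShapeInfo().GetShape();
      Ort::UnownedValue output = ctx.GetOutput(0, shape);

      // CopyTensor consults the EP's OrtDataTransferImpl, so this works whether or not
      // the input and output tensors reside on the same device.
      RETURN_IF_ERROR(CopyTensor(*kernel->data_transfer_impl_, input, output));
      return nullptr;
      EXCEPTION_TO_RETURNED_STATUS_END
    }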