5 changes: 3 additions & 2 deletions cmake/onnxruntime_unittests.cmake
@@ -2104,11 +2104,12 @@ if (onnxruntime_BUILD_SHARED_LIB AND
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_factory.cc"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep.cc"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_allocator.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_data_transfer.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_data_transfer.cc"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.cc"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/utils.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/base.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/base.cc"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.cc"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.h"
24 changes: 23 additions & 1 deletion include/onnxruntime/core/framework/op_kernel.h
@@ -4,6 +4,7 @@
#pragma once

#include "boost/mp11.hpp"
#include <gsl/gsl>

// It is safe to include the below header even if SHARED_PROVIDER macro is enabled
// as it doesn't include any pb headers.
@@ -26,7 +27,6 @@
#include "core/graph/constants.h"
#include "core/graph/graph_viewer.h"
#include "core/graph/onnx_protobuf.h"
#include <gsl/gsl>
namespace onnxruntime {
class OpKernelContext;
}
@@ -105,6 +105,7 @@ class OpKernel {
return Status::OK();
}

// Note: New implementations should override OpKernel::UseSharedPrePackedBuffers_V2 instead.
// Override this function to use provided pre-packed weight.
// Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers,
// int input_idx,
@@ -130,6 +131,27 @@
return Status::OK();
}

/// <summary>
/// Version 2 of OpKernel::UseSharedPrePackedBuffers() that additionally accepts the buffer sizes as a parameter.
/// The default implementation simply forwards to OpKernel::UseSharedPrePackedBuffers() so that existing
/// kernel-based provider-bridge EPs do not need to be updated.
///
/// TODO: Consolidate UseSharedPrePackedBuffers and UseSharedPrePackedBuffers_V2 into a single function,
/// which will require updating the kernel-based provider-bridge EPs (cpu, cuda, webgpu).
/// </summary>
/// <param name="prepacked_buffers">Buffers holding the shared pre-packed weight data.</param>
/// <param name="prepacked_buffer_sizes">Byte size of each buffer in prepacked_buffers.</param>
/// <param name="input_idx">Input index of the weight that the pre-packed buffers belong to.</param>
/// <param name="used_shared_buffers">Set to true if the kernel used the shared buffers.</param>
/// <returns>Status indicating success or failure.</returns>
virtual Status UseSharedPrePackedBuffers_V2(std::vector<BufferUniquePtr>& prepacked_buffers,
gsl::span<const size_t> /*prepacked_buffer_sizes*/,
int input_idx,
/*out*/ bool& used_shared_buffers) {
return UseSharedPrePackedBuffers(prepacked_buffers, input_idx, used_shared_buffers);
}

const OrtDevice GetDevice(OrtMemType mem_type) const;
const OpKernelInfo& Info() const {
return *op_kernel_info_;
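As a point of reference, here is a minimal, hypothetical sketch of how a kernel-based provider-bridge kernel could override the new UseSharedPrePackedBuffers_V2 overload. The kernel name, member fields, and the assumption that only input index 1 is pre-packed are illustrative and not part of this change; the matching PrePack() override that produces the packed data is omitted.

// Hypothetical internal kernel; a sketch only, not code from this PR.
#include <utility>
#include <vector>
#include <gsl/gsl>
#include "core/framework/op_kernel.h"

namespace onnxruntime {

class MyPackedGemm final : public OpKernel {
 public:
  explicit MyPackedGemm(const OpKernelInfo& info) : OpKernel(info) {}

  Status UseSharedPrePackedBuffers_V2(std::vector<BufferUniquePtr>& prepacked_buffers,
                                      gsl::span<const size_t> prepacked_buffer_sizes,
                                      int input_idx,
                                      /*out*/ bool& used_shared_buffers) override {
    used_shared_buffers = false;
    if (input_idx != 1) {  // in this sketch only input 1 (the weight B) was pre-packed
      return Status::OK();
    }

    packed_b_ = std::move(prepacked_buffers[0]);   // handle to the shared pre-packed data provided by ORT
    packed_b_size_ = prepacked_buffer_sizes[0];    // byte size, newly available through the V2 overload
    used_shared_buffers = true;
    return Status::OK();
  }

  Status Compute(OpKernelContext* /*context*/) const override { return Status::OK(); }  // elided

 private:
  BufferUniquePtr packed_b_;
  size_t packed_b_size_{0};
};

}  // namespace onnxruntime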
29 changes: 29 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -3379,5 +3379,34 @@ struct KernelRegistry : detail::Base<OrtKernelRegistry> {
Status AddKernel(const OrtKernelDef* kernel_def, OrtKernelCreateFunc kernel_create_func,
void* kernel_create_func_state);
};

namespace detail {
template <typename T>
struct SharedPrePackedWeightCacheImpl : Ort::detail::Base<T> {
using B = Ort::detail::Base<T>;
using B::B;

/// Wraps OrtEpApi::SharedPrePackedWeightCache_StoreWeightData
Status StoreWeightData(void** buffer_data_ptrs, size_t* buffer_sizes, size_t num_buffers);
};
} // namespace detail

/** \brief Convenience C++ wrapper class around a ::OrtSharedPrePackedWeightCache instance owned by ORT.
*
* An `OrtSharedPrePackedWeightCache*` instance is passed as an argument to OrtKernelImpl::PrePackWeight.
* Example use:
* OrtStatus* MyKernel::PrePackWeightImpl(OrtKernelImpl*, ..., OrtSharedPrePackedWeightCache* c_cache, ...) {
* ...
* if (c_cache != nullptr) {
* Ort::UnownedSharedPrePackedWeightCache cpp_cache(c_cache);
* Ort::Status status = cpp_cache.StoreWeightData(...);
* }
* ...
* }
*
* \remarks OrtSharedPrePackedWeightCache is always unowned, but mutable, for EpApi users.
*/
using UnownedSharedPrePackedWeightCache =
detail::SharedPrePackedWeightCacheImpl<Ort::detail::Unowned<OrtSharedPrePackedWeightCache>>;
} // namespace Ort
#include "onnxruntime_cxx_inline.h"
9 changes: 9 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -3713,4 +3713,13 @@ inline Status KernelRegistry::AddKernel(const OrtKernelDef* kernel_def, OrtKerne
void* kernel_create_func_state) {
return Status{GetEpApi().KernelRegistry_AddKernel(p_, kernel_def, kernel_create_func, kernel_create_func_state)};
}

namespace detail {
template <typename T>
inline Status SharedPrePackedWeightCacheImpl<T>::StoreWeightData(void** buffer_data_ptrs, size_t* buffer_sizes,
size_t num_buffers) {
return Status{GetEpApi().SharedPrePackedWeightCache_StoreWeightData(this->p_, buffer_data_ptrs, buffer_sizes,
num_buffers)};
}
} // namespace detail
} // namespace Ort
125 changes: 125 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_ep_c_api.h
@@ -29,6 +29,7 @@ ORT_RUNTIME_CLASS(KernelRegistry);
ORT_RUNTIME_CLASS(KernelDefBuilder);
ORT_RUNTIME_CLASS(KernelDef);
ORT_RUNTIME_CLASS(DataType); // combination of ONNXType (e.g., Tensor, Map, Sequence) and ONNXTensorElementDataType
ORT_RUNTIME_CLASS(SharedPrePackedWeightCache);

/** \brief Struct that an EP implements for IDataTransfer to copy between devices it uses and CPU.
*
@@ -308,6 +309,101 @@ struct OrtKernelImpl {
* \since Version 1.24.
*/
ORT_API_T(void, Release, _In_ OrtKernelImpl* this_ptr);

/** \brief Optional function to pre-pack a constant tensor (i.e., a weight) to the kernel's preferred data layout.
*
* For example, a Conv kernel can define this function to pack input W to the channel-last data layout
* before inference.
*
* Pre-packing can operate in three different modes: no pre-packing mode, sharing mode, and non-sharing mode.
* 1) No pre-packing mode: The kernel can forgo any weight pre-packing for the given `input_index` by setting
* `is_packed` to false and returning a successful OrtStatus. In this mode, the kernel's
* OrtKernelImpl::SetSharedPrePackedWeight() function is not called for that specific
* `input_index`.
* 2) Sharing mode: Sharing is allowed if the `prepacked_weight_cache` argument is not NULL and the EP stores
* weight data in CPU-accessible memory. In this case, the kernel can optionally choose to share
* the packed weight with other kernels that use the same weight (compared by content hash).
* To do so, the kernel must allocate the packed weight with the provided `allocator`, store the
* packed weight data into `prepacked_weight_cache` via
* SharedPrePackedWeightCache_StoreWeightData(), set `is_packed` to true, and return a
* successful OrtStatus. ORT will subsequently call OrtKernelImpl::SetSharedPrePackedWeight()
* to provide this kernel with the actual shared weight data, whose memory location could
* differ (e.g., if the shared data was allocated by a previously processed kernel).
* 3) Non-sharing mode: The `prepacked_weight_cache` argument is ignored. The implementation allocates the
* packed data with the provided `allocator`, sets `is_packed` to true, and returns a
* successful OrtStatus. The kernel is ultimately responsible for releasing the packed data
* for the weight with `allocator`. ORT may release the original (unpacked) weight, which
* must not be accessed in OrtKernelImpl::Compute(). Note that in this mode, the kernel's
* OrtKernelImpl::SetSharedPrePackedWeight() function is not called by ORT for that specific
* `input_index`.
*
* \note This function is based on the internal OpKernel::PrePack() virtual function used within ORT.
*
* \param[in] this_ptr The OrtKernelImpl instance.
* \param[in] tensor The OrtValue instance representing the constant tensor (weight). Do not cache in the kernel.
* \param[in] input_index The input index of the tensor in this kernel.
* \param[in] allocator Allocator for allocating the pre-packed data. Its use is required in sharing mode and
* recommended, but not required, in non-sharing mode. This will be an allocator set by
* the application for the session/environment (e.g., via CreateAndRegisterAllocator[V2]
* or RegisterAllocator), or an allocator on the OrtEpDevice (read-only or default) otherwise.
* The allocator remains valid throughout the lifetime of the OrtKernelImpl instance.
* \param[in] prepacked_weight_cache May be NULL. If not NULL, the kernel may choose to share a packed weight by
* first storing it in the OrtSharedPrePackedWeightCache instance and then
* receiving the actual shared weight data in the call to
* OrtKernelImpl::SetSharedPrePackedWeight(). See the above description for
* "sharing mode".
* \param[out] is_packed Output parameter that the implementation sets to true if the kernel packed the tensor data.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \note Implementation of this function is optional. If not implemented (set to NULL), ORT assumes the kernel
* does not pre-pack weight data (i.e., `is_packed` defaults to false).
*
* \since Version 1.24.
*/
ORT_API2_STATUS(PrePackWeight, _In_ OrtKernelImpl* this_ptr, _In_ const OrtValue* tensor,
_In_ int input_index, _Inout_ OrtAllocator* allocator,
_In_opt_ OrtSharedPrePackedWeightCache* prepacked_weight_cache, _Out_ bool* is_packed);

/** \brief Optional function that receives data for a shared pre-packed weight from ORT.
*
* ORT calls this function after calling OrtKernelImpl::PrePackWeight for a specific `input_index` if:
* - OrtKernelImpl::PrePackWeight set the output parameter `is_packed` to true.
* - OrtKernelImpl::PrePackWeight stored weight data to share into the provided OrtSharedPrePackedWeightCache
* parameter (`prepacked_weight_cache`) via the API SharedPrePackedWeightCache_StoreWeightData.
*
* Refer to the description of "sharing mode" in the documentation for OrtKernelImpl::PrePackWeight().
*
* \note ORT will not call this function for an `input_index` that a previous call to
* OrtKernelImpl::PrePackWeight() did not elect to pre-pack and share.
*
* \note This function is based on the internal OpKernel::UseSharedPrePackedBuffers() virtual function used
* within ORT.
*
* \param[in] this_ptr The OrtKernelImpl instance.
* \param[in] buffer_data_ptrs An array of buffer data pointers that collectively hold the pre-packed data for a
* single shared weight. The buffers are provided in the same order and with the same
* contents (in a potentially different memory location) as the buffers
* passed into SharedPrePackedWeightCache_StoreWeightData() within the
* OrtKernelImpl::PrePackWeight() call for the same `input_index`.
* \param[in] buffer_data_sizes An array of buffer byte sizes, one per element in `buffer_data_ptrs`.
* \param[in] num_buffers The number of buffers used to store the data for the shared pre-packed weight.
* Specifies the number of elements in the `buffer_data_ptrs` and `buffer_data_sizes` arrays.
* \param[in] input_index The input index of the tensor in this kernel. This index identifies which weight the
* shared data belongs to.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \note Implementation of this function is generally optional. It is only required if
* OrtKernelImpl::PrePackWeight() elects to share pre-packed weights.
*
* \since Version 1.24.
*/
ORT_API2_STATUS(SetSharedPrePackedWeight, _In_ OrtKernelImpl* this_ptr,
_In_reads_(num_buffers) const void* const* buffer_data_ptrs,
_In_reads_(num_buffers) const size_t* buffer_data_sizes,
_In_ size_t num_buffers, _In_ int input_index);
};
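To make the pre-packing modes and the SetSharedPrePackedWeight handshake above concrete, here is a hedged sketch of a plugin-EP kernel using the raw C API. It assumes the common pattern of embedding the OrtKernelImpl function table as the first member of the kernel struct and an `ep_api` pointer cached during library initialization; the struct name, function names, sizes, and the choice to pre-pack only input index 1 are all illustrative, not part of this change.

// Hypothetical plugin-EP kernel sketch using the raw C API; all names are illustrative.
#include "onnxruntime_c_api.h"  // assumed to make the EP API declarations (onnxruntime_ep_c_api.h) visible

extern const OrtEpApi* ep_api;  // assumed: cached by the EP library during initialization

// OrtKernelImpl is embedded as the first member so ORT's callbacks can be cast back to the kernel struct.
struct MyConvKernel {
  OrtKernelImpl base;               // function table filled in by the kernel's create function
  const void* packed_w = nullptr;   // pre-packed weight used by Compute()
  size_t packed_w_size = 0;
  void* owned_packed_w = nullptr;   // non-null only in non-sharing mode; the kernel must free it with `allocator`
  OrtAllocator* allocator = nullptr;
};

static OrtStatus* ORT_API_CALL MyPrePackWeight(OrtKernelImpl* this_ptr, const OrtValue* tensor, int input_index,
                                               OrtAllocator* allocator,
                                               OrtSharedPrePackedWeightCache* prepacked_weight_cache,
                                               bool* is_packed) noexcept {
  auto* kernel = reinterpret_cast<MyConvKernel*>(this_ptr);
  *is_packed = false;
  if (input_index != 1) return nullptr;  // "no pre-packing mode": only the weight input (W) is pre-packed

  (void)tensor;  // a real kernel reads the weight through the OrtApi tensor accessors
  const size_t packed_size = 256;  // placeholder for the size derived from the tensor's shape and element type
  void* packed = allocator->Alloc(allocator, packed_size);
  // ... re-lay-out the weight data into `packed` ...

  if (prepacked_weight_cache != nullptr) {  // "sharing mode": offer the packed data for sharing
    void* ptrs[] = {packed};
    size_t sizes[] = {packed_size};
    OrtStatus* status = ep_api->SharedPrePackedWeightCache_StoreWeightData(prepacked_weight_cache, ptrs, sizes, 1);
    if (status != nullptr) {
      allocator->Free(allocator, packed);  // ownership did not transfer on failure
      return status;
    }
    // The actual shared data arrives later through MySetSharedPrePackedWeight.
  } else {  // "non-sharing mode": keep the packed data and release it when the kernel is released
    kernel->packed_w = packed;
    kernel->packed_w_size = packed_size;
    kernel->owned_packed_w = packed;
    kernel->allocator = allocator;
  }

  *is_packed = true;
  return nullptr;
}

static OrtStatus* ORT_API_CALL MySetSharedPrePackedWeight(OrtKernelImpl* this_ptr,
                                                          const void* const* buffer_data_ptrs,
                                                          const size_t* buffer_data_sizes,
                                                          size_t num_buffers, int input_index) noexcept {
  auto* kernel = reinterpret_cast<MyConvKernel*>(this_ptr);
  if (input_index == 1 && num_buffers == 1) {
    kernel->packed_w = buffer_data_ptrs[0];       // this memory is owned by the shared cache, not the kernel
    kernel->packed_w_size = buffer_data_sizes[0];
  }
  return nullptr;
}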

/** \brief Type definition for a function that creates an OrtKernelImpl instance for an operator kernel.
@@ -846,6 +942,35 @@ struct OrtEpApi {
*/
ORT_API2_STATUS(EpGraphSupportInfo_LookUpKernel, _In_ OrtEpGraphSupportInfo* graph_support_info,
_In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtKernelDef** out_kernel_def);

/** \brief Sets one or more data buffers that collectively hold the pre-packed data for a single shared weight.
*
* \note Used within the implementation of OrtKernelImpl::PrePackWeight() when the kernel wants to share pre-packed
* weight data with other kernels. The buffer data MUST be allocated with the OrtAllocator provided to
* OrtKernelImpl::PrePackWeight().
*
* \note Ownership of weight data transfers to the OrtSharedPrePackedWeightCache instance on success.
* If this function returns an error status, the caller retains ownership of the weight data.
*
* \note Subsequent calls with the same OrtSharedPrePackedWeightCache instance release and replace the old data.
*
* \param[in] prepacked_weight_cache The OrtSharedPrePackedWeightCache instance that will hold the shared data.
* \param[in] buffer_data_ptrs An array of buffer data pointers that collectively hold the pre-packed data for a
* single shared weight. Note that sometimes a single weight may have multiple pre-packed
* buffers and it is up to the kernel implementation to determine how to split the data
* into multiple buffers (if desired).
* \param[in] buffer_data_sizes An array of buffer byte sizes, one per element in `buffer_data_ptrs`.
* \param[in] num_buffers The number of buffers used to store the data for the shared pre-packed weight.
* Specifies the number of elements in the `buffer_data_ptrs` and `buffer_data_sizes` arrays.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.24.
*/
ORT_API2_STATUS(SharedPrePackedWeightCache_StoreWeightData,
_In_ OrtSharedPrePackedWeightCache* prepacked_weight_cache,
_In_reads_(num_buffers) void** buffer_data_ptrs, _In_reads_(num_buffers) size_t* buffer_data_sizes,
_In_ size_t num_buffers);
};
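The multi-buffer and ownership notes above can be illustrated with a short, hypothetical helper, using the same includes and assumptions as the sketch above. The helper name, the two-buffer split (payload plus metadata), and the parameters are placeholders that would come from within an OrtKernelImpl::PrePackWeight implementation.

// Hypothetical helper called from within OrtKernelImpl::PrePackWeight; buffers come from that call's allocator.
static OrtStatus* StoreSplitPackedWeight(const OrtEpApi* ep_api,
                                         OrtSharedPrePackedWeightCache* prepacked_weight_cache,
                                         OrtAllocator* allocator,
                                         void* packed_main, size_t packed_main_size,
                                         void* packed_meta, size_t packed_meta_size) {
  // Two buffers that together represent one pre-packed weight (e.g., the payload plus a small metadata blob).
  void* ptrs[] = {packed_main, packed_meta};
  size_t sizes[] = {packed_main_size, packed_meta_size};

  OrtStatus* status = ep_api->SharedPrePackedWeightCache_StoreWeightData(prepacked_weight_cache, ptrs, sizes, 2);
  if (status != nullptr) {
    // Ownership did not transfer: the caller still owns, and must free, both buffers.
    allocator->Free(allocator, packed_main);
    allocator->Free(allocator, packed_meta);
    return status;
  }

  // On success the cache owns both buffers; they are handed back (possibly relocated, in the same order)
  // through OrtKernelImpl::SetSharedPrePackedWeight.
  return nullptr;
}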

/**