Skip to content

Commit 12b8394

Browse files
Merge main and fix conflicts
2 parents 3323330 + c60472d commit 12b8394

File tree

25 files changed

+1003
-125
lines changed

25 files changed

+1003
-125
lines changed

cmake/onnxruntime_unittests.cmake

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2104,11 +2104,12 @@ if (onnxruntime_BUILD_SHARED_LIB AND
21042104
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_factory.cc"
21052105
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep.h"
21062106
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep.cc"
2107+
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_allocator.h"
2108+
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_data_transfer.h"
2109+
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_data_transfer.cc"
21072110
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.h"
21082111
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.cc"
21092112
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/utils.h"
2110-
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/base.h"
2111-
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/base.cc"
21122113
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.h"
21132114
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.cc"
21142115
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.h"

include/onnxruntime/core/framework/op_kernel.h

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#pragma once
55

66
#include "boost/mp11.hpp"
7+
#include <gsl/gsl>
78

89
// It is safe to include the below header even if SHARED_PROVIDER macro is enabled
910
// as it doesn't include any pb headers.
@@ -26,7 +27,6 @@
2627
#include "core/graph/constants.h"
2728
#include "core/graph/graph_viewer.h"
2829
#include "core/graph/onnx_protobuf.h"
29-
#include <gsl/gsl>
3030
namespace onnxruntime {
3131
class OpKernelContext;
3232
}
@@ -105,6 +105,7 @@ class OpKernel {
105105
return Status::OK();
106106
}
107107

108+
// Note: New implementations should override OpKernel::UseSharedPrePackedBuffers_V2 instead.
108109
// Override this function to use provided pre-packed weight.
109110
// Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers,
110111
// int input_idx,
@@ -130,6 +131,27 @@ class OpKernel {
130131
return Status::OK();
131132
}
132133

134+
/// <summary>
135+
/// Version 2 of OpKernel::UseSharedPrePackedBuffers() that additionally accepts the buffer sizes as a parameter.
136+
/// The default implementation of this function just calls directly to OpKernel::UseSharedPrePackedBuffers()
137+
/// to avoid the need to update all existing kernel-based provider-bridge EPs.
138+
///
139+
/// TODO: Consolidate UseSharedPrePackedBuffers and UseSharedPrePackedBuffers_V2 into a single function,
140+
/// which will require updating kernel-based provider-bridge EPs (cpu, cuda, webgpu).
141+
///
142+
/// </summary>
143+
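/// <param name="prepacked_buffers">The shared pre-packed buffers for the weight. The kernel may use them but must not delete them.</param>
144+
/// <param name="prepacked_buffer_sizes">The size of each buffer in prepacked_buffers, one entry per buffer. Unused by the default implementation.</param>
145+
/// <param name="input_idx">The index of the kernel input that the shared pre-packed buffers belong to.</param>
146+
/// <param name="used_shared_buffers">Set to true if the kernel used the provided shared buffers.</param>
147+
/// <returns>Status indicating success or failure.</returns>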
148+
virtual Status UseSharedPrePackedBuffers_V2(std::vector<BufferUniquePtr>& prepacked_buffers,
149+
gsl::span<const size_t> /*prepacked_buffer_sizes*/,
150+
int input_idx,
151+
/*out*/ bool& used_shared_buffers) {
152+
return UseSharedPrePackedBuffers(prepacked_buffers, input_idx, used_shared_buffers);
153+
}
154+
133155
const OrtDevice GetDevice(OrtMemType mem_type) const;
134156
const OpKernelInfo& Info() const {
135157
return *op_kernel_info_;

include/onnxruntime/core/session/onnxruntime_cxx_api.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3386,5 +3386,34 @@ struct KernelRegistry : detail::Base<OrtKernelRegistry> {
33863386
Status AddKernel(const OrtKernelDef* kernel_def, OrtKernelCreateFunc kernel_create_func,
33873387
void* kernel_create_func_state);
33883388
};
3389+
3390+
namespace detail {
3391+
template <typename T>
3392+
struct SharedPrePackedWeightCacheImpl : Ort::detail::Base<T> {
3393+
using B = Ort::detail::Base<T>;
3394+
using B::B;
3395+
3396+
/// Wraps OrtEpApi::SharedPrePackedWeightCache_StoreWeightData
3397+
Status StoreWeightData(void** buffer_data_ptrs, size_t* buffer_sizes, size_t num_buffers);
3398+
};
3399+
} // namespace detail
3400+
3401+
/** \brief Convenience C++ wrapper class around a ::OrtSharedPrePackedWeightCache instance owned by ORT.
3402+
*
3403+
* An `OrtSharedPrePackedWeightCache*` instance is passed as an argument to OrtKernelImpl::PrePackWeight.
3404+
* Example use:
3405+
* OrtStatus* MyKernel::PrePackWeightImpl(OrtKernelImpl*, ..., OrtSharedPrePackedWeightCache* c_cache, ...) {
3406+
* ...
3407+
* if (c_cache != nullptr) {
3408+
* Ort::UnownedSharedPrePackedWeightCache cpp_cache(c_cache);
3409+
* Ort::Status status = cpp_cache.StoreWeightData(...);
3410+
* }
3411+
* ...
3412+
* }
3413+
*
3414+
* \remarks OrtSharedPrePackedWeightCache is always unowned, but mutable, for EpApi users.
3415+
*/
3416+
using UnownedSharedPrePackedWeightCache =
3417+
detail::SharedPrePackedWeightCacheImpl<Ort::detail::Unowned<OrtSharedPrePackedWeightCache>>;
33893418
} // namespace Ort
33903419
#include "onnxruntime_cxx_inline.h"

include/onnxruntime/core/session/onnxruntime_cxx_inline.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3742,4 +3742,13 @@ inline Status KernelRegistry::AddKernel(const OrtKernelDef* kernel_def, OrtKerne
37423742
void* kernel_create_func_state) {
37433743
return Status{GetEpApi().KernelRegistry_AddKernel(p_, kernel_def, kernel_create_func, kernel_create_func_state)};
37443744
}
3745+
3746+
namespace detail {
3747+
template <typename T>
3748+
inline Status SharedPrePackedWeightCacheImpl<T>::StoreWeightData(void** buffer_data_ptrs, size_t* buffer_sizes,
3749+
size_t num_buffers) {
3750+
return Status{GetEpApi().SharedPrePackedWeightCache_StoreWeightData(this->p_, buffer_data_ptrs, buffer_sizes,
3751+
num_buffers)};
3752+
}
3753+
} // namespace detail
37453754
} // namespace Ort

include/onnxruntime/core/session/onnxruntime_ep_c_api.h

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ ORT_RUNTIME_CLASS(KernelRegistry);
2929
ORT_RUNTIME_CLASS(KernelDefBuilder);
3030
ORT_RUNTIME_CLASS(KernelDef);
3131
ORT_RUNTIME_CLASS(DataType); // combination of ONNXType (e.g., Tensor, Map, Sequence) and ONNXTensorElementDataType
32+
ORT_RUNTIME_CLASS(SharedPrePackedWeightCache);
3233

3334
/** \brief Struct that an EP implements for IDataTransfer to copy between devices it uses and CPU.
3435
*
@@ -308,6 +309,98 @@ struct OrtKernelImpl {
308309
* \since Version 1.24.
309310
*/
310311
ORT_API_T(void, Release, _In_ OrtKernelImpl* this_ptr);
312+
313+
/** \brief Optional function to pre-pack a constant tensor (i.e., a weight) to the kernel's preferred data layout.
314+
*
315+
* For example, a Conv kernel can define this function to pack input W to the channel-last data layout
316+
* before inference.
317+
*
318+
* Pre-packing can operate in three different modes: no pre-packing mode, sharing mode, and non-sharing mode.
319+
* 1) No pre-packing mode: The kernel can forgo any weight pre-packing for the given `input_index` by setting
320+
* `is_packed` to false and returning a successful OrtStatus. In this mode, the kernel's
321+
* OrtKernelImpl::SetSharedPrePackedWeight() function is not called for that specific
322+
* `input_index`.
323+
* 2) Sharing mode: Sharing is allowed if the `prepacked_weight_cache` argument is not NULL and the EP stores
324+
* weight data in CPU-accessible memory. In this case, the kernel can optionally choose
325+
* to share the packed weight with other kernels that use the same weight
326+
* (compared by content hash). To do so, the kernel must allocate the packed weight with the
327+
* provided `allocator`, store the packed weight data into `prepacked_weight_cache`
328+
* via SharedPrePackedWeightCache_StoreWeightData(), set `is_packed` to true, and return a
329+
* successful OrtStatus. ORT will subsequently call OrtKernelImpl::SetSharedPrePackedWeight()
330+
* to provide this kernel with the actual shared weight data, whose memory location could
331+
* differ (i.e., if shared data was allocated by a previously processed kernel).
332+
* 3) Non-sharing mode: In non-sharing mode, the `prepacked_weight_cache` argument is ignored. In this mode,
333+
* the implementation allocates the packed data with the provided `allocator`, sets
334+
* `is_packed` to true, and returns a successful OrtStatus. The kernel is ultimately
335+
* responsible for releasing the packed data for the weight with `allocator`.
336+
* ORT may release the original (unpacked) weight, which must not be accessed in
337+
* OrtKernelImpl::Compute(). Note that in this mode, the kernel's
338+
* OrtKernelImpl::SetSharedPrePackedWeight() function is not called by ORT for that specific
339+
* `input_index`.
340+
*
341+
* \note This function is based on the internal OpKernel::PrePack() virtual function used within ORT.
342+
*
343+
* \param[in] this_ptr The OrtKernelImpl instance.
344+
* \param[in] tensor The OrtValue instance representing the constant tensor (weight). Do not cache in the kernel.
345+
* \param[in] input_index The input index of the tensor in this kernel.
346+
* \param[in] allocator Allocator for allocating the pre-packed data. Its use is required in sharing mode and
347+
* recommended, but not required, in the non-sharing mode. This will be an allocator set by
348+
* the application for the session/environment (e.g., via CreateAndRegisterAllocator[V2]
349+
* or RegisterAllocator), or an allocator on the OrtEpDevice (read-only or default) otherwise.
350+
* The allocator remains valid throughout the lifetime of the OrtKernelImpl instance.
351+
* \param[in] prepacked_weight_cache May be NULL. If not NULL, the kernel may choose to share a packed weight by
352+
* first storing it in the OrtSharedPrePackedWeightCache instance and then
353+
* receiving the actual shared weight data in the call to
354+
* OrtKernelImpl::SetSharedPrePackedWeight(). See the above description for
355+
* "sharing mode".
356+
* \param[out] is_packed Output parameter that the implementation sets to true if the kernel packed the tensor data.
357+
*
358+
* \snippet{doc} snippets.dox OrtStatus Return Value
359+
*
360+
* \note Implementation of this function is optional. If not implemented (set to NULL), ORT assumes the kernel
361+
* does not pre-pack weight data (i.e., `is_packed` defaults to false).
362+
*
363+
* \since Version 1.24.
364+
*/
365+
ORT_API2_STATUS(PrePackWeight, _In_ OrtKernelImpl* this_ptr, _In_ const OrtValue* tensor,
366+
_In_ int input_index, _Inout_ OrtAllocator* allocator,
367+
_In_opt_ OrtSharedPrePackedWeightCache* prepacked_weight_cache, _Out_ bool* is_packed);
368+
369+
/** \brief Optional function that receives data for a shared pre-packed weight from ORT.
370+
*
371+
* This function is called after a prior call to OrtKernelImpl::PrePackWeight for a specific `input_index` set
372+
* `is_packed` to true and stored weight data (to share) into the provided OrtSharedPrePackedWeightCache instance.
373+
* Refer to the description of the "sharing mode" in the documentation for OrtKernelImpl::PrePackWeight().
374+
*
375+
* \note ORT will not call this function for an `input_index` that a previous call to
376+
* OrtKernelImpl::PrePackWeight() did not elect to pre-pack and share.
377+
*
378+
* \note This function is based on the internal OpKernel::UseSharedPrePackedBuffers() virtual function used
379+
* within ORT.
380+
*
381+
* \param[in] this_ptr The OrtKernelImpl instance.
382+
* \param[in] buffer_data_ptrs An array of buffer data pointers that collectively hold the pre-packed data for a
383+
* single shared weight. The buffers are provided in the same order and with the same
384+
* contents (in a potentially different memory location) as the buffers
385+
* passed into SharedPrePackedWeightCache_StoreWeightData() within the
386+
* OrtKernelImpl::PrePackWeight() call for the same `input_index`.
387+
* \param[in] buffer_data_sizes An array of buffer byte sizes, one per element in `buffer_data_ptrs`.
388+
* \param[in] num_buffers The number of buffers used to store the data for the shared pre-packed weight.
389+
* Specifies the number of elements in the `buffer_data_ptrs` and `buffer_data_sizes` arrays.
390+
* \param[in] input_index The input index of the tensor in this kernel. This index identifies
391+
* the weight.
392+
*
393+
* \snippet{doc} snippets.dox OrtStatus Return Value
394+
*
395+
* \note Implementation of this function is generally optional. It is only required if OrtKernelImpl::PrePackWeight()
396+
* elects to share pre-packed weights.
397+
*
398+
* \since Version 1.24.
399+
*/
400+
ORT_API2_STATUS(SetSharedPrePackedWeight, _In_ OrtKernelImpl* this_ptr,
401+
_In_reads_(num_buffers) const void* const* buffer_data_ptrs,
402+
_In_reads_(num_buffers) const size_t* buffer_data_sizes,
403+
_In_ size_t num_buffers, _In_ int input_index);
311404
};
312405

313406
/** \brief Type definition for a function that creates an OrtKernelImpl instance for an operator kernel.
@@ -847,6 +940,35 @@ struct OrtEpApi {
847940
ORT_API2_STATUS(EpGraphSupportInfo_LookUpKernel, _In_ OrtEpGraphSupportInfo* graph_support_info,
848941
_In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtKernelDef** out_kernel_def);
849942

943+
/** \brief Sets one or more data buffers that collectively hold the pre-packed data for a single shared weight.
944+
*
945+
* \note Used within the implementation of OrtKernelImpl::PrePackWeight() when the kernel wants to share pre-packed
946+
* weight data with other kernels. The buffer data MUST be allocated with the OrtAllocator provided to
947+
* OrtKernelImpl::PrePackWeight().
948+
*
949+
* \note Ownership of weight data transfers to the OrtSharedPrePackedWeightCache instance on success.
950+
* If this function returns an error status, the caller retains ownership of the weight data.
951+
*
952+
* \note Subsequent calls with the same OrtSharedPrePackedWeightCache instance release and replace the old data.
953+
*
954+
* \param[in] prepacked_weight_cache The OrtSharedPrePackedWeightCache instance that receives the weight data.
955+
* \param[in] buffer_data_ptrs An array of buffer data pointers that collectively hold the pre-packed data for a
956+
* single shared weight. Note that sometimes a single weight may have multiple pre-packed
957+
* buffers and it is up to the kernel implementation to determine how to split the data
958+
* into multiple buffers (if desired).
959+
* \param[in] buffer_data_sizes An array of buffer byte sizes, one per element in `buffer_data_ptrs`.
960+
* \param[in] num_buffers The number of buffers used to store the data for the shared pre-packed weight.
961+
* Specifies the number of elements in the `buffer_data_ptrs` and `buffer_data_sizes` arrays.
962+
*
963+
* \snippet{doc} snippets.dox OrtStatus Return Value
964+
*
965+
* \since Version 1.24.
966+
*/
967+
ORT_API2_STATUS(SharedPrePackedWeightCache_StoreWeightData,
968+
_In_ OrtSharedPrePackedWeightCache* prepacked_weight_cache,
969+
_In_reads_(num_buffers) void** buffer_data_ptrs, _In_reads_(num_buffers) size_t* buffer_data_sizes,
970+
_In_ size_t num_buffers);
971+
850972
/** \brief Get the OrtEp instance to which the node is assigned from the OrtKernelInfo.
851973
*
852974
* \note Used within OrtKernelImpl implementations to obtain a reference to the OrtEp.

onnxruntime/core/framework/session_state.cc

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -421,16 +421,23 @@ void SessionState::CleanInitializedTensorsFromGraph() {
421421
static Status KernelUseSharedPrePackedBuffers(OpKernel& kernel, int input_idx,
422422
const PrePackedWeights& prepacked_weights,
423423
const std::string& node_name) {
424+
const size_t num_buffers = prepacked_weights.buffers_.size();
425+
assert(prepacked_weights.buffer_sizes_.size() == num_buffers);
426+
424427
std::vector<BufferUniquePtr> shared_prepacked_buffers;
425-
shared_prepacked_buffers.reserve(4); // Unlikely to see more than 4 prepacked buffers per initializer
428+
std::vector<size_t> shared_prepacked_buffer_sizes;
429+
shared_prepacked_buffers.reserve(num_buffers);
430+
shared_prepacked_buffer_sizes.reserve(num_buffers);
426431

427-
for (const auto& prepacked_buffer : prepacked_weights.buffers_) {
432+
for (size_t i = 0; i < num_buffers; i++) {
428433
// BufferDeleter is nullptr because the kernel should not delete the shared buffer - it can only use it
429-
shared_prepacked_buffers.emplace_back(prepacked_buffer.get(), BufferDeleter(nullptr));
434+
shared_prepacked_buffers.emplace_back(prepacked_weights.buffers_[i].get(), BufferDeleter(nullptr));
435+
shared_prepacked_buffer_sizes.push_back(prepacked_weights.buffer_sizes_[i]);
430436
}
431437

432438
bool used_shared_buffers = false;
433-
ORT_RETURN_IF_ERROR(kernel.UseSharedPrePackedBuffers(shared_prepacked_buffers, input_idx, used_shared_buffers));
439+
ORT_RETURN_IF_ERROR(kernel.UseSharedPrePackedBuffers_V2(shared_prepacked_buffers, shared_prepacked_buffer_sizes,
440+
input_idx, used_shared_buffers));
434441

435442
// BUG CHECK: Ensure that the kernel used the provided shared buffers
436443
// Mostly a debug check to ensure that the kernel has an overridden implementation of the

onnxruntime/core/session/plugin_ep/ep_api.cc

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,47 @@ ORT_API_STATUS_IMPL(EpGraphSupportInfo_LookUpKernel, _In_ OrtEpGraphSupportInfo*
582582
API_IMPL_END
583583
}
584584

585+
ORT_API_STATUS_IMPL(SharedPrePackedWeightCache_StoreWeightData,
586+
_In_ OrtSharedPrePackedWeightCache* prepacked_weight_cache,
587+
_In_reads_(num_buffers) void** buffer_data_ptrs, _In_reads_(num_buffers) size_t* buffer_data_sizes,
588+
_In_ size_t num_buffers) {
589+
API_IMPL_BEGIN
590+
if (prepacked_weight_cache == nullptr) {
591+
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
592+
"Must specify a valid OrtPrePackedWeightsCache instance");
593+
}
594+
595+
if (buffer_data_ptrs == nullptr) {
596+
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid array of buffer data pointers");
597+
}
598+
599+
if (buffer_data_sizes == nullptr) {
600+
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid array of buffer data sizes");
601+
}
602+
603+
if (num_buffers == 0) {
604+
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify at least one weight data buffer");
605+
}
606+
607+
OrtStatus* status = nullptr;
608+
609+
ORT_TRY {
610+
prepacked_weight_cache->SetBuffers(buffer_data_ptrs, buffer_data_sizes, num_buffers);
611+
}
612+
ORT_CATCH(const std::exception& ex) {
613+
ORT_HANDLE_EXCEPTION([&]() {
614+
// This API function promises that ORT will take ownership of the data only if it returns successfully.
615+
// If any exception occurred while filling out `prepacked_weight_cache`, we try to release ownership so that
616+
// the caller retains ownership of all of the original data and can delete it.
617+
prepacked_weight_cache->ReleaseAllData();
618+
status = OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, ex.what());
619+
});
620+
}
621+
622+
return status;
623+
API_IMPL_END
624+
}
625+
585626
ORT_API_STATUS_IMPL(KernelInfo_GetEp, _In_ const OrtKernelInfo* info, _Outptr_ const OrtEp** ep) {
586627
API_IMPL_BEGIN
587628
if (ep == nullptr) {
@@ -663,6 +704,7 @@ static constexpr OrtEpApi ort_ep_api = {
663704
&OrtExecutionProviderApi::KernelDef_GetOutputMemType,
664705
&OrtExecutionProviderApi::GetTensorDataType,
665706
&OrtExecutionProviderApi::EpGraphSupportInfo_LookUpKernel,
707+
&OrtExecutionProviderApi::SharedPrePackedWeightCache_StoreWeightData,
666708
&OrtExecutionProviderApi::KernelInfo_GetEp,
667709
};
668710

0 commit comments

Comments
 (0)