Skip to content

Commit 2d1ed5b

Browse files
[EP ABI] Add weight pre-packing support to kernel-based plugin EPs (#26754)
### Description
- Adds C APIs to support pre-packing of const weights for `OrtKernelImpl` implementations.
- APIs optionally support sharing of pre-packed weight data (for CPU-accessible memory).
- Updates the example kernel (Mul) to use the new pre-packing API. Tested by an existing unit test: https://github.com/microsoft/onnxruntime/blob/549d7415e26e2b3f86c42f86e135bb746caa37b4/onnxruntime/test/autoep/test_execution.cc#L242-L256

### Motivation and Context
The [previous PR](#26206) added the base APIs that support kernel-based plugin EPs. This PR adds an additional feature that was identified as necessary for the port of the WebGPU EP.

---------

Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
1 parent 573d395 commit 2d1ed5b

File tree

25 files changed

+1009
-125
lines changed

25 files changed

+1009
-125
lines changed

cmake/onnxruntime_unittests.cmake

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2104,11 +2104,12 @@ if (onnxruntime_BUILD_SHARED_LIB AND
21042104
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_factory.cc"
21052105
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep.h"
21062106
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep.cc"
2107+
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_allocator.h"
2108+
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_data_transfer.h"
2109+
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_data_transfer.cc"
21072110
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.h"
21082111
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.cc"
21092112
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/utils.h"
2110-
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/base.h"
2111-
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/base.cc"
21122113
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.h"
21132114
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.cc"
21142115
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.h"

include/onnxruntime/core/framework/op_kernel.h

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#pragma once
55

66
#include "boost/mp11.hpp"
7+
#include <gsl/gsl>
78

89
// It is safe to include the below header even if SHARED_PROVIDER macro is enabled
910
// as it doesn't include any pb headers.
@@ -26,7 +27,6 @@
2627
#include "core/graph/constants.h"
2728
#include "core/graph/graph_viewer.h"
2829
#include "core/graph/onnx_protobuf.h"
29-
#include <gsl/gsl>
3030
namespace onnxruntime {
3131
class OpKernelContext;
3232
}
@@ -105,6 +105,7 @@ class OpKernel {
105105
return Status::OK();
106106
}
107107

108+
// Note: New implementations should override OpKernel::UseSharedPrePackedBuffers_V2 instead.
108109
// Override this function to use provided pre-packed weight.
109110
// Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers,
110111
// int input_idx,
@@ -130,6 +131,27 @@ class OpKernel {
130131
return Status::OK();
131132
}
132133

134+
/// <summary>
135+
/// Version 2 of OpKernel::UseSharedPrePackedBuffers() that additionally accepts the buffer sizes as a parameter.
136+
/// The default implementation of this function just calls directly to OpKernel::UseSharedPrePackedBuffers()
137+
/// to avoid the need to update all existing kernel-based provider-bridge EPs.
138+
///
139+
/// TODO: Consolidate UseSharedPrePackedBuffers and UseSharedPrePackedBuffers_V2 into a single function,
140+
/// which will require updating kernel-based provider-bridge EPs (cpu, cuda, webgpu).
141+
///
142+
/// </summary>
143+
/// <param name="prepacked_buffers"></param>
144+
/// <param name="prepacked_buffer_sizes"></param>
145+
/// <param name="input_idx"></param>
146+
/// <param name="used_shared_buffers"></param>
147+
/// <returns></returns>
148+
virtual Status UseSharedPrePackedBuffers_V2(std::vector<BufferUniquePtr>& prepacked_buffers,
149+
gsl::span<const size_t> /*prepacked_buffer_sizes*/,
150+
int input_idx,
151+
/*out*/ bool& used_shared_buffers) {
152+
return UseSharedPrePackedBuffers(prepacked_buffers, input_idx, used_shared_buffers);
153+
}
154+
133155
const OrtDevice GetDevice(OrtMemType mem_type) const;
134156
const OpKernelInfo& Info() const {
135157
return *op_kernel_info_;

include/onnxruntime/core/session/onnxruntime_cxx_api.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3379,5 +3379,34 @@ struct KernelRegistry : detail::Base<OrtKernelRegistry> {
33793379
Status AddKernel(const OrtKernelDef* kernel_def, OrtKernelCreateFunc kernel_create_func,
33803380
void* kernel_create_func_state);
33813381
};
3382+
3383+
namespace detail {
3384+
template <typename T>
3385+
struct SharedPrePackedWeightCacheImpl : Ort::detail::Base<T> {
3386+
using B = Ort::detail::Base<T>;
3387+
using B::B;
3388+
3389+
/// Wraps OrtEpApi::SharedPrePackedWeightCache_StoreWeightData
3390+
Status StoreWeightData(void** buffer_data_ptrs, size_t* buffer_sizes, size_t num_buffers);
3391+
};
3392+
} // namespace detail
3393+
3394+
/** \brief Convenience C++ wrapper class around a ::OrtSharedPrePackedWeightCache instance owned by ORT.
3395+
*
3396+
* An `OrtSharedPrePackedWeightCache*` instance is passed as an argument to OrtKernelImpl::PrePackWeight.
3397+
* Example use:
3398+
* OrtStatus* MyKernel::PrePackWeightImpl(OrtKernelImpl*, ..., OrtSharedPrePackedWeightCache* c_cache, ...) {
3399+
* ...
3400+
* if (c_cache != nullptr) {
3401+
* Ort::UnownedSharedPrePackedWeightCache cpp_cache(c_cache);
3402+
* Ort::Status status = cpp_cache.StoreWeightData(...);
3403+
* }
3404+
* ...
3405+
* }
3406+
*
3407+
* \remarks OrtSharedPrePackedWeightCache is always unowned, but mutable, for EpApi users.
3408+
*/
3409+
using UnownedSharedPrePackedWeightCache =
3410+
detail::SharedPrePackedWeightCacheImpl<Ort::detail::Unowned<OrtSharedPrePackedWeightCache>>;
33823411
} // namespace Ort
33833412
#include "onnxruntime_cxx_inline.h"

include/onnxruntime/core/session/onnxruntime_cxx_inline.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3713,4 +3713,13 @@ inline Status KernelRegistry::AddKernel(const OrtKernelDef* kernel_def, OrtKerne
37133713
void* kernel_create_func_state) {
37143714
return Status{GetEpApi().KernelRegistry_AddKernel(p_, kernel_def, kernel_create_func, kernel_create_func_state)};
37153715
}
3716+
3717+
namespace detail {
3718+
template <typename T>
3719+
inline Status SharedPrePackedWeightCacheImpl<T>::StoreWeightData(void** buffer_data_ptrs, size_t* buffer_sizes,
3720+
size_t num_buffers) {
3721+
return Status{GetEpApi().SharedPrePackedWeightCache_StoreWeightData(this->p_, buffer_data_ptrs, buffer_sizes,
3722+
num_buffers)};
3723+
}
3724+
} // namespace detail
37163725
} // namespace Ort

include/onnxruntime/core/session/onnxruntime_ep_c_api.h

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ ORT_RUNTIME_CLASS(KernelRegistry);
2929
ORT_RUNTIME_CLASS(KernelDefBuilder);
3030
ORT_RUNTIME_CLASS(KernelDef);
3131
ORT_RUNTIME_CLASS(DataType); // combination of ONNXType (e.g., Tensor, Map, Sequence) and ONNXTensorElementDataType
32+
ORT_RUNTIME_CLASS(SharedPrePackedWeightCache);
3233

3334
/** \brief Struct that an EP implements for IDataTransfer to copy between devices it uses and CPU.
3435
*
@@ -308,6 +309,101 @@ struct OrtKernelImpl {
308309
* \since Version 1.24.
309310
*/
310311
ORT_API_T(void, Release, _In_ OrtKernelImpl* this_ptr);
312+
313+
/** \brief Optional function to pre-pack a constant tensor (i.e., a weight) to the kernel's preferred data layout.
314+
*
315+
* For example, a Conv kernel can define this function to pack input W to the channel-last data layout
316+
* before inference.
317+
*
318+
* Pre-packing can operate in three different modes: no pre-packing mode, sharing mode, and non-sharing mode.
319+
* 1) No pre-packing mode: The kernel can forgo any weight pre-packing for the given `input_index` by setting
320+
* `is_packed` to false and returning a successful OrtStatus. In this mode, the kernel's
321+
* OrtKernelImpl::SetSharedPrePackedWeight() function is not called for that specific
322+
* `input_index`.
323+
* 2) Sharing mode: Sharing is allowed if the `prepacked_weight_cache` argument is not NULL and the EP stores
324+
* weight data in CPU-accessible memory. In this case, the kernel can optionally choose
325+
* to share the packed weight with other kernels that use the same weight
326+
* (compared by content hash). To do so, the kernel must allocate the packed weight with the
327+
* provided `allocator`, then it stores the packed weight data into `prepacked_weight_cache`
328+
* via SharedPrePackedWeightCache_StoreWeightData(), sets `is_packed` to true, and returns a
329+
* successful OrtStatus. ORT will subsequently call OrtKernelImpl::SetSharedPrePackedWeight()
330+
* to provide this kernel with the actual shared weight data, whose memory location could
331+
* differ (i.e., if shared data was allocated by a previously processed kernel).
332+
* 3) Non-sharing mode: In non-sharing mode, the `prepacked_weight_cache` argument is ignored. In this mode,
333+
* the implementation allocates the packed data with the provided `allocator`, sets
334+
* `is_packed` to true, and returns a successful OrtStatus. The kernel is ultimately
335+
* responsible for releasing the packed data for the weight with `allocator`.
336+
* ORT may release the original (unpacked) weight, which must not be accessed in
337+
* OrtKernelImpl::Compute(). Note that in this mode, the kernel's
338+
* OrtKernelImpl::SetSharedPrePackedWeight() function is not called by ORT for that specific
339+
* `input_index`.
340+
*
341+
* \note This function is based on the internal OpKernel::PrePack() virtual function used within ORT.
342+
*
343+
* \param[in] this_ptr The OrtKernelImpl instance.
344+
* \param[in] tensor The OrtValue instance representing the constant tensor (weight). Do not cache in the kernel.
345+
* \param[in] input_index The input index of the tensor in this kernel.
346+
* \param[in] allocator Allocator for allocating the pre-packed data. Its use is required in sharing mode and
347+
* recommended, but not required, in the non-sharing mode. This will be an allocator set by
348+
* the application for the session/environment (e.g., via CreateAndRegisterAllocator[V2]
349+
* or RegisterAllocator), or an allocator on the OrtEpDevice (read-only or default) otherwise.
350+
* The allocator remains valid throughout the lifetime of the OrtKernelImpl instance.
351+
* \param[in] prepacked_weight_cache May be NULL. If not NULL, the kernel may choose to share a packed weight by
352+
* first storing it in the OrtSharedPrePackedWeightCache instance and then
353+
* receiving the actual shared weight data in the call to
354+
* OrtKernelImpl::SetSharedPrePackedWeight(). See the above description for
355+
* "sharing mode".
356+
* \param[out] is_packed Output parameter that the implementation sets to true if the kernel packed the tensor data.
357+
*
358+
* \snippet{doc} snippets.dox OrtStatus Return Value
359+
*
360+
* \note Implementation of this function is optional. If not implemented (set to NULL), ORT assumes the kernel
361+
* does not pre-pack weight data (i.e., `is_packed` defaults to false).
362+
*
363+
* \since Version 1.24.
364+
*/
365+
ORT_API2_STATUS(PrePackWeight, _In_ OrtKernelImpl* this_ptr, _In_ const OrtValue* tensor,
366+
_In_ int input_index, _Inout_ OrtAllocator* allocator,
367+
_In_opt_ OrtSharedPrePackedWeightCache* prepacked_weight_cache, _Out_ bool* is_packed);
368+
369+
/** \brief Optional function that receives data for a shared pre-packed weight from ORT.
370+
*
371+
* ORT calls this function after calling OrtKernelImpl::PrePackWeight for a specific `input_index` if:
372+
* - OrtKernelImpl::PrePackWeight set the output parameter `is_packed` to true.
373+
* - OrtKernelImpl::PrePackWeight stored weight data to share into the provided OrtSharedPrePackedWeightCache
374+
* parameter (`prepacked_weight_cache`) via the API SharedPrePackedWeightCache_StoreWeightData.
375+
*
376+
* Refer to the description of the "sharing-mode" in the documentation for OrtKernelImpl::PrePackWeight().
377+
*
378+
* \note ORT will not call this function for an `input_index` that a previous call to
379+
* OrtKernelImpl::PrePackWeight() did not elect to pre-pack and share.
380+
*
381+
* \note This function is based on the internal OpKernel::UseSharedPrePackedBuffers() virtual function used
382+
* within ORT.
383+
*
384+
* \param[in] this_ptr The OrtKernelImpl instance.
385+
* \param[in] buffer_data_ptrs An array of buffer data pointers that collectively hold the pre-packed data for a
386+
* single shared weight. The buffers are provided in the same order and with the same
387+
* contents (in a potentially different memory location) as the buffers
388+
* passed into SharedPrePackedWeightCache_StoreWeightData() within the
389+
* OrtKernelImpl::PrePackWeight() call for the same `input_index`.
390+
* \param[in] buffer_data_sizes An array of buffer byte sizes, one per element in `buffer_data_ptrs`.
391+
* \param[in] num_buffers The number of buffers used to store the data for the shared pre-packed weight.
392+
* Specifies the number of elements in the `buffer_data_ptrs` and `buffer_data_sizes` arrays.
393+
* \param[in] input_index The input index of the tensor in this kernel. This index identifies the identity of
394+
* the weight.
395+
*
396+
* \snippet{doc} snippets.dox OrtStatus Return Value
397+
*
398+
* \note Implementation of this function is generally optional. It is only required if OrtKernelImpl::PrePackWeight()
399+
* elects to share pre-packed weights.
400+
*
401+
* \since Version 1.24.
402+
*/
403+
ORT_API2_STATUS(SetSharedPrePackedWeight, _In_ OrtKernelImpl* this_ptr,
404+
_In_reads_(num_buffers) const void* const* buffer_data_ptrs,
405+
_In_reads_(num_buffers) const size_t* buffer_data_sizes,
406+
_In_ size_t num_buffers, _In_ int input_index);
311407
};
312408

313409
/** \brief Type definition for a function that creates an OrtKernelImpl instance for an operator kernel.
@@ -846,6 +942,35 @@ struct OrtEpApi {
846942
*/
847943
ORT_API2_STATUS(EpGraphSupportInfo_LookUpKernel, _In_ OrtEpGraphSupportInfo* graph_support_info,
848944
_In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtKernelDef** out_kernel_def);
945+
946+
/** \brief Sets one or more data buffers that collectively hold the pre-packed data for a single shared weight.
947+
*
948+
* \note Used within the implementation of OrtKernelImpl::PrePackWeight() when the kernel wants to share pre-packed
949+
* weight data with other kernels. The buffer data MUST be allocated with the OrtAllocator provided to
950+
* OrtKernelImpl::PrePackWeight().
951+
*
952+
* \note Ownership of weight data transfers to the OrtSharedPrePackedWeightCache instance on success.
953+
* If this function returns an error status, the caller retains ownership of the weight data.
954+
*
955+
* \note Subsequent calls with the same OrtSharedPrePackedWeightCache instance release and replace the old data.
956+
*
957+
* \param[in] prepacked_weight_cache The OrtSharedPrePackedWeightCache instance that takes ownership of the weight data.
958+
* \param[in] buffer_data_ptrs An array of buffer data pointers that collectively hold the pre-packed data for a
959+
* single shared weight. Note that sometimes a single weight may have multiple pre-packed
960+
* buffers and it is up to the kernel implementation to determine how to split the data
961+
* into multiple buffers (if desired).
962+
* \param[in] buffer_data_sizes An array of buffer byte sizes, one per element in `buffer_data_ptrs`.
963+
* \param[in] num_buffers The number of buffers used to store the data for the shared pre-packed weight.
964+
* Specifies the number of elements in the `buffer_data_ptrs` and `buffer_data_sizes` arrays.
965+
*
966+
* \snippet{doc} snippets.dox OrtStatus Return Value
967+
*
968+
* \since Version 1.24.
969+
*/
970+
ORT_API2_STATUS(SharedPrePackedWeightCache_StoreWeightData,
971+
_In_ OrtSharedPrePackedWeightCache* prepacked_weight_cache,
972+
_In_reads_(num_buffers) void** buffer_data_ptrs, _In_reads_(num_buffers) size_t* buffer_data_sizes,
973+
_In_ size_t num_buffers);
849974
};
850975

851976
/**

onnxruntime/core/framework/session_state.cc

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -421,16 +421,23 @@ void SessionState::CleanInitializedTensorsFromGraph() {
421421
static Status KernelUseSharedPrePackedBuffers(OpKernel& kernel, int input_idx,
422422
const PrePackedWeights& prepacked_weights,
423423
const std::string& node_name) {
424+
const size_t num_buffers = prepacked_weights.buffers_.size();
425+
assert(prepacked_weights.buffer_sizes_.size() == num_buffers);
426+
424427
std::vector<BufferUniquePtr> shared_prepacked_buffers;
425-
shared_prepacked_buffers.reserve(4); // Unlikely to see more than 4 prepacked buffers per initializer
428+
std::vector<size_t> shared_prepacked_buffer_sizes;
429+
shared_prepacked_buffers.reserve(num_buffers);
430+
shared_prepacked_buffer_sizes.reserve(num_buffers);
426431

427-
for (const auto& prepacked_buffer : prepacked_weights.buffers_) {
432+
for (size_t i = 0; i < num_buffers; i++) {
428433
// BufferDeleter is nullptr because the kernel should not delete the shared buffer - it can only use it
429-
shared_prepacked_buffers.emplace_back(prepacked_buffer.get(), BufferDeleter(nullptr));
434+
shared_prepacked_buffers.emplace_back(prepacked_weights.buffers_[i].get(), BufferDeleter(nullptr));
435+
shared_prepacked_buffer_sizes.push_back(prepacked_weights.buffer_sizes_[i]);
430436
}
431437

432438
bool used_shared_buffers = false;
433-
ORT_RETURN_IF_ERROR(kernel.UseSharedPrePackedBuffers(shared_prepacked_buffers, input_idx, used_shared_buffers));
439+
ORT_RETURN_IF_ERROR(kernel.UseSharedPrePackedBuffers_V2(shared_prepacked_buffers, shared_prepacked_buffer_sizes,
440+
input_idx, used_shared_buffers));
434441

435442
// BUG CHECK: Ensure that the kernel used the provided shared buffers
436443
// Mostly a debug check to ensure that the kernel has an overridden implementation of the

onnxruntime/core/session/plugin_ep/ep_api.cc

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,47 @@ ORT_API_STATUS_IMPL(EpGraphSupportInfo_LookUpKernel, _In_ OrtEpGraphSupportInfo*
582582
API_IMPL_END
583583
}
584584

585+
ORT_API_STATUS_IMPL(SharedPrePackedWeightCache_StoreWeightData,
586+
_In_ OrtSharedPrePackedWeightCache* prepacked_weight_cache,
587+
_In_reads_(num_buffers) void** buffer_data_ptrs, _In_reads_(num_buffers) size_t* buffer_data_sizes,
588+
_In_ size_t num_buffers) {
589+
API_IMPL_BEGIN
590+
if (prepacked_weight_cache == nullptr) {
591+
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
592+
"Must specify a valid OrtPrePackedWeightsCache instance");
593+
}
594+
595+
if (buffer_data_ptrs == nullptr) {
596+
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid array of buffer data pointers");
597+
}
598+
599+
if (buffer_data_sizes == nullptr) {
600+
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid array of buffer data sizes");
601+
}
602+
603+
if (num_buffers == 0) {
604+
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify at least one weight data buffer");
605+
}
606+
607+
OrtStatus* status = nullptr;
608+
609+
ORT_TRY {
610+
prepacked_weight_cache->SetBuffers(buffer_data_ptrs, buffer_data_sizes, num_buffers);
611+
}
612+
ORT_CATCH(const std::exception& ex) {
613+
ORT_HANDLE_EXCEPTION([&]() {
614+
// This API function promises that ORT will take ownership of the data only if it returns successfully.
615+
// If any exception occurred while filling out `prepacked_weight_cache`, we try to release ownership so that
616+
// the caller retains ownership of all of the original data and can delete it.
617+
prepacked_weight_cache->ReleaseAllData();
618+
status = OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, ex.what());
619+
});
620+
}
621+
622+
return status;
623+
API_IMPL_END
624+
}
625+
585626
static constexpr OrtEpApi ort_ep_api = {
586627
// NOTE: ABI compatibility depends on the order within this struct so all additions must be at the end,
587628
// and no functions can be removed (the implementation needs to change to return an error).
@@ -636,6 +677,7 @@ static constexpr OrtEpApi ort_ep_api = {
636677
&OrtExecutionProviderApi::KernelDef_GetOutputMemType,
637678
&OrtExecutionProviderApi::GetTensorDataType,
638679
&OrtExecutionProviderApi::EpGraphSupportInfo_LookUpKernel,
680+
&OrtExecutionProviderApi::SharedPrePackedWeightCache_StoreWeightData,
639681
};
640682

641683
// checks that we don't violate the rule that the functions must remain in the slots they were originally assigned

0 commit comments

Comments
 (0)