@@ -29,6 +29,7 @@ ORT_RUNTIME_CLASS(KernelRegistry);
2929ORT_RUNTIME_CLASS (KernelDefBuilder );
3030ORT_RUNTIME_CLASS (KernelDef );
3131ORT_RUNTIME_CLASS (DataType ); // combination of ONNXType (e.g., Tensor, Map, Sequence) and ONNXTensorElementDataType
32+ ORT_RUNTIME_CLASS (SharedPrePackedWeightCache );
3233
3334/** \brief Struct that an EP implements for IDataTransfer to copy between devices it uses and CPU.
3435 *
@@ -308,6 +309,101 @@ struct OrtKernelImpl {
308309 * \since Version 1.24.
309310 */
310311 ORT_API_T (void , Release , _In_ OrtKernelImpl * this_ptr );
312+
313+ /** \brief Optional function to pre-pack a constant tensor (i.e., a weight) to the kernel's preferred data layout.
314+ *
315+ * For example, a Conv kernel can define this function to pack input W to the channel-last data layout
316+ * before inference.
317+ *
318+ * Pre-packing can operate in three different modes: no pre-packing mode, sharing mode, and non-sharing mode.
319+ * 1) No pre-packing mode: The kernel can forgo any weight pre-packing for the given `input_index` by setting
320+ * `is_packed` to false and returning a successful OrtStatus. In this mode, the kernel's
321+ * OrtKernelImpl::SetSharedPrePackedWeight() function is not called for that specific
322+ * `input_index`.
323+ * 2) Sharing mode: Sharing is allowed if the `prepacked_weight_cache` argument is not NULL and the EP stores
324+ * weight data in CPU-accessible memory. In this case, the kernel can optionally choose
325+ * to share the packed weight with other kernels that use the same weight
326+ * (compared by content hash). To do so, the kernel must allocate the packed weight with the
327+ * provided `allocator`, then it stores the packed weight data into `prepacked_weight_cache`
328+ * via SharedPrePackedWeightCache_StoreWeightData(), sets `is_packed` to true, and returns a
329+ * successful OrtStatus. ORT will subsequently call OrtKernelImpl::SetSharedPrePackedWeight()
330+ * to provide this kernel with the actual shared weight data, whose memory location could
331+ * differ (i.e., if shared data was allocated by a previously processed kernel).
332+ * 3) Non-sharing mode: In non-sharing mode, the `prepacked_weight_cache` argument is ignored. In this mode,
333+ * the implementation allocates the packed data with the provided `allocator`, sets
334+ * `is_packed` to true, and returns a successful OrtStatus. The kernel is ultimately
335+ * responsible for releasing the packed data for the weight with `allocator`.
336+ * ORT may release the original (unpacked) weight, which must not be accessed in
337+ * OrtKernelImpl::Compute(). Note that in this mode, the kernel's
338+ * OrtKernelImpl::SetSharedPrePackedWeight() function is not called by ORT for that specific
339+ * `input_index`.
340+ *
341+ * \note This function is based on the internal OpKernel::PrePack() virtual function used within ORT.
342+ *
343+ * \param[in] this_ptr The OrtKernelImpl instance.
344+ * \param[in] tensor The OrtValue instance representing the constant tensor (weight). Do not cache in the kernel.
345+ * \param[in] input_index The input index of the tensor in this kernel.
346+ * \param[in] allocator Allocator for allocating the pre-packed data. Its use is required in sharing mode and
347+ * recommended, but not required, in the non-sharing mode. This will be an allocator set by
348+ * the application for the session/environment (e.g., via CreateAndRegisterAllocator[V2]
349+ * or RegisterAllocator), or an allocator on the OrtEpDevice (read-only or default) otherwise.
350+ * The allocator remains valid throughout the lifetime of the OrtKernelImpl instance.
351+ * \param[in] prepacked_weight_cache May be NULL. If not NULL, the kernel may choose to share a packed weight by
352+ * first storing it in the OrtSharedPrePackedWeightCache instance and then
353+ * receiving the actual shared weight data in the call to
354+ * OrtKernelImpl::SetSharedPrePackedWeight(). See the above description for
355+ * "sharing mode".
356+ * \param[out] is_packed Output parameter that the implementation sets to true if the kernel packed the tensor data.
357+ *
358+ * \snippet{doc} snippets.dox OrtStatus Return Value
359+ *
360+ * \note Implementation of this function is optional. If not implemented (set to NULL), ORT assumes the kernel
361+ * does not pre-pack weight data (i.e., `is_packed` defaults to false).
362+ *
363+ * \since Version 1.24.
364+ */
365+ ORT_API2_STATUS (PrePackWeight , _In_ OrtKernelImpl * this_ptr , _In_ const OrtValue * tensor ,
366+ _In_ int input_index , _Inout_ OrtAllocator * allocator ,
367+ _In_opt_ OrtSharedPrePackedWeightCache * prepacked_weight_cache , _Out_ bool * is_packed );
368+
369+ /** \brief Optional function that receives data for a shared pre-packed weight from ORT.
370+ *
371+ * ORT calls this function after calling OrtKernelImpl::PrePackWeight for a specific `input_index` if:
372+ * - OrtKernelImpl::PrePackWeight set the output parameter `is_packed` to true.
373+ * - OrtKernelImpl::PrePackWeight stored weight data to share into the provided OrtSharedPrePackedWeightCache
374+ * parameter (`prepacked_weight_cache`) via the API SharedPrePackedWeightCache_StoreWeightData.
375+ *
376+ * Refer to the description of "sharing mode" in the documentation for OrtKernelImpl::PrePackWeight().
377+ *
378+ * \note ORT will not call this function for an `input_index` that a previous call to
379+ * OrtKernelImpl::PrePackWeight() did not elect to pre-pack and share.
380+ *
381+ * \note This function is based on the internal OpKernel::UseSharedPrePackedBuffers() virtual function used
382+ * within ORT.
383+ *
384+ * \param[in] this_ptr The OrtKernelImpl instance.
385+ * \param[in] buffer_data_ptrs An array of buffer data pointers that collectively hold the pre-packed data for a
386+ * single shared weight. The buffers are provided in the same order and with the same
387+ * contents (in a potentially different memory location) as the buffers
388+ * passed into SharedPrePackedWeightCache_StoreWeightData() within the
389+ * OrtKernelImpl::PrePackWeight() call for the same `input_index`.
390+ * \param[in] buffer_data_sizes An array of buffer byte sizes, one per element in `buffer_data_ptrs`.
391+ * \param[in] num_buffers The number of buffers used to store the data for the shared pre-packed weight.
392+ * Specifies the number of elements in the `buffer_data_ptrs` and `buffer_data_sizes` arrays.
393+ * \param[in] input_index The input index of the tensor in this kernel. This index identifies the weight
394+ * that the provided buffers belong to.
395+ *
396+ * \snippet{doc} snippets.dox OrtStatus Return Value
397+ *
398+ * \note Implementation of this function is generally optional. It is only required if OrtKernelImpl::PrePackWeight()
399+ * elects to share pre-packed weights.
400+ *
401+ * \since Version 1.24.
402+ */
403+ ORT_API2_STATUS (SetSharedPrePackedWeight , _In_ OrtKernelImpl * this_ptr ,
404+ _In_reads_ (num_buffers ) const void * const * buffer_data_ptrs ,
405+ _In_reads_ (num_buffers ) const size_t * buffer_data_sizes ,
406+ _In_ size_t num_buffers , _In_ int input_index );
311407};
312408
313409/** \brief Type definition for a function that creates an OrtKernelImpl instance for an operator kernel.
@@ -846,6 +942,35 @@ struct OrtEpApi {
846942 */
847943 ORT_API2_STATUS (EpGraphSupportInfo_LookUpKernel , _In_ OrtEpGraphSupportInfo * graph_support_info ,
848944 _In_ const OrtNode * node , _Outptr_result_maybenull_ const OrtKernelDef * * out_kernel_def );
945+
946+ /** \brief Sets one or more data buffers that collectively hold the pre-packed data for a single shared weight.
947+ *
948+ * \note Used within the implementation of OrtKernelImpl::PrePackWeight() when the kernel wants to share pre-packed
949+ * weight data with other kernels. The buffer data MUST be allocated with the OrtAllocator provided to
950+ * OrtKernelImpl::PrePackWeight().
951+ *
952+ * \note Ownership of weight data transfers to the OrtSharedPrePackedWeightCache instance on success.
953+ * If this function returns an error status, the caller retains ownership of the weight data.
954+ *
955+ * \note Subsequent calls with the same OrtSharedPrePackedWeightCache instance release and replace the old data.
956+ *
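957+ * \param[in] prepacked_weight_cache The OrtSharedPrePackedWeightCache instance that stores the pre-packed weight data and takes ownership of it on success.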
958+ * \param[in] buffer_data_ptrs An array of buffer data pointers that collectively hold the pre-packed data for a
959+ * single shared weight. Note that sometimes a single weight may have multiple pre-packed
960+ * buffers and it is up to the kernel implementation to determine how to split the data
961+ * into multiple buffers (if desired).
962+ * \param[in] buffer_data_sizes An array of buffer byte sizes, one per element in `buffer_data_ptrs`.
963+ * \param[in] num_buffers The number of buffers used to store the data for the shared pre-packed weight.
964+ * Specifies the number of elements in the `buffer_data_ptrs` and `buffer_data_sizes` arrays.
965+ *
966+ * \snippet{doc} snippets.dox OrtStatus Return Value
967+ *
968+ * \since Version 1.24.
969+ */
970+ ORT_API2_STATUS (SharedPrePackedWeightCache_StoreWeightData ,
971+ _In_ OrtSharedPrePackedWeightCache * prepacked_weight_cache ,
972+ _In_reads_ (num_buffers ) void * * buffer_data_ptrs , _In_reads_ (num_buffers ) size_t * buffer_data_sizes ,
973+ _In_ size_t num_buffers );
849974};
850975
851976/**
0 commit comments