@@ -29,6 +29,7 @@ ORT_RUNTIME_CLASS(KernelRegistry);
2929ORT_RUNTIME_CLASS (KernelDefBuilder );
3030ORT_RUNTIME_CLASS (KernelDef );
3131ORT_RUNTIME_CLASS (DataType ); // combination of ONNXType (e.g., Tensor, Map, Sequence) and ONNXTensorElementDataType
32+ ORT_RUNTIME_CLASS (SharedPrePackedWeightCache );
3233
3334/** \brief Struct that an EP implements for IDataTransfer to copy between devices it uses and CPU.
3435 *
@@ -308,6 +309,98 @@ struct OrtKernelImpl {
308309 * \since Version 1.24.
309310 */
310311 ORT_API_T (void , Release , _In_ OrtKernelImpl * this_ptr );
312+
313+ /** \brief Optional function to pre-pack a constant tensor (i.e., a weight) to the kernel's preferred data layout.
314+ *
315+ * For example, a Conv kernel can define this function to pack input W to the channel-last data layout
316+ * before inference.
317+ *
318+ * Pre-packing can operate in three different modes: no pre-packing mode, sharing mode, and non-sharing mode.
319+ * 1) No pre-packing mode: The kernel can forgo any weight pre-packing for the given `input_index` by setting
320+ * `is_packed` to false and returning a successful OrtStatus. In this mode, the kernel's
321+ * OrtKernelImpl::SetSharedPrePackedWeight() function is not called for that specific
322+ * `input_index`.
323+ * 2) Sharing mode: Sharing is allowed if the `prepacked_weight_cache` argument is not NULL and the EP stores
324+ * weight data in CPU-accessible memory. In this case, the kernel can optionally choose
325+ * to share the packed weight with other kernels that use the same weight
326+ * (compared by content hash). To do so, the kernel must allocate the packed weight with the
327+ * provided `allocator`, store the packed weight data into `prepacked_weight_cache`
328+ * via SharedPrePackedWeightCache_StoreWeightData(), set `is_packed` to true, and return a
329+ * successful OrtStatus. ORT will subsequently call OrtKernelImpl::SetSharedPrePackedWeight()
330+ * to provide this kernel with the actual shared weight data, whose memory location could
331+ * differ (i.e., if shared data was allocated by a previously processed kernel).
332+ * 3) Non-sharing mode: In non-sharing mode, the `prepacked_weight_cache` argument is ignored. In this mode,
333+ * the implementation allocates the packed data with the provided `allocator`, sets
334+ * `is_packed` to true, and returns a successful OrtStatus. The kernel is ultimately
335+ * responsible for releasing the packed data for the weight with `allocator`.
336+ * ORT may release the original (unpacked) weight, which must not be accessed in
337+ * OrtKernelImpl::Compute(). Note that in this mode, the kernel's
338+ * OrtKernelImpl::SetSharedPrePackedWeight() function is not called by ORT for that specific
339+ * `input_index`.
340+ *
341+ * \note This function is based on the internal OpKernel::PrePack() virtual function used within ORT.
342+ *
343+ * \param[in] this_ptr The OrtKernelImpl instance.
344+ * \param[in] tensor The OrtValue instance representing the constant tensor (weight). Do not cache in the kernel.
345+ * \param[in] input_index The input index of the tensor in this kernel.
346+ * \param[in] allocator Allocator for allocating the pre-packed data. Its use is required in sharing mode and
347+ * recommended, but not required, in the non-sharing mode. This will be an allocator set by
348+ * the application for the session/environment (e.g., via CreateAndRegisterAllocator[V2]
349+ * or RegisterAllocator), or an allocator on the OrtEpDevice (read-only or default) otherwise.
350+ * The allocator remains valid throughout the lifetime of the OrtKernelImpl instance.
351+ * \param[in] prepacked_weight_cache May be NULL. If not NULL, the kernel may choose to share a packed weight by
352+ * first storing it in the OrtSharedPrePackedWeightCache instance and then
353+ * receiving the actual shared weight data in the call to
354+ * OrtKernelImpl::SetSharedPrePackedWeight(). See the above description for
355+ * "sharing mode".
356+ * \param[out] is_packed Output parameter that the implementation sets to true if the kernel packed the tensor data.
357+ *
358+ * \snippet{doc} snippets.dox OrtStatus Return Value
359+ *
360+ * \note Implementation of this function is optional. If not implemented (set to NULL), ORT assumes the kernel
361+ * does not pre-pack weight data (i.e., `is_packed` defaults to false).
362+ *
363+ * \since Version 1.24.
364+ */
365+ ORT_API2_STATUS (PrePackWeight , _In_ OrtKernelImpl * this_ptr , _In_ const OrtValue * tensor ,
366+ _In_ int input_index , _Inout_ OrtAllocator * allocator ,
367+ _In_opt_ OrtSharedPrePackedWeightCache * prepacked_weight_cache , _Out_ bool * is_packed );
368+
369+ /** \brief Optional function that receives data for a shared pre-packed weight from ORT.
370+ *
371+ * This function is called after a prior call to OrtKernelImpl::PrePackWeight for a specific `input_index` set
372+ * `is_packed` to true and stored weight data (to share) into the provided OrtSharedPrePackedWeightCache instance.
373+ * Refer to the description of "sharing mode" in the documentation for OrtKernelImpl::PrePackWeight().
374+ *
375+ * \note ORT will not call this function for an `input_index` that a previous call to
376+ * OrtKernelImpl::PrePackWeight() did not elect to pre-pack and share.
377+ *
378+ * \note This function is based on the internal OpKernel::UseSharedPrePackedBuffers() virtual function used
379+ * within ORT.
380+ *
381+ * \param[in] this_ptr The OrtKernelImpl instance.
382+ * \param[in] buffer_data_ptrs An array of buffer data pointers that collectively hold the pre-packed data for a
383+ * single shared weight. The buffers are provided in the same order and with the same
384+ * contents (in a potentially different memory location) as the buffers
385+ * passed into SharedPrePackedWeightCache_StoreWeightData() within the
386+ * OrtKernelImpl::PrePackWeight() call for the same `input_index`.
387+ * \param[in] buffer_data_sizes An array of buffer byte sizes, one per element in `buffer_data_ptrs`.
388+ * \param[in] num_buffers The number of buffers used to store the data for the shared pre-packed weight.
389+ * Specifies the number of elements in the `buffer_data_ptrs` and `buffer_data_sizes` arrays.
390+ * \param[in] input_index The input index of the tensor in this kernel. This index identifies which weight
391+ * the shared data belongs to.
392+ *
393+ * \snippet{doc} snippets.dox OrtStatus Return Value
394+ *
395+ * \note Implementation of this function is generally optional. It is only required if OrtKernelImpl::PrePackWeight()
396+ * elects to share pre-packed weights.
397+ *
398+ * \since Version 1.24.
399+ */
400+ ORT_API2_STATUS (SetSharedPrePackedWeight , _In_ OrtKernelImpl * this_ptr ,
401+ _In_reads_ (num_buffers ) const void * const * buffer_data_ptrs ,
402+ _In_reads_ (num_buffers ) const size_t * buffer_data_sizes ,
403+ _In_ size_t num_buffers , _In_ int input_index );
311404};
312405
313406/** \brief Type definition for a function that creates an OrtKernelImpl instance for an operator kernel.
@@ -847,6 +940,35 @@ struct OrtEpApi {
847940 ORT_API2_STATUS (EpGraphSupportInfo_LookUpKernel , _In_ OrtEpGraphSupportInfo * graph_support_info ,
848941 _In_ const OrtNode * node , _Outptr_result_maybenull_ const OrtKernelDef * * out_kernel_def );
849942
943+ /** \brief Sets one or more data buffers that collectively hold the pre-packed data for a single shared weight.
944+ *
945+ * \note Used within the implementation of OrtKernelImpl::PrePackWeight() when the kernel wants to share pre-packed
946+ * weight data with other kernels. The buffer data MUST be allocated with the OrtAllocator provided to
947+ * OrtKernelImpl::PrePackWeight().
948+ *
949+ * \note Ownership of weight data transfers to the OrtSharedPrePackedWeightCache instance on success.
950+ * If this function returns an error status, the caller retains ownership of the weight data.
951+ *
952+ * \note Subsequent calls with the same OrtSharedPrePackedWeightCache instance release and replace the old data.
953+ *
954+ * \param[in] prepacked_weight_cache The OrtSharedPrePackedWeightCache instance that receives the weight data.
955+ * \param[in] buffer_data_ptrs An array of buffer data pointers that collectively hold the pre-packed data for a
956+ * single shared weight. Note that sometimes a single weight may have multiple pre-packed
957+ * buffers and it is up to the kernel implementation to determine how to split the data
958+ * into multiple buffers (if desired).
959+ * \param[in] buffer_data_sizes An array of buffer byte sizes, one per element in `buffer_data_ptrs`.
960+ * \param[in] num_buffers The number of buffers used to store the data for the shared pre-packed weight.
961+ * Specifies the number of elements in the `buffer_data_ptrs` and `buffer_data_sizes` arrays.
962+ *
963+ * \snippet{doc} snippets.dox OrtStatus Return Value
964+ *
965+ * \since Version 1.24.
966+ */
967+ ORT_API2_STATUS (SharedPrePackedWeightCache_StoreWeightData ,
968+ _In_ OrtSharedPrePackedWeightCache * prepacked_weight_cache ,
969+ _In_reads_ (num_buffers ) void * * buffer_data_ptrs , _In_reads_ (num_buffers ) size_t * buffer_data_sizes ,
970+ _In_ size_t num_buffers );
971+
850972 /** \brief Get the OrtEp instance to which the node is assigned from the OrtKernelInfo.
851973 *
852974 * \note Used within OrtKernelImpl implementations to obtain a reference to the OrtEp.
0 commit comments