Skip to content

Commit 96fe212

Browse files
Add Shape and Reshape kernels to example kernel EP
1 parent 838b17a commit 96fe212

File tree

6 files changed

+309
-0
lines changed

6 files changed

+309
-0
lines changed

cmake/onnxruntime_unittests.cmake

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2108,6 +2108,10 @@ if (onnxruntime_BUILD_SHARED_LIB AND
21082108
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.h"
21092109
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.cc"
21102110
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/utils.h"
2111+
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/shape.h"
2112+
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/shape.cc"
2113+
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/reshape.h"
2114+
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/reshape.cc"
21112115
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.h"
21122116
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.cc"
21132117
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.h"

onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,18 @@ static const BuildKernelCreateInfoFn build_kernel_create_info_funcs[] = {
1919
BuildKernelCreateInfo<class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kOnnxDomain, 21, 22, Squeeze)>,
2020
BuildKernelCreateInfo<class ONNX_OPERATOR_KERNEL_CLASS_NAME(kOnnxDomain, 23, Squeeze)>,
2121
BuildKernelCreateInfo<class ONNX_OPERATOR_KERNEL_CLASS_NAME(kOnnxDomain, 24, Squeeze)>,
22+
23+
// Support Shape 21, 23, and 24.
24+
// Note: end versions are inclusive.
25+
BuildKernelCreateInfo<class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kOnnxDomain, 21, 22, Shape)>,
26+
BuildKernelCreateInfo<class ONNX_OPERATOR_KERNEL_CLASS_NAME(kOnnxDomain, 23, Shape)>,
27+
BuildKernelCreateInfo<class ONNX_OPERATOR_KERNEL_CLASS_NAME(kOnnxDomain, 24, Shape)>,
28+
29+
// Support Reshape 21, 23, and 24.
30+
// Note: end versions are inclusive.
31+
BuildKernelCreateInfo<class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kOnnxDomain, 21, 22, Reshape)>,
32+
BuildKernelCreateInfo<class ONNX_OPERATOR_KERNEL_CLASS_NAME(kOnnxDomain, 23, Reshape)>,
33+
BuildKernelCreateInfo<class ONNX_OPERATOR_KERNEL_CLASS_NAME(kOnnxDomain, 24, Reshape)>,
2234
};
2335

2436
size_t GetNumKernels() {
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "reshape.h"
5+
6+
#include <gsl/span>
7+
#include <vector>
8+
#include "utils.h"
9+
10+
// Kernel registrations for the example Reshape kernel. All three registrations use the same kernel
// definition: "T" is constrained to float tensors, "shape" to int64 tensors, input 0 is aliased with
// output 0, and input 1 (the requested shape) is given the OrtMemTypeCPU memory type.
//
// ONNX Reshape opset versions 21 through 22 (the end version is inclusive).
11+
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
12+
Reshape,
13+
kOnnxDomain,
14+
/*start_version*/ 21, /*end_version (inclusive)*/ 22,
15+
(Ort::KernelDefBuilder()
16+
.AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))
17+
.AddTypeConstraint("shape", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64))
18+
.AddInputOutputAlias(0, 0)
19+
.SetInputMemType(1, OrtMemTypeCPU)),
20+
Reshape)
21+
22+
// ONNX Reshape opset version 23 (single-version registration).
23+
ONNX_OPERATOR_KERNEL_EX(
24+
Reshape,
25+
kOnnxDomain,
26+
/*version*/ 23, // Equivalent to start_version: 23, end_version: 23
27+
(Ort::KernelDefBuilder()
28+
.AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))
29+
.AddTypeConstraint("shape", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64))
30+
.AddInputOutputAlias(0, 0)
31+
.SetInputMemType(1, OrtMemTypeCPU)),
32+
Reshape)
33+
34+
// ONNX Reshape opset version 24 (single-version registration).
35+
ONNX_OPERATOR_KERNEL_EX(
36+
Reshape,
37+
kOnnxDomain,
38+
/*version*/ 24, // Equivalent to start_version: 24, end_version: 24
39+
(Ort::KernelDefBuilder()
40+
.AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))
41+
.AddTypeConstraint("shape", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64))
42+
.AddInputOutputAlias(0, 0)
43+
.SetInputMemType(1, OrtMemTypeCPU)),
44+
Reshape)
45+
46+
Reshape::Reshape(const OrtKernelInfo* info, void* state, bool allow_zero, PrivateTag)
47+
: OrtKernelImpl{}, // Initialize all OrtKernelImpl functions to NULL
48+
info_{info},
49+
data_transfer_impl_{reinterpret_cast<OrtDataTransferImpl*>(state)},
50+
allow_zero_{allow_zero} {
51+
ort_version_supported = ORT_API_VERSION;
52+
Compute = ComputeImpl;
53+
Release = ReleaseImpl;
54+
}
55+
56+
/*static*/
57+
OrtStatus* Reshape::Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr<Reshape>& kernel) noexcept {
58+
EXCEPTION_TO_RETURNED_STATUS_BEGIN
59+
Ort::ConstKernelInfo kernel_info(info);
60+
bool allow_zero = kernel_info.GetAttribute<int64_t>("allowzero") == 1;
61+
62+
kernel = std::make_unique<Reshape>(info, state, allow_zero, PrivateTag{});
63+
return nullptr;
64+
EXCEPTION_TO_RETURNED_STATUS_END
65+
}
66+
67+
// Computes the requested shape for the reshape operation.
68+
// Implementation is based on ReshapeHelper in onnxruntime/core/providers/cpu/tensor/reshape_helper.h
69+
static OrtStatus* GetRequestedShape(gsl::span<const int64_t> input_shape, bool allow_zero,
70+
/*out*/ std::vector<int64_t>& requested_shape) {
71+
EXCEPTION_TO_RETURNED_STATUS_BEGIN
72+
const OrtApi& ort_api = Ort::GetApi();
73+
74+
int64_t num_input_elems = 1;
75+
for (auto dim_val : input_shape) {
76+
num_input_elems *= dim_val;
77+
}
78+
RETURN_IF(num_input_elems == -1, ort_api, "Input tensor must not have dynamic (-1) dimensions.");
79+
80+
size_t num_dims = requested_shape.size();
81+
int64_t unknown_dim = -1;
82+
int64_t size = 1;
83+
84+
for (size_t i = 0; i < num_dims; i++) {
85+
RETURN_IF(requested_shape[i] < -1, ort_api, "A dimension cannot be less than -1");
86+
87+
if (requested_shape[i] == -1) {
88+
RETURN_IF(unknown_dim != -1, ort_api, "At most one dimension can be -1");
89+
unknown_dim = static_cast<int64_t>(i);
90+
} else {
91+
if (!allow_zero && requested_shape[i] == 0) {
92+
RETURN_IF(i >= input_shape.size(), ort_api,
93+
"The dimension with value zero exceeds the dimension size of the input");
94+
requested_shape[i] = input_shape[i];
95+
}
96+
97+
size *= requested_shape[i];
98+
}
99+
}
100+
101+
if (unknown_dim != -1) {
102+
// Calculate unknown dimension.
103+
RETURN_IF(size == 0 || (num_input_elems % size) != 0, ort_api,
104+
"The input cannot be reshaped to the requested shape");
105+
requested_shape[unknown_dim] = num_input_elems / size;
106+
} else {
107+
// Check if the output shape is valid.
108+
RETURN_IF(num_input_elems != size, ort_api, "The input cannot be reshaped to the requested shape");
109+
}
110+
111+
return nullptr;
112+
EXCEPTION_TO_RETURNED_STATUS_END
113+
}
114+
115+
/*static*/
116+
OrtStatus* ORT_API_CALL Reshape::ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept {
117+
EXCEPTION_TO_RETURNED_STATUS_BEGIN
118+
Reshape* reshape_kernel = static_cast<Reshape*>(this_ptr);
119+
static_cast<void>(reshape_kernel->info_); // NOTE: Unused in this example.
120+
121+
Ort::KernelContext kernel_context(kernel_ctx);
122+
123+
// Input[0] has the data to reshape.
124+
Ort::ConstValue input = kernel_context.GetInput(0);
125+
auto type_shape_info = input.GetTensorTypeAndShapeInfo();
126+
std::vector<int64_t> input_shape = type_shape_info.GetShape();
127+
128+
// Input[1] has the requested shape for the reshape operation.
129+
Ort::ConstValue shape_input = kernel_context.GetInput(1);
130+
gsl::span<const int64_t> shape_input_data;
131+
std::vector<int64_t> final_shape;
132+
133+
RETURN_IF_ERROR(GetValueDataAndShape(shape_input, shape_input_data, final_shape));
134+
RETURN_IF(final_shape.size() != 1, Ort::GetApi(), "A shape tensor must have one dimension");
135+
RETURN_IF_ERROR(GetRequestedShape(input_shape, reshape_kernel->allow_zero_, final_shape));
136+
137+
Ort::UnownedValue output = kernel_context.GetOutput(0, final_shape);
138+
139+
// This kernel aliases the input and output, so a copy is not really necessary.
140+
// CopyTensor() will not do a copy if the source and destination buffers are the same.
141+
RETURN_IF_ERROR(CopyTensor(*reshape_kernel->data_transfer_impl_, input, output));
142+
return nullptr;
143+
EXCEPTION_TO_RETURNED_STATUS_END
144+
}
145+
146+
/*static*/
147+
void ORT_API_CALL Reshape::ReleaseImpl(OrtKernelImpl* this_ptr) noexcept {
148+
delete static_cast<Reshape*>(this_ptr);
149+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "../../plugin_ep_utils.h"
7+
8+
// Example Reshape kernel, implemented as an OrtKernelImpl subclass.
// Instances are expected to be created through Create().
class Reshape : public OrtKernelImpl {
9+
private:
10+
// Tag type taken by the public constructor. Because the type is private, code outside this class
// cannot name it, while Create() can still construct the kernel via std::make_unique.
struct PrivateTag {};
11+
12+
public:
13+
// Creates a kernel from `info` and the opaque `state` pointer and stores it in `kernel`.
// Returns nullptr on success, otherwise an error status.
static OrtStatus* Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr<Reshape>& kernel) noexcept;
14+
// Stores the arguments and assigns the OrtKernelImpl function fields; `state` is interpreted as an
// OrtDataTransferImpl*.
Reshape(const OrtKernelInfo* info, void* state, bool allow_zero, PrivateTag);
15+
16+
// Static functions assigned to the OrtKernelImpl fields:
17+
static OrtStatus* ORT_API_CALL ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept;
18+
static void ORT_API_CALL ReleaseImpl(OrtKernelImpl* this_ptr) noexcept;
19+
20+
private:
21+
// Kernel info handle passed at construction.
const OrtKernelInfo* info_;
22+
OrtDataTransferImpl* data_transfer_impl_; // Custom state passed from OrtEp
23+
// True when the node's "allowzero" attribute equals 1.
bool allow_zero_;
24+
};
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "shape.h"
5+
6+
#include <vector>
7+
#include "utils.h"
8+
9+
// Kernel registrations for the example Shape kernel. All three registrations use the same kernel
// definition: "T" is constrained to float tensors, "T1" to int64 tensors, and output 0 is given the
// OrtMemTypeCPU memory type.
//
// ONNX Shape opset versions 21 through 22 (the end version is inclusive).
10+
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
11+
Shape,
12+
kOnnxDomain,
13+
/*start_version*/ 21, /*end_version (inclusive)*/ 22,
14+
(Ort::KernelDefBuilder()
15+
.AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))
16+
.AddTypeConstraint("T1", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64))
17+
.SetOutputMemType(0, OrtMemTypeCPU)),
18+
Shape)
19+
20+
// ONNX Shape opset version 23 (single-version registration).
21+
ONNX_OPERATOR_KERNEL_EX(
22+
Shape,
23+
kOnnxDomain,
24+
/*version*/ 23, // Equivalent to start_version: 23, end_version: 23
25+
(Ort::KernelDefBuilder()
26+
.AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))
27+
.AddTypeConstraint("T1", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64))
28+
.SetOutputMemType(0, OrtMemTypeCPU)),
29+
Shape)
30+
31+
// ONNX Shape opset version 24 (single-version registration).
32+
ONNX_OPERATOR_KERNEL_EX(
33+
Shape,
34+
kOnnxDomain,
35+
/*version*/ 24, // Equivalent to start_version: 24, end_version: 24
36+
(Ort::KernelDefBuilder()
37+
.AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))
38+
.AddTypeConstraint("T1", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64))
39+
.SetOutputMemType(0, OrtMemTypeCPU)),
40+
Shape)
41+
42+
Shape::Shape(const OrtKernelInfo* info, void* state, PrivateTag)
43+
: OrtKernelImpl{}, // Initialize all OrtKernelImpl functions to NULL
44+
info_{info},
45+
data_transfer_impl_{reinterpret_cast<OrtDataTransferImpl*>(state)} {
46+
ort_version_supported = ORT_API_VERSION;
47+
Compute = ComputeImpl;
48+
Release = ReleaseImpl;
49+
}
50+
51+
/*static*/
52+
OrtStatus* Shape::Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr<Shape>& kernel) noexcept {
53+
EXCEPTION_TO_RETURNED_STATUS_BEGIN
54+
Ort::ConstKernelInfo kernel_info(info);
55+
56+
int64_t start = kernel_info.GetAttribute<int64_t>("start");
57+
int64_t end = 0;
58+
Ort::Status status{Ort::GetApi().KernelInfoGetAttribute_int64(info, "end", &end)};
59+
60+
// This example kernel does not support shape slicing.
61+
RETURN_IF(start != 0 || status.IsOK(), Ort::GetApi(),
62+
"Example Shape kernel does not support non-default start/end attributes");
63+
64+
kernel = std::make_unique<Shape>(info, state, PrivateTag{});
65+
return nullptr;
66+
EXCEPTION_TO_RETURNED_STATUS_END
67+
}
68+
69+
/*static*/
70+
OrtStatus* ORT_API_CALL Shape::ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept {
71+
EXCEPTION_TO_RETURNED_STATUS_BEGIN
72+
Shape* shape_kernel = static_cast<Shape*>(this_ptr);
73+
static_cast<void>(shape_kernel->info_); // NOTE: Unused in this example.
74+
static_cast<void>(shape_kernel->data_transfer_impl_); // NOTE: Unused in this example.
75+
76+
Ort::KernelContext kernel_context(kernel_ctx);
77+
78+
Ort::ConstValue input = kernel_context.GetInput(0);
79+
auto type_shape_info = input.GetTensorTypeAndShapeInfo();
80+
std::vector<int64_t> input_shape = type_shape_info.GetShape();
81+
82+
std::vector<int64_t> output_shape = {static_cast<int64_t>(input_shape.size())};
83+
Ort::UnownedValue output = kernel_context.GetOutput(0, output_shape);
84+
int64_t* output_data = output.GetTensorMutableData<int64_t>();
85+
86+
for (size_t i = 0; i < input_shape.size(); i++) {
87+
output_data[i] = input_shape[i];
88+
}
89+
90+
return nullptr;
91+
EXCEPTION_TO_RETURNED_STATUS_END
92+
}
93+
94+
/*static*/
95+
void ORT_API_CALL Shape::ReleaseImpl(OrtKernelImpl* this_ptr) noexcept {
96+
delete static_cast<Shape*>(this_ptr);
97+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "../../plugin_ep_utils.h"
7+
8+
// Example Shape kernel, implemented as an OrtKernelImpl subclass.
// Instances are expected to be created through Create().
class Shape : public OrtKernelImpl {
9+
private:
10+
// Tag type taken by the public constructor. Because the type is private, code outside this class
// cannot name it, while Create() can still construct the kernel via std::make_unique.
struct PrivateTag {};
11+
12+
public:
13+
// Creates a kernel from `info` and the opaque `state` pointer and stores it in `kernel`.
// Returns nullptr on success, otherwise an error status.
static OrtStatus* Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr<Shape>& kernel) noexcept;
14+
// Stores the arguments and assigns the OrtKernelImpl function fields; `state` is interpreted as an
// OrtDataTransferImpl*.
Shape(const OrtKernelInfo* info, void* state, PrivateTag);
15+
16+
// Static functions assigned to the OrtKernelImpl fields:
17+
static OrtStatus* ORT_API_CALL ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept;
18+
static void ORT_API_CALL ReleaseImpl(OrtKernelImpl* this_ptr) noexcept;
19+
20+
private:
21+
// Kernel info handle passed at construction.
const OrtKernelInfo* info_;
22+
OrtDataTransferImpl* data_transfer_impl_; // Custom state passed from OrtEp
23+
};

0 commit comments

Comments
 (0)