Skip to content

Commit 3968ec3

Browse files
Update example BinaryOp class to support Sub instead of Add
1 parent ebe58ed commit 3968ec3

File tree

7 files changed

+115
-19
lines changed

7 files changed

+115
-19
lines changed

onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,10 @@ OrtStatus* ORT_API_CALL ExampleKernelEp::GetCapabilityImpl(OrtEp* this_ptr, cons
6868

6969
if (op_type == "Relu" || op_type == "Squeeze") {
7070
candidate_nodes.push_back(node);
71-
} else if (op_type == "Mul" || op_type == "Add") {
71+
} else if (op_type == "Mul" || op_type == "Sub") {
7272
std::vector<Ort::ConstValueInfo> inputs = node.GetInputs();
7373

74-
// Note: ONNX shape inference should ensure Mul/Add has two inputs.
74+
// Note: ONNX shape inference should ensure Mul/Sub has two inputs.
7575
std::optional<std::vector<int64_t>> input_0_shape = GetTensorShape(inputs[0]);
7676
std::optional<std::vector<int64_t>> input_1_shape = GetTensorShape(inputs[1]);
7777

onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ static const BuildKernelCreateInfoFn build_kernel_create_info_funcs[] = {
1111
// Mul version 14
1212
BuildKernelCreateInfo<class ONNX_OPERATOR_KERNEL_CLASS_NAME(kOnnxDomain, 14, Mul)>,
1313

14-
// Add version 14
15-
BuildKernelCreateInfo<class ONNX_OPERATOR_KERNEL_CLASS_NAME(kOnnxDomain, 14, Add)>,
14+
// Sub version 14
15+
BuildKernelCreateInfo<class ONNX_OPERATOR_KERNEL_CLASS_NAME(kOnnxDomain, 14, Sub)>,
1616

1717
// Relu version 14
1818
BuildKernelCreateInfo<class ONNX_OPERATOR_KERNEL_CLASS_NAME(kOnnxDomain, 14, Relu)>,

onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/binary_op.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ ONNX_OPERATOR_KERNEL_EX(
1717
.AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))),
1818
BinaryOp)
1919

20-
// Defines a kernel creation function for version 14 of Add.
20+
// Defines a kernel creation function for version 14 of Sub.
2121
ONNX_OPERATOR_KERNEL_EX(
22-
Add,
22+
Sub,
2323
kOnnxDomain,
2424
/*version*/ 14, // Equivalent to start_version: 14, end_version: 14 (inclusive)
2525
(Ort::KernelDefBuilder()
@@ -35,7 +35,7 @@ BinaryOp::BinaryOp(Ort::ConstKernelInfo info, void* state, PrivateTag)
3535
Release = ReleaseImpl;
3636

3737
// Optional functions that are only needed to pre-pack weights. This BinaryOp kernel pre-packs
38-
// input[1] weights as an example (not typically done by an actual implementations of Mul/Add).
38+
// input[1] weights as an example (not typically done by an actual implementation of Mul/Sub).
3939
PrePackWeight = PrePackWeightImpl;
4040
SetSharedPrePackedWeight = SetSharedPrePackedWeightImpl;
4141
}
@@ -47,11 +47,11 @@ OrtStatus* BinaryOp::Create(const OrtKernelInfo* info, void* state,
4747
Ort::ConstKernelInfo kernel_info(info);
4848

4949
// Note: can do basic validation or preprocessing via the OrtKernelInfo APIs.
50-
// Here, we check that this BinaryOp class is only instantiated for an onnx Mul or Add operator.
50+
// Here, we check that this BinaryOp class is only instantiated for an onnx Mul or Sub operator.
5151
std::string op_domain = kernel_info.GetOperatorDomain();
5252
std::string op_type = kernel_info.GetOperatorType();
5353

54-
if ((!op_domain.empty() && op_domain != "ai.onnx") || (op_type != "Add" && op_type != "Mul")) {
54+
if ((!op_domain.empty() && op_domain != "ai.onnx") || (op_type != "Sub" && op_type != "Mul")) {
5555
std::ostringstream oss;
5656
oss << "ExampleKernelEp's BinaryOp class does not support operator with domain '" << op_domain << "' and "
5757
<< " type '" << op_type << "'.";
@@ -110,9 +110,9 @@ OrtStatus* ORT_API_CALL BinaryOp::ComputeImpl(OrtKernelImpl* this_ptr, OrtKernel
110110
float* output_data = output.GetTensorMutableData<float>();
111111

112112
std::string op_type = binary_op_kernel->info_.GetOperatorType();
113-
if (op_type == "Add") {
113+
if (op_type == "Sub") {
114114
for (size_t i = 0; i < input0.size(); ++i) {
115-
output_data[i] = input0[i] + input1[i];
115+
output_data[i] = input0[i] - input1[i];
116116
}
117117
} else {
118118
assert(op_type == "Mul"); // Checked by BinaryOp::Create

onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/binary_op.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
/// <summary>
1111
/// An OrtKernelImpl class for binary element-wise operations.
12-
/// Only Add and Mul are supported currently.
12+
/// Only Sub and Mul are supported currently.
1313
/// </summary>
1414
class BinaryOp : public OrtKernelImpl {
1515
private:

onnxruntime/test/autoep/test_execution.cc

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,40 @@ void RunSqueezeMulReluModel(const Ort::SessionOptions& session_options) {
7878
EXPECT_THAT(output_span, ::testing::ElementsAre(4, 0, 24, 0, 0, 84));
7979
}
8080

81-
void RunAddMulAddModel(const Ort::SessionOptions& session_options) {
82-
// This model has Add -> Mul -> Add. The example plugin EP only supports Mul.
81+
// Runs testdata/sub_mul_sub.onnx, which computes (A - B) * B - A with a
// Sub -> Mul -> Sub chain. The example plugin EP supports all three ops.
void RunSubMulSubModel(const Ort::SessionOptions& session_options) {
  Ort::Session session(*ort_env, ORT_TSTR("testdata/sub_mul_sub.onnx"), session_options);

  // Build the two float input tensors over caller-owned buffers.
  Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
  std::vector<int64_t> shape{3, 2};

  std::vector<float> a_data{1, 2, 3, 4, 5, 6};
  std::vector<float> b_data{2, 3, 4, 5, 6, 7};

  std::vector<Ort::Value> ort_inputs;
  for (auto* data : {&a_data, &b_data}) {
    ort_inputs.emplace_back(
        Ort::Value::CreateTensor<float>(memory_info, data->data(), data->size(), shape.data(), shape.size()));
  }

  std::array ort_input_names{"A", "B"};
  std::array output_names{"C"};

  // Run the session and fetch the single output tensor.
  std::vector<Ort::Value> ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(),
                                                    ort_inputs.data(), ort_inputs.size(),
                                                    output_names.data(), output_names.size());

  // Validate against (A - B) * B - A computed by hand for the inputs above.
  const float* output_data = ort_outputs[0].GetTensorData<float>();
  gsl::span<const float> output_span(output_data, 6);
  EXPECT_THAT(output_span, ::testing::ElementsAre(-3, -5, -7, -9, -11, -13));
}
112+
113+
void RunPartiallySupportedModelWithPluginEp(const Ort::SessionOptions& session_options) {
114+
// This model has Add -> Mul -> Add. The example plugin EP supports Mul but not Add.
83115
Ort::Session session(*ort_env, ORT_TSTR("testdata/add_mul_add.onnx"), session_options);
84116

85117
// Create inputs
@@ -150,7 +182,7 @@ TEST(OrtEpLibrary, PluginEp_AppendV2_PartiallySupportedModelInference) {
150182
std::unordered_map<std::string, std::string> ep_options;
151183
session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options);
152184

153-
RunAddMulAddModel(session_options);
185+
RunPartiallySupportedModelWithPluginEp(session_options);
154186
}
155187

156188
// Generate an EPContext model with a plugin EP.
@@ -298,26 +330,26 @@ TEST(OrtEpLibrary, KernelPluginEp_Inference) {
298330
ASSERT_NO_FATAL_FAILURE(RunSqueezeMulReluModel(session_options));
299331
}
300332

301-
// Run model with add, mul, add.
333+
// Run model with sub, mul, sub.
302334
// No sharing of pre-packed weights.
303335
{
304336
Ort::SessionOptions session_options;
305337
std::unordered_map<std::string, std::string> ep_options;
306338

307339
session_options.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1"); // Fail if any node assigned to CPU EP
308340
session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options);
309-
ASSERT_NO_FATAL_FAILURE(RunAddMulAddModel(session_options));
341+
ASSERT_NO_FATAL_FAILURE(RunSubMulSubModel(session_options));
310342
}
311343

312-
// Run model with add, mul, add.
344+
// Run model with sub, mul, sub.
313345
// Enable sharing of pre-packed weights.
314346
{
315347
std::unordered_map<std::string, std::string> ep_options = {{"enable_prepack_weight_sharing", "1"}};
316348
Ort::SessionOptions session_options;
317349

318350
session_options.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1"); // Fail if any node assigned to CPU EP
319351
session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options);
320-
ASSERT_NO_FATAL_FAILURE(RunAddMulAddModel(session_options));
352+
ASSERT_NO_FATAL_FAILURE(RunSubMulSubModel(session_options));
321353
}
322354
}
323355
} // namespace test
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
 :�
2+

3+
A
4+
B
5+
sub_outputsub_0"Sub
6+
'
7+
8+
sub_output
9+
B
10+
mul_outputmul_0"Mul
11+

12+
13+
mul_output
14+
ACsub_1"Sub
15+
Main_graphZ
16+
A
17+

18+

19+
Z
20+
B
21+

22+

23+
b
24+
C
25+

26+

27+
B
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
"""Generates testdata/sub_mul_sub.onnx for the example plugin EP tests.

The model is a three-node chain computing (A - B) * B - A:
    Sub(A, B) -> Mul(sub_output, B) -> Sub(mul_output, A) -> C
with float inputs A, B and output C of shape [3, 2].
"""
from onnx import TensorProto, checker, helper, save

# (A - B) * B - A
graph_proto = helper.make_graph(
    nodes=[
        helper.make_node(
            "Sub",
            inputs=["A", "B"],
            outputs=["sub_output"],
            name="sub_0",
        ),
        helper.make_node(
            "Mul",
            inputs=["sub_output", "B"],
            outputs=["mul_output"],
            name="mul_0",
        ),
        helper.make_node(
            "Sub",
            inputs=["mul_output", "A"],
            outputs=["C"],
            name="sub_1",
        ),
    ],
    name="Main_graph",
    inputs=[
        helper.make_tensor_value_info("A", TensorProto.FLOAT, [3, 2]),
        helper.make_tensor_value_info("B", TensorProto.FLOAT, [3, 2]),
    ],
    outputs=[
        helper.make_tensor_value_info("C", TensorProto.FLOAT, [3, 2]),
    ],
)

# Pin the default-domain opset to 14: the example plugin EP registers its
# Sub/Mul kernels for opset version 14 only (start_version == end_version == 14),
# so regenerating this model with a newer onnx package's default opset would
# leave the nodes unmatched by the EP's kernels.
model = helper.make_model(graph_proto, opset_imports=[helper.make_opsetid("", 14)])
checker.check_model(model, True)  # full_check=True also runs shape inference
save(model, "sub_mul_sub.onnx")

0 commit comments

Comments
 (0)