Skip to content

Commit 2b95e74

Browse files
Versioning for custom op (#18088)
Allow custom ops to have versions. --------- Co-authored-by: Randy Shuai <[email protected]>
1 parent 62c7894 commit 2b95e74

File tree

8 files changed

+148
-37
lines changed

8 files changed

+148
-37
lines changed

include/onnxruntime/core/session/onnxruntime_c_api.h

+4
Original file line numberDiff line numberDiff line change
@@ -4605,6 +4605,10 @@ struct OrtCustomOp {
46054605
OrtStatusPtr(ORT_API_CALL* KernelComputeV2)(_In_ void* op_kernel, _In_ OrtKernelContext* context);
46064606

46074607
OrtStatusPtr(ORT_API_CALL* InferOutputShapeFn)(_In_ const struct OrtCustomOp* op, _In_ OrtShapeInferContext*);
4608+
4609+
// Get start range
4610+
int(ORT_API_CALL* GetStartVersion)(_In_ const struct OrtCustomOp* op);
4611+
int(ORT_API_CALL* GetEndVersion)(_In_ const struct OrtCustomOp* op);
46084612
};
46094613

46104614
/*

include/onnxruntime/core/session/onnxruntime_cxx_api.h

+13
Original file line numberDiff line numberDiff line change
@@ -2228,6 +2228,8 @@ struct ShapeInferContext {
22282228

22292229
using ShapeInferFn = Ort::Status (*)(Ort::ShapeInferContext&);
22302230

2231+
#define MAX_CUSTOM_OP_END_VER (1UL << 31) - 1
2232+
22312233
template <typename TOp, typename TKernel, bool WithStatus = false>
22322234
struct CustomOpBase : OrtCustomOp {
22332235
CustomOpBase() {
@@ -2280,6 +2282,14 @@ struct CustomOpBase : OrtCustomOp {
22802282
}
22812283

22822284
SetShapeInferFn<TOp>(0);
2285+
2286+
OrtCustomOp::GetStartVersion = [](const OrtCustomOp* this_) {
2287+
return static_cast<const TOp*>(this_)->start_ver_;
2288+
};
2289+
2290+
OrtCustomOp::GetEndVersion = [](const OrtCustomOp* this_) {
2291+
return static_cast<const TOp*>(this_)->end_ver_;
2292+
};
22832293
}
22842294

22852295
// Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
@@ -2348,6 +2358,9 @@ struct CustomOpBase : OrtCustomOp {
23482358
protected:
23492359
// Helper function that returns a map of session config entries specified by CustomOpBase::GetSessionConfigKeys.
23502360
void GetSessionConfigs(std::unordered_map<std::string, std::string>& out, ConstSessionOptions options) const;
2361+
2362+
int start_ver_ = 1;
2363+
int end_ver_ = MAX_CUSTOM_OP_END_VER;
23512364
};
23522365

23532366
} // namespace Ort

include/onnxruntime/core/session/onnxruntime_lite_custom_op.h

+43-16
Original file line numberDiff line numberDiff line change
@@ -773,8 +773,11 @@ struct OrtLiteCustomOp : public OrtCustomOp {
773773
PARSE_ARGS(Ort::Float8E5M2FNUZ_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ)
774774

775775
OrtLiteCustomOp(const char* op_name,
776-
const char* execution_provider) : op_name_(op_name),
777-
execution_provider_(execution_provider) {
776+
const char* execution_provider,
777+
int start_ver = 1, int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name),
778+
execution_provider_(execution_provider),
779+
start_ver_(start_ver),
780+
end_ver_(end_ver) {
778781
OrtCustomOp::version = ORT_API_VERSION;
779782

780783
OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); };
@@ -837,13 +840,26 @@ struct OrtLiteCustomOp : public OrtCustomOp {
837840
OrtCustomOp::KernelCompute = {};
838841

839842
OrtCustomOp::InferOutputShapeFn = {};
843+
844+
OrtCustomOp::GetStartVersion = [](const OrtCustomOp* op) {
845+
auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
846+
return self->start_ver_;
847+
};
848+
849+
OrtCustomOp::GetEndVersion = [](const OrtCustomOp* op) {
850+
auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
851+
return self->end_ver_;
852+
};
840853
}
841854

842855
const std::string op_name_;
843856
const std::string execution_provider_;
844857

845858
std::vector<ONNXTensorElementDataType> input_types_;
846859
std::vector<ONNXTensorElementDataType> output_types_;
860+
861+
int start_ver_ = 1;
862+
int end_ver_ = MAX_CUSTOM_OP_END_VER;
847863
};
848864

849865
//////////////////////////// OrtLiteCustomFunc ////////////////////////////////
@@ -873,9 +889,11 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp {
873889
OrtLiteCustomFunc(const char* op_name,
874890
const char* execution_provider,
875891
ComputeFn compute_fn,
876-
ShapeInferFn shape_infer_fn = {}) : OrtLiteCustomOp(op_name, execution_provider),
877-
compute_fn_(compute_fn),
878-
shape_infer_fn_(shape_infer_fn) {
892+
ShapeInferFn shape_infer_fn = {},
893+
int start_ver = 1,
894+
int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver),
895+
compute_fn_(compute_fn),
896+
shape_infer_fn_(shape_infer_fn) {
879897
ParseArgs<Args...>(input_types_, output_types_);
880898

881899
OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
@@ -911,9 +929,11 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp {
911929
OrtLiteCustomFunc(const char* op_name,
912930
const char* execution_provider,
913931
ComputeFnReturnStatus compute_fn_return_status,
914-
ShapeInferFn shape_infer_fn = {}) : OrtLiteCustomOp(op_name, execution_provider),
915-
compute_fn_return_status_(compute_fn_return_status),
916-
shape_infer_fn_(shape_infer_fn) {
932+
ShapeInferFn shape_infer_fn = {},
933+
int start_ver = 1,
934+
int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver),
935+
compute_fn_return_status_(compute_fn_return_status),
936+
shape_infer_fn_(shape_infer_fn) {
917937
ParseArgs<Args...>(input_types_, output_types_);
918938

919939
OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
@@ -985,8 +1005,9 @@ struct OrtLiteCustomStruct : public OrtLiteCustomOp {
9851005
};
9861006

9871007
OrtLiteCustomStruct(const char* op_name,
988-
const char* execution_provider) : OrtLiteCustomOp(op_name,
989-
execution_provider) {
1008+
const char* execution_provider,
1009+
int start_ver = 1,
1010+
int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver) {
9901011
SetCompute(&CustomOp::Compute);
9911012

9921013
OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
@@ -1049,25 +1070,31 @@ template <typename... Args>
10491070
OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
10501071
const char* execution_provider,
10511072
void (*custom_compute_fn)(Args...),
1052-
Status (*shape_infer_fn)(ShapeInferContext&) = {}) {
1073+
Status (*shape_infer_fn)(ShapeInferContext&) = {},
1074+
int start_ver = 1,
1075+
int end_ver = MAX_CUSTOM_OP_END_VER) {
10531076
using LiteOp = OrtLiteCustomFunc<Args...>;
1054-
return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn, shape_infer_fn).release();
1077+
return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn, shape_infer_fn, start_ver, end_ver).release();
10551078
}
10561079

10571080
template <typename... Args>
10581081
OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
10591082
const char* execution_provider,
10601083
Status (*custom_compute_fn_v2)(Args...),
1061-
Status (*shape_infer_fn)(ShapeInferContext&) = {}) {
1084+
Status (*shape_infer_fn)(ShapeInferContext&) = {},
1085+
int start_ver = 1,
1086+
int end_ver = MAX_CUSTOM_OP_END_VER) {
10621087
using LiteOp = OrtLiteCustomFunc<Args...>;
1063-
return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn_v2, shape_infer_fn).release();
1088+
return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn_v2, shape_infer_fn, start_ver, end_ver).release();
10641089
}
10651090

10661091
template <typename CustomOp>
10671092
OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
1068-
const char* execution_provider) {
1093+
const char* execution_provider,
1094+
int start_ver = 1,
1095+
int end_ver = MAX_CUSTOM_OP_END_VER) {
10691096
using LiteOp = OrtLiteCustomStruct<CustomOp>;
1070-
return std::make_unique<LiteOp>(op_name, execution_provider).release();
1097+
return std::make_unique<LiteOp>(op_name, execution_provider, start_ver, end_ver).release();
10711098
}
10721099

10731100
} // namespace Custom

onnxruntime/core/session/custom_ops.cc

+19-3
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#if !defined(ORT_MINIMAL_BUILD)
2626
static constexpr uint32_t min_ort_version_with_optional_io_support = 8;
2727
static constexpr uint32_t min_ort_version_with_variadic_io_support = 14;
28+
static constexpr uint32_t min_ort_version_with_custom_version = 17;
2829
#endif
2930

3031
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
@@ -698,8 +699,19 @@ KernelCreateInfo CreateKernelCreateInfo(const std::string& domain, const OrtCust
698699

699700
KernelDefBuilder def_builder;
700701
def_builder.SetName(op->GetName(op))
701-
.SetDomain(domain)
702-
.SinceVersion(1);
702+
.SetDomain(domain);
703+
704+
if (op->version >= min_ort_version_with_custom_version) {
705+
if (op->GetStartVersion && op->GetEndVersion) {
706+
def_builder.SinceVersion(op->GetStartVersion(op), op->GetEndVersion(op));
707+
} else if (op->GetStartVersion) {
708+
def_builder.SinceVersion(op->GetStartVersion(op));
709+
} else {
710+
def_builder.SinceVersion(1);
711+
}
712+
} else {
713+
def_builder.SinceVersion(1);
714+
}
703715

704716
// GetInputMemoryType was introduced in ver 13. This check allows custom ops compiled using older versions
705717
// to work with newer versions (> 12) of the ORT binary.
@@ -820,7 +832,11 @@ ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const OrtCustom
820832
schema.TypeConstraint(output_name, DataTypeImpl::ToString(SUPPORTED_TENSOR_TYPES), "all types");
821833
}
822834
schema.SetDomain(domain);
823-
schema.SinceVersion(1);
835+
if (op->version >= min_ort_version_with_custom_version && op->GetStartVersion) {
836+
schema.SinceVersion(op->GetStartVersion(op));
837+
} else {
838+
schema.SinceVersion(1);
839+
}
824840
schema.AllowUncheckedAttributes();
825841

826842
if (op->version >= min_ort_version_with_shape_inference && op->InferOutputShapeFn) {

onnxruntime/test/shared_lib/test_inference.cc

+16
Original file line numberDiff line numberDiff line change
@@ -3323,6 +3323,22 @@ TEST(LiteCustomOpTest, CustomFunc) {
33233323
ASSERT_TRUE(floats_output[1] == 16);
33243324
}
33253325

3326+
TEST(LiteCustomOpTest, CustomFuncOpsetMismatch) {
3327+
Ort::SessionOptions session_options;
3328+
session_options.SetIntraOpNumThreads(1);
3329+
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
3330+
session_options.SetLogSeverityLevel(0);
3331+
#if defined(_WIN32)
3332+
session_options.RegisterCustomOpsLibrary(ORT_TSTR("custom_op_library.dll"));
3333+
#elif defined(__APPLE__)
3334+
session_options.RegisterCustomOpsLibrary(ORT_TSTR("libcustom_op_library.dylib"));
3335+
#else
3336+
session_options.RegisterCustomOpsLibrary(ORT_TSTR("./libcustom_op_library.so"));
3337+
#endif
3338+
3339+
EXPECT_THROW(Ort::Session(*ort_env, TSTR("testdata/fuse_select_filter_opset_8.onnx"), session_options), std::exception);
3340+
}
3341+
33263342
struct Merge {
33273343
Merge(const OrtApi* ort_api, const OrtKernelInfo* info) {
33283344
int64_t reverse;

onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc

+21-16
Original file line numberDiff line numberDiff line change
@@ -94,23 +94,28 @@ void Select(const Ort::Custom::Span<int32_t>& indices_in,
9494
}
9595
}
9696

97-
void Filter(const Ort::Custom::Tensor<float>& floats_in,
98-
Ort::Custom::Tensor<float>& floats_out) {
99-
const float* in = floats_in.Data();
100-
auto in_len = floats_in.NumberOfElement();
97+
struct Filter {
98+
Filter(const OrtApi*, const OrtKernelInfo*) {}
99+
Ort::Status Compute(const Ort::Custom::Tensor<float>& floats_in,
100+
Ort::Custom::Tensor<float>& floats_out) {
101+
const float* in = floats_in.Data();
102+
auto in_len = floats_in.NumberOfElement();
103+
104+
std::vector<float> filter_floats;
105+
for (int64_t i = 0; i < in_len; ++i) {
106+
if (in[i] > 1.f) {
107+
filter_floats.push_back(in[i]);
108+
}
109+
}
101110

102-
std::vector<float> filter_floats;
103-
for (int64_t i = 0; i < in_len; ++i) {
104-
if (in[i] > 1.f) {
105-
filter_floats.push_back(in[i]);
111+
float* out = static_cast<float*>(floats_out.Allocate({static_cast<int64_t>(filter_floats.size())}));
112+
for (size_t j = 0; j < filter_floats.size(); ++j) {
113+
out[j] = filter_floats[j];
106114
}
107-
}
108115

109-
float* out = static_cast<float*>(floats_out.Allocate({static_cast<int64_t>(filter_floats.size())}));
110-
for (size_t j = 0; j < filter_floats.size(); ++j) {
111-
out[j] = filter_floats[j];
116+
return Ort::Status{nullptr};
112117
}
113-
}
118+
};
114119

115120
void Box(const Ort::Custom::Tensor<float>* float_in_1,
116121
const Ort::Custom::Tensor<float>* float_in_2,
@@ -293,9 +298,9 @@ void RegisterOps(Ort::CustomOpDomain& domain) {
293298
static const std::unique_ptr<OrtLiteCustomOp> c_CustomOpTwo{Ort::Custom::CreateLiteCustomOp("CustomOpTwo", "CPUExecutionProvider", KernelTwo)};
294299
static const std::unique_ptr<OrtLiteCustomOp> c_MulTopOpFloat{Ort::Custom::CreateLiteCustomOp("MulTop", "CPUExecutionProvider", MulTop<float>)};
295300
static const std::unique_ptr<OrtLiteCustomOp> c_MulTopOpInt32{Ort::Custom::CreateLiteCustomOp("MulTop", "CPUExecutionProvider", MulTop<int32_t>)};
296-
static const std::unique_ptr<OrtLiteCustomOp> c_Fuse{Ort::Custom::CreateLiteCustomOp("Fuse", "CPUExecutionProvider", Fuse)};
301+
static const std::unique_ptr<OrtLiteCustomOp> c_Fuse{Ort::Custom::CreateLiteCustomOp("Fuse", "CPUExecutionProvider", Fuse, {}, 10, 12)};
297302
static const std::unique_ptr<OrtLiteCustomOp> c_Select{Ort::Custom::CreateLiteCustomOp("Select", "CPUExecutionProvider", Select)};
298-
static const std::unique_ptr<OrtLiteCustomOp> c_Fill{Ort::Custom::CreateLiteCustomOp("Filter", "CPUExecutionProvider", Filter)};
303+
static const std::unique_ptr<OrtLiteCustomOp> c_Filter{Ort::Custom::CreateLiteCustomOp<Filter>("Filter", "CPUExecutionProvider", 15, 17)};
299304
static const std::unique_ptr<OrtLiteCustomOp> c_Box{Ort::Custom::CreateLiteCustomOp("Box", "CPUExecutionProvider", Box)};
300305
static const std::unique_ptr<OrtLiteCustomOp> c_CopyTensorArrayAllVariadic{Ort::Custom::CreateLiteCustomOp("CopyTensorArrayAllVariadic", "CPUExecutionProvider", CopyTensorArrayAllVariadic<float>)};
301306
static const std::unique_ptr<OrtLiteCustomOp> c_CopyTensorArrayCombined{Ort::Custom::CreateLiteCustomOp("CopyTensorArrayCombined", "CPUExecutionProvider", CopyTensorArrayCombined<float>)};
@@ -314,7 +319,7 @@ void RegisterOps(Ort::CustomOpDomain& domain) {
314319
domain.Add(c_MulTopOpInt32.get());
315320
domain.Add(c_Fuse.get());
316321
domain.Add(c_Select.get());
317-
domain.Add(c_Fill.get());
322+
domain.Add(c_Filter.get());
318323
domain.Add(c_Box.get());
319324
domain.Add(c_CopyTensorArrayAllVariadic.get());
320325
domain.Add(c_CopyTensorArrayCombined.get());

onnxruntime/test/testdata/fuse_select_filter.onnx

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
:�
1+
 :�
22
P
33
vector_1
44
vector_2
@@ -25,4 +25,5 @@ N
2525
���������b&
2626
vector_filtered
2727

28-
���������B
28+
���������B
29+
v2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
 :�
2+
P
3+
vector_1
4+
vector_2
5+
alpha vector_fused fuse_node"Fuse*
6+
fuse_algo�:v2
7+
4
8+
indicesindices_selected select_node"Select:v2
9+
N
10+
vector_fused
11+
indices_selectedvector_gathered gather_node"GatherElements
12+
;
13+
vector_gatheredvector_filtered filter_node"Filter:v2graphZ
14+
vector_1
15+

16+
���������Z
17+
vector_2
18+

19+
���������Z
20+
alpha
21+

22+
���������Z
23+
indices
24+

25+
���������b&
26+
vector_filtered
27+

28+
���������B
29+
v2

0 commit comments

Comments
 (0)