@@ -773,8 +773,11 @@ struct OrtLiteCustomOp : public OrtCustomOp {
773
773
PARSE_ARGS (Ort::Float8E5M2FNUZ_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ)
774
774
775
775
OrtLiteCustomOp (const char * op_name,
776
- const char * execution_provider) : op_name_(op_name),
777
- execution_provider_ (execution_provider) {
776
+ const char * execution_provider,
777
+ int start_ver = 1 , int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name),
778
+ execution_provider_ (execution_provider),
779
+ start_ver_(start_ver),
780
+ end_ver_(end_ver) {
778
781
OrtCustomOp::version = ORT_API_VERSION;
779
782
780
783
OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast <const OrtLiteCustomOp*>(op)->op_name_ .c_str (); };
@@ -837,13 +840,26 @@ struct OrtLiteCustomOp : public OrtCustomOp {
837
840
OrtCustomOp::KernelCompute = {};
838
841
839
842
OrtCustomOp::InferOutputShapeFn = {};
843
+
844
+ OrtCustomOp::GetStartVersion = [](const OrtCustomOp* op) {
845
+ auto self = reinterpret_cast <const OrtLiteCustomOp*>(op);
846
+ return self->start_ver_ ;
847
+ };
848
+
849
+ OrtCustomOp::GetEndVersion = [](const OrtCustomOp* op) {
850
+ auto self = reinterpret_cast <const OrtLiteCustomOp*>(op);
851
+ return self->end_ver_ ;
852
+ };
840
853
}
841
854
842
855
const std::string op_name_;
843
856
const std::string execution_provider_;
844
857
845
858
std::vector<ONNXTensorElementDataType> input_types_;
846
859
std::vector<ONNXTensorElementDataType> output_types_;
860
+
861
+ int start_ver_ = 1 ;
862
+ int end_ver_ = MAX_CUSTOM_OP_END_VER;
847
863
};
848
864
849
865
// ////////////////////////// OrtLiteCustomFunc ////////////////////////////////
@@ -873,9 +889,11 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp {
873
889
OrtLiteCustomFunc (const char * op_name,
874
890
const char * execution_provider,
875
891
ComputeFn compute_fn,
876
- ShapeInferFn shape_infer_fn = {}) : OrtLiteCustomOp(op_name, execution_provider),
877
- compute_fn_ (compute_fn),
878
- shape_infer_fn_(shape_infer_fn) {
892
+ ShapeInferFn shape_infer_fn = {},
893
+ int start_ver = 1 ,
894
+ int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver),
895
+ compute_fn_ (compute_fn),
896
+ shape_infer_fn_(shape_infer_fn) {
879
897
ParseArgs<Args...>(input_types_, output_types_);
880
898
881
899
OrtCustomOp::KernelCompute = [](void * op_kernel, OrtKernelContext* context) {
@@ -911,9 +929,11 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp {
911
929
OrtLiteCustomFunc (const char * op_name,
912
930
const char * execution_provider,
913
931
ComputeFnReturnStatus compute_fn_return_status,
914
- ShapeInferFn shape_infer_fn = {}) : OrtLiteCustomOp(op_name, execution_provider),
915
- compute_fn_return_status_(compute_fn_return_status),
916
- shape_infer_fn_(shape_infer_fn) {
932
+ ShapeInferFn shape_infer_fn = {},
933
+ int start_ver = 1 ,
934
+ int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver),
935
+ compute_fn_return_status_(compute_fn_return_status),
936
+ shape_infer_fn_(shape_infer_fn) {
917
937
ParseArgs<Args...>(input_types_, output_types_);
918
938
919
939
OrtCustomOp::KernelComputeV2 = [](void * op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
@@ -985,8 +1005,9 @@ struct OrtLiteCustomStruct : public OrtLiteCustomOp {
985
1005
};
986
1006
987
1007
OrtLiteCustomStruct (const char * op_name,
988
- const char * execution_provider) : OrtLiteCustomOp(op_name,
989
- execution_provider) {
1008
+ const char * execution_provider,
1009
+ int start_ver = 1 ,
1010
+ int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver) {
990
1011
SetCompute (&CustomOp::Compute);
991
1012
992
1013
OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
@@ -1049,25 +1070,31 @@ template <typename... Args>
1049
1070
OrtLiteCustomOp* CreateLiteCustomOp (const char * op_name,
1050
1071
const char * execution_provider,
1051
1072
void (*custom_compute_fn)(Args...),
1052
- Status (*shape_infer_fn)(ShapeInferContext&) = {}) {
1073
+ Status (*shape_infer_fn)(ShapeInferContext&) = {},
1074
+ int start_ver = 1 ,
1075
+ int end_ver = MAX_CUSTOM_OP_END_VER) {
1053
1076
using LiteOp = OrtLiteCustomFunc<Args...>;
1054
- return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn, shape_infer_fn).release ();
1077
+ return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn, shape_infer_fn, start_ver, end_ver ).release ();
1055
1078
}
1056
1079
1057
1080
template <typename ... Args>
1058
1081
OrtLiteCustomOp* CreateLiteCustomOp (const char * op_name,
1059
1082
const char * execution_provider,
1060
1083
Status (*custom_compute_fn_v2)(Args...),
1061
- Status (*shape_infer_fn)(ShapeInferContext&) = {}) {
1084
+ Status (*shape_infer_fn)(ShapeInferContext&) = {},
1085
+ int start_ver = 1 ,
1086
+ int end_ver = MAX_CUSTOM_OP_END_VER) {
1062
1087
using LiteOp = OrtLiteCustomFunc<Args...>;
1063
- return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn_v2, shape_infer_fn).release ();
1088
+ return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn_v2, shape_infer_fn, start_ver, end_ver ).release ();
1064
1089
}
1065
1090
1066
1091
template <typename CustomOp>
1067
1092
OrtLiteCustomOp* CreateLiteCustomOp (const char * op_name,
1068
- const char * execution_provider) {
1093
+ const char * execution_provider,
1094
+ int start_ver = 1 ,
1095
+ int end_ver = MAX_CUSTOM_OP_END_VER) {
1069
1096
using LiteOp = OrtLiteCustomStruct<CustomOp>;
1070
- return std::make_unique<LiteOp>(op_name, execution_provider).release ();
1097
+ return std::make_unique<LiteOp>(op_name, execution_provider, start_ver, end_ver ).release ();
1071
1098
}
1072
1099
1073
1100
} // namespace Custom
0 commit comments