Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/onnxruntime/core/graph/basic_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class TensorProto;
class SparseTensorProto;
class TypeProto;
class AttributeProto;
class FunctionProto;
// define types that would come from the ONNX library if we were building against it.
#if defined(ORT_MINIMAL_BUILD)
using OperatorSetVersion = int;
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/shared_library/provider_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ struct TypeProto_Sequence;
struct TypeProto;
struct ValueInfoProto;
struct ValueInfoProtos; // RepeatedPtrField
struct FunctionProto;
struct InferenceContext;
class GraphInferencer;
using InferenceFunction = std::function<void(InferenceContext&)>;
Expand All @@ -146,6 +147,7 @@ struct ConfigOptions;
struct DataTransferManager;
struct IndexedSubGraph;
struct IndexedSubGraph_MetaDef;
enum class IndexedSubGraph_SourceOfSchema : uint8_t;
struct KernelCreateInfo;
struct KernelDef;
struct KernelDefBuilder;
Expand Down
62 changes: 62 additions & 0 deletions onnxruntime/core/providers/shared_library/provider_interfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ struct ProviderHost {
virtual int NodeProto__attribute_size(ONNX_NAMESPACE::NodeProto* p) = 0;
virtual const ONNX_NAMESPACE::AttributeProto& NodeProto__attribute(const ONNX_NAMESPACE::NodeProto* p, int index) const = 0;
virtual ONNX_NAMESPACE::AttributeProto* NodeProto__mutable_attribute(ONNX_NAMESPACE::NodeProto* p, int index) = 0;
virtual ONNX_NAMESPACE::AttributeProto* NodeProto__add_attribute(ONNX_NAMESPACE::NodeProto* p) = 0;

// TensorProto
virtual std::unique_ptr<ONNX_NAMESPACE::TensorProto> TensorProto__construct() = 0;
Expand Down Expand Up @@ -489,6 +490,64 @@ struct ProviderHost {

virtual const ONNX_NAMESPACE::ValueInfoProto& ValueInfoProtos__operator_array(const ONNX_NAMESPACE::ValueInfoProtos* p, int index) = 0;

// FunctionProto
virtual std::unique_ptr<ONNX_NAMESPACE::FunctionProto> FunctionProto__construct() = 0;
virtual void FunctionProto__operator_delete(ONNX_NAMESPACE::FunctionProto* p) = 0;

virtual bool FunctionProto__SerializeToString(const ONNX_NAMESPACE::FunctionProto* p, std::string& string) = 0;
virtual bool FunctionProto__SerializeToOstream(const ONNX_NAMESPACE::FunctionProto* p, std::ostream& output) = 0;
virtual bool FunctionProto__ParseFromString(ONNX_NAMESPACE::FunctionProto* p, const std::string& data) = 0;
virtual std::string FunctionProto__SerializeAsString(const ONNX_NAMESPACE::FunctionProto* p) = 0;

virtual bool FunctionProto__has_name(const ONNX_NAMESPACE::FunctionProto* p) = 0;
virtual const std::string& FunctionProto__name(const ONNX_NAMESPACE::FunctionProto* p) const = 0;
virtual void FunctionProto__set_name(ONNX_NAMESPACE::FunctionProto* p, const ::std::string& name) = 0;

virtual bool FunctionProto__has_doc_string(const ONNX_NAMESPACE::FunctionProto* p) = 0;
virtual const std::string& FunctionProto__doc_string(const ONNX_NAMESPACE::FunctionProto* p) const = 0;
virtual void FunctionProto__set_doc_string(ONNX_NAMESPACE::FunctionProto* p, const ::std::string& doc_string) = 0;

virtual bool FunctionProto__has_domain(const ONNX_NAMESPACE::FunctionProto* p) = 0;
virtual const std::string& FunctionProto__domain(const ONNX_NAMESPACE::FunctionProto* p) const = 0;
virtual void FunctionProto__set_domain(ONNX_NAMESPACE::FunctionProto* p, const ::std::string& domain) = 0;

virtual const std::string& FunctionProto__input(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
virtual std::string* FunctionProto__mutable_input(ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
virtual int FunctionProto__input_size(const ONNX_NAMESPACE::FunctionProto* p) = 0;
virtual void FunctionProto__add_input(ONNX_NAMESPACE::FunctionProto* p, const std::string& value) = 0;

virtual const std::string& FunctionProto__output(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
virtual std::string* FunctionProto__mutable_output(ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
virtual int FunctionProto__output_size(const ONNX_NAMESPACE::FunctionProto* p) = 0;
virtual void FunctionProto__add_output(ONNX_NAMESPACE::FunctionProto* p, const std::string& value) = 0;

virtual const std::string& FunctionProto__attribute(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
virtual std::string* FunctionProto__mutable_attribute(ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
virtual int FunctionProto__attribute_size(const ONNX_NAMESPACE::FunctionProto* p) = 0;
virtual void FunctionProto__add_attribute(ONNX_NAMESPACE::FunctionProto* p, const std::string& value) = 0;

virtual const ONNX_NAMESPACE::AttributeProto& FunctionProto__attribute_proto(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
virtual ONNX_NAMESPACE::AttributeProto* FunctionProto__mutable_attribute_proto(ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
virtual int FunctionProto__attribute_proto_size(const ONNX_NAMESPACE::FunctionProto* p) = 0;
virtual ONNX_NAMESPACE::AttributeProto* FunctionProto__add_attribute_proto(ONNX_NAMESPACE::FunctionProto* p) = 0;

virtual const ONNX_NAMESPACE::NodeProto& FunctionProto__node(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
virtual ONNX_NAMESPACE::NodeProto* FunctionProto__mutable_node(ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
virtual int FunctionProto__node_size(const ONNX_NAMESPACE::FunctionProto* p) = 0;
virtual ONNX_NAMESPACE::NodeProto* FunctionProto__add_node(ONNX_NAMESPACE::FunctionProto* p) = 0;

virtual const ONNX_NAMESPACE::ValueInfoProto& FunctionProto__value_info(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
virtual ONNX_NAMESPACE::ValueInfoProtos* FunctionProto__mutable_value_info(ONNX_NAMESPACE::FunctionProto* p) = 0;
virtual ONNX_NAMESPACE::ValueInfoProto* FunctionProto__mutable_value_info(ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
virtual int FunctionProto__value_info_size(const ONNX_NAMESPACE::FunctionProto* p) = 0;
virtual ONNX_NAMESPACE::ValueInfoProto* FunctionProto__add_value_info(ONNX_NAMESPACE::FunctionProto* p) = 0;

virtual const ONNX_NAMESPACE::StringStringEntryProto& FunctionProto__metadata_props(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
virtual ONNX_NAMESPACE::StringStringEntryProtos* FunctionProto__mutable_metadata_props(ONNX_NAMESPACE::FunctionProto* p) = 0;
virtual ONNX_NAMESPACE::StringStringEntryProto* FunctionProto__mutable_metadata_props(ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
virtual int FunctionProto__metadata_props_size(const ONNX_NAMESPACE::FunctionProto* p) = 0;
virtual ONNX_NAMESPACE::StringStringEntryProto* FunctionProto__add_metadata_props(ONNX_NAMESPACE::FunctionProto* p) = 0;

virtual void RegisterSchema(const std::string& domain, const OrtCustomOp* op, int type) = 0;

// ConfigOptions
Expand Down Expand Up @@ -540,6 +599,9 @@ struct ProviderHost {
virtual void IndexedSubGraph__SetMetaDef(IndexedSubGraph* p, std::unique_ptr<IndexedSubGraph_MetaDef>&& meta_def_) = 0;
virtual const IndexedSubGraph_MetaDef* IndexedSubGraph__GetMetaDef(const IndexedSubGraph* p) = 0;

virtual void IndexedSubGraph__SetSchemaSource(IndexedSubGraph* p, IndexedSubGraph_SourceOfSchema schema_source) = 0;
virtual IndexedSubGraph_SourceOfSchema IndexedSubGraph__GetSchemaSource(const IndexedSubGraph* p) = 0;

// KernelDef
virtual void KernelDef__operator_delete(KernelDef* p) = 0;
virtual int KernelDef__ExecQueueId(const KernelDef* p) = 0;
Expand Down
73 changes: 73 additions & 0 deletions onnxruntime/core/providers/shared_library/provider_wrappedtypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ struct NodeProto final {
int attribute_size() { return g_host->NodeProto__attribute_size(this); }
const AttributeProto& attribute(int index) const { return g_host->NodeProto__attribute(this, index); }
AttributeProto* mutable_attribute(int index) { return g_host->NodeProto__mutable_attribute(this, index); }
AttributeProto* add_attribute() { return g_host->NodeProto__add_attribute(this); }

NodeProto() = delete;
NodeProto(const NodeProto&) = delete;
Expand Down Expand Up @@ -372,6 +373,69 @@ struct ValueInfoProtos final {

PROVIDER_DISALLOW_ALL(ValueInfoProtos)
};

struct FunctionProto final {
static std::unique_ptr<FunctionProto> Create() { return g_host->FunctionProto__construct(); }
static void operator delete(void* p) { g_host->FunctionProto__operator_delete(reinterpret_cast<FunctionProto*>(p)); }

bool SerializeToString(std::string& string) const { return g_host->FunctionProto__SerializeToString(this, string); }
bool SerializeToOstream(std::ostream& output) const { return g_host->FunctionProto__SerializeToOstream(this, output); }
bool ParseFromString(const std::string& data) { return g_host->FunctionProto__ParseFromString(this, data); }
std::string SerializeAsString() const { return g_host->FunctionProto__SerializeAsString(this); }

bool has_name() const { return g_host->FunctionProto__has_name(this); }
const std::string& name() const { return g_host->FunctionProto__name(this); }
void set_name(const std::string& name) { g_host->FunctionProto__set_name(this, name); }

bool has_doc_string() const { return g_host->FunctionProto__has_doc_string(this); }
const std::string& doc_string() const { return g_host->FunctionProto__doc_string(this); }
void set_doc_string(const std::string& doc_string) { g_host->FunctionProto__set_doc_string(this, doc_string); }

bool has_domain() const { return g_host->FunctionProto__has_domain(this); }
const std::string& domain() const { return g_host->FunctionProto__domain(this); }
void set_domain(const std::string& domain) { g_host->FunctionProto__set_domain(this, domain); }

const std::string& input(int index) const { return g_host->FunctionProto__input(this, index); }
std::string* mutable_input(int index) { return g_host->FunctionProto__mutable_input(this, index); }
int input_size() const { return g_host->FunctionProto__input_size(this); }
void add_input(const std::string& value) { g_host->FunctionProto__add_input(this, value); }

const std::string& output(int index) const { return g_host->FunctionProto__output(this, index); }
std::string* mutable_output(int index) { return g_host->FunctionProto__mutable_output(this, index); }
int output_size() const { return g_host->FunctionProto__output_size(this); }
void add_output(const std::string& value) { g_host->FunctionProto__add_output(this, value); }

const std::string& attribute(int index) const { return g_host->FunctionProto__attribute(this, index); }
std::string* mutable_attribute(int index) { return g_host->FunctionProto__mutable_attribute(this, index); }
int attribute_size() const { return g_host->FunctionProto__attribute_size(this); }
void add_attribute(const std::string& value) { g_host->FunctionProto__add_attribute(this, value); }

const AttributeProto& attribute_proto(int index) const { return g_host->FunctionProto__attribute_proto(this, index); }
AttributeProto* mutable_attribute_proto(int index) { return g_host->FunctionProto__mutable_attribute_proto(this, index); }
int attribute_proto_size() const { return g_host->FunctionProto__attribute_proto_size(this); }
AttributeProto* add_attribute_proto() { return g_host->FunctionProto__add_attribute_proto(this); }

const NodeProto& node(int index) const { return g_host->FunctionProto__node(this, index); }
NodeProto* mutable_node(int index) { return g_host->FunctionProto__mutable_node(this, index); }
int node_size() const { return g_host->FunctionProto__node_size(this); }
NodeProto* add_node() { return g_host->FunctionProto__add_node(this); }

const ValueInfoProto& value_info(int index) const { return g_host->FunctionProto__value_info(this, index); }
ValueInfoProtos* mutable_value_info() { return g_host->FunctionProto__mutable_value_info(this); }
ValueInfoProto* mutable_value_info(int index) { return g_host->FunctionProto__mutable_value_info(this, index); }
int value_info_size() const { return g_host->FunctionProto__value_info_size(this); }
ValueInfoProto* add_value_info() { return g_host->FunctionProto__add_value_info(this); }

const StringStringEntryProto& metadata_props(int index) const { return g_host->FunctionProto__metadata_props(this, index); }
StringStringEntryProtos* mutable_metadata_props() { return g_host->FunctionProto__mutable_metadata_props(this); }
StringStringEntryProto* mutable_metadata_props(int index) { return g_host->FunctionProto__mutable_metadata_props(this, index); }
int metadata_props_size() const { return g_host->FunctionProto__metadata_props_size(this); }
StringStringEntryProto* add_metadata_props() { return g_host->FunctionProto__add_metadata_props(this); }

FunctionProto() = delete;
FunctionProto(const FunctionProto&) = delete;
void operator=(const FunctionProto&) = delete;
};
} // namespace ONNX_NAMESPACE

namespace onnxruntime {
Expand Down Expand Up @@ -449,6 +513,12 @@ struct IndexedSubGraph_MetaDef final {
void operator=(const IndexedSubGraph_MetaDef&) = delete;
};

enum class IndexedSubGraph_SourceOfSchema : uint8_t {
CREATE,
REUSE_OR_CREATE,
EXISTING,
};

struct IndexedSubGraph final {
static std::unique_ptr<IndexedSubGraph> Create() { return g_host->IndexedSubGraph__construct(); }
static void operator delete(void* p) { g_host->IndexedSubGraph__operator_delete(reinterpret_cast<IndexedSubGraph*>(p)); }
Expand All @@ -458,6 +528,9 @@ struct IndexedSubGraph final {
void SetMetaDef(std::unique_ptr<IndexedSubGraph_MetaDef>&& meta_def_) { return g_host->IndexedSubGraph__SetMetaDef(this, std::move(*reinterpret_cast<std::unique_ptr<IndexedSubGraph_MetaDef>*>(&meta_def_))); }
const IndexedSubGraph_MetaDef* GetMetaDef() const { return reinterpret_cast<const IndexedSubGraph_MetaDef*>(g_host->IndexedSubGraph__GetMetaDef(this)); }

void SetSchemaSource(IndexedSubGraph_SourceOfSchema schema_source) { return g_host->IndexedSubGraph__SetSchemaSource(this, schema_source); }
IndexedSubGraph_SourceOfSchema GetSchemaSource() const { return g_host->IndexedSubGraph__GetSchemaSource(this); }

IndexedSubGraph() = delete;
IndexedSubGraph(const IndexedSubGraph&) = delete;
void operator=(const IndexedSubGraph&) = delete;
Expand Down
Loading