diff --git a/include/onnxruntime/core/graph/basic_types.h b/include/onnxruntime/core/graph/basic_types.h index 36984d0405bbd..67c7db7979f81 100644 --- a/include/onnxruntime/core/graph/basic_types.h +++ b/include/onnxruntime/core/graph/basic_types.h @@ -19,6 +19,7 @@ class TensorProto; class SparseTensorProto; class TypeProto; class AttributeProto; +class FunctionProto; // define types that would come from the ONNX library if we were building against it. #if defined(ORT_MINIMAL_BUILD) using OperatorSetVersion = int; diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 7cdfb0ffc19f2..0bb11ec74fe32 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -120,6 +120,7 @@ struct TypeProto_Sequence; struct TypeProto; struct ValueInfoProto; struct ValueInfoProtos; // RepeatedPtrField +struct FunctionProto; struct InferenceContext; class GraphInferencer; using InferenceFunction = std::function; @@ -146,6 +147,7 @@ struct ConfigOptions; struct DataTransferManager; struct IndexedSubGraph; struct IndexedSubGraph_MetaDef; +enum class IndexedSubGraph_SourceOfSchema : uint8_t; struct KernelCreateInfo; struct KernelDef; struct KernelDefBuilder; diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index cc3b13f696a96..de741f752abae 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -421,6 +421,7 @@ struct ProviderHost { virtual int NodeProto__attribute_size(ONNX_NAMESPACE::NodeProto* p) = 0; virtual const ONNX_NAMESPACE::AttributeProto& NodeProto__attribute(const ONNX_NAMESPACE::NodeProto* p, int index) const = 0; virtual ONNX_NAMESPACE::AttributeProto* NodeProto__mutable_attribute(ONNX_NAMESPACE::NodeProto* p, int index) = 0; + virtual 
ONNX_NAMESPACE::AttributeProto* NodeProto__add_attribute(ONNX_NAMESPACE::NodeProto* p) = 0; // TensorProto virtual std::unique_ptr TensorProto__construct() = 0; @@ -489,6 +490,64 @@ struct ProviderHost { virtual const ONNX_NAMESPACE::ValueInfoProto& ValueInfoProtos__operator_array(const ONNX_NAMESPACE::ValueInfoProtos* p, int index) = 0; + // FunctionProto + virtual std::unique_ptr FunctionProto__construct() = 0; + virtual void FunctionProto__operator_delete(ONNX_NAMESPACE::FunctionProto* p) = 0; + + virtual bool FunctionProto__SerializeToString(const ONNX_NAMESPACE::FunctionProto* p, std::string& string) = 0; + virtual bool FunctionProto__SerializeToOstream(const ONNX_NAMESPACE::FunctionProto* p, std::ostream& output) = 0; + virtual bool FunctionProto__ParseFromString(ONNX_NAMESPACE::FunctionProto* p, const std::string& data) = 0; + virtual std::string FunctionProto__SerializeAsString(const ONNX_NAMESPACE::FunctionProto* p) = 0; + + virtual bool FunctionProto__has_name(const ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual const std::string& FunctionProto__name(const ONNX_NAMESPACE::FunctionProto* p) const = 0; + virtual void FunctionProto__set_name(ONNX_NAMESPACE::FunctionProto* p, const ::std::string& name) = 0; + + virtual bool FunctionProto__has_doc_string(const ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual const std::string& FunctionProto__doc_string(const ONNX_NAMESPACE::FunctionProto* p) const = 0; + virtual void FunctionProto__set_doc_string(ONNX_NAMESPACE::FunctionProto* p, const ::std::string& doc_string) = 0; + + virtual bool FunctionProto__has_domain(const ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual const std::string& FunctionProto__domain(const ONNX_NAMESPACE::FunctionProto* p) const = 0; + virtual void FunctionProto__set_domain(ONNX_NAMESPACE::FunctionProto* p, const ::std::string& domain) = 0; + + virtual const std::string& FunctionProto__input(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual std::string* 
FunctionProto__mutable_input(ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual int FunctionProto__input_size(const ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual void FunctionProto__add_input(ONNX_NAMESPACE::FunctionProto* p, const std::string& value) = 0; + + virtual const std::string& FunctionProto__output(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual std::string* FunctionProto__mutable_output(ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual int FunctionProto__output_size(const ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual void FunctionProto__add_output(ONNX_NAMESPACE::FunctionProto* p, const std::string& value) = 0; + + virtual const std::string& FunctionProto__attribute(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual std::string* FunctionProto__mutable_attribute(ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual int FunctionProto__attribute_size(const ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual void FunctionProto__add_attribute(ONNX_NAMESPACE::FunctionProto* p, const std::string& value) = 0; + + virtual const ONNX_NAMESPACE::AttributeProto& FunctionProto__attribute_proto(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual ONNX_NAMESPACE::AttributeProto* FunctionProto__mutable_attribute_proto(ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual int FunctionProto__attribute_proto_size(const ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual ONNX_NAMESPACE::AttributeProto* FunctionProto__add_attribute_proto(ONNX_NAMESPACE::FunctionProto* p) = 0; + + virtual const ONNX_NAMESPACE::NodeProto& FunctionProto__node(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual ONNX_NAMESPACE::NodeProto* FunctionProto__mutable_node(ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual int FunctionProto__node_size(const ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual ONNX_NAMESPACE::NodeProto* FunctionProto__add_node(ONNX_NAMESPACE::FunctionProto* p) = 0; + + virtual 
const ONNX_NAMESPACE::ValueInfoProto& FunctionProto__value_info(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual ONNX_NAMESPACE::ValueInfoProtos* FunctionProto__mutable_value_info(ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual ONNX_NAMESPACE::ValueInfoProto* FunctionProto__mutable_value_info(ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual int FunctionProto__value_info_size(const ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual ONNX_NAMESPACE::ValueInfoProto* FunctionProto__add_value_info(ONNX_NAMESPACE::FunctionProto* p) = 0; + + virtual const ONNX_NAMESPACE::StringStringEntryProto& FunctionProto__metadata_props(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual ONNX_NAMESPACE::StringStringEntryProtos* FunctionProto__mutable_metadata_props(ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual ONNX_NAMESPACE::StringStringEntryProto* FunctionProto__mutable_metadata_props(ONNX_NAMESPACE::FunctionProto* p, int index) = 0; + virtual int FunctionProto__metadata_props_size(const ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual ONNX_NAMESPACE::StringStringEntryProto* FunctionProto__add_metadata_props(ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual void RegisterSchema(const std::string& domain, const OrtCustomOp* op, int type) = 0; // ConfigOptions @@ -540,6 +599,9 @@ struct ProviderHost { virtual void IndexedSubGraph__SetMetaDef(IndexedSubGraph* p, std::unique_ptr&& meta_def_) = 0; virtual const IndexedSubGraph_MetaDef* IndexedSubGraph__GetMetaDef(const IndexedSubGraph* p) = 0; + virtual void IndexedSubGraph__SetSchemaSource(IndexedSubGraph* p, IndexedSubGraph_SourceOfSchema schema_source) = 0; + virtual IndexedSubGraph_SourceOfSchema IndexedSubGraph__GetSchemaSource(const IndexedSubGraph* p) = 0; + // KernelDef virtual void KernelDef__operator_delete(KernelDef* p) = 0; virtual int KernelDef__ExecQueueId(const KernelDef* p) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h 
b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index fd2540b42a3db..5569ac0dcba6f 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -190,6 +190,7 @@ struct NodeProto final { int attribute_size() { return g_host->NodeProto__attribute_size(this); } const AttributeProto& attribute(int index) const { return g_host->NodeProto__attribute(this, index); } AttributeProto* mutable_attribute(int index) { return g_host->NodeProto__mutable_attribute(this, index); } + AttributeProto* add_attribute() { return g_host->NodeProto__add_attribute(this); } NodeProto() = delete; NodeProto(const NodeProto&) = delete; @@ -372,6 +373,69 @@ struct ValueInfoProtos final { PROVIDER_DISALLOW_ALL(ValueInfoProtos) }; + +struct FunctionProto final { + static std::unique_ptr Create() { return g_host->FunctionProto__construct(); } + static void operator delete(void* p) { g_host->FunctionProto__operator_delete(reinterpret_cast(p)); } + + bool SerializeToString(std::string& string) const { return g_host->FunctionProto__SerializeToString(this, string); } + bool SerializeToOstream(std::ostream& output) const { return g_host->FunctionProto__SerializeToOstream(this, output); } + bool ParseFromString(const std::string& data) { return g_host->FunctionProto__ParseFromString(this, data); } + std::string SerializeAsString() const { return g_host->FunctionProto__SerializeAsString(this); } + + bool has_name() const { return g_host->FunctionProto__has_name(this); } + const std::string& name() const { return g_host->FunctionProto__name(this); } + void set_name(const std::string& name) { g_host->FunctionProto__set_name(this, name); } + + bool has_doc_string() const { return g_host->FunctionProto__has_doc_string(this); } + const std::string& doc_string() const { return g_host->FunctionProto__doc_string(this); } + void set_doc_string(const std::string& doc_string) { 
g_host->FunctionProto__set_doc_string(this, doc_string); } + + bool has_domain() const { return g_host->FunctionProto__has_domain(this); } + const std::string& domain() const { return g_host->FunctionProto__domain(this); } + void set_domain(const std::string& domain) { g_host->FunctionProto__set_domain(this, domain); } + + const std::string& input(int index) const { return g_host->FunctionProto__input(this, index); } + std::string* mutable_input(int index) { return g_host->FunctionProto__mutable_input(this, index); } + int input_size() const { return g_host->FunctionProto__input_size(this); } + void add_input(const std::string& value) { g_host->FunctionProto__add_input(this, value); } + + const std::string& output(int index) const { return g_host->FunctionProto__output(this, index); } + std::string* mutable_output(int index) { return g_host->FunctionProto__mutable_output(this, index); } + int output_size() const { return g_host->FunctionProto__output_size(this); } + void add_output(const std::string& value) { g_host->FunctionProto__add_output(this, value); } + + const std::string& attribute(int index) const { return g_host->FunctionProto__attribute(this, index); } + std::string* mutable_attribute(int index) { return g_host->FunctionProto__mutable_attribute(this, index); } + int attribute_size() const { return g_host->FunctionProto__attribute_size(this); } + void add_attribute(const std::string& value) { g_host->FunctionProto__add_attribute(this, value); } + + const AttributeProto& attribute_proto(int index) const { return g_host->FunctionProto__attribute_proto(this, index); } + AttributeProto* mutable_attribute_proto(int index) { return g_host->FunctionProto__mutable_attribute_proto(this, index); } + int attribute_proto_size() const { return g_host->FunctionProto__attribute_proto_size(this); } + AttributeProto* add_attribute_proto() { return g_host->FunctionProto__add_attribute_proto(this); } + + const NodeProto& node(int index) const { return 
g_host->FunctionProto__node(this, index); } + NodeProto* mutable_node(int index) { return g_host->FunctionProto__mutable_node(this, index); } + int node_size() const { return g_host->FunctionProto__node_size(this); } + NodeProto* add_node() { return g_host->FunctionProto__add_node(this); } + + const ValueInfoProto& value_info(int index) const { return g_host->FunctionProto__value_info(this, index); } + ValueInfoProtos* mutable_value_info() { return g_host->FunctionProto__mutable_value_info(this); } + ValueInfoProto* mutable_value_info(int index) { return g_host->FunctionProto__mutable_value_info(this, index); } + int value_info_size() const { return g_host->FunctionProto__value_info_size(this); } + ValueInfoProto* add_value_info() { return g_host->FunctionProto__add_value_info(this); } + + const StringStringEntryProto& metadata_props(int index) const { return g_host->FunctionProto__metadata_props(this, index); } + StringStringEntryProtos* mutable_metadata_props() { return g_host->FunctionProto__mutable_metadata_props(this); } + StringStringEntryProto* mutable_metadata_props(int index) { return g_host->FunctionProto__mutable_metadata_props(this, index); } + int metadata_props_size() const { return g_host->FunctionProto__metadata_props_size(this); } + StringStringEntryProto* add_metadata_props() { return g_host->FunctionProto__add_metadata_props(this); } + + FunctionProto() = delete; + FunctionProto(const FunctionProto&) = delete; + void operator=(const FunctionProto&) = delete; +}; } // namespace ONNX_NAMESPACE namespace onnxruntime { @@ -449,6 +513,12 @@ struct IndexedSubGraph_MetaDef final { void operator=(const IndexedSubGraph_MetaDef&) = delete; }; +enum class IndexedSubGraph_SourceOfSchema : uint8_t { + CREATE, + REUSE_OR_CREATE, + EXISTING, +}; + struct IndexedSubGraph final { static std::unique_ptr Create() { return g_host->IndexedSubGraph__construct(); } static void operator delete(void* p) { g_host->IndexedSubGraph__operator_delete(reinterpret_cast(p)); } 
@@ -458,6 +528,9 @@ struct IndexedSubGraph final { void SetMetaDef(std::unique_ptr&& meta_def_) { return g_host->IndexedSubGraph__SetMetaDef(this, std::move(*reinterpret_cast*>(&meta_def_))); } const IndexedSubGraph_MetaDef* GetMetaDef() const { return reinterpret_cast(g_host->IndexedSubGraph__GetMetaDef(this)); } + void SetSchemaSource(IndexedSubGraph_SourceOfSchema schema_source) { return g_host->IndexedSubGraph__SetSchemaSource(this, schema_source); } + IndexedSubGraph_SourceOfSchema GetSchemaSource() const { return g_host->IndexedSubGraph__GetSchemaSource(this); } + IndexedSubGraph() = delete; IndexedSubGraph(const IndexedSubGraph&) = delete; void operator=(const IndexedSubGraph&) = delete; diff --git a/onnxruntime/core/providers/vitisai/imp/ep_context_utils.cc b/onnxruntime/core/providers/vitisai/imp/ep_context_utils.cc new file mode 100644 index 0000000000000..2094967cd34f8 --- /dev/null +++ b/onnxruntime/core/providers/vitisai/imp/ep_context_utils.cc @@ -0,0 +1,345 @@ +// Standard headers/libs. +#include +#include +#include + +#include "ep_context_utils.h" + +namespace fs = std::filesystem; + +namespace onnxruntime { + +std::unique_ptr ConvertIndexedSubGraphToFunctionProto( + const IndexedSubGraph& sub_graph, const Graph& parent_graph) { + auto p_func_proto = ONNX_NAMESPACE::FunctionProto::Create(); + auto* p_meta_def = const_cast(sub_graph.GetMetaDef()); + if (p_meta_def) { + p_func_proto->set_name(p_meta_def->name()); + p_func_proto->set_domain(p_meta_def->domain()); + for (const auto& input : p_meta_def->inputs()) { + p_func_proto->add_input(input); + } + auto* p_metadata_props_0 = p_func_proto->add_metadata_props(); + *(p_metadata_props_0->mutable_key()) = "meta_def_inputs_size"; + *(p_metadata_props_0->mutable_value()) = std::to_string(p_meta_def->inputs().size()); + for (const auto& output : p_meta_def->outputs()) { + p_func_proto->add_output(output); + } + // XXX: SerDes with different fields. 
+ for (const auto& initializer : p_meta_def->constant_initializers()) { + p_func_proto->add_input(initializer); + } + // XXX: SerDes with different numbers of fields. + for (const auto& attr_pair : p_meta_def->attributes()) { + p_func_proto->add_attribute(attr_pair.first); + auto* p_attr_proto = p_func_proto->add_attribute_proto(); + *p_attr_proto = attr_pair.second; + } + p_func_proto->set_doc_string(p_meta_def->doc_string()); + // "since_version" + auto* p_metadata_props_1 = p_func_proto->add_metadata_props(); + *(p_metadata_props_1->mutable_key()) = "meta_def_since_version"; + *(p_metadata_props_1->mutable_value()) = std::to_string(p_meta_def->since_version()); + // "status" + auto* p_metadata_props_2 = p_func_proto->add_metadata_props(); + *(p_metadata_props_2->mutable_key()) = "meta_def_status"; + *(p_metadata_props_2->mutable_value()) = std::to_string(static_cast(p_meta_def->status())); + // TODO: `MetaDef::type_and_shape_inference_function`. + } + auto p_parent_graph_proto = parent_graph.ToGraphProto(); + for (auto node_index : const_cast(sub_graph).Nodes()) { + auto* p_node_proto = p_parent_graph_proto->mutable_node(node_index); + auto* p_attr_proto = p_node_proto->add_attribute(); + p_attr_proto->set_name("parent_graph_node_index"); + p_attr_proto->set_type(ONNX_NAMESPACE::AttributeProto::INT); + p_attr_proto->set_i(node_index); + *(p_func_proto->add_node()) = *p_node_proto; + } +#if 0 + // Alternative. 
+ for (const auto node_index : sub_graph.Nodes()) { + const auto* p_node = parent_graph.GetNode(node_index); + auto p_node_proto = ONNX_NAMESPACE::NodeProto::Create(); + // XXX + p_node->ToProto(*p_node_proto, true); + p_attr_proto->set_name("parent_graph_node_index"); + p_attr_proto->set_type(ONNX_NAMESPACE::AttributeProto::INT); + p_attr_proto->set_i(node_index); + *(p_func_proto.add_node()) = *p_node_proto; + } +#endif + auto* p_metadata_props_3 = p_func_proto->add_metadata_props(); + *p_metadata_props_3->mutable_key() = "schema_source"; + *p_metadata_props_3->mutable_value() = std::to_string(static_cast(sub_graph.GetSchemaSource())); + return p_func_proto; +} + +std::unique_ptr ConvertFunctionProtoToIndexedSubGraph( + const std::unique_ptr& p_func_proto) { + auto p_isg = IndexedSubGraph::Create(); + // "meta_def_inputs_size" (optional) and "schema_source". + int func_metadata_props_size = p_func_proto->metadata_props_size(); + // Precisely, func_metadata_props_size == 4, which implies + // `IndexedSubGraph::meta_def_` is not null and `IndexedSubGraph::nodes` > 1. 
+ if (func_metadata_props_size > 1) { + auto& prop0 = const_cast(p_func_proto->metadata_props(0)); + int isg_meta_def_inputs_size = std::stoi(*(prop0.mutable_value())); + auto p_meta_def = IndexedSubGraph_MetaDef::Create(); + p_meta_def->name() = p_func_proto->name(); + p_meta_def->domain() = p_func_proto->domain(); + auto& prop1 = const_cast(p_func_proto->metadata_props(1)); + p_meta_def->since_version() = std::stoi(*(prop1.mutable_value())); + auto& prop2 = const_cast(p_func_proto->metadata_props(2)); + p_meta_def->status() = static_cast(std::stoi(*(prop2.mutable_value()))); + auto& meta_def_inputs = p_meta_def->inputs(); + for (int i = 0; i < isg_meta_def_inputs_size; i++) { + meta_def_inputs.push_back(p_func_proto->input(i)); + } + auto& meta_def_outputs = p_meta_def->outputs(); + for (int i = 0, l = p_func_proto->output_size(); i < l; i++) { + meta_def_outputs.push_back(p_func_proto->output(i)); + } + auto& meta_def_initializers = p_meta_def->constant_initializers(); + for (int i = isg_meta_def_inputs_size, l = p_func_proto->input_size(); i < l; i++) { + meta_def_initializers.push_back(p_func_proto->input(i)); + } + auto& meta_def_attrs = p_meta_def->attributes(); + for (int i = 0, l = p_func_proto->attribute_size(); i < l; i++) { + meta_def_attrs.emplace(p_func_proto->attribute(i), p_func_proto->attribute_proto(i)); + } + p_meta_def->doc_string() = p_func_proto->doc_string(); + // TODO: `IndexedSubGraph::type_and_shape_inference_function`. 
+    p_isg->SetMetaDef(std::move(p_meta_def));
+  }
+  auto& isg_nodes = p_isg->Nodes();
+  for (int i = 0, l = p_func_proto->node_size(); i < l; i++) {
+    const auto& node_proto = p_func_proto->node(i);
+    isg_nodes.push_back(node_proto.attribute(const_cast(node_proto).attribute_size() - 1).i());
+  }
+  auto schema_source = static_cast(
+      std::stoi(*(const_cast(p_func_proto->metadata_props(func_metadata_props_size - 1)).mutable_value())));
+  p_isg->SetSchemaSource(schema_source);
+  return p_isg;
+}
+
+// Serializes the sub-graph of every capability into a single binary payload.
+//
+// Each record is a `size_t` byte count (written in the host's in-memory
+// representation) followed by that many bytes of a serialized `FunctionProto`.
+// Throws if the output stream ends up in a bad state.
+std::string SerializeCapabilities(
+    const std::vector<std::unique_ptr<ComputeCapability>>& capability_ptrs,
+    const Graph& graph) {
+  std::stringstream ss;
+  for (const auto& p : capability_ptrs) {
+    auto& p_subgraph = p->SubGraph();
+    auto p_func_proto = ConvertIndexedSubGraphToFunctionProto(*p_subgraph, graph);
+    std::string func_proto_buf;
+    p_func_proto->SerializeToString(func_proto_buf);
+    size_t buf_len = func_proto_buf.length();
+    ss.write(reinterpret_cast<const char*>(&buf_len), sizeof(buf_len));
+    ss.write(func_proto_buf.data(), buf_len);
+  }
+  if (!ss.good()) {
+    ORT_THROW("Serialization stream bad");
+  }
+  return ss.str();
+}
+
+// Inverse of `SerializeCapabilities`: appends one `ComputeCapability` to
+// `capability_ptrs` per record found in `ser_capabilities`.
+//
+// An empty payload yields no capabilities. Throws when a length prefix or a
+// record is truncated, or when a record cannot be parsed as a `FunctionProto`.
+void DeserializeCapabilities(const std::string& ser_capabilities,
+                             std::vector<std::unique_ptr<ComputeCapability>>& capability_ptrs) {
+  // Walk the payload by offset instead of looping on `eof()`: the EOF flag is
+  // only raised by a read that already failed, so an `eof()`-driven loop runs
+  // one extra iteration on an uninitialized length.
+  const size_t total_len = ser_capabilities.size();
+  size_t offset = 0;
+  while (offset < total_len) {
+    size_t buf_len = 0;
+    if (total_len - offset < sizeof(buf_len)) {
+      ORT_THROW("Serialized capabilities: truncated length prefix");
+    }
+    ser_capabilities.copy(reinterpret_cast<char*>(&buf_len), sizeof(buf_len), offset);
+    offset += sizeof(buf_len);
+    // Reject a length that points past the end of the payload before using it.
+    if (total_len - offset < buf_len) {
+      ORT_THROW("Serialized capabilities: truncated record");
+    }
+    auto p_func_proto = ONNX_NAMESPACE::FunctionProto::Create();
+    if (!p_func_proto->ParseFromString(ser_capabilities.substr(offset, buf_len))) {
+      ORT_THROW("Serialized capabilities: malformed record");
+    }
+    offset += buf_len;
+    auto p_subgraph = ConvertFunctionProtoToIndexedSubGraph(p_func_proto);
+    capability_ptrs.push_back(ComputeCapability::Create(std::move(p_subgraph)));
+  }
+}
+
+// Builds a model whose graph holds a single `EPContext` node. The node's
+// inputs and outputs mirror those of `graph_viewer`.
+//
+// The `ep_cache_context` attribute is `ctx_cache_file_loc` when `embed_mode`
+// is 0 and `serialized_ctx_cache` otherwise. `p_logger` is dereferenced and
+// must not be null.
+std::unique_ptr<Model> CreateEPContexModel(
+    const GraphViewer& graph_viewer,
+    const std::string& serialized_ctx_cache,
+    const std::string& ctx_cache_file_loc,
+    const int64_t embed_mode,
+    const logging::Logger* p_logger) {
+  // Create a new graph/model, reusing the graph name,
+  // the op-domain-to-opset-version map,
+  // and the op schema registry of the current graph.
+  // The owning model is kept in a named variable: binding the reference
+  // directly to `CreateModel(...)->MainGraph()` would leave it dangling as
+  // soon as the temporary `unique_ptr` is destroyed.
+  auto p_temp_model = graph_viewer.CreateModel(*p_logger);
+  auto& ep_ctx_graph = p_temp_model->MainGraph();
+
+  std::vector<NodeArg*> input_node_arg_ptrs;
+  // XXX: vs `GraphViewer::GetInputsIncludingInitializers()`.
+  for (const auto* p_node_arg : graph_viewer.GetInputs()) {
+    auto& temp_node_arg = ep_ctx_graph.GetOrCreateNodeArg(
+        p_node_arg->Name(), p_node_arg->TypeAsProto());
+    input_node_arg_ptrs.push_back(&temp_node_arg);
+  }
+  std::vector<NodeArg*> output_node_arg_ptrs;
+  for (const auto* p_node_arg : graph_viewer.GetOutputs()) {
+    auto& temp_node_arg = ep_ctx_graph.GetOrCreateNodeArg(p_node_arg->Name(), p_node_arg->TypeAsProto());
+    output_node_arg_ptrs.push_back(&temp_node_arg);
+  }
+
+  // Attr "embed_mode".
+  auto p_attr_0 = ONNX_NAMESPACE::AttributeProto::Create();
+  p_attr_0->set_name(kEmbedModeAttr);
+  // p_attr_0->set_type(onnx::AttributeProto_AttributeType_INT);
+  p_attr_0->set_type(ONNX_NAMESPACE::AttributeProto::INT);
+  p_attr_0->set_i(embed_mode);
+  // Attr "ep_cache_context".
+  auto p_attr_1 = ONNX_NAMESPACE::AttributeProto::Create();
+  p_attr_1->set_name(kEPCacheContextAttr);
+  // p_attr_1->set_type(onnx::AttributeProto_AttributeType_STRING);
+  p_attr_1->set_type(ONNX_NAMESPACE::AttributeProto::STRING);
+  p_attr_1->set_s(embed_mode == 0 ? ctx_cache_file_loc : serialized_ctx_cache);
+  // Attr "source".
+  auto p_attr_2 = ONNX_NAMESPACE::AttributeProto::Create();
+  p_attr_2->set_name(kSourceAttr);
+  // p_attr_2->set_type(onnx::AttributeProto_AttributeType_STRING);
+  p_attr_2->set_type(ONNX_NAMESPACE::AttributeProto::STRING);
+  p_attr_2->set_s(kVitisAIExecutionProvider);
+
+  auto p_node_attrs = NodeAttributes::Create();
+  constexpr int num_attrs = 3;
+  p_node_attrs->reserve(num_attrs);
+  p_node_attrs->emplace(kEmbedModeAttr, *p_attr_0);
+  p_node_attrs->emplace(kEPCacheContextAttr, *p_attr_1);
+  p_node_attrs->emplace(kSourceAttr, *p_attr_2);
+
+  ep_ctx_graph.AddNode(kEPContextOp, kEPContextOp, "", input_node_arg_ptrs, output_node_arg_ptrs, p_node_attrs.get(), kEPContextOpDomain);
+  ORT_ENFORCE(ep_ctx_graph.Resolve().IsOK());
+  auto p_ep_ctx_graph_viewer = ep_ctx_graph.CreateGraphViewer();
+  // The model created from the viewer only supplies the model-level fields
+  // of the proto; the graph itself is written into the proto by the viewer.
+  auto p_skeleton_model = p_ep_ctx_graph_viewer->CreateModel(*p_logger);
+  auto p_ep_ctx_model_proto = p_skeleton_model->ToProto();
+  p_ep_ctx_model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
+  p_ep_ctx_graph_viewer->ToProto(*(p_ep_ctx_model_proto->mutable_graph()), true, true);
+
+  // Return a model built from the populated proto. Returning the skeleton
+  // model would drop the EPContext node that was just written into the proto.
+  return Model::Create(std::move(*p_ep_ctx_model_proto), PathString(), nullptr, *p_logger);
+}
+
+// Writes `p_model` in serialized form to `ep_ctx_model_file_loc`, truncating
+// any existing file. Failure to open or write the file is logged, not thrown.
+void DumpEPContextModel(
+    const std::unique_ptr<Model>& p_model, const std::string& ep_ctx_model_file_loc) {
+  std::fstream dump_stream(ep_ctx_model_file_loc, std::ios::out | std::ios::trunc | std::ios::binary);
+  if (!dump_stream.is_open() || !p_model->ToProto()->SerializeToOstream(dump_stream)) {
+    LOGS_DEFAULT(WARNING) << "[VitisAI EP] Failed to dump " << ep_ctx_model_file_loc;
+    return;
+  }
+  LOGS_DEFAULT(VERBOSE) << "[VitisAI EP] Dumped " << ep_ctx_model_file_loc;
+}
+
+bool ValidateEPContextNode(const Graph& graph) {
+  // TODO: Support for multi-node EP context model.
+  assert(graph.Nodes().size() == 1);
+  auto* p_node = graph.GetNode(0);
+  assert(p_node->OpType() == kEPContextOp);
+  auto& attrs = p_node->GetAttributes();
+  assert(attrs.count(kEmbedModeAttr) > 0);
+  assert(attrs.count(kEPCacheContextAttr) > 0);
+  assert(attrs.count(kSourceAttr) > 0);
+  (void)attrs;
+  return true;
+}
+
+// Returns the EP context cache payload referenced by node 0 of `graph`.
+//
+// When the node's `embed_mode` attribute is non-zero, the `ep_cache_context`
+// attribute is the payload itself. Otherwise that attribute names a file
+// whose full binary content is returned.
+// Throws when validation fails or the file is missing or cannot be read.
+std::string RetrieveEPContextCache(const Graph& graph) {
+  if (!ValidateEPContextNode(graph)) {
+    ORT_THROW("Invalid EP context model for Vitis AI");
+  }
+  // TODO: Support for multi-node EP context model.
+  auto* p_node = graph.GetNode(0);
+  const auto& attrs = p_node->GetAttributes();
+  int64_t embed_mode = attrs.at(kEmbedModeAttr).i();
+  const std::string& ep_ctx_cache = attrs.at(kEPCacheContextAttr).s();
+  if (embed_mode) {
+    return ep_ctx_cache;
+  }
+  fs::path ep_ctx_file_loc(ep_ctx_cache);
+  // TODO: Validation of the file location to make sure security is met.
+  if (!fs::exists(ep_ctx_file_loc) || !fs::is_regular_file(ep_ctx_file_loc)) {
+    ORT_THROW("File for EP context cache is missing");
+  }
+  std::ifstream ifs(ep_ctx_cache, std::ios::binary | std::ios::in);
+  if (!ifs.is_open()) {
+    ORT_THROW("Exception opening EP context cache file");
+  }
+  ifs.seekg(0, ifs.end);
+  // `tellg()` reports -1 on failure; keep the full offset width instead of
+  // narrowing to `int`.
+  const std::streamoff cache_len = ifs.tellg();
+  if (cache_len < 0) {
+    ORT_THROW("Exception reading EP context cache file");
+  }
+  ifs.seekg(0, ifs.beg);
+  // Read straight into a string of the exact size. The payload is binary, so
+  // it must not go through a C-string constructor, which would stop at the
+  // first NUL byte and needs a terminator the raw buffer never had. The
+  // string also owns the memory, so nothing leaks if a throw follows.
+  std::string cache_payload(static_cast<size_t>(cache_len), '\0');
+  ifs.read(&cache_payload[0], static_cast<std::streamsize>(cache_len));
+  if (!ifs.good()) {
+    ORT_THROW("Exception reading EP context cache file");
+  }
+  return cache_payload;
+}
+
+// Returns true when `graph_viewer` contains at least one `EPContext` node
+// whose `source` attribute equals the Vitis AI execution provider name.
+bool GraphHasEPContextNode(const GraphViewer& graph_viewer) {
+  for (size_t i = 0, l = static_cast<size_t>(graph_viewer.MaxNodeIndex()); i < l; i++) {
+    auto* p_node = graph_viewer.GetNode(i);
+    if (p_node != nullptr && p_node->OpType() == kEPContextOp) {
+      const auto& attrs = p_node->GetAttributes();
+      if (attrs.count(kSourceAttr) > 0 && attrs.at(kSourceAttr).s() == kVitisAIExecutionProvider) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+bool FusedGraphHasEPContextNode(
+
const std::vector& fused_nodes_and_graphs) {
+  for (const auto& fused_node_graph : fused_nodes_and_graphs) {
+    bool has_node = GraphHasEPContextNode(fused_node_graph.filtered_graph);
+    if (has_node) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// Follows the parent-graph chain from the viewer's graph up to the outermost
+// graph and returns that graph's model path.
+const Path& GetTopLevelModelPath(const GraphViewer& graph_viewer) {
+  const Graph* p_current = &graph_viewer.GetGraph();
+  for (; p_current->IsSubgraph(); p_current = p_current->ParentGraph()) {
+    // Nothing to do per level; the loop only climbs.
+  }
+  return p_current->ModelPath();
+}
+
+// Chooses the EP context model file location and stores it in
+// `ep_ctx_model_file_loc`, in this order of preference:
+//   1. the configured path, when non-empty;
+//   2. the model path itself, when the model already is an EP context model;
+//   3. "<model file stem>_ctx.onnx" otherwise.
+// When both the configured path and the model path are empty, the output
+// argument is left untouched.
+// Returns true only if the resulting location names an existing regular file.
+bool GetEPContextModelFileLocation(
+    const std::string& ep_ctx_model_path_cfg,
+    const PathString& model_path_str,
+    bool is_ep_ctx_model,
+    PathString& ep_ctx_model_file_loc) {
+  // if (!ep_ctx_model_file_loc.empty()) {
+  //   return true;
+  // }
+  const bool has_cfg_path = !ep_ctx_model_path_cfg.empty();
+  const bool has_model_path = !model_path_str.empty();
+  if (has_cfg_path) {
+    ep_ctx_model_file_loc = ToPathString(ep_ctx_model_path_cfg);
+  } else if (has_model_path && is_ep_ctx_model) {
+    ep_ctx_model_file_loc = model_path_str;
+  } else if (has_model_path) {
+    const auto model_stem = fs::path(model_path_str).stem().string();
+    ep_ctx_model_file_loc = ToPathString(model_stem + "_ctx.onnx");
+  }
+  if (ep_ctx_model_file_loc.empty()) {
+    return false;
+  }
+  return fs::exists(ep_ctx_model_file_loc) && fs::is_regular_file(ep_ctx_model_file_loc);
+}
+
+// The file for EP context binary is in the same folder as the EP context model file.
+PathString GetEPContextCacheFileLocation(
+    const PathString& ep_ctx_model_file_loc, const PathString& model_path_str) {
+  // The cache file name is derived from the EP context model location when
+  // one is known, and from the original model path otherwise.
+  const PathString& base_path_str =
+      ep_ctx_model_file_loc.empty() ? model_path_str : ep_ctx_model_file_loc;
+  fs::path cache_fs_path(base_path_str);
+  // The replacement has no leading dot, so `replace_extension` supplies one:
+  // "dir/model.onnx" becomes "dir/model.__ep_ctx_cache.bin".
+  cache_fs_path.replace_extension(fs::path("__ep_ctx_cache.bin"));
+  return ToPathString(cache_fs_path.string());
+}
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/vitisai/include/ep_context_utils.h b/onnxruntime/core/providers/vitisai/include/ep_context_utils.h
new file mode 100644
index 0000000000000..1e3bf8e53ba3f
--- /dev/null
+++ b/onnxruntime/core/providers/vitisai/include/ep_context_utils.h
@@ -0,0 +1,56 @@
+#pragma once
+
+// Standard headers/libs.
+#include
+#include
+#include
+
+// 1st-party headers/libs.
+#include "core/providers/shared_library/provider_api.h"
+
+namespace onnxruntime {
+
+// Op type, attribute names and domain of the `EPContext` node.
+static constexpr const char* kEPContextOp = "EPContext";
+static constexpr const char* kMainContextAttr = "main_context";
+static constexpr const char* kEPCacheContextAttr = "ep_cache_context";
+static constexpr const char* kEmbedModeAttr = "embed_mode";
+static constexpr const char* kPartitionNameAttr = "partition_name";
+static constexpr const char* kSourceAttr = "source";
+static constexpr const char* kEPSDKVersionAttr = "ep_sdk_version";
+static constexpr const char* kEPContextOpDomain = "com.microsoft";
+
+// Converts an indexed sub-graph of the given parent graph to a `FunctionProto`.
+std::unique_ptr<ONNX_NAMESPACE::FunctionProto>
+ConvertIndexedSubGraphToFunctionProto(const IndexedSubGraph&, const Graph&);
+
+// Rebuilds an indexed sub-graph from a `FunctionProto` produced by
+// `ConvertIndexedSubGraphToFunctionProto`.
+std::unique_ptr<IndexedSubGraph> ConvertFunctionProtoToIndexedSubGraph(
+    const std::unique_ptr<ONNX_NAMESPACE::FunctionProto>&);
+
+// Serializes the sub-graphs of the capabilities into one binary string.
+std::string SerializeCapabilities(
+    const std::vector<std::unique_ptr<ComputeCapability>>&, const Graph&);
+
+// Appends the capabilities found in a string produced by `SerializeCapabilities`.
+void DeserializeCapabilities(
+    const std::string&, std::vector<std::unique_ptr<ComputeCapability>>&);
+
+// Arguments: graph viewer, serialized context cache, context cache file
+// location, embed mode, logger.
+std::unique_ptr<Model> CreateEPContexModel(const GraphViewer&, const std::string&,
+                                           const std::string&, const int64_t, const logging::Logger*);
+
+// The model is only read, so it is taken by const reference; this also matches
+// the definition in ep_context_utils.cc.
+void DumpEPContextModel(const std::unique_ptr<Model>&, const std::string&);
+
+bool ValidateEPContextNode(const Graph&);
+
+std::string RetrieveEPContextCache(const Graph&);
+
+bool GraphHasEPContextNode(const GraphViewer&);
+
+bool FusedGraphHasEPContextNode(
+    const std::vector<IExecutionProvider::FusedNodeAndGraph>&);
+
+const Path& GetTopLevelModelPath(const GraphViewer&);
+
+// Arguments: configured EP context model path, model path, whether the model
+// is already an EP context model, and the resulting file location (output).
+bool GetEPContextModelFileLocation(
+    const std::string&, const PathString&, bool, PathString&);
+
+// The file for EP context binary is in the same folder as the EP context model file.
+PathString GetEPContextCacheFileLocation(const PathString&, const PathString&); + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc index 6fc09f3495aa1..26036bb34d13c 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc @@ -2,22 +2,80 @@ // Licensed under the MIT License. #include "vitisai_execution_provider.h" +// Standard headers/libs. #include #include #include +#include #include "vaip/capability.h" #include "vaip/global_api.h" +#include "ep_context_utils.h" using namespace ONNX_NAMESPACE; +namespace fs = std::filesystem; + namespace onnxruntime { constexpr const char* VITISAI = "VITISAI"; VitisAIExecutionProvider::VitisAIExecutionProvider( const ProviderOptions& info) + // const ProviderOptions& info, const SessionOptions* p_sess_opts) : IExecutionProvider{onnxruntime::kVitisAIExecutionProvider}, info_(info) { CreateKernelRegistry(); + +#if 0 + if (p_sess_opts) { + ep_ctx_enabled_ = p_sess_opts->config_options.GetConfigOrDefault( + kOrtSessionOptionEpContextEnable, "0") == "1"; + std::string embed_mode = p_sess_opts->config_options.GetConfigOrDefault( + kOrtSessionOptionEpContextEmbedMode, "1"); + if ("1" == embed_mode) { + ep_ctx_embed_mode_ = true; + } else if ("0" == embed_mode) { + ep_ctx_embed_mode_ = false; + } else { + LOGS_DEFAULT(VERBOSE) << "Invalid ep.context_embed_mode: " << embed_mode << " only 0 or 1 allowed. 
Set to 1."; + ep_ctx_embed_mode_ = true; + } + ep_ctx_model_path_cfg_ = p_sess_opts->config_options.GetConfigOrDefault( + kOrtSessionOptionEpContextFilePath, ""); + } else { +#endif + auto it = info_.find("ep_context_enable"); + ep_ctx_enabled_ = it != info_.end() && it->second == "1"; + it = info_.find("ep_context_embed_mode"); + ep_ctx_embed_mode_ = it != info_.end() && it->second != "0"; + // ep_ctx_embed_mode_ = it == info_.end() || it->second != "0"; + it = info_.find("ep_context_file_path"); + ep_ctx_model_path_cfg_ = it == info_.end() ? "" : it->second; +#if 0 + } +#endif + LOGS_DEFAULT(VERBOSE) << "EP Context cache enabled: " << ep_ctx_enabled_; + LOGS_DEFAULT(VERBOSE) << "EP context cache embed mode: " << ep_ctx_embed_mode_; + LOGS_DEFAULT(VERBOSE) << "User specified EP context cache path: " << ep_ctx_model_path_cfg_; +} + +#if 0 +VitisAIExecutionProvider::~VitisAIExecutionProvider() { + // TODO: EP context related sources. +} +#endif + +void VitisAIExecutionProvider::LoadEPContexModelFromFile() const { + // XXX: should "p_ep_ctx_model_" be checked or not? + if (!p_ep_ctx_model_ && !ep_ctx_model_file_loc_.empty()) { + auto p_model_proto = ONNX_NAMESPACE::ModelProto::Create(); + auto status = Model::Load(ep_ctx_model_file_loc_, *p_model_proto); + if (!status.IsOK()) { + ORT_THROW("Loading EP context model failed from ", ep_ctx_model_file_loc_); + } + auto& logger = logging::LoggingManager::DefaultLogger(); + p_ep_ctx_model_ = Model::Create(std::move(*p_model_proto), ep_ctx_model_file_loc_, nullptr, logger); + LOGS_DEFAULT(VERBOSE) << "Loaded EP context model from: " << ep_ctx_model_file_loc_; + } } void VitisAIExecutionProvider::CreateKernelRegistry() { @@ -30,9 +88,73 @@ void VitisAIExecutionProvider::CreateKernelRegistry() { std::shared_ptr VitisAIExecutionProvider::GetKernelRegistry() const { return get_kernel_registry_vitisaiep(); } +// This method is called after both `GetComputeCapabilityOps()` and `Compile()`. 
+// This timing is required to work with both compliation-based EPs and non-compilation-based EPs. +const InlinedVector VitisAIExecutionProvider::GetEpContextNodes() const { + InlinedVector ep_context_node_ptrs; + // All preconditions are supposed to have happened. + if (p_ep_ctx_model_) { + auto& graph = p_ep_ctx_model_->MainGraph(); + for (const auto* p_node : graph.Nodes()) { + ep_context_node_ptrs.push_back(p_node); + } + } + return ep_context_node_ptrs; +} + +// Create EP context model and dump it for future use. +// This implementation here is only working for non-compilation-based EPs. +void VitisAIExecutionProvider::FulfillEPContextEnablement( + const std::vector>& capability_ptrs, + const onnxruntime::GraphViewer& graph_viewer) const { + auto& logger = logging::LoggingManager::DefaultLogger(); + auto model_path_str = GetTopLevelModelPath(graph_viewer).ToPathString(); + auto ep_ctx_payload = SerializeCapabilities(capability_ptrs, graph_viewer.GetGraph()); + if (!ep_ctx_embed_mode_) { + GetEPContextModelFileLocation(ep_ctx_model_path_cfg_, model_path_str, false, ep_ctx_model_file_loc_); + auto ep_ctx_cache_path_str = GetEPContextCacheFileLocation(ep_ctx_model_file_loc_, model_path_str); + std::ofstream ep_ctx_cache_ofs(ep_ctx_cache_path_str.c_str(), std::ios::trunc | std::ios::binary); + if (!ep_ctx_cache_ofs.is_open()) { + ORT_THROW("Failed to open a file to write EP context cache: ", ep_ctx_cache_path_str.c_str()); + } + ep_ctx_cache_ofs.write(ep_ctx_payload.c_str(), ep_ctx_payload.length()); + if (!ep_ctx_cache_ofs.good()) { + ep_ctx_cache_ofs.close(); + ORT_THROW("Exception writing EP context cache file: ", ep_ctx_cache_path_str.c_str()); + } + ep_ctx_cache_ofs.close(); + p_ep_ctx_model_ = CreateEPContexModel(graph_viewer, ep_ctx_payload, ep_ctx_cache_path_str, 0, &logger); + } else { + p_ep_ctx_model_ = CreateEPContexModel(graph_viewer, ep_ctx_payload, "", 1, &logger); + } + DumpEPContextModel(p_ep_ctx_model_, ep_ctx_model_file_loc_); +} + std::vector> 
VitisAIExecutionProvider::GetCapability( - const onnxruntime::GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/) const { - if (graph.IsSubgraph()) { + const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/) const { + bool is_ep_ctx_model = GraphHasEPContextNode(graph_viewer); + if (is_ep_ctx_model) { + auto ep_ctx_payload = RetrieveEPContextCache(graph_viewer.GetGraph()); + std::vector> capability_ptrs; + DeserializeCapabilities(ep_ctx_payload, capability_ptrs); + return capability_ptrs; + } else { + // FIXME: Will it make sense to do this? + // One of the potential problems is the existing EP-context model file may be stale. + auto model_path_str = GetTopLevelModelPath(graph_viewer).ToPathString(); + if (GetEPContextModelFileLocation( + ep_ctx_model_path_cfg_, model_path_str, false, ep_ctx_model_file_loc_)) { + LOGS_DEFAULT(WARNING) << "The inference session was created with a normal ONNX model " + << "but a model file with EP context cache exists at " << ep_ctx_model_file_loc_.c_str(); + LoadEPContexModelFromFile(); + auto ep_ctx_payload = RetrieveEPContextCache(p_ep_ctx_model_->MainGraph()); + std::vector> capability_ptrs; + DeserializeCapabilities(ep_ctx_payload, capability_ptrs); + return capability_ptrs; + } + } + + if (graph_viewer.IsSubgraph()) { // VITIS AI EP not support sungraph. Assigned to CPU. 
return {}; } @@ -40,13 +162,16 @@ std::vector> VitisAIExecutionProvider::GetCap // Only compiling a model once is currently supported return {}; } - execution_providers_ = std::make_unique(compile_onnx_model(graph, *GetLogger(), info_)); - auto result = vaip::GetComputeCapabilityOps(graph, execution_providers_.get(), vitisai_optypes_); + execution_providers_ = std::make_unique(compile_onnx_model(graph_viewer, *GetLogger(), info_)); + auto result = vaip::GetComputeCapabilityOps(graph_viewer, execution_providers_.get(), vitisai_optypes_); size_t index = 0u; for (auto& ep : **execution_providers_) { - result.emplace_back(vaip::XirSubgraphToComputeCapability1(graph, ep.get(), index)); + result.emplace_back(vaip::XirSubgraphToComputeCapability1(graph_viewer, ep.get(), index)); index = index + 1; } + if (ep_ctx_enabled_) { + FulfillEPContextEnablement(result, graph_viewer); + } return result; } diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h index 186427be4fab2..2c852b9853a36 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h @@ -3,14 +3,18 @@ #pragma once +// Standard headers/libs. #include #include #include #include #include +// 1st-party headers/libs. 
+// #include "core/framework/session_options.h" #include "core/providers/shared_library/provider_api.h" #include "core/session/onnxruntime_c_api.h" +#include "core/common/inlined_containers_fwd.h" // we cannot include vaip/vaip.hpp here because header file referred by // onnxruntime_pybind_state_common.cc @@ -24,9 +28,11 @@ namespace onnxruntime { class VitisAIExecutionProvider : public IExecutionProvider { public: explicit VitisAIExecutionProvider(const ProviderOptions& info); + // explicit VitisAIExecutionProvider(const ProviderOptions& info, + // const SessionOptions* p_sess_opts = nullptr); ~VitisAIExecutionProvider() = default; - std::vector> GetCapability(const onnxruntime::GraphViewer& graph, + std::vector> GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/) const override; int GetDeviceId() const { return 0; } @@ -35,6 +41,10 @@ class VitisAIExecutionProvider : public IExecutionProvider { std::vector& node_compute_funcs) override; std::shared_ptr GetKernelRegistry() const override; + // This method is called after both `GetComputeCapabilityOps()` and `Compile()`. + // This timing is required to work with both compliation-based EPs and non-compilation-based EPs. + const InlinedVector GetEpContextNodes() const override; + private: void CreateKernelRegistry(); using my_ep_t = vaip_core::DllSafe>>; @@ -45,6 +55,21 @@ class VitisAIExecutionProvider : public IExecutionProvider { std::vector custom_op_domains_; std::shared_ptr registry_; std::set vitisai_optypes_; + // EP context related. + bool ep_ctx_enabled_ = false; + bool ep_ctx_embed_mode_ = true; + std::string ep_ctx_model_path_cfg_{""}; + mutable PathString ep_ctx_model_file_loc_{""}; + // FIXME: This might not be needed. + mutable std::unique_ptr p_ep_ctx_model_; + // It might need to be called before loading + // the EP context model that is compiled AOT/offline. 
+ void LoadEPContexModelFromFile() const; + // Create EP context model and dump it for future use. + // This implementation here is only working for non-compilation-based EPs. + void FulfillEPContextEnablement( + const std::vector>&, + const onnxruntime::GraphViewer&) const; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc b/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc old mode 100755 new mode 100644 diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index d18b3ac40d489..68cab2d86518d 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -27,6 +27,7 @@ #include "core/session/inference_session.h" #include "core/session/abi_session_options_impl.h" #include "core/session/ort_apis.h" +#include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/provider_bridge_ort.h" #include "core/util/math.h" #include "core/framework/sparse_utils.h" @@ -66,10 +67,12 @@ using StringStringEntryProtos = google::protobuf::RepeatedPtrField; using TensorShapeProto_Dimensions = google::protobuf::RepeatedPtrField; using ValueInfoProtos = google::protobuf::RepeatedPtrField; +using FunctionProtos = google::protobuf::RepeatedPtrField; } // namespace ONNX_NAMESPACE namespace onnxruntime { using IndexedSubGraph_MetaDef = IndexedSubGraph::MetaDef; +using IndexedSubGraph_SourceOfSchema = IndexedSubGraph::SourceOfSchema; } // namespace onnxruntime #include "core/common/cpuid_info.h" @@ -526,6 +529,7 @@ struct ProviderHostImpl : ProviderHost { int NodeProto__attribute_size(ONNX_NAMESPACE::NodeProto* p) override { return p->attribute_size(); } const ONNX_NAMESPACE::AttributeProto& NodeProto__attribute(const ONNX_NAMESPACE::NodeProto* p, int index) const override { return p->attribute(index); } ONNX_NAMESPACE::AttributeProto* 
NodeProto__mutable_attribute(ONNX_NAMESPACE::NodeProto* p, int index) override { return p->mutable_attribute(index); } + ONNX_NAMESPACE::AttributeProto* NodeProto__add_attribute(ONNX_NAMESPACE::NodeProto* p) override { return p->add_attribute(); } // TensorProto (wrapped) std::unique_ptr TensorProto__construct() override { return std::make_unique(); } @@ -600,6 +604,64 @@ struct ProviderHostImpl : ProviderHost { const ONNX_NAMESPACE::ValueInfoProto& ValueInfoProtos__operator_array(const ONNX_NAMESPACE::ValueInfoProtos* p, int index) override { return (*p)[index]; } + // FunctionProto (wrapped) + std::unique_ptr FunctionProto__construct() override { return std::make_unique(); } + void FunctionProto__operator_delete(ONNX_NAMESPACE::FunctionProto* p) override { delete p; } + + bool FunctionProto__SerializeToString(const ONNX_NAMESPACE::FunctionProto* p, std::string& string) override { return p->SerializeToString(&string); } + bool FunctionProto__SerializeToOstream(const ONNX_NAMESPACE::FunctionProto* p, std::ostream& output) override { return p->SerializeToOstream(&output); } + bool FunctionProto__ParseFromString(ONNX_NAMESPACE::FunctionProto* p, const std::string& data) override { return p->ParseFromString(data); } + std::string FunctionProto__SerializeAsString(const ONNX_NAMESPACE::FunctionProto* p) override { return p->SerializeAsString(); } + + bool FunctionProto__has_name(const ONNX_NAMESPACE::FunctionProto* p) override { return p->has_name(); } + const std::string& FunctionProto__name(const ONNX_NAMESPACE::FunctionProto* p) const override { return p->name(); } + void FunctionProto__set_name(ONNX_NAMESPACE::FunctionProto* p, const std::string& name) override { p->set_name(name); } + + bool FunctionProto__has_doc_string(const ONNX_NAMESPACE::FunctionProto* p) override { return p->has_doc_string(); } + const std::string& FunctionProto__doc_string(const ONNX_NAMESPACE::FunctionProto* p) const override { return p->doc_string(); } + void 
FunctionProto__set_doc_string(ONNX_NAMESPACE::FunctionProto* p, const std::string& doc_string) override { p->set_doc_string(doc_string); } + + bool FunctionProto__has_domain(const ONNX_NAMESPACE::FunctionProto* p) override { return p->has_domain(); } + const std::string& FunctionProto__domain(const ONNX_NAMESPACE::FunctionProto* p) const override { return p->domain(); } + void FunctionProto__set_domain(ONNX_NAMESPACE::FunctionProto* p, const std::string& domain) override { p->set_domain(domain); } + + const std::string& FunctionProto__input(const ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->input(index); } + std::string* FunctionProto__mutable_input(ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->mutable_input(index); } + int FunctionProto__input_size(const ONNX_NAMESPACE::FunctionProto* p) override { return p->input_size(); } + void FunctionProto__add_input(ONNX_NAMESPACE::FunctionProto* p, const std::string& value) override { p->add_input(value); } + + const std::string& FunctionProto__output(const ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->output(index); } + std::string* FunctionProto__mutable_output(ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->mutable_output(index); } + int FunctionProto__output_size(const ONNX_NAMESPACE::FunctionProto* p) override { return p->output_size(); } + void FunctionProto__add_output(ONNX_NAMESPACE::FunctionProto* p, const std::string& value) override { p->add_output(value); } + + const std::string& FunctionProto__attribute(const ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->attribute(index); } + std::string* FunctionProto__mutable_attribute(ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->mutable_attribute(index); } + int FunctionProto__attribute_size(const ONNX_NAMESPACE::FunctionProto* p) override { return p->attribute_size(); } + void FunctionProto__add_attribute(ONNX_NAMESPACE::FunctionProto* p, const 
std::string& value) override { p->add_attribute(value); } + + const ONNX_NAMESPACE::AttributeProto& FunctionProto__attribute_proto(const ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->attribute_proto(index); } + ONNX_NAMESPACE::AttributeProto* FunctionProto__mutable_attribute_proto(ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->mutable_attribute_proto(index); } + int FunctionProto__attribute_proto_size(const ONNX_NAMESPACE::FunctionProto* p) override { return p->attribute_proto_size(); } + ONNX_NAMESPACE::AttributeProto* FunctionProto__add_attribute_proto(ONNX_NAMESPACE::FunctionProto* p) override { return p->add_attribute_proto(); } + + const ONNX_NAMESPACE::NodeProto& FunctionProto__node(const ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->node(index); } + ONNX_NAMESPACE::NodeProto* FunctionProto__mutable_node(ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->mutable_node(index); } + int FunctionProto__node_size(const ONNX_NAMESPACE::FunctionProto* p) override { return p->node_size(); } + ONNX_NAMESPACE::NodeProto* FunctionProto__add_node(ONNX_NAMESPACE::FunctionProto* p) override { return p->add_node(); } + + const ONNX_NAMESPACE::ValueInfoProto& FunctionProto__value_info(const ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->value_info(index); } + ONNX_NAMESPACE::ValueInfoProto* FunctionProto__mutable_value_info(ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->mutable_value_info(index); } + ONNX_NAMESPACE::ValueInfoProtos* FunctionProto__mutable_value_info(ONNX_NAMESPACE::FunctionProto* p) override { return p->mutable_value_info(); } + int FunctionProto__value_info_size(const ONNX_NAMESPACE::FunctionProto* p) override { return p->value_info_size(); } + ONNX_NAMESPACE::ValueInfoProto* FunctionProto__add_value_info(ONNX_NAMESPACE::FunctionProto* p) override { return p->add_value_info(); } + + const ONNX_NAMESPACE::StringStringEntryProto& 
FunctionProto__metadata_props(const ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->metadata_props(index); } + ONNX_NAMESPACE::StringStringEntryProto* FunctionProto__mutable_metadata_props(ONNX_NAMESPACE::FunctionProto* p, int index) override { return p->mutable_metadata_props(index); } + ONNX_NAMESPACE::StringStringEntryProtos* FunctionProto__mutable_metadata_props(ONNX_NAMESPACE::FunctionProto* p) override { return p->mutable_metadata_props(); } + int FunctionProto__metadata_props_size(const ONNX_NAMESPACE::FunctionProto* p) override { return p->metadata_props_size(); } + ONNX_NAMESPACE::StringStringEntryProto* FunctionProto__add_metadata_props(ONNX_NAMESPACE::FunctionProto* p) override { return p->add_metadata_props(); } + static void xir_shape_infer(ONNX_NAMESPACE::InferenceContext& ctx) { auto* shape = ctx.getAttribute("shape"); auto* data_type = ctx.getAttribute("data_type"); @@ -734,9 +796,12 @@ struct ProviderHostImpl : ProviderHost { std::vector& IndexedSubGraph__Nodes(IndexedSubGraph* p) override { return p->nodes; } - void IndexedSubGraph__SetMetaDef(IndexedSubGraph* p, std::unique_ptr&& meta_def_) override { return p->SetMetaDef(std::move(meta_def_)); } + void IndexedSubGraph__SetMetaDef(IndexedSubGraph* p, std::unique_ptr&& meta_def_) override { p->SetMetaDef(std::move(meta_def_)); } const IndexedSubGraph_MetaDef* IndexedSubGraph__GetMetaDef(const IndexedSubGraph* p) override { return p->GetMetaDef(); } + void IndexedSubGraph__SetSchemaSource(IndexedSubGraph* p, IndexedSubGraph_SourceOfSchema schema_source) override { p->schema_source = schema_source; } + IndexedSubGraph_SourceOfSchema IndexedSubGraph__GetSchemaSource(const IndexedSubGraph* p) override { return p->schema_source; } + // KernelDef (wrapped) void KernelDef__operator_delete(KernelDef* p) override { delete p; } void KernelDef__SinceVersion(const KernelDef* p, int* start, int* end) override { return p->SinceVersion(start, end); } @@ -2753,6 +2818,11 @@ 
ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_VitisAI, _In_ provider_options[provider_options_keys[i]] = provider_options_values[i]; } + // EP context related session config options. + provider_options["ep_context_enable"] = options->value.config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0"); + provider_options["ep_context_embed_mode"] = options->value.config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "1"); + provider_options["ep_context_file_path"] = options->value.config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); + auto factory = onnxruntime::VitisAIProviderFactoryCreator::Create(provider_options); if (!factory) { return OrtApis::CreateStatus(ORT_FAIL, "SessionOptionsAppendExecutionProvider_VitisAI: Failed to load shared library");