Skip to content

Commit 75f8480

Browse files
authored
[EP ABI] support Graph_GetModelMetadata (microsoft#25768)
### Description Add a new API `Graph_GetModelMetadata` ### Motivation and Context VitisAI EP would convert ONNX IR to another IR which is suitable for AMD AI compilers. The metadata in a OrtModel contains many important infomation produced by other tools, e.g. Olive. This API potentially used by many other execution providers which need to access the same information.
1 parent f46113d commit 75f8480

File tree

10 files changed

+83
-1
lines changed

10 files changed

+83
-1
lines changed

include/onnxruntime/core/session/onnxruntime_c_api.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6469,6 +6469,17 @@ struct OrtApi {
64696469
_In_reads_(num_tensors) OrtValue* const* dst_tensors,
64706470
_In_opt_ OrtSyncStream* stream,
64716471
_In_ size_t num_tensors);
6472+
6473+
/** \brief Get ::OrtModelMetadata from an ::OrtGraph
6474+
*
6475+
* \param[in] graph The OrtGraph instance.
6476+
* \param[out] out Newly created ::OrtModelMetadata. Must be freed using OrtApi::ReleaseModelMetadata.
6477+
*
6478+
* \snippet{doc} snippets.dox OrtStatus Return Value
6479+
*
6480+
* \since Version 1.23.
6481+
*/
6482+
ORT_API2_STATUS(Graph_GetModelMetadata, _In_ const OrtGraph* graph, _Outptr_ OrtModelMetadata** out);
64726483
};
64736484

64746485
/*

include/onnxruntime/core/session/onnxruntime_cxx_api.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2834,6 +2834,7 @@ struct GraphImpl : Ort::detail::Base<T> {
28342834
void SetOutputs(std::vector<ValueInfo>& outputs);
28352835
void AddInitializer(const std::string& name, Value& initializer, bool data_is_external); // Graph takes ownership of Value
28362836
void AddNode(Node& node); // Graph takes ownership of Node
2837+
ModelMetadata GetModelMetadata() const; ///< Wraps OrtApi::Graph_GetModelMetadata
28372838
#endif // !defined(ORT_MINIMAL_BUILD)
28382839
};
28392840
} // namespace detail
@@ -2848,6 +2849,7 @@ struct Graph : detail::GraphImpl<OrtGraph> {
28482849
Graph();
28492850
#endif
28502851
};
2852+
using ConstGraph = detail::GraphImpl<Ort::detail::Unowned<const OrtGraph>>;
28512853

28522854
namespace detail {
28532855
template <typename T>

include/onnxruntime/core/session/onnxruntime_cxx_inline.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2798,6 +2798,13 @@ inline void GraphImpl<OrtGraph>::AddNode(Node& node) {
27982798
ThrowOnError(GetModelEditorApi().AddNodeToGraph(p_, node.release()));
27992799
}
28002800

2801+
template <typename T>
2802+
inline ModelMetadata GraphImpl<T>::GetModelMetadata() const {
2803+
OrtModelMetadata* out;
2804+
ThrowOnError(GetApi().Graph_GetModelMetadata(this->p_, &out));
2805+
return ModelMetadata{out};
2806+
}
2807+
28012808
template <>
28022809
inline void ModelImpl<OrtModel>::AddGraph(Graph& graph) {
28032810
// Model takes ownership of `graph`

onnxruntime/core/graph/abi_graph_types.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "core/framework/tensor_external_data_info.h"
1111
#include "core/framework/onnxruntime_typeinfo.h"
1212
#include "core/graph/onnx_protobuf.h"
13+
#include "core/session/inference_session.h"
1314

1415
#define DEFINE_ORT_GRAPH_IR_TO_EXTERNAL_INTERNAL_FUNCS(external_type, internal_type, internal_api) \
1516
external_type* ToExternal() { return static_cast<external_type*>(this); } \
@@ -301,6 +302,11 @@ struct OrtGraph {
301302
/// <returns>The graph's name.</returns>
302303
virtual const std::string& GetName() const = 0;
303304

305+
/// <summary>
306+
/// Returns the model's metadata.
307+
/// </summary>
308+
/// <returns>The model metadata.</returns>
309+
virtual std::unique_ptr<onnxruntime::ModelMetadata> GetModelMetadata() const = 0;
304310
/// <summary>
305311
/// Returns the model's path, which is empty if unknown.
306312
/// </summary>

onnxruntime/core/graph/ep_api_types.cc

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "core/framework/onnxruntime_typeinfo.h"
2121
#include "core/graph/graph_viewer.h"
2222
#include "core/graph/graph.h"
23+
#include "core/graph/model.h"
2324

2425
namespace onnxruntime {
2526

@@ -769,6 +770,25 @@ Status EpGraph::CreateImpl(std::unique_ptr<EpGraph> ep_graph, const GraphViewer&
769770

770771
const std::string& EpGraph::GetName() const { return graph_viewer_.Name(); }
771772

773+
std::unique_ptr<ModelMetadata> EpGraph::GetModelMetadata() const {
774+
#if !defined(ORT_MINIMAL_BUILD)
775+
const auto& model = graph_viewer_.GetGraph().GetModel();
776+
auto model_metadata = std::make_unique<ModelMetadata>();
777+
778+
model_metadata->producer_name = model.ProducerName();
779+
model_metadata->producer_version = model.ProducerVersion();
780+
model_metadata->description = model.DocString();
781+
model_metadata->graph_description = model.GraphDocString();
782+
model_metadata->domain = model.Domain();
783+
model_metadata->version = model.ModelVersion();
784+
model_metadata->custom_metadata_map = model.MetaData();
785+
model_metadata->graph_name = model.MainGraph().Name();
786+
return model_metadata;
787+
#else
788+
return nullptr;
789+
#endif
790+
}
791+
772792
const ORTCHAR_T* EpGraph::GetModelPath() const {
773793
return graph_viewer_.ModelPath().c_str();
774794
}

onnxruntime/core/graph/ep_api_types.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,9 @@ struct EpGraph : public OrtGraph {
298298
// Returns the graph's name.
299299
const std::string& GetName() const override;
300300

301+
// Returns the graph's metadata
302+
std::unique_ptr<ModelMetadata> GetModelMetadata() const override;
303+
301304
// Returns the model path.
302305
const ORTCHAR_T* GetModelPath() const override;
303306

onnxruntime/core/graph/model_editor_api_types.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "core/framework/ort_value.h"
1414
#include "core/graph/abi_graph_types.h"
1515
#include "core/graph/onnx_protobuf.h"
16+
#include "core/session/inference_session.h"
1617

1718
namespace onnxruntime {
1819

@@ -184,6 +185,9 @@ struct ModelEditorGraph : public OrtGraph {
184185

185186
const std::string& GetName() const override { return name; }
186187

188+
std::unique_ptr<ModelMetadata> GetModelMetadata() const override {
189+
return std::make_unique<ModelMetadata>(model_metadata);
190+
}
187191
const ORTCHAR_T* GetModelPath() const override { return model_path.c_str(); }
188192

189193
int64_t GetOnnxIRVersion() const override {
@@ -241,6 +245,7 @@ struct ModelEditorGraph : public OrtGraph {
241245
std::vector<std::unique_ptr<onnxruntime::ModelEditorNode>> nodes;
242246
std::string name = "ModelEditorGraph";
243247
std::filesystem::path model_path;
248+
ModelMetadata model_metadata;
244249
};
245250

246251
} // namespace onnxruntime

onnxruntime/core/session/onnxruntime_c_api.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2626,6 +2626,16 @@ ORT_API_STATUS_IMPL(OrtApis::Graph_GetName, _In_ const OrtGraph* graph, _Outptr_
26262626
API_IMPL_END
26272627
}
26282628

2629+
ORT_API_STATUS_IMPL(OrtApis::Graph_GetModelMetadata, _In_ const OrtGraph* graph, _Outptr_ OrtModelMetadata** out) {
2630+
API_IMPL_BEGIN
2631+
if (out == nullptr) {
2632+
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'out' argument is NULL");
2633+
}
2634+
*out = reinterpret_cast<OrtModelMetadata*>(graph->GetModelMetadata().release());
2635+
return nullptr;
2636+
API_IMPL_END
2637+
}
2638+
26292639
ORT_API_STATUS_IMPL(OrtApis::Graph_GetModelPath, _In_ const OrtGraph* graph, _Outptr_ const ORTCHAR_T** model_path) {
26302640
API_IMPL_BEGIN
26312641
if (model_path == nullptr) {
@@ -4095,6 +4105,8 @@ static constexpr OrtApi ort_api_1_to_23 = {
40954105
&OrtApis::ReleaseSyncStream,
40964106

40974107
&OrtApis::CopyTensors,
4108+
4109+
&OrtApis::Graph_GetModelMetadata,
40984110
};
40994111

41004112
// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.

onnxruntime/core/session/ort_apis.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,6 +635,7 @@ ORT_API_STATUS_IMPL(ValueInfo_IsFromOuterScope, _In_ const OrtValueInfo* value_i
635635

636636
// OrtGraph
637637
ORT_API_STATUS_IMPL(Graph_GetName, _In_ const OrtGraph* graph, _Outptr_ const char** graph_name);
638+
ORT_API_STATUS_IMPL(Graph_GetModelMetadata, _In_ const OrtGraph* graph, _Outptr_ OrtModelMetadata** out);
638639
ORT_API_STATUS_IMPL(Graph_GetModelPath, _In_ const OrtGraph* graph, _Outptr_ const ORTCHAR_T** model_path);
639640
ORT_API_STATUS_IMPL(Graph_GetOnnxIRVersion, _In_ const OrtGraph* graph, _Out_ int64_t* onnx_ir_version);
640641
ORT_API_STATUS_IMPL(Graph_GetNumOperatorSets, _In_ const OrtGraph* graph, _Out_ size_t* num_operator_sets);

onnxruntime/test/ep_graph/test_ep_graph.cc

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -914,7 +914,22 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_
914914
const ORTCHAR_T* api_model_path = nullptr;
915915
ASSERT_ORTSTATUS_OK(ort_api.Graph_GetModelPath(&api_graph, &api_model_path));
916916
ASSERT_EQ(PathString(api_model_path), PathString(model_path.c_str()));
917-
917+
// Check the model metadata
918+
Ort::AllocatorWithDefaultOptions default_allocator;
919+
auto ort_cxx_graph = Ort::ConstGraph(&api_graph);
920+
auto ort_cxx_model_metadat = ort_cxx_graph.GetModelMetadata();
921+
auto& model = graph_viewer.GetGraph().GetModel();
922+
ASSERT_EQ(std::strcmp(ort_cxx_model_metadat.GetProducerNameAllocated(default_allocator).get(), model.ProducerName().c_str()), 0);
923+
ASSERT_EQ(std::strcmp(ort_cxx_model_metadat.GetGraphNameAllocated(default_allocator).get(), model.MainGraph().Name().c_str()), 0);
924+
ASSERT_EQ(std::strcmp(ort_cxx_model_metadat.GetDomainAllocated(default_allocator).get(), model.Domain().c_str()), 0);
925+
ASSERT_EQ(std::strcmp(ort_cxx_model_metadat.GetDescriptionAllocated(default_allocator).get(), model.DocString().c_str()), 0);
926+
ASSERT_EQ(std::strcmp(ort_cxx_model_metadat.GetGraphDescriptionAllocated(default_allocator).get(), model.GraphDocString().c_str()), 0);
927+
ASSERT_EQ(ort_cxx_model_metadat.GetVersion(), model.ModelVersion());
928+
auto model_meta_data = model.MetaData();
929+
for (auto& [k, v] : model_meta_data) {
930+
ASSERT_EQ(std::strcmp(ort_cxx_model_metadat.LookupCustomMetadataMapAllocated(k.c_str(), default_allocator).get(), v.c_str()), 0)
931+
<< " key=" << k << "; value=" << v;
932+
}
918933
// Check graph inputs.
919934
const auto& graph_input_node_args = graph_viewer.GetInputsIncludingInitializers();
920935

0 commit comments

Comments
 (0)