Skip to content

Commit 57aafc5

Browse files
Address review comments: Add API for domain, clean up test, clean up docs, etc
1 parent 64fa011 commit 57aafc5

File tree

8 files changed

+155
-96
lines changed

8 files changed

+155
-96
lines changed

include/onnxruntime/core/session/onnxruntime_c_api.h

Lines changed: 43 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6915,20 +6915,30 @@ struct OrtApi {
69156915

69166916
/// @}
69176917

6918-
/** \brief Get information about the subgraphs assigned to each EP and the nodes within.
6918+
/** \brief Get information about the subgraphs assigned to each execution provider (EP) and the nodes within.
69196919
*
6920-
* Each returned OrtEpAssignedSubgraph instance contains details of the subgraph/nodes assigned to an execution provider,
6921-
* including the execution provider's name, and the name and operator type for every node. For compiling EPs,
6922-
* a subgraph contains one or more nodes. Alternatively, for EPs that use kernel registration (e.g., CPU EP), each
6923-
* registered kernel for a node is contained in its own subgraph (i.e., a subgraph contains one node).
6920+
* Each returned OrtEpAssignedSubgraph instance contains details of the subgraph/nodes assigned to an execution
6921+
* provider, including the execution provider's name, and the name, domain, and operator type for every node.
69246922
*
6925-
* \note Application must enable the recording of graph partitioning information by enabling the session configuration
6926-
* for the key "session.record_ep_graph_assignment_info". Refer to onnxruntime_session_options_config_keys.h.
6927-
* If the session configuration is not enabled, this function returns an empty result.
6923+
* For compiling execution providers, a single OrtEpAssignedSubgraph instance contains information about the
6924+
* nodes that are fused and compiled within a single subgraph assigned to the execution provider.
69286925
*
6929-
* \param[in] session The OrtSession instance to query.
6930-
* \param[out] ep_subgraphs The OrtEpAssignedSubgraph instances denoting the EP graph partitioning.
6931-
* \param[out] num_ep_subgraphs The number of OrtEpAssignedSubgraph instances returned.
6926+
* For execution providers that use kernel registration (e.g., CPU EP), each node with a registered kernel is
6927+
* contained in its own OrtEpAssignedSubgraph instance.
6928+
*
6929+
* \note The caller must enable the collection of this information by enabling the session
6930+
* configuration entry "session.record_ep_graph_assignment_info" during session creation.
6931+
* Refer to onnxruntime_session_options_config_keys.h. If this entry is not enabled, this function returns a
6932+
* status with error code ORT_FAIL.
6933+
*
6934+
* \note The information reported by this function is obtained immediately after running basic optimizations on the
6935+
* original graph if the session optimization level is set to ORT_ENABLE_BASIC or higher. If the session
6936+
* optimization level is set to ORT_DISABLE_ALL, only minimal/required optimizations are run before
6937+
* the information is collected.
6938+
*
6939+
* \param[in] session The OrtSession instance.
6940+
* \param[out] ep_subgraphs Output parameter set to the array of OrtEpAssignedSubgraph instances.
6941+
* \param[out] num_ep_subgraphs Output parameter set to the number of elements in the `ep_subgraphs` array.
69326942
*
69336943
* \snippet{doc} snippets.dox OrtStatus Return Value
69346944
*
@@ -6940,18 +6950,18 @@ struct OrtApi {
69406950

69416951
/** \brief Get the name of the execution provider to which the subgraph was assigned.
69426952
*
6943-
* \param[in] ep_subgraph The OrtEpAssignedSubgraph instance to query.
6944-
* \return The execution provider name.
6953+
* \param[in] ep_subgraph The OrtEpAssignedSubgraph instance.
6954+
* \return The execution provider's name as a UTF-8 null-terminated string.
69456955
*
69466956
* \since Version 1.24.
69476957
*/
6948-
const char*(ORT_API_CALL* EpAssignedSubgraph_GetEpName)(_In_ const OrtEpAssignedSubgraph* ep_subgraph);
6958+
ORT_API_T(const char*, EpAssignedSubgraph_GetEpName, _In_ const OrtEpAssignedSubgraph* ep_subgraph);
69496959

6950-
/** \brief Get the list of nodes assigned to an execution provider.
6960+
/** \brief Get the nodes in a subgraph assigned to a specific execution provider.
69516961
*
6952-
* \param[in] ep_subgraph The OrtEpAssignedSubgraph instance to query.
6953-
* \param[out] ep_nodes Output parameter set to the list of OrtEpAssignedNode instances.
6954-
* \param[out] num_ep_nodes Output parameter set to the number of OrtEpAssignedNode instances returned.
6962+
* \param[in] ep_subgraph The OrtEpAssignedSubgraph instance.
6963+
* \param[out] ep_nodes Output parameter set to the array of OrtEpAssignedNode instances.
6964+
* \param[out] num_ep_nodes Output parameter set to the number of OrtEpAssignedNode instances returned.
69556965
*
69566966
* \snippet{doc} snippets.dox OrtStatus Return Value
69576967
*
@@ -6962,21 +6972,30 @@ struct OrtApi {
69626972

69636973
/** \brief Get the name of the node assigned to an execution provider.
69646974
*
6965-
* \param[in] ep_node The OrtEpAssignedNode instance to query.
6966-
* \return The node's name.
6975+
* \param[in] ep_node The OrtEpAssignedNode instance.
6976+
* \return The node's name as a UTF-8 null-terminated string.
6977+
*
6978+
* \since Version 1.24.
6979+
*/
6980+
ORT_API_T(const char*, EpAssignedNode_GetName, _In_ const OrtEpAssignedNode* ep_node);
6981+
6982+
/** \brief Get the domain of the node assigned to an execution provider.
6983+
*
6984+
* \param[in] ep_node The OrtEpAssignedNode instance.
6985+
* \return The node's domain as a UTF-8 null-terminated string.
69676986
*
69686987
* \since Version 1.24.
69696988
*/
6970-
const char*(ORT_API_CALL* EpAssignedNode_GetName)(_In_ const OrtEpAssignedNode* ep_node);
6989+
ORT_API_T(const char*, EpAssignedNode_GetDomain, _In_ const OrtEpAssignedNode* ep_node);
69716990

69726991
/** \brief Get the operator type of the node assigned to an execution provider.
69736992
*
6974-
* \param[in] ep_node The OrtEpAssignedNode instance to query.
6975-
* \return The node's operator type.
6993+
* \param[in] ep_node The OrtEpAssignedNode instance.
6994+
* \return The node's operator type as a UTF-8 null-terminated string.
69766995
*
69776996
* \since Version 1.24.
69786997
*/
6979-
const char*(ORT_API_CALL* EpAssignedNode_GetOperatorType)(_In_ const OrtEpAssignedNode* ep_node);
6998+
ORT_API_T(const char*, EpAssignedNode_GetOperatorType, _In_ const OrtEpAssignedNode* ep_node);
69806999
};
69817000

69827001
/*

include/onnxruntime/core/session/onnxruntime_cxx_api.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1170,8 +1170,9 @@ struct EpAssignedNodeImpl : Ort::detail::Base<T> {
11701170
using B = Ort::detail::Base<T>;
11711171
using B::B;
11721172

1173-
const char* GetName() const;
1174-
const char* GetOperatorType() const;
1173+
std::string GetName() const;
1174+
std::string GetDomain() const;
1175+
std::string GetOperatorType() const;
11751176
};
11761177
} // namespace detail
11771178

@@ -1186,7 +1187,7 @@ struct EpAssignedSubgraphImpl : Ort::detail::Base<T> {
11861187
using B = Ort::detail::Base<T>;
11871188
using B::B;
11881189

1189-
const char* GetEpName() const;
1190+
std::string GetEpName() const;
11901191
std::vector<ConstEpAssignedNode> GetNodes() const;
11911192
};
11921193
} // namespace detail

include/onnxruntime/core/session/onnxruntime_cxx_inline.h

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -749,8 +749,8 @@ inline EpDevice::EpDevice(OrtEpFactory& ep_factory, ConstHardwareDevice& hardwar
749749

750750
namespace detail {
751751
template <typename T>
752-
inline const char* EpAssignedSubgraphImpl<T>::GetEpName() const {
753-
return GetApi().EpAssignedSubgraph_GetEpName(this->p_);
752+
inline std::string EpAssignedSubgraphImpl<T>::GetEpName() const {
753+
return std::string(GetApi().EpAssignedSubgraph_GetEpName(this->p_));
754754
}
755755

756756
template <typename T>
@@ -771,13 +771,18 @@ inline std::vector<ConstEpAssignedNode> EpAssignedSubgraphImpl<T>::GetNodes() co
771771
}
772772

773773
template <typename T>
774-
inline const char* EpAssignedNodeImpl<T>::GetName() const {
775-
return GetApi().EpAssignedNode_GetName(this->p_);
774+
inline std::string EpAssignedNodeImpl<T>::GetName() const {
775+
return std::string(GetApi().EpAssignedNode_GetName(this->p_));
776776
}
777777

778778
template <typename T>
779-
inline const char* EpAssignedNodeImpl<T>::GetOperatorType() const {
780-
return GetApi().EpAssignedNode_GetOperatorType(this->p_);
779+
inline std::string EpAssignedNodeImpl<T>::GetDomain() const {
780+
return std::string(GetApi().EpAssignedNode_GetDomain(this->p_));
781+
}
782+
783+
template <typename T>
784+
inline std::string EpAssignedNodeImpl<T>::GetOperatorType() const {
785+
return std::string(GetApi().EpAssignedNode_GetOperatorType(this->p_));
781786
}
782787
} // namespace detail
783788

onnxruntime/core/session/ep_graph_assignment_info.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
/// </summary>
1414
struct OrtEpAssignedNode {
1515
std::string name;
16+
std::string domain;
1617
std::string op_type;
1718
};
1819

onnxruntime/core/session/inference_session.cc

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1288,10 +1288,9 @@ common::Status InferenceSession::ApplyUpdates(const OrtModel& model_editor_api_m
12881288
}
12891289

12901290
#if !defined(ORT_MINIMAL_BUILD)
1291-
static void RecordEpGraphPartitionAssignment(std::vector<std::unique_ptr<OrtEpAssignedSubgraph>>& ep_assigned_subgraphs,
1292-
const Graph& graph,
1293-
const ComputeCapability& capability,
1294-
const std::string& ep_name) {
1291+
static std::unique_ptr<OrtEpAssignedSubgraph> CreateEpAssignedSubgraph(const Graph& graph,
1292+
const ComputeCapability& capability,
1293+
const std::string& ep_name) {
12951294
auto assigned_subgraph = std::make_unique<OrtEpAssignedSubgraph>();
12961295
assigned_subgraph->ep_name = ep_name;
12971296

@@ -1300,18 +1299,17 @@ static void RecordEpGraphPartitionAssignment(std::vector<std::unique_ptr<OrtEpAs
13001299
for (NodeIndex node_index : node_indices) {
13011300
const Node* node = graph.GetNode(node_index);
13021301
if (node != nullptr) {
1303-
const std::string& op_type = node->OpType();
1304-
13051302
auto assigned_node = std::make_unique<OrtEpAssignedNode>();
13061303
assigned_node->name = node->Name();
1307-
assigned_node->op_type = op_type;
1304+
assigned_node->domain = node->Domain();
1305+
assigned_node->op_type = node->OpType();
13081306

13091307
assigned_subgraph->nodes.push_back(assigned_node.get());
13101308
assigned_subgraph->nodes_storage.push_back(std::move(assigned_node));
13111309
}
13121310
}
13131311

1314-
ep_assigned_subgraphs.push_back(std::move(assigned_subgraph));
1312+
return assigned_subgraph;
13151313
}
13161314
#endif // !defined(ORT_MINIMAL_BUILD)
13171315

@@ -1329,15 +1327,19 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool
13291327
// 8. Repeat steps 5 to 7 depending on the graph optimizations loop level.
13301328
// 9. insert copy nodes (required transformer).
13311329

1332-
OnPartitionAssignmentFunction on_partition_assign_fn;
1330+
OnPartitionAssignmentFunction on_partition_assignment_fn;
13331331
#if !defined(ORT_MINIMAL_BUILD)
1334-
bool record_ep_graph_partitioning =
1332+
bool record_ep_graph_assignment =
13351333
session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsRecordEpGraphAssignmentInfo, "0") == "1";
1336-
if (record_ep_graph_partitioning) {
1337-
on_partition_assign_fn = [this](const Graph& graph, const ComputeCapability& assigned_subgraph,
1338-
const std::string& assigned_ep_type) {
1339-
RecordEpGraphPartitionAssignment(this->ep_graph_assignment_info_storage_, graph, assigned_subgraph,
1340-
assigned_ep_type);
1334+
if (record_ep_graph_assignment) {
1335+
on_partition_assignment_fn = [this](const Graph& graph, const ComputeCapability& compute_capability,
1336+
const std::string& ep_name) {
1337+
std::unique_ptr<OrtEpAssignedSubgraph> assigned_subgraph = CreateEpAssignedSubgraph(graph,
1338+
compute_capability,
1339+
ep_name);
1340+
1341+
this->ep_graph_assignment_info_.push_back(assigned_subgraph.get());
1342+
this->ep_graph_assignment_info_storage_.push_back(std::move(assigned_subgraph));
13411343
};
13421344
}
13431345
#endif // !defined(ORT_MINIMAL_BUILD)
@@ -1347,7 +1349,7 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool
13471349
execution_providers_.Get(onnxruntime::kCpuExecutionProvider),
13481350
session_logger_);
13491351
GraphPartitioner partitioner(kernel_registry_manager_, execution_providers_, std::move(graph_optimizer_registry),
1350-
check_load_cancellation_fn_, on_partition_assign_fn);
1352+
check_load_cancellation_fn_, on_partition_assignment_fn);
13511353

13521354
// Run Ahead Of time function inlining
13531355
if (const bool disable_aot_function_inlining =
@@ -1489,14 +1491,6 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool
14891491
session_options_.config_options, *session_logger_,
14901492
mode, session_options_.GetEpContextGenerationOptions(), debug_graph_fn));
14911493

1492-
#if !defined(ORT_MINIMAL_BUILD)
1493-
if (record_ep_graph_partitioning) {
1494-
for (std::unique_ptr<OrtEpAssignedSubgraph>& ep_subgraph : ep_graph_assignment_info_storage_) {
1495-
ep_graph_assignment_info_.push_back(ep_subgraph.get());
1496-
}
1497-
}
1498-
#endif // !defined(ORT_MINIMAL_BUILD)
1499-
15001494
// Get graph optimizations loop level from session config, if not present, set to default value of 1 as per
15011495
// the definition of kOrtSessionOptionsGraphOptimizationsLoopLevel.
15021496
unsigned int graph_optimizations_loop_level = static_cast<unsigned int>(std::stoi(

onnxruntime/core/session/onnxruntime_c_api.cc

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
#include "core/session/lora_adapters.h"
4949
#include "core/session/model_editor_api.h"
5050
#include "core/session/onnxruntime_c_api.h"
51+
#include "core/session/onnxruntime_session_options_config_keys.h"
5152
#include "core/session/ort_apis.h"
5253
#include "core/session/ort_env.h"
5354
#include "core/session/utils.h"
@@ -878,6 +879,18 @@ ORT_API_STATUS_IMPL(OrtApis::Session_GetEpGraphAssignmentInfo, _In_ const OrtSes
878879
_Out_ size_t* num_ep_subgraphs) {
879880
API_IMPL_BEGIN
880881
#if !defined(ORT_MINIMAL_BUILD)
882+
const auto* inference_session = reinterpret_cast<const onnxruntime::InferenceSession*>(session);
883+
const auto& session_options = inference_session->GetSessionOptions();
884+
bool is_enabled =
885+
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsRecordEpGraphAssignmentInfo, "0") == "1";
886+
887+
if (!is_enabled) {
888+
std::ostringstream oss;
889+
oss << "Session configuration entry '" << kOrtSessionOptionsRecordEpGraphAssignmentInfo
890+
<< "' must be set to \"1\" to retrieve EP graph assignment information.";
891+
return OrtApis::CreateStatus(ORT_FAIL, oss.str().c_str());
892+
}
893+
881894
if (ep_subgraphs == nullptr) {
882895
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'ep_subgraphs' argument is null");
883896
}
@@ -886,7 +899,6 @@ ORT_API_STATUS_IMPL(OrtApis::Session_GetEpGraphAssignmentInfo, _In_ const OrtSes
886899
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'num_ep_subgraphs' argument is null");
887900
}
888901

889-
auto inference_session = reinterpret_cast<const onnxruntime::InferenceSession*>(session);
890902
const std::vector<const OrtEpAssignedSubgraph*>& ep_assignment_info = inference_session->GetEpGraphAssignmentInfo();
891903

892904
*ep_subgraphs = ep_assignment_info.data();
@@ -896,7 +908,7 @@ ORT_API_STATUS_IMPL(OrtApis::Session_GetEpGraphAssignmentInfo, _In_ const OrtSes
896908
ORT_UNUSED_PARAMETER(session);
897909
ORT_UNUSED_PARAMETER(ep_subgraphs);
898910
ORT_UNUSED_PARAMETER(num_ep_subgraphs);
899-
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "EP graph partitioning information is not supported in this build");
911+
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "EP graph assignment information is not supported in this build");
900912
#endif // !defined(ORT_MINIMAL_BUILD)
901913
API_IMPL_END
902914
}
@@ -906,7 +918,7 @@ ORT_API(const char*, OrtApis::EpAssignedSubgraph_GetEpName, _In_ const OrtEpAssi
906918
return ep_subgraph->ep_name.c_str();
907919
#else
908920
ORT_UNUSED_PARAMETER(ep_subgraph);
909-
fprintf(stderr, "EP graph partitioning information is not supported in this build\n");
921+
fprintf(stderr, "EP graph assignment information is not supported in this build\n");
910922
return nullptr;
911923
#endif // !defined(ORT_MINIMAL_BUILD)
912924
}
@@ -916,11 +928,15 @@ ORT_API_STATUS_IMPL(OrtApis::EpAssignedSubgraph_GetNodes, _In_ const OrtEpAssign
916928
API_IMPL_BEGIN
917929
#if !defined(ORT_MINIMAL_BUILD)
918930
if (ep_nodes == nullptr) {
919-
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'ep_nodes' argument is null");
931+
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
932+
"EpAssignedSubgraph_GetNodes requires a valid (non-null) `ep_nodes` output parameter "
933+
"into which to store the pointer to the node array.");
920934
}
921935

922936
if (num_ep_nodes == nullptr) {
923-
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'num_ep_nodes' argument is null");
937+
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
938+
"EpAssignedSubgraph_GetNodes requires a valid (non-null) `num_ep_nodes` "
939+
"output parameter into which to store the number of nodes.");
924940
}
925941

926942
*ep_nodes = ep_subgraph->nodes.data();
@@ -930,7 +946,7 @@ ORT_API_STATUS_IMPL(OrtApis::EpAssignedSubgraph_GetNodes, _In_ const OrtEpAssign
930946
ORT_UNUSED_PARAMETER(ep_subgraph);
931947
ORT_UNUSED_PARAMETER(ep_nodes);
932948
ORT_UNUSED_PARAMETER(num_ep_nodes);
933-
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "EP graph partitioning information is not supported in this build");
949+
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "EP graph assignment information is not supported in this build");
934950
#endif // !defined(ORT_MINIMAL_BUILD)
935951
API_IMPL_END
936952
}
@@ -940,7 +956,17 @@ ORT_API(const char*, OrtApis::EpAssignedNode_GetName, _In_ const OrtEpAssignedNo
940956
return ep_node->name.c_str();
941957
#else
942958
ORT_UNUSED_PARAMETER(ep_node);
943-
fprintf(stderr, "EP graph partitioning information is not supported in this build\n");
959+
fprintf(stderr, "EP graph assignment information is not supported in this build\n");
960+
return nullptr;
961+
#endif // !defined(ORT_MINIMAL_BUILD)
962+
}
963+
964+
ORT_API(const char*, OrtApis::EpAssignedNode_GetDomain, _In_ const OrtEpAssignedNode* ep_node) {
965+
#if !defined(ORT_MINIMAL_BUILD)
966+
return ep_node->domain.c_str();
967+
#else
968+
ORT_UNUSED_PARAMETER(ep_node);
969+
fprintf(stderr, "EP graph assignment information is not supported in this build\n");
944970
return nullptr;
945971
#endif // !defined(ORT_MINIMAL_BUILD)
946972
}
@@ -950,7 +976,7 @@ ORT_API(const char*, OrtApis::EpAssignedNode_GetOperatorType, _In_ const OrtEpAs
950976
return ep_node->op_type.c_str();
951977
#else
952978
ORT_UNUSED_PARAMETER(ep_node);
953-
fprintf(stderr, "EP graph partitioning information is not supported in this build\n");
979+
fprintf(stderr, "EP graph assignment information is not supported in this build\n");
954980
return nullptr;
955981
#endif // !defined(ORT_MINIMAL_BUILD)
956982
}
@@ -4386,6 +4412,7 @@ static constexpr OrtApi ort_api_1_to_24 = {
43864412
&OrtApis::EpAssignedSubgraph_GetEpName,
43874413
&OrtApis::EpAssignedSubgraph_GetNodes,
43884414
&OrtApis::EpAssignedNode_GetName,
4415+
&OrtApis::EpAssignedNode_GetDomain,
43894416
&OrtApis::EpAssignedNode_GetOperatorType,
43904417
};
43914418

onnxruntime/core/session/ort_apis.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -790,5 +790,6 @@ ORT_API(const char*, EpAssignedSubgraph_GetEpName, _In_ const OrtEpAssignedSubgr
790790
ORT_API_STATUS_IMPL(EpAssignedSubgraph_GetNodes, _In_ const OrtEpAssignedSubgraph* ep_subgraph,
791791
_Outptr_ const OrtEpAssignedNode* const** ep_nodes, _Out_ size_t* num_ep_nodes);
792792
ORT_API(const char*, EpAssignedNode_GetName, _In_ const OrtEpAssignedNode* ep_node);
793+
ORT_API(const char*, EpAssignedNode_GetDomain, _In_ const OrtEpAssignedNode* ep_node);
793794
ORT_API(const char*, EpAssignedNode_GetOperatorType, _In_ const OrtEpAssignedNode* ep_node);
794795
} // namespace OrtApis

0 commit comments

Comments
 (0)