Skip to content

Commit d8f0318

Browse files
Add API to get ep graph partitioning info (#26781)
### Description - Adds API functions to get information about the subgraphs/nodes assigned to the EPs in the session. - `Session_GetEpGraphAssignmentInfo`: Returns a list of "subgraphs", each with information about the assigned EP and nodes. - Note: App must enable session configuration `"session.record_ep_graph_assignment_info"` to signal ORT to collect this information. If not enabled, API returns empty results. - `EpAssignedSubgraph_GetEpName`: Returns the name of the EP to which the subgraph is assigned - `EpAssignedSubgraph_GetNodes`: Returns a list of assigned nodes - `EpAssignedNode_GetName`: Returns the assigned node's name - `EpAssignedNode_GetDomain`: Returns the assigned node's domain - `EpAssignedNode_GetOperatorType`: Returns the assigned node's operator type - Also adds C++ and Python bindings #### Structure of returned information The API returns a list of "subgraphs". Each subgraph has the following information: - Subgraph info: - EP name: The name of the execution provider to which this subgraph is assigned. - nodes: Name and operator type of each node. Ex: `[{"multiply", "Mul"}, ...]` Python example program (taken from unit tests): ```python def test_get_graph_provider_assignment_info(self): """ Tests querying for information about the nodes assigned to the CPU EP. """ # Create session options that enables recording EP graph partitioning info. session_options = onnxrt.SessionOptions() session_options.add_session_config_entry("session.record_ep_graph_assignment_info", "1") session = onnxrt.InferenceSession(get_name("add_mul_add.onnx"), sess_options=session_options) # Query session for information on each subgraph assigned to an EP. ep_subgraphs = session.get_provider_graph_assignment_info() # Check that all 3 nodes are assigned to CPU EP (each in its own subgraph) self.assertEqual(len(ep_subgraphs), 3) for ep_subgraph in ep_subgraphs: self.assertEqual(ep_subgraph.ep_name, "CPUExecutionProvider") self.assertEqual(len(ep_subgraph.get_nodes()), 1) # Serialize each node to an identifier (concatenates operator type and node name) node_ids: list[str] = [f"{n.op_type}/{n.name}" for s in ep_subgraphs for n in s.get_nodes()] # Should have 1 Mul and 2 Adds. self.assertEqual(len(node_ids), 3) self.assertIn("Add/add_0", node_ids) self.assertIn("Add/add_1", node_ids) self.assertIn("Mul/mul_0", node_ids) ``` C++ program (taken from unit test): ```c++ // Check the ep graph partitioning (Mul on plugin EP, others on CPU EP). // Model has 3 subgraphs (in no particular order): // - Subgraph 1: Add assigned to CPU EP. // - Subgraph 2: Mul assigned to plugin EP. // - Subgraph 3: Add assigned to CPU EP. std::vector<Ort::ConstEpAssignedSubgraph> ep_subgraphs = session.GetEpGraphAssignmentInfo(); ASSERT_EQ(ep_subgraphs.size(), 3); for (Ort::ConstEpAssignedSubgraph ep_subgraph : ep_subgraphs) { std::string ep_name = ep_subgraph.EpName(); ASSERT_TRUE(ep_name == Utils::example_ep_info.ep_name || ep_name == kCpuExecutionProvider); const std::vector<Ort::ConstEpAssignedNode> ep_nodes = ep_subgraph.GetNodes(); ASSERT_GE(ep_nodes.size(), 1); // All of these subgraphs just have one node. if (ep_name == kCpuExecutionProvider) { std::string op_type = ep_nodes[0].OpType(); std::string node_name = ep_nodes[0].Name(); ASSERT_EQ(op_type, "Add"); ASSERT_TRUE(node_name == "add_0" || node_name == "add_1"); } else { ASSERT_TRUE(ep_name == Utils::example_ep_info.ep_name); std::string op_type = ep_nodes[0].OpType(); std::string node_name = ep_nodes[0].Name(); ASSERT_EQ(op_type, "Mul"); ASSERT_EQ(node_name, "mul_0"); } } ``` ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent bbd3850 commit d8f0318

17 files changed

+687
-7
lines changed

include/onnxruntime/core/session/onnxruntime_c_api.h

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,8 @@ ORT_RUNTIME_CLASS(ExternalResourceImporter); // Capability object for external
337337
ORT_RUNTIME_CLASS(ExternalMemoryHandle); // EP-imported view of shared external allocation
338338
ORT_RUNTIME_CLASS(ExternalSemaphoreHandle); // EP-imported view of shared external semaphore
339339
ORT_RUNTIME_CLASS(DeviceEpIncompatibilityDetails);
340+
ORT_RUNTIME_CLASS(EpAssignedSubgraph);
341+
ORT_RUNTIME_CLASS(EpAssignedNode);
340342

341343
#ifdef _MSC_VER
342344
typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr;
@@ -7016,6 +7018,101 @@ struct OrtApi {
70167018
* \since Version 1.24
70177019
*/
70187020
ORT_API2_STATUS(CreateEnvWithOptions, _In_ const OrtEnvCreationOptions* options, _Outptr_ OrtEnv** out);
7021+
7022+
/** \brief Get information about the subgraphs assigned to each execution provider (EP) and the nodes within.
7023+
*
7024+
* Each returned OrtEpAssignedSubgraph instance contains details of the subgraph/nodes assigned to an execution
7025+
* provider, including the execution provider's name, and the name, domain, and operator type for every node.
7026+
*
7027+
* For compiling execution providers, a single OrtEpAssignedSubgraph instance contains information about the
7028+
* nodes that are fused and compiled within a single subgraph assigned to the execution provider.
7029+
*
7030+
* For execution providers that use kernel registration (e.g., CPU EP), each node with a registered kernel is
7031+
* contained in its own OrtEpAssignedSubgraph instance.
7032+
*
7033+
* \note The caller must enable the collection of this information by enabling the session
7034+
* configuration entry "session.record_ep_graph_assignment_info" during session creation.
7035+
* Refer to onnxruntime_session_options_config_keys.h. Otherwise, if not enabled, this function returns a
7036+
* status with error code ORT_FAIL.
7037+
*
7038+
* \note The information reported by this function is obtained immediately after running basic optimizations on the
7039+
* original graph if the session optimization level is set to ORT_ENABLE_BASIC or higher. If the session
7040+
* optimization level is set to ORT_DISABLE_ALL, only minimal/required optimizations are run before
7041+
* the information is collected.
7042+
*
7043+
* \param[in] session The OrtSession instance.
7044+
* \param[out] ep_subgraphs Output parameter set to the array of OrtEpAssignedSubgraph instances.
7045+
* \param[out] num_ep_subgraphs Output parameter set to the number of elements in the `ep_subgraphs` array.
7046+
*
7047+
* \snippet{doc} snippets.dox OrtStatus Return Value
7048+
*
7049+
* \since Version 1.24.
7050+
*/
7051+
ORT_API2_STATUS(Session_GetEpGraphAssignmentInfo, _In_ const OrtSession* session,
7052+
_Outptr_ const OrtEpAssignedSubgraph* const** ep_subgraphs,
7053+
_Out_ size_t* num_ep_subgraphs);
7054+
7055+
/** \brief Get the name of the execution provider to which the subgraph was assigned.
7056+
*
7057+
* \param[in] ep_subgraph The OrtEpAssignedSubgraph instance.
7058+
* \param[out] out Output parameter set to the execution provider's name as a UTF-8 null-terminated string.
7059+
* Owned by the OrtEpAssignedSubgraph instance (do not free).
7060+
*
7061+
* \snippet{doc} snippets.dox OrtStatus Return Value
7062+
*
7063+
* \since Version 1.24.
7064+
*/
7065+
ORT_API2_STATUS(EpAssignedSubgraph_GetEpName, _In_ const OrtEpAssignedSubgraph* ep_subgraph,
7066+
_Outptr_ const char** out);
7067+
7068+
/** \brief Get the nodes in a subgraph assigned to a specific execution provider.
7069+
*
7070+
* \param[in] ep_subgraph The OrtEpAssignedSubgraph instance.
7071+
* \param[out] ep_nodes Output parameter set to the array of OrtEpAssignedNode instances.
7072+
* \param[out] num_ep_nodes Output parameter set to the number of OrtEpAssignedNode instance returned.
7073+
*
7074+
* \snippet{doc} snippets.dox OrtStatus Return Value
7075+
*
7076+
* \since Version 1.24.
7077+
*/
7078+
ORT_API2_STATUS(EpAssignedSubgraph_GetNodes, _In_ const OrtEpAssignedSubgraph* ep_subgraph,
7079+
_Outptr_ const OrtEpAssignedNode* const** ep_nodes, _Out_ size_t* num_ep_nodes);
7080+
7081+
/** \brief Get the name of the node assigned to an execution provider.
7082+
*
7083+
* \param[in] ep_node The OrtEpAssignedNode instance.
7084+
* \param[out] out Output parameter set to the node's name as a UTF-8 null-terminated string.
7085+
* Owned by the OrtEpAssignedNode instance (do not free).
7086+
*
7087+
* \snippet{doc} snippets.dox OrtStatus Return Value
7088+
*
7089+
* \since Version 1.24.
7090+
*/
7091+
ORT_API2_STATUS(EpAssignedNode_GetName, _In_ const OrtEpAssignedNode* ep_node, _Outptr_ const char** out);
7092+
7093+
/** \brief Get the domain of the node assigned to an execution provider.
7094+
*
7095+
* \param[in] ep_node The OrtEpAssignedNode instance.
7096+
* \param[out] out Output parameter set to the node's domain as a UTF-8 null-terminated string.
7097+
* Owned by the OrtEpAssignedNode instance (do not free).
7098+
*
7099+
* \snippet{doc} snippets.dox OrtStatus Return Value
7100+
*
7101+
* \since Version 1.24.
7102+
*/
7103+
ORT_API2_STATUS(EpAssignedNode_GetDomain, _In_ const OrtEpAssignedNode* ep_node, _Outptr_ const char** out);
7104+
7105+
/** \brief Get the operator type of the node assigned to an execution provider.
7106+
*
7107+
* \param[in] ep_node The OrtEpAssignedNode instance.
7108+
* \param[out] out Output parameter set to the node's operator type as a UTF-8 null-terminated string.
7109+
* Owned by the OrtEpAssignedNode instance (do not free).
7110+
*
7111+
* \snippet{doc} snippets.dox OrtStatus Return Value
7112+
*
7113+
* \since Version 1.24.
7114+
*/
7115+
ORT_API2_STATUS(EpAssignedNode_GetOperatorType, _In_ const OrtEpAssignedNode* ep_node, _Outptr_ const char** out);
70197116
};
70207117

70217118
/*

include/onnxruntime/core/session/onnxruntime_cxx_api.h

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1164,6 +1164,39 @@ OrtCompiledModelCompatibility GetModelCompatibilityForEpDevices(
11641164
const std::vector<ConstEpDevice>& ep_devices,
11651165
const char* compatibility_info);
11661166

1167+
namespace detail {
1168+
template <typename T>
1169+
struct EpAssignedNodeImpl : Ort::detail::Base<T> {
1170+
using B = Ort::detail::Base<T>;
1171+
using B::B;
1172+
1173+
std::string GetName() const;
1174+
std::string GetDomain() const;
1175+
std::string GetOperatorType() const;
1176+
};
1177+
} // namespace detail
1178+
1179+
/** \brief Constant wrapper around ::OrtEpAssignedNode
1180+
* \remarks EpAssignedNode is always read-only for ORT API users.
1181+
*/
1182+
using ConstEpAssignedNode = detail::EpAssignedNodeImpl<Ort::detail::Unowned<const OrtEpAssignedNode>>;
1183+
1184+
namespace detail {
1185+
template <typename T>
1186+
struct EpAssignedSubgraphImpl : Ort::detail::Base<T> {
1187+
using B = Ort::detail::Base<T>;
1188+
using B::B;
1189+
1190+
std::string GetEpName() const;
1191+
std::vector<ConstEpAssignedNode> GetNodes() const;
1192+
};
1193+
} // namespace detail
1194+
1195+
/** \brief Constant wrapper around ::OrtEpAssignedSubgraph
1196+
* \remarks EpAssignedSubgraph is always read-only for ORT API users.
1197+
*/
1198+
using ConstEpAssignedSubgraph = detail::EpAssignedSubgraphImpl<Ort::detail::Unowned<const OrtEpAssignedSubgraph>>;
1199+
11671200
/** \brief The Env (Environment)
11681201
*
11691202
* The Env holds the logging state used by all other objects.
@@ -1665,9 +1698,14 @@ struct ConstSessionImpl : Base<T> {
16651698

16661699
int GetOpset(const std::string& domain) const; ///< Wraps OrtApi::SessionGetOpsetForDomain
16671700

1668-
// Will move before checkin if that's the case.
16691701
std::vector<ValueInfo> GetInputs() const;
16701702
std::vector<ValueInfo> GetOutputs() const;
1703+
1704+
/** \brief Returns information on the subgraph/nodes assigned to execution providers in the session.
1705+
*
1706+
* \return A list of ConstEpAssignedSubgraph instances.
1707+
*/
1708+
std::vector<ConstEpAssignedSubgraph> GetEpGraphAssignmentInfo() const;
16711709
};
16721710

16731711
template <typename T>

include/onnxruntime/core/session/onnxruntime_cxx_inline.h

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -747,6 +747,61 @@ inline EpDevice::EpDevice(OrtEpFactory& ep_factory, ConstHardwareDevice& hardwar
747747
ThrowOnError(GetEpApi().CreateEpDevice(&ep_factory, hardware_device, ep_metadata, ep_options, &p_));
748748
}
749749

750+
namespace detail {
751+
template <typename T>
752+
inline std::string EpAssignedSubgraphImpl<T>::GetEpName() const {
753+
const char* ep_name = nullptr;
754+
755+
// Returned null-terminated string will not be null if API function returns successfully.
756+
ThrowOnError(GetApi().EpAssignedSubgraph_GetEpName(this->p_, &ep_name));
757+
return std::string(ep_name);
758+
}
759+
760+
template <typename T>
761+
inline std::vector<ConstEpAssignedNode> EpAssignedSubgraphImpl<T>::GetNodes() const {
762+
size_t num_ep_nodes = 0;
763+
const OrtEpAssignedNode* const* ep_node_ptrs = nullptr;
764+
ThrowOnError(GetApi().EpAssignedSubgraph_GetNodes(this->p_, &ep_node_ptrs, &num_ep_nodes));
765+
766+
std::vector<ConstEpAssignedNode> ep_nodes;
767+
if (num_ep_nodes > 0) {
768+
ep_nodes.reserve(num_ep_nodes);
769+
for (size_t i = 0; i < num_ep_nodes; ++i) {
770+
ep_nodes.emplace_back(ep_node_ptrs[i]);
771+
}
772+
}
773+
774+
return ep_nodes;
775+
}
776+
777+
template <typename T>
778+
inline std::string EpAssignedNodeImpl<T>::GetName() const {
779+
const char* node_name = nullptr;
780+
781+
// Returned null-terminated string will not be null if API function returns successfully.
782+
ThrowOnError(GetApi().EpAssignedNode_GetName(this->p_, &node_name));
783+
return std::string(node_name);
784+
}
785+
786+
template <typename T>
787+
inline std::string EpAssignedNodeImpl<T>::GetDomain() const {
788+
const char* domain = nullptr;
789+
790+
// Returned null-terminated string will not be null if API function returns successfully.
791+
ThrowOnError(GetApi().EpAssignedNode_GetDomain(this->p_, &domain));
792+
return std::string(domain);
793+
}
794+
795+
template <typename T>
796+
inline std::string EpAssignedNodeImpl<T>::GetOperatorType() const {
797+
const char* op_type = nullptr;
798+
799+
// Returned null-terminated string will not be null if API function returns successfully.
800+
ThrowOnError(GetApi().EpAssignedNode_GetOperatorType(this->p_, &op_type));
801+
return std::string(op_type);
802+
}
803+
} // namespace detail
804+
750805
inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) {
751806
ThrowOnError(GetApi().CreateEnv(logging_level, logid, &p_));
752807
if (strcmp(logid, "onnxruntime-node") == 0) {
@@ -1756,6 +1811,23 @@ std::vector<ValueInfo> ConstSessionImpl<T>::GetOutputs() const {
17561811
return outputs;
17571812
}
17581813

1814+
template <typename T>
1815+
inline std::vector<ConstEpAssignedSubgraph> ConstSessionImpl<T>::GetEpGraphAssignmentInfo() const {
1816+
size_t num_ep_subgraphs = 0;
1817+
const OrtEpAssignedSubgraph* const* ep_subgraph_ptrs = nullptr;
1818+
ThrowOnError(GetApi().Session_GetEpGraphAssignmentInfo(this->p_, &ep_subgraph_ptrs, &num_ep_subgraphs));
1819+
1820+
std::vector<ConstEpAssignedSubgraph> ep_subgraphs;
1821+
if (num_ep_subgraphs > 0) {
1822+
ep_subgraphs.reserve(num_ep_subgraphs);
1823+
for (size_t i = 0; i < num_ep_subgraphs; ++i) {
1824+
ep_subgraphs.emplace_back(ep_subgraph_ptrs[i]);
1825+
}
1826+
}
1827+
1828+
return ep_subgraphs;
1829+
}
1830+
17591831
template <typename T>
17601832
inline std::vector<Value> SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
17611833
const char* const* output_names, size_t output_count) {

include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,3 +421,11 @@ static const char* const kOrtSessionOptionsFailOnSuboptimalCompiledModel =
421421
// "high_power_saver", "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver",
422422
// "sustained_high_performance". Default to "default".
423423
static const char* const kOrtEpDynamicOptionsQnnHtpPerformanceMode = "ep.dynamic.qnn_htp_performance_mode";
424+
425+
// Enables the session to record information about the subgraphs/nodes assigned to execution providers.
426+
// When enabled, an application may call Session_GetEpGraphAssignmentInfo() to retrieve the information.
427+
//
428+
// Option values:
429+
// - "0": Recording of EP graph assignment information is disabled. [DEFAULT]
430+
// - "1": Recording of EP graph assignment information is enabled.
431+
static const char* const kOrtSessionOptionsRecordEpGraphAssignmentInfo = "session.record_ep_graph_assignment_info";

onnxruntime/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
OrtArenaCfg, # noqa: F401
3535
OrtCompileApiFlags, # noqa: F401
3636
OrtDeviceMemoryType, # noqa: F401
37+
OrtEpAssignedNode, # noqa: F401
38+
OrtEpAssignedSubgraph, # noqa: F401
3739
OrtEpDevice, # noqa: F401
3840
OrtExecutionProviderDevicePolicy, # noqa: F401
3941
OrtExternalInitializerInfo, # noqa: F401

onnxruntime/core/framework/graph_partitioner.cc

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ struct PartitionParams {
6868
std::reference_wrapper<const layout_transformation::TransformLayoutFunction> transform_layout_function;
6969
std::reference_wrapper<const layout_transformation::DebugGraphFn> debug_graph_fn;
7070
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
71+
std::reference_wrapper<const OnPartitionAssignmentFunction> on_partition_assignment_fn;
7172
};
7273
} // namespace
7374

@@ -426,6 +427,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
426427
const layout_transformation::TransformLayoutFunction& transform_layout_fn,
427428
const layout_transformation::DebugGraphFn& debug_graph_fn,
428429
const CheckLoadCancellationFn& check_load_cancellation_fn,
430+
const OnPartitionAssignmentFunction& on_partition_assignment_fn,
429431
const logging::Logger& logger, IResourceAccountant* resource_accountant,
430432
const GraphOptimizerRegistry& graph_optimizer_registry,
431433
bool disable_model_compile) {
@@ -444,6 +446,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
444446
fused_kernel_registry, current_ep, mode, fused_node_unique_id,
445447
transform_layout_fn, debug_graph_fn,
446448
check_load_cancellation_fn,
449+
on_partition_assignment_fn,
447450
logger, resource_accountant,
448451
graph_optimizer_registry, disable_model_compile));
449452
}
@@ -518,6 +521,12 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
518521

519522
Node* n = nullptr;
520523
if (sub_graph_available_for_assignment) {
524+
if (on_partition_assignment_fn) {
525+
// Call custom function provided by owner of GraphPartitioner whenever a subgraph is assigned to an EP.
526+
// This can be used, for example, to collect partitioning information.
527+
on_partition_assignment_fn(graph, *capability, type);
528+
}
529+
521530
n = PlaceNode(graph, *capability->sub_graph, fusion_style, type, mode, fused_node_unique_id);
522531
}
523532

@@ -1018,6 +1027,7 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params,
10181027
auto& fused_node_unique_id = partition_params.fused_node_unique_id.get();
10191028
const auto& transform_layout_function = partition_params.transform_layout_function;
10201029
const CheckLoadCancellationFn& check_load_cancellation_fn = partition_params.check_load_cancellation_fn;
1030+
const OnPartitionAssignmentFunction& on_partition_assignment_fn = partition_params.on_partition_assignment_fn;
10211031

10221032
do {
10231033
// process full graph with each EP
@@ -1034,6 +1044,7 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params,
10341044
transform_layout_function,
10351045
partition_params.debug_graph_fn,
10361046
check_load_cancellation_fn,
1047+
on_partition_assignment_fn,
10371048
logger, resource_accountant, graph_optimizer_registry,
10381049
disable_model_compile));
10391050
}
@@ -1280,7 +1291,8 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
12801291
std::ref(*fused_kernel_registry),
12811292
std::ref(fused_node_unique_id),
12821293
std::cref(transform_layout_function),
1283-
std::cref(debug_graph_fn)};
1294+
std::cref(debug_graph_fn),
1295+
std::cref(on_partition_assignment_fn_)};
12841296

12851297
#else // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
12861298

@@ -1290,6 +1302,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
12901302
PartitionParams partition_params{
12911303
std::ref(graph),
12921304
std::cref(check_load_cancellation_fn),
1305+
std::cref(on_partition_assignment_fn_),
12931306
};
12941307

12951308
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

onnxruntime/core/framework/graph_partitioner.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@ namespace epctx {
2020
struct ModelGenOptions;
2121
}
2222

23+
// OnPartitionAssignmentFunction is called by GraphPartitioner when a subgraph is assigned to
24+
// an execution provider. Can be used to collect partitioning information.
25+
using OnPartitionAssignmentFunction = std::function<void(const Graph& graph,
26+
const ComputeCapability& assigned_subgraph,
27+
const std::string& assigned_ep_type)>;
28+
2329
class GraphPartitioner {
2430
public:
2531
enum class Mode {
@@ -40,11 +46,13 @@ class GraphPartitioner {
4046
GraphPartitioner(KernelRegistryManager& kernel_registry_mgr,
4147
const ExecutionProviders& providers,
4248
std::unique_ptr<GraphOptimizerRegistry> graph_optimizer_registry,
43-
CheckLoadCancellationFn check_load_cancellation_fn)
49+
CheckLoadCancellationFn check_load_cancellation_fn,
50+
OnPartitionAssignmentFunction on_partition_assignment_fn = {})
4451
: kernel_registry_mgr_(kernel_registry_mgr),
4552
providers_(providers),
4653
graph_optimizer_registry_(std::move(graph_optimizer_registry)),
47-
check_load_cancellation_fn_(std::move(check_load_cancellation_fn)) {
54+
check_load_cancellation_fn_(std::move(check_load_cancellation_fn)),
55+
on_partition_assignment_fn_(std::move(on_partition_assignment_fn)) {
4856
}
4957

5058
// Run partitioning.
@@ -89,6 +97,7 @@ class GraphPartitioner {
8997
const ExecutionProviders& providers_;
9098
std::unique_ptr<GraphOptimizerRegistry> graph_optimizer_registry_;
9199
CheckLoadCancellationFn check_load_cancellation_fn_;
100+
OnPartitionAssignmentFunction on_partition_assignment_fn_;
92101
};
93102

94103
} // namespace onnxruntime

0 commit comments

Comments
 (0)