1313#include " core/common/safeint.h"
1414#include " core/providers/webnn/allocator.h"
1515#include " core/providers/webnn/data_transfer.h"
16+ #include " core/providers/partitioning_utils.h"
17+ #include " core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
18+ #include " core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
1619
1720#include " builders/model.h"
1821#include " builders/helper.h"
1922#include " builders/model_builder.h"
2023
2124namespace onnxruntime {
2225
26+ constexpr const char * WEBNN = " WEBNN" ;
27+
2328WebNNExecutionProvider::WebNNExecutionProvider (const std::string& webnn_device_flags)
2429 : IExecutionProvider{
2530 onnxruntime::kWebNNExecutionProvider ,
@@ -51,8 +56,6 @@ WebNNExecutionProvider::~WebNNExecutionProvider() {}
5156std::vector<std::unique_ptr<ComputeCapability>>
5257WebNNExecutionProvider::GetCapability (const onnxruntime::GraphViewer& graph_viewer,
5358 const IKernelLookup& /* kernel_registries*/ ) const {
54- std::vector<std::unique_ptr<ComputeCapability>> result;
55-
5659 // For subgraph which is the attribute of the control flow nodes, part of its initializers are stored in its
5760 // ancestor graphs as common initializers shared for other subgraphs. We need to collect all of them used for
5861 // identifying the required initializer names and storing into 'meta_def->constant_initializers'.
@@ -64,67 +67,44 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
6467 all_initializers = webnn::CollectAllInitializedTensors (graph_viewer);
6568 }
6669
67- /*
68- Very basic search for groups of nodes that can be handled by the EP.
69- This doesn't work perfectly if you have a scenario like the following where A and D could be handled by the EP
70- but B is between them in the topological sort as you'll get two single node capabilities. However if can also
71- be advantageous if C and E could be handled by the EP as they would be combined with D even though not connected.
72- Not sure how often each of these scenarios happens.
73-
74- A B C
75- | / |
76- D E
77- | |
78-
79- Would probably be better to walk the edges for each node the EP can handle as they are iterated in topological order,
80- accumulating nodes (and saving which ones have been taken) until you run out. This would guarantee all
81- connected nodes that can be handled are grouped together.
82- */
83-
8470 const auto & logger = *GetLogger ();
8571
8672 emscripten::val wnn_builder = emscripten::val::global (" MLGraphBuilder" ).new_ (wnn_context_);
8773 if (!wnn_builder.as <bool >()) {
8874 ORT_THROW (" Failed to create WebNN builder." );
8975 }
9076
91- const auto node_groups = webnn::GetSupportedNodes (graph_viewer, wnn_builder, wnn_device_type_, wnn_limits_, logger);
92- wnn_builder = emscripten::val::undefined ();
77+ // Get all the NodeUnits in the graph_viewer
78+ std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
79+ std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
9380
94- if (node_groups.empty ()) {
95- return result;
96- }
81+ std::tie (node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits (graph_viewer, logger);
9782
98- const auto & graph_output_list = graph_viewer.GetOutputs ();
99- InlinedHashSet<const NodeArg*> graph_outputs (graph_output_list.cbegin (), graph_output_list.cend ());
83+ const auto supported_nodes = webnn::GetSupportedNodes (graph_viewer, wnn_builder, wnn_device_type_, wnn_limits_, logger);
10084
101- size_t num_of_supported_nodes = 0 ;
102- for (const auto & group : node_groups) {
103- if (group.empty ())
104- continue ;
85+ const auto gen_metadef_name = [&]() {
86+ HashValue model_hash;
87+ int metadef_id = metadef_id_generator_.GenerateId (graph_viewer, model_hash);
88+ return MakeString (WEBNN, " _" , model_hash, " _" , metadef_id);
89+ };
10590
106- num_of_supported_nodes += group. size ();
107- LOGS (logger, VERBOSE) << " WebNNExecutionProvider::GetCapability, current supported node group size: "
108- << group. size ( );
91+ auto result = utils::CreateSupportedPartitions (graph_viewer, supported_nodes, {},
92+ gen_metadef_name, WEBNN, kWebNNExecutionProvider ,
93+ &node_unit_map, /* drop_constant_initializers */ true );
10994
110- InlinedHashSet<NodeIndex> node_set;
111- node_set.reserve (group.size ());
112- for (const auto & index : group) {
113- node_set.insert (index);
114- }
95+ // Release wnn_builder
96+ wnn_builder = emscripten::val::undefined ();
11597
116- std::unique_ptr<IndexedSubGraph> sub_graph = std::make_unique<IndexedSubGraph>();
98+ const auto & graph_output_list = graph_viewer.GetOutputs ();
99+ InlinedHashSet<const NodeArg*> graph_outputs (graph_output_list.cbegin (), graph_output_list.cend ());
100+
101+ for (auto & capability : result) {
102+ auto & sub_graph = capability->sub_graph ;
103+ if (sub_graph->nodes .empty ())
104+ continue ;
117105
118106 std::vector<std::string> subgraph_initializers;
119- InlinedHashSet<const NodeArg*> node_outputs;
120- InlinedHashSet<const NodeArg*> subgraph_inputs;
121- InlinedHashSet<const NodeArg*> subgraph_outputs;
122- std::vector<const NodeArg*> ordered_subgraph_inputs;
123- // Output should be unique. It may be produced as graph output and subgraph output.
124- InlinedHashSet<const NodeArg*> ordered_subgraph_outputs;
125-
126- for (const auto & index : group) {
127- sub_graph->nodes .push_back (index);
107+ for (const auto & index : sub_graph->nodes ) {
128108 const auto * node = graph_viewer.GetNode (index);
129109
130110 for (const auto * input : node->InputDefs ()) {
@@ -136,39 +116,13 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
136116 if (is_subgraph && Contains (all_initializers, input->Name ())) {
137117 subgraph_initializers.push_back (input->Name ());
138118 }
139- // If the node input was not produced by this subgraph, add it to the subgraph inputs.
140- if (node_outputs.count (input) == 0 ) {
141- if (subgraph_inputs.count (input) == 0 ) {
142- subgraph_inputs.insert (input);
143- ordered_subgraph_inputs.push_back (input);
144- }
145- }
146- }
147-
148- const auto & output_defs = node->OutputDefs ();
149- for (const auto * output_def : output_defs) {
150- node_outputs.insert (output_def);
151- // if output is overall graph output we need to produce it.
152- if (graph_outputs.count (output_def) != 0 ) {
153- ordered_subgraph_outputs.insert (output_def);
154- }
155- }
156-
157- // if output connects to a node not in this subgraph we need to produce it.
158- for (auto it = node->OutputEdgesBegin (), end = node->OutputEdgesEnd (); it != end; ++it) {
159- if (node_set.count (it->GetNode ().Index ()) == 0 ) {
160- const auto * output_def = output_defs[it->GetSrcArgIndex ()];
161- if (subgraph_outputs.count (output_def) == 0 ) {
162- subgraph_outputs.insert (output_def);
163- ordered_subgraph_outputs.insert (output_def);
164- }
165- }
166119 }
167120 }
168121
169122 // Assign inputs and outputs to subgraph's meta_def.
170123 uint64_t model_hash;
171124 int metadef_id = metadef_id_generator_.GenerateId (graph_viewer, model_hash);
125+ const auto meta_def_old = sub_graph->GetMetaDef ();
172126 auto meta_def = std::make_unique<::onnxruntime::IndexedSubGraph::MetaDef>();
173127 meta_def->name = " WEBNN_" + std::to_string (model_hash) + " _" + std::to_string (metadef_id);
174128 meta_def->domain = kMSDomain ;
@@ -181,20 +135,24 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
181135 }
182136 }
183137
184- for (const auto & input : ordered_subgraph_inputs ) {
185- meta_def->inputs .push_back (input-> Name () );
138+ for (const auto & input : meta_def_old-> inputs ) {
139+ meta_def->inputs .push_back (input);
186140 }
187141
188- for (const auto & output : ordered_subgraph_outputs ) {
189- meta_def->outputs .push_back (output-> Name () );
142+ for (const auto & output : meta_def_old-> outputs ) {
143+ meta_def->outputs .push_back (output);
190144 }
191145
192146 sub_graph->SetMetaDef (std::move (meta_def));
193-
194- result.push_back (std::make_unique<ComputeCapability>(std::move (sub_graph)));
195147 }
196148
197- auto num_of_partitions = result.size ();
149+ const auto num_of_partitions = result.size ();
150+ const auto num_of_supported_nodes = std::accumulate (
151+ result.begin (), result.end (), size_t {0 },
152+ [](const auto & acc, const auto & partition) -> size_t {
153+ return acc + (partition && partition->sub_graph ? partition->sub_graph ->nodes .size () : 0 );
154+ });
155+
198156 const auto summary_msg = MakeString (
199157 " WebNNExecutionProvider::GetCapability," ,
200158 " number of partitions supported by WebNN: " , num_of_partitions,
0 commit comments