Skip to content

Commit 853a41c

Browse files
peishenyanfdwr
authored and committed
[WebNN EP] Optimize model partitioning (#23332)
### Description <!-- Describe your changes. --> The old `GetCapability` function of WebNN EP is just a very simple search for groups of nodes that can be handled. This doesn't work well in the following example graph, where A and D could be handled by the EP, but B is between them in the topological order, as you get two single node capabilities. However, it may also be advantageous if C and E could be handled by the EP, since they would be combined with D even though they are not connected. ``` A B C | / | D E | | ``` Therefore, we improve partitioning results by reusing `utils::CreateSupportedPartitions`, which walks the edges for each node that the EP can handle as they are iterated in topological order. This would guarantee that all connected nodes that can be handled are grouped together. Correspondingly, we modify the `webnn::GetSupportedNodes` function to return the supported nodes instead of the group of supported partitions. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Co-authored-by: Dwayne Robinson <fdwr@hotmail.com>
1 parent 8b7f688 commit 853a41c

File tree

3 files changed

+61
-117
lines changed

3 files changed

+61
-117
lines changed

onnxruntime/core/providers/webnn/builders/helper.cc

Lines changed: 15 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -99,44 +99,30 @@ bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_n
9999
return true;
100100
}
101101

102-
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
103-
const emscripten::val& wnn_builder,
104-
const WebnnDeviceType device_type,
105-
const emscripten::val& wnn_limits,
106-
const logging::Logger& logger) {
107-
std::vector<std::vector<size_t>> supported_node_groups;
108-
std::vector<size_t> supported_node_group;
109-
const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
110-
111-
for (size_t i = 0; i < node_indices.size(); i++) {
112-
auto node_idx = node_indices[i];
113-
const auto* node(graph_viewer.GetNode(node_idx));
102+
std::unordered_set<const Node*> GetSupportedNodes(const GraphViewer& graph_viewer,
103+
const emscripten::val& wnn_builder,
104+
const WebnnDeviceType device_type,
105+
const emscripten::val& wnn_limits,
106+
const logging::Logger& logger) {
107+
std::unordered_set<const Node*> supported_nodes;
108+
109+
for (const auto& node : graph_viewer.Nodes()) {
114110
bool supported = false;
115111
// Firstly check if platform supports the WebNN op.
116-
if (CheckSingleOp(node->OpType(), wnn_builder, device_type)) {
117-
supported = IsNodeSupported(*node, graph_viewer, device_type, wnn_limits, logger);
112+
if (CheckSingleOp(node.OpType(), wnn_builder, device_type)) {
113+
supported = IsNodeSupported(node, graph_viewer, device_type, wnn_limits, logger);
118114
}
119-
120-
LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType()
121-
<< "] index: [" << node_idx
122-
<< "] name: [" << node->Name()
115+
LOGS(logger, VERBOSE) << "Operator type: [" << node.OpType()
116+
<< "] index: [" << node.Index()
117+
<< "] name: [" << node.Name()
123118
<< "] supported: [" << supported
124119
<< "]";
125120
if (supported) {
126-
supported_node_group.push_back(node_idx);
127-
} else {
128-
if (!supported_node_group.empty()) {
129-
supported_node_groups.push_back(supported_node_group);
130-
supported_node_group.clear();
131-
}
121+
supported_nodes.insert(&node);
132122
}
133123
}
134124

135-
if (!supported_node_group.empty()) {
136-
supported_node_groups.push_back(supported_node_group);
137-
}
138-
139-
return supported_node_groups;
125+
return supported_nodes;
140126
}
141127

142128
bool AreInputDataTypesSame(const std::string& op_type,

onnxruntime/core/providers/webnn/builders/helper.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -188,12 +188,12 @@ inline bool TensorExists(const ConstPointerContainer<std::vector<NodeArg*>>& def
188188
bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name,
189189
const logging::Logger& logger, bool allow_empty_input = false);
190190

191-
// Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP.
192-
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
193-
const emscripten::val& wnn_builder,
194-
const WebnnDeviceType device_type,
195-
const emscripten::val& wnn_limits,
196-
const logging::Logger& logger);
191+
// Get a set of nodes supported by WebNN EP.
192+
std::unordered_set<const Node*> GetSupportedNodes(const GraphViewer& graph_viewer,
193+
const emscripten::val& wnn_builder,
194+
const WebnnDeviceType device_type,
195+
const emscripten::val& wnn_limits,
196+
const logging::Logger& logger);
197197
// TODO(@Honry): Some ONNX ops are supported by decomposed WebNN ops,
198198
// we need to check the support of the decomposed ops.
199199
static const InlinedHashMap<std::string, std::string> op_map = {

onnxruntime/core/providers/webnn/webnn_execution_provider.cc

Lines changed: 40 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,18 @@
1313
#include "core/common/safeint.h"
1414
#include "core/providers/webnn/allocator.h"
1515
#include "core/providers/webnn/data_transfer.h"
16+
#include "core/providers/partitioning_utils.h"
17+
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
18+
#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
1619

1720
#include "builders/model.h"
1821
#include "builders/helper.h"
1922
#include "builders/model_builder.h"
2023

2124
namespace onnxruntime {
2225

26+
constexpr const char* WEBNN = "WEBNN";
27+
2328
WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_flags)
2429
: IExecutionProvider{
2530
onnxruntime::kWebNNExecutionProvider,
@@ -51,8 +56,6 @@ WebNNExecutionProvider::~WebNNExecutionProvider() {}
5156
std::vector<std::unique_ptr<ComputeCapability>>
5257
WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer,
5358
const IKernelLookup& /*kernel_registries*/) const {
54-
std::vector<std::unique_ptr<ComputeCapability>> result;
55-
5659
// For subgraph which is the attribute of the control flow nodes, part of its initializers are stored in its
5760
// ancestor graphs as common initializers shared for other subgraphs. We need to collect all of them used for
5861
// identifying the required initializer names and storing into 'meta_def->constant_initializers'.
@@ -64,67 +67,44 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
6467
all_initializers = webnn::CollectAllInitializedTensors(graph_viewer);
6568
}
6669

67-
/*
68-
Very basic search for groups of nodes that can be handled by the EP.
69-
This doesn't work perfectly if you have a scenario like the following where A and D could be handled by the EP
70-
but B is between them in the topological sort as you'll get two single node capabilities. However if can also
71-
be advantageous if C and E could be handled by the EP as they would be combined with D even though not connected.
72-
Not sure how often each of these scenarios happens.
73-
74-
A B C
75-
| / |
76-
D E
77-
| |
78-
79-
Would probably be better to walk the edges for each node the EP can handle as they are iterated in topological order,
80-
accumulating nodes (and saving which ones have been taken) until you run out. This would guarantee all
81-
connected nodes that can be handled are grouped together.
82-
*/
83-
8470
const auto& logger = *GetLogger();
8571

8672
emscripten::val wnn_builder = emscripten::val::global("MLGraphBuilder").new_(wnn_context_);
8773
if (!wnn_builder.as<bool>()) {
8874
ORT_THROW("Failed to create WebNN builder.");
8975
}
9076

91-
const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder, wnn_device_type_, wnn_limits_, logger);
92-
wnn_builder = emscripten::val::undefined();
77+
// Get all the NodeUnits in the graph_viewer
78+
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
79+
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
9380

94-
if (node_groups.empty()) {
95-
return result;
96-
}
81+
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger);
9782

98-
const auto& graph_output_list = graph_viewer.GetOutputs();
99-
InlinedHashSet<const NodeArg*> graph_outputs(graph_output_list.cbegin(), graph_output_list.cend());
83+
const auto supported_nodes = webnn::GetSupportedNodes(graph_viewer, wnn_builder, wnn_device_type_, wnn_limits_, logger);
10084

101-
size_t num_of_supported_nodes = 0;
102-
for (const auto& group : node_groups) {
103-
if (group.empty())
104-
continue;
85+
const auto gen_metadef_name = [&]() {
86+
HashValue model_hash;
87+
int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash);
88+
return MakeString(WEBNN, "_", model_hash, "_", metadef_id);
89+
};
10590

106-
num_of_supported_nodes += group.size();
107-
LOGS(logger, VERBOSE) << "WebNNExecutionProvider::GetCapability, current supported node group size: "
108-
<< group.size();
91+
auto result = utils::CreateSupportedPartitions(graph_viewer, supported_nodes, {},
92+
gen_metadef_name, WEBNN, kWebNNExecutionProvider,
93+
&node_unit_map, /*drop_constant_initializers*/ true);
10994

110-
InlinedHashSet<NodeIndex> node_set;
111-
node_set.reserve(group.size());
112-
for (const auto& index : group) {
113-
node_set.insert(index);
114-
}
95+
// Release wnn_builder
96+
wnn_builder = emscripten::val::undefined();
11597

116-
std::unique_ptr<IndexedSubGraph> sub_graph = std::make_unique<IndexedSubGraph>();
98+
const auto& graph_output_list = graph_viewer.GetOutputs();
99+
InlinedHashSet<const NodeArg*> graph_outputs(graph_output_list.cbegin(), graph_output_list.cend());
100+
101+
for (auto& capability : result) {
102+
auto& sub_graph = capability->sub_graph;
103+
if (sub_graph->nodes.empty())
104+
continue;
117105

118106
std::vector<std::string> subgraph_initializers;
119-
InlinedHashSet<const NodeArg*> node_outputs;
120-
InlinedHashSet<const NodeArg*> subgraph_inputs;
121-
InlinedHashSet<const NodeArg*> subgraph_outputs;
122-
std::vector<const NodeArg*> ordered_subgraph_inputs;
123-
// Output should be unique. It may be produced as graph output and subgraph output.
124-
InlinedHashSet<const NodeArg*> ordered_subgraph_outputs;
125-
126-
for (const auto& index : group) {
127-
sub_graph->nodes.push_back(index);
107+
for (const auto& index : sub_graph->nodes) {
128108
const auto* node = graph_viewer.GetNode(index);
129109

130110
for (const auto* input : node->InputDefs()) {
@@ -136,39 +116,13 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
136116
if (is_subgraph && Contains(all_initializers, input->Name())) {
137117
subgraph_initializers.push_back(input->Name());
138118
}
139-
// If the node input was not produced by this subgraph, add it to the subgraph inputs.
140-
if (node_outputs.count(input) == 0) {
141-
if (subgraph_inputs.count(input) == 0) {
142-
subgraph_inputs.insert(input);
143-
ordered_subgraph_inputs.push_back(input);
144-
}
145-
}
146-
}
147-
148-
const auto& output_defs = node->OutputDefs();
149-
for (const auto* output_def : output_defs) {
150-
node_outputs.insert(output_def);
151-
// if output is overall graph output we need to produce it.
152-
if (graph_outputs.count(output_def) != 0) {
153-
ordered_subgraph_outputs.insert(output_def);
154-
}
155-
}
156-
157-
// if output connects to a node not in this subgraph we need to produce it.
158-
for (auto it = node->OutputEdgesBegin(), end = node->OutputEdgesEnd(); it != end; ++it) {
159-
if (node_set.count(it->GetNode().Index()) == 0) {
160-
const auto* output_def = output_defs[it->GetSrcArgIndex()];
161-
if (subgraph_outputs.count(output_def) == 0) {
162-
subgraph_outputs.insert(output_def);
163-
ordered_subgraph_outputs.insert(output_def);
164-
}
165-
}
166119
}
167120
}
168121

169122
// Assign inputs and outputs to subgraph's meta_def.
170123
uint64_t model_hash;
171124
int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash);
125+
const auto meta_def_old = sub_graph->GetMetaDef();
172126
auto meta_def = std::make_unique<::onnxruntime::IndexedSubGraph::MetaDef>();
173127
meta_def->name = "WEBNN_" + std::to_string(model_hash) + "_" + std::to_string(metadef_id);
174128
meta_def->domain = kMSDomain;
@@ -181,20 +135,24 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
181135
}
182136
}
183137

184-
for (const auto& input : ordered_subgraph_inputs) {
185-
meta_def->inputs.push_back(input->Name());
138+
for (const auto& input : meta_def_old->inputs) {
139+
meta_def->inputs.push_back(input);
186140
}
187141

188-
for (const auto& output : ordered_subgraph_outputs) {
189-
meta_def->outputs.push_back(output->Name());
142+
for (const auto& output : meta_def_old->outputs) {
143+
meta_def->outputs.push_back(output);
190144
}
191145

192146
sub_graph->SetMetaDef(std::move(meta_def));
193-
194-
result.push_back(std::make_unique<ComputeCapability>(std::move(sub_graph)));
195147
}
196148

197-
auto num_of_partitions = result.size();
149+
const auto num_of_partitions = result.size();
150+
const auto num_of_supported_nodes = std::accumulate(
151+
result.begin(), result.end(), size_t{0},
152+
[](const auto& acc, const auto& partition) -> size_t {
153+
return acc + (partition && partition->sub_graph ? partition->sub_graph->nodes.size() : 0);
154+
});
155+
198156
const auto summary_msg = MakeString(
199157
"WebNNExecutionProvider::GetCapability,",
200158
" number of partitions supported by WebNN: ", num_of_partitions,

0 commit comments

Comments
 (0)