Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WebNN] Add op support validation for decomposed WebNN ops #23370

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions onnxruntime/core/providers/webnn/builders/helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,7 @@ std::unordered_set<const Node*> GetSupportedNodes(const GraphViewer& graph_viewe
std::unordered_set<const Node*> supported_nodes;

for (const auto& node : graph_viewer.Nodes()) {
bool supported = false;
// Firstly check if platform supports the WebNN op.
if (CheckSingleOp(node.OpType(), wnn_builder, device_type)) {
supported = IsNodeSupported(node, graph_viewer, device_type, wnn_limits, logger);
}
const bool supported = IsNodeSupported(node, graph_viewer, device_type, wnn_limits, logger);
LOGS(logger, VERBOSE) << "Operator type: [" << node.OpType()
<< "] index: [" << node.Index()
<< "] name: [" << node.Name()
Expand Down Expand Up @@ -140,7 +136,7 @@ bool AreInputDataTypesSame(const std::string& op_type,
return true;
}

bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& webnn_supported_data_types) {
bool IsSupportedDataType(const int32_t& onnx_data_type, const emscripten::val& webnn_supported_data_types) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is an immutable simple value being passed by reference? 🤔

auto it = onnx_to_webnn_data_type_map.find(static_cast<ONNX_NAMESPACE::TensorProto_DataType>(onnx_data_type));
if (it == onnx_to_webnn_data_type_map.end())
return false;
Expand All @@ -155,7 +151,7 @@ bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& we

// Check if the input or output data type of ONNX node is supported by the WebNN operator.
bool IsDataTypeSupportedByOp(const std::string& onnx_op_type,
const int32_t onnx_data_type,
const int32_t& onnx_data_type,
const emscripten::val& wnn_limits,
const std::string& webnn_input_output_name,
const std::string& onnx_input_output_name,
Expand All @@ -170,7 +166,7 @@ bool IsDataTypeSupportedByOp(const std::string& onnx_op_type,

bool IsDataTypeSupportedByWebNNOp(const std::string& onnx_op_type,
const std::string& webnn_op_type,
const int32_t onnx_data_type,
const int32_t& onnx_data_type,
const emscripten::val& wnn_limits,
const std::string& webnn_input_output_name,
const std::string& onnx_input_output_name,
Expand All @@ -179,6 +175,7 @@ bool IsDataTypeSupportedByWebNNOp(const std::string& onnx_op_type,
LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] WebNN op [" << webnn_op_type << "] is not supported for now";
return false;
}

if (wnn_limits[webnn_op_type][webnn_input_output_name].isUndefined()) {
LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] WebNN op [" << webnn_op_type << "] doesn't have parameter ["
<< webnn_input_output_name << "]";
Expand Down
60 changes: 42 additions & 18 deletions onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,15 @@ std::unordered_set<const Node*> GetSupportedNodes(const GraphViewer& graph_viewe
const WebnnDeviceType device_type,
const emscripten::val& wnn_limits,
const logging::Logger& logger);
// TODO(@Honry): Some ONNX ops are supported by decomposed WebNN ops,
// we need to check the support of the decomposed ops.

// Some ONNX ops are supported by decomposed WebNN ops.
// Maps an ONNX op type to the WebNN ops its implementation decomposes into.
// Callers (e.g. HasSupportedInputsImpl/HasSupportedOutputsImpl of the relevant
// op builders) must verify that the browser's MLGraphBuilder supports every op
// in the list, and that each accepts the tensor's data type, before claiming
// the ONNX node for the WebNN EP.
static const InlinedHashMap<std::string, std::vector<std::string>> decomposed_op_map = {
    {"LRN", {"add", "averagePool2d", "div", "mul", "pad", "pow", "transpose"}},
    {"RotaryEmbedding", {"add", "concat", "gather", "mul", "reshape", "split"}},
    {"SimplifiedLayerNormalization", {"add", "div", "mul", "pow", "reduceMean", "sqrt"}},
    {"SkipSimplifiedLayerNormalization", {"add", "div", "mul", "pow", "reduceMean", "sqrt"}},
};
// ONNX op type to WebNN op type mapping.
static const InlinedHashMap<std::string, std::string> op_map = {
{"Abs", "abs"},
{"Add", "add"},
Expand Down Expand Up @@ -247,7 +254,6 @@ static const InlinedHashMap<std::string, std::string> op_map = {
{"Log", "log"},
{"LpPool", "l2Pool2d"},
{"LSTM", "lstm"},
{"LRN", "averagePool2d"},
{"MatMul", "matmul"},
{"MatMulInteger", "matmulInteger"},
{"Max", "max"},
Expand Down Expand Up @@ -275,17 +281,14 @@ static const InlinedHashMap<std::string, std::string> op_map = {
{"Relu", "relu"},
{"Reshape", "reshape"},
{"Resize", "resample2d"},
{"RotaryEmbedding", "gather"},
{"ScatterElements", "scatterElements"},
{"ScatterND", "scatterND"},
{"Shape", "slice"},
{"Sigmoid", "sigmoid"},
{"Sign", "sign"},
{"SimplifiedLayerNormalization", "layerNormalization"},
{"Softplus", "softplus"},
{"Softsign", "softsign"},
{"Sin", "sin"},
{"SkipSimplifiedLayerNormalization", "layerNormalization"},
{"Slice", "slice"},
{"Softmax", "softmax"},
{"Split", "split"},
Expand All @@ -302,16 +305,37 @@ static const InlinedHashMap<std::string, std::string> op_map = {
{"Xor", "logicalXor"},
};

inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder,
const WebnnDeviceType device_type) {
auto op_map_entry = op_map.find(op_type);
// Returns false if the op_type is not listed in the op_map or
// if the WebNN op has not been implemented in MLGraphBuilder in current browser.
if (op_map_entry == op_map.end() || !wnn_builder[op_map_entry->second].as<bool>()) {
return false;
}
// Maps a WebNN op to the name of its first input operand as exposed by
// OpSupportLimits. Only ops whose first input is NOT called "input" are
// listed; GetWebNNOpFirstInputName() falls back to "input" for all others.
static const InlinedHashMap<std::string, std::string> webnn_op_first_input_name_map = {
    {"add", "a"},
    {"concat", "inputs"},
    {"div", "a"},
    {"equal", "a"},
    {"gemm", "a"},
    {"greater", "a"},
    {"greaterOrEqual", "a"},
    {"lesser", "a"},
    {"lesserOrEqual", "a"},
    {"logicalAnd", "a"},
    {"logicalNot", "a"},
    {"logicalOr", "a"},
    {"logicalXor", "a"},
    {"matmul", "a"},
    {"max", "a"},
    {"min", "a"},
    {"mul", "a"},
    {"pow", "a"},
    {"sub", "a"},
    {"where", "condition"},
};

return true;
// Retrieve the first input name of a WebNN op, used when validating supported
// input data types against OpSupportLimits. WebNN ops name their first input
// variously ('a', 'input', 'inputs', 'condition', ...); the non-default names
// are recorded in webnn_op_first_input_name_map, and anything not found there
// defaults to "input".
inline std::string GetWebNNOpFirstInputName(const std::string& webnn_op_type) {
  const auto entry = webnn_op_first_input_name_map.find(webnn_op_type);
  if (entry == webnn_op_first_input_name_map.end()) {
    return "input";
  }
  return entry->second;
}

inline bool GetWebNNOpType(const std::string& op_type, std::string& webnn_op_type) {
Expand Down Expand Up @@ -341,16 +365,16 @@ static const InlinedHashMap<ONNX_NAMESPACE::TensorProto_DataType, std::string> o
bool AreInputDataTypesSame(const std::string& op_type,
gsl::span<const int32_t> input_types,
const logging::Logger& logger);
bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& webnn_supported_data_types);
bool IsSupportedDataType(const int32_t& onnx_data_type, const emscripten::val& webnn_supported_data_types);
bool IsDataTypeSupportedByOp(const std::string& onnx_op_type,
const int32_t onnx_data_type,
const int32_t& onnx_data_type,
const emscripten::val& wnn_limits,
const std::string& webnn_input_output_name,
const std::string& onnx_input_output_name,
const logging::Logger& logger);
bool IsDataTypeSupportedByWebNNOp(const std::string& onnx_op_type,
const std::string& webnn_op_type,
const int32_t onnx_data_type,
const int32_t& onnx_data_type,
const emscripten::val& wnn_limits,
const std::string& webnn_input_output_name,
const std::string& onnx_input_output_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,13 @@
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;
std::string webnn_op_type;

Check warning on line 65 in onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc:65: Add #include <string> for string [build/include_what_you_use] [4]
if (!GetWebNNOpType(op_type, webnn_op_type))
return false;

return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "Input", logger);
const auto webnn_input_name = GetWebNNOpFirstInputName(op_type);
return IsDataTypeSupportedByWebNNOp(op_type, webnn_op_type, input_type, wnn_limits,
webnn_input_name, "input", logger);
}

bool BaseOpBuilder::HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits,
Expand Down
24 changes: 0 additions & 24 deletions onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,6 @@ class CastOpBuilder : public BaseOpBuilder {
private:
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;

// Operator support related.
private:
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};

// Add operator related.
Expand Down Expand Up @@ -85,25 +80,6 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
return Status::OK();
}

// Operator support related.
bool CastOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
int32_t input_type;

if (!GetType(*input_defs[0], input_type, logger))
return false;

if (!IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "input", logger))
return false;

NodeAttrHelper helper(node);
// Check cast to type.
const auto to_type = helper.Get("to", ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED);
return IsDataTypeSupportedByOp(op_type, to_type, wnn_limits, "output", "to", logger);
}

void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<CastOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
Expand Down
45 changes: 45 additions & 0 deletions onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ class LRNOpBuilder : public BaseOpBuilder {
private:
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
};

Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
Expand Down Expand Up @@ -142,6 +146,47 @@ bool LRNOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
return true;
}

bool LRNOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
                                          const emscripten::val& wnn_limits, const logging::Logger& logger) const {
  // LRN is emulated with decomposed WebNN ops ("add", "averagePool2d", "div",
  // "mul", "pad", "pow" and "transpose"), so the input data type must be
  // accepted by each op in the decomposition.
  const auto& op_type = node.OpType();
  int32_t input_type = 0;
  if (!GetType(*node.InputDefs()[0], input_type, logger)) {
    return false;
  }

  bool all_supported = true;
  for (const auto& decomposed_op : decomposed_op_map.at(op_type)) {
    const auto first_input_name = GetWebNNOpFirstInputName(decomposed_op);
    if (!IsDataTypeSupportedByWebNNOp(op_type, decomposed_op, input_type, wnn_limits,
                                      first_input_name, "X", logger)) {
      all_supported = false;
      break;
    }
  }

  return all_supported;
}

bool LRNOpBuilder::HasSupportedOutputsImpl(const Node& node,
                                           const emscripten::val& wnn_limits,
                                           const logging::Logger& logger) const {
  // The output data type must likewise be accepted by every WebNN op that LRN
  // decomposes into.
  const auto& op_type = node.OpType();
  int32_t output_type = 0;
  if (!GetType(*node.OutputDefs()[0], output_type, logger)) {
    return false;
  }

  bool all_supported = true;
  for (const auto& decomposed_op : decomposed_op_map.at(op_type)) {
    if (!IsDataTypeSupportedByWebNNOp(op_type, decomposed_op, output_type, wnn_limits, "output", "Y", logger)) {
      all_supported = false;
      break;
    }
  }

  return all_supported;
}

void CreateLRNOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<LRNOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class NormalizationOpBuilder : public BaseOpBuilder {
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
};

Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
Expand Down Expand Up @@ -305,7 +307,44 @@ bool NormalizationOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet&
return false;
}

return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger);
if (op_type == "SimplifiedLayerNormalization" || op_type == "SkipSimplifiedLayerNormalization") {
// SkipSimplifiedLayerNormalization and SimplifiedLayerNormalization are supported by decomposed WebNN ops.
// Check if the input data type is supported by each decomposed WebNN op.
// Decomposed ops include: "add", "div", "mul", "pow", "reduceMean" and "sqrt".
for (const auto& webnn_op_type : decomposed_op_map.at(op_type)) {
const auto webnn_input_name = GetWebNNOpFirstInputName(webnn_op_type);
if (!IsDataTypeSupportedByWebNNOp(
op_type, webnn_op_type, input0_type, wnn_limits, webnn_input_name, "input", logger)) {
return false;
}
}
return true;
} else {
return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger);
}
}

bool NormalizationOpBuilder::HasSupportedOutputsImpl(const Node& node,
                                                     const emscripten::val& wnn_limits,
                                                     const logging::Logger& logger) const {
  const auto& op_type = node.OpType();
  int32_t output_type = 0;
  if (!GetType(*node.OutputDefs()[0], output_type, logger)) {
    return false;
  }

  const bool is_decomposed = op_type == "SimplifiedLayerNormalization" ||
                             op_type == "SkipSimplifiedLayerNormalization";
  if (!is_decomposed) {
    return IsDataTypeSupportedByOp(op_type, output_type, wnn_limits, "output", "Output", logger);
  }

  // SimplifiedLayerNormalization and SkipSimplifiedLayerNormalization are
  // emulated with decomposed WebNN ops; the output data type must be accepted
  // by every op in the decomposition.
  for (const auto& decomposed_op : decomposed_op_map.at(op_type)) {
    if (!IsDataTypeSupportedByWebNNOp(op_type, decomposed_op, output_type, wnn_limits, "output", "output", logger)) {
      return false;
    }
  }
  return true;
}

void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
Expand Down
Loading
Loading