js/web/docs/webnn-operators.md (2 additions, 2 deletions)
@@ -27,7 +27,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtime
| Cos | ai.onnx(7+) | cos | ✓ | ✓ | |
| CumSum | ai.onnx(11-13, 14+) | cumulativeSum | ✓ | ✓ | 'axis' input should be a constant |
| Div | ai.onnx(7-12, 13, 14+) | div | ✓ | ✓ | |
-| DequantizeLinear | ai.onnx(10-12, 13-18, 19-20, 21-22, 23+) | dequantizeLinear | | ✓ | |
+| DequantizeLinear | ai.onnx(10-12, 13-18, 19-20, 21-22, 23+) | dequantizeLinear | | ✓ | The shape of x_scale should be a subsample of the shape of input |
| Dropout | ai.onnx(7-9, 10-11, 12, 13-21, 22+) | identity | ✓ | ✓ | Only supports test mode |
| Einsum | ai.onnx(12+) | reshape, transpose, matmul, reduceSum, mul, triangular | ✓ | ✓ | |
| Elu | ai.onnx(7+) | elu | ✓ | ✓ | WebNN CPU backend only supports 'alpha' value of 1.0 |
@@ -71,7 +71,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtime
| Pad | ai.onnx(7-10, 11-12, 13-17, 18, 19-20, 21+) | pad | ✓ | ✓ | modes == 'wrap' is not supported |
| Pow | ai.onnx(7-11, 12, 13-14, 15+) | pow | ✓ | ✓ | |
| PRelu | ai.onnx(7-8, 9-15, 16+) | prelu | ✓ | ✓ | WebNN CPU backend restricts the last dimension of input and slope to be the same (Chromium issue: https://issues.chromium.org/issues/335517470) |
-| QuantizeLinear | ai.onnx(10-12, 13-18, 19-20, 21-22, 23+) | quantizeLinear | | ✓ | |
+| QuantizeLinear | ai.onnx(10-12, 13-18, 19-20, 21-22, 23+) | quantizeLinear | | ✓ | The shape of x_scale should be a subsample of the shape of input |
| Reciprocal | ai.onnx(7-12, 13+) | reciprocal | ✓ | ✓ | |
| ReduceL1 | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceL1 | ✓ | ✓ | Input 'axes' if present should be a constant |
| ReduceL2 | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceL2 | ✓ | ✓ | Input 'axes' if present should be a constant |
onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc (34 additions, 0 deletions)
@@ -22,6 +22,8 @@ class QDQOpBuilder : public BaseOpBuilder {
                               const logging::Logger& logger) const override ORT_MUST_USE_RESULT;

  // Operator support related.
  bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
                         const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
  bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
                              const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};
@@ -118,6 +120,38 @@ Status QDQOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
  return Status::OK();
}

// Operator support related.
bool QDQOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */,
                                     const Node& node,
                                     const WebnnDeviceType /* device_type */,
                                     const logging::Logger& logger) const {
  const auto& input_defs = node.InputDefs();

  std::vector<int64_t> input_shape;
  std::vector<int64_t> scale_shape;

  if (!GetShape(*input_defs[0], input_shape, logger) || !GetShape(*input_defs[1], scale_shape, logger)) {
    return false;
  }

  // WebNN requires the scale_shape to be a subsample of the input_shape.
  if (scale_shape.size() > input_shape.size()) {
    LOGS(logger, VERBOSE) << "The rank of scale is larger than the rank of input";
    return false;
  }

  for (size_t i = 0; i < scale_shape.size(); ++i) {
    auto scale_dim = scale_shape[scale_shape.size() - i - 1];
    auto input_dim = input_shape[input_shape.size() - i - 1];
    if (input_dim % scale_dim != 0) {
      LOGS(logger, VERBOSE) << "The shape of scale is not a subsample of the shape of input";
      return false;
    }
  }

  return true;
}

bool QDQOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
                                          const emscripten::val& wnn_limits, const logging::Logger& logger) const {
  const auto& input_defs = node.InputDefs();
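For reference, below is a minimal standalone sketch (not part of this diff) of the subsample check performed by IsOpSupportedImpl: dimensions are compared right-aligned, and every trailing input dimension must be evenly divisible by the matching scale dimension. The helper name and the example shapes are invented for illustration, and the zero guard is an extra safety check added only in this sketch.

// subsample_check_example.cc -- illustrative sketch only, not part of the PR.
#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors the check in QDQOpBuilder::IsOpSupportedImpl: the scale shape may not
// have a higher rank than the input shape, and each trailing input dimension must
// be evenly divisible by the corresponding scale dimension.
bool IsSubsampleShape(const std::vector<int64_t>& input_shape,
                      const std::vector<int64_t>& scale_shape) {
  if (scale_shape.size() > input_shape.size()) {
    return false;
  }
  for (size_t i = 0; i < scale_shape.size(); ++i) {
    const int64_t scale_dim = scale_shape[scale_shape.size() - i - 1];
    const int64_t input_dim = input_shape[input_shape.size() - i - 1];
    // The zero guard below is extra safety for this sketch.
    if (scale_dim == 0 || input_dim % scale_dim != 0) {
      return false;
    }
  }
  return true;
}

int main() {
  // Per-axis scale on a [4, 8, 16] tensor: accepted.
  std::cout << IsSubsampleShape({4, 8, 16}, {1, 8, 1}) << "\n";  // prints 1
  // Blockwise scale with block size 4 along the last axis: accepted (16 % 4 == 0).
  std::cout << IsSubsampleShape({4, 8, 16}, {1, 8, 4}) << "\n";  // prints 1
  // 8 is not divisible by 3, so this scale shape would be rejected by the builder.
  std::cout << IsSubsampleShape({4, 8, 16}, {1, 3, 1}) << "\n";  // prints 0
  return 0;
}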