@@ -60,16 +60,17 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
60
60
emscripten::val options = emscripten::val::object ();
61
61
options.set (" label" , node.Name ());
62
62
NodeAttrHelper helper (node);
63
+ const bool is_nhwc = model_builder.GetPreferredLayout () == DataLayout::NHWC;
63
64
64
- const auto kernel_shape = helper.Get (" kernel_shape" , std::vector<int32_t >{0 , 0 });
65
+ const auto onnx_kernel_shape = helper.Get (" kernel_shape" , std::vector<int32_t >{0 , 0 });
65
66
if (!is_global_pooling) {
66
- options.set (" windowDimensions" , emscripten::val::array (kernel_shape ));
67
+ options.set (" windowDimensions" , emscripten::val::array (onnx_kernel_shape ));
67
68
}
68
69
const auto strides = helper.Get (" strides" , std::vector<int32_t >{1 , 1 });
69
70
options.set (" strides" , emscripten::val::array (strides));
70
71
const auto dilations = helper.Get (" dilations" , std::vector<int32_t >{1 , 1 });
71
72
options.set (" dilations" , emscripten::val::array (dilations));
72
- if (model_builder. GetPreferredLayout () == DataLayout::NHWC ) {
73
+ if (is_nhwc ) {
73
74
options.set (" layout" , emscripten::val (" nhwc" ));
74
75
} else {
75
76
options.set (" layout" , emscripten::val (" nchw" ));
@@ -78,7 +79,6 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
78
79
// Add Padding.
79
80
// Usually using autopadding is more efficient than using explicit padding.
80
81
// Try to see if we can map explicit padding to auto padding.
81
- const auto onnx_kernel_shape = helper.Get (" kernel_shape" , std::vector<int64_t >{0 , 0 });
82
82
const auto onnx_strides = helper.Get (" strides" , std::vector<int64_t >{1 , 1 });
83
83
const auto onnx_pads = helper.Get (" pads" , std::vector<int64_t >{0 , 0 , 0 , 0 });
84
84
auto pads = helper.Get (" pads" , std::vector<uint32_t >{0 , 0 , 0 , 0 });
@@ -93,7 +93,7 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
93
93
helper.Get (" dilations" , std::vector<int64_t >{1 , 1 }),
94
94
auto_pad_type,
95
95
pads_out,
96
- model_builder. GetPreferredLayout () == DataLayout::NCHW ));
96
+ !is_nhwc ));
97
97
pads = GetNarrowedIntfromInt64<uint32_t >(pads_out);
98
98
}
99
99
// Permute the ONNX's pads, which is [beginning_height, beginning_width, ending_height, ending_width],
@@ -105,6 +105,23 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
105
105
options.set (" roundingType" , ceil_mode == 0 ? emscripten::val (" floor" )
106
106
: emscripten::val (" ceil" ));
107
107
108
+ // WebNN doesn't support AveragePool with count_include_pad == 1, emulate it by pad + averagePool2d.
109
+ if (op_type == " AveragePool" && helper.Get (" count_include_pad" , 0 ) == 1 ) {
110
+ std::vector<uint32_t > beginning_padding{0 , 0 , pads[0 ], pads[1 ]};
111
+ std::vector<uint32_t > ending_padding{0 , 0 , pads[2 ], pads[3 ]};
112
+ // Unset padding option, because we will use pad op instead.
113
+ options.set (" padding" , emscripten::val::array (std::vector<uint32_t >{0 , 0 , 0 , 0 }));
114
+ if (is_nhwc) {
115
+ beginning_padding = {0 , pads[0 ], pads[1 ], 0 };
116
+ ending_padding = {0 , pads[2 ], pads[3 ], 0 };
117
+ }
118
+
119
+ emscripten::val pad_options = emscripten::val::object ();
120
+ pad_options.set (" label" , node.Name () + " _pad" );
121
+ input = model_builder.GetBuilder ().call <emscripten::val>(" pad" , input, emscripten::val::array (beginning_padding),
122
+ emscripten::val::array (ending_padding), pad_options);
123
+ }
124
+
108
125
emscripten::val output = model_builder.GetBuilder ().call <emscripten::val>(webnn_op_name.c_str (), input, options);
109
126
model_builder.AddOperand (node.OutputDefs ()[0 ]->Name (), std::move (output));
110
127
return Status::OK ();
@@ -138,13 +155,6 @@ bool PoolOpBuilder::IsOpSupportedImpl(const GraphViewer&,
138
155
}
139
156
}
140
157
141
- if (op_type == " AveragePool" ) {
142
- if (helper.Get (" count_include_pad" , 0 ) != 0 ) {
143
- LOGS (logger, VERBOSE) << " AveragePool only supports count_include_pad == 0" ;
144
- return false ;
145
- }
146
- }
147
-
148
158
if (op_type == " MaxPool" ) {
149
159
if (helper.Get (" storage_order" , 0 ) == 1 ) {
150
160
LOGS (logger, VERBOSE) << " MaxPool storage_order == 1 is not supported" ;
0 commit comments