@@ -93,6 +93,10 @@ def get_output_shape_explicit_padding(
9393
9494 if ceil_mode :
9595 output_spatial_shape [dim ] = int (np .ceil (dim_size ))
96+ if (output_spatial_shape [dim ] - 1 ) * strides_spatial [
97+ dim
98+ ] >= input_spatial_shape [dim ] + pads [dim ]:
99+ output_spatial_shape [dim ] -= 1
96100 else :
97101 output_spatial_shape [dim ] = int (np .floor (dim_size ))
98102
@@ -152,20 +156,14 @@ def get_output_shape_auto_pad(
152156 return out_shape
153157
154158
155- def lp_pool (x : np .array , p : int ) -> float :
156- y = 0
157- for v in np .nditer (x ):
158- y += abs (v ) ** p
159- return y ** (1.0 / p )
160-
161-
162159def pool (
163160 padded : np .ndarray ,
164161 x_shape : Sequence [int ],
165162 kernel : Sequence [int ],
166163 strides : Sequence [int ],
167164 out_shape : Sequence [int ],
168165 pooling_type : str ,
166+ pads_required : Sequence [int ] | None = None ,
169167 pads : Sequence [int ] | None = None ,
170168 dilations : Sequence [int ] | None = None ,
171169 count_include_pad : int = 0 ,
@@ -178,6 +176,7 @@ def pool(
178176 strides: the strides
179177 out_shape: the shape of the output tensor
180178 pooling_type: the pooling type, can be "AVG", "LPPOOL", or "MAX"
179+ pads_required: the required padding to make sure the sliding window does not go out-of-bound
181180 pads: the padding in an order of head_pad_1, head_pad_2, ..., tail_pad_1, tail_pad_2, ...
182181 dilations: the dilation
183182 count_include_pad: whether to include the padding in the calculation of average and lp pooling
@@ -187,25 +186,27 @@ def pool(
187186 y = np .zeros ([x_shape [0 ], x_shape [1 ], * list (out_shape )], dtype = padded .dtype )
188187 if dilations is None :
189188 dilations = np .ones ([spatial_size ], dtype = np .int64 )
189+ if pads_required is None :
190+ pads_required = np .zeros ([spatial_size * 2 ], dtype = np .int64 )
191+ elif len (pads_required ) == 1 :
192+ pads_required = pads_required * spatial_size * 2
190193 if pads is None :
191194 pads = np .zeros ([spatial_size * 2 ], dtype = np .int64 )
192195 elif len (pads ) == 1 :
193196 pads = pads * spatial_size * 2
194197 strides = strides or [1 ] * spatial_size
195198
196- def lp_pool_p (x ):
197- return lp_pool (x , p )
198-
199+ # Iterate all the possible sliding windows
199200 for shape in itertools .product (
200- range (x_shape [0 ]),
201- range (x_shape [1 ]),
201+ range (x_shape [0 ]), # e.g. dim=0: [0]
202+ range (x_shape [1 ]), # e.g. dim=1: [0, 1]
202203 * [
203204 range (
204205 int (
205206 (
206207 x_shape [i + 2 ]
207- + pads [i ]
208- + pads [i + spatial_size ]
208+ + pads_required [i ]
209+ + pads_required [i + spatial_size ]
209210 - (1 + (kernel [i ] - 1 ) * dilations [i ])
210211 )
211212 / strides [i ]
@@ -216,30 +217,33 @@ def lp_pool_p(x):
216217 ],
217218 ):
218219 window = padded [shape [0 ], shape [1 ]]
219- window_vals = np .array (
220- [
221- window [i ]
222- for i in list (
223- itertools .product (
224- * [
225- range (
226- strides [i ] * shape [i + 2 ],
227- strides [i ] * shape [i + 2 ]
228- + (1 + (kernel [i ] - 1 ) * dilations [i ]),
229- dilations [i ],
230- )
231- for i in range (spatial_size )
232- ]
233- )
220+ elements = []
221+ for i in range (spatial_size ):
222+ # NOTE: The if condition is to avoid the case where the window is out of bound
223+ # we need to avoid the pixels that are out of bound being included in the window
224+ elements .extend (
225+ num
226+ for num in range (
227+ strides [i ] * shape [i + 2 ],
228+ strides [i ] * shape [i + 2 ] + (1 + (kernel [i ] - 1 ) * dilations [i ]),
229+ dilations [i ],
234230 )
235- ]
236- )
231+ if num < x_shape [i + 2 ] + pads [i ] * 2
232+ )
233+ window_vals = np .array (
234+ [window [indices ] for indices in itertools .product (elements )]
235+ )
236+
237237 if pooling_type == "AVG" :
238238 f = np .average
239239 elif pooling_type == "MAX" :
240240 f = np .max
241241 elif pooling_type == "LPPOOL" :
242- f = lp_pool_p
242+
243+ def lp_pool (x : np .array , p : int = p ) -> float :
244+ return np .sum (np .abs (x ) ** p ) ** (1.0 / p )
245+
246+ f = lp_pool
243247 else :
244248 raise NotImplementedError (
245249 f"Pooling type { pooling_type } does not support. Should be AVG, MAX"
@@ -296,18 +300,19 @@ def _run(
296300 out_shape ,
297301 pooling_type ,
298302 pads ,
303+ pads ,
299304 dilations ,
300305 count_include_pad ,
301306 p ,
302307 )
303308 return (y ,)
304309 else :
305- out_shape , pads = get_output_shape_explicit_padding (
310+ out_shape , extra_pads = get_output_shape_explicit_padding (
306311 pads , x_shape [2 :], kernel_shape , strides , dilations , ceil_mode
307312 )
308313 # convert pads from [x1_begin, x2_begin,...,x1_end, x2_end,...] to [(x1_begin, x1_end), (x2_begin, x2_end),...]
309- n_dims = len (pads ) // 2
310- pads_np = [(pads [i ], pads [i + n_dims ]) for i in range (n_dims )]
314+ n_dims = len (extra_pads ) // 2
315+ pads_np = [(extra_pads [i ], extra_pads [i + n_dims ]) for i in range (n_dims )]
311316 padded = np .pad (
312317 x ,
313318 ((0 , 0 ), (0 , 0 ), * pads_np ),
@@ -321,6 +326,7 @@ def _run(
321326 strides ,
322327 out_shape ,
323328 pooling_type ,
329+ extra_pads ,
324330 pads ,
325331 dilations ,
326332 count_include_pad ,
0 commit comments