Skip to content

Commit a742de6

Browse files
committed
op_common
1 parent 9b0e2e8 commit a742de6

File tree

1 file changed

+41
-35
lines changed

1 file changed

+41
-35
lines changed

onnx/reference/ops/op_pool_common.py

Lines changed: 41 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ def get_output_shape_explicit_padding(
9393

9494
if ceil_mode:
9595
output_spatial_shape[dim] = int(np.ceil(dim_size))
96+
if (output_spatial_shape[dim] - 1) * strides_spatial[
97+
dim
98+
] >= input_spatial_shape[dim] + pads[dim]:
99+
output_spatial_shape[dim] -= 1
96100
else:
97101
output_spatial_shape[dim] = int(np.floor(dim_size))
98102

@@ -152,20 +156,14 @@ def get_output_shape_auto_pad(
152156
return out_shape
153157

154158

155-
def lp_pool(x: np.array, p: int) -> float:
156-
y = 0
157-
for v in np.nditer(x):
158-
y += abs(v) ** p
159-
return y ** (1.0 / p)
160-
161-
162159
def pool(
163160
padded: np.ndarray,
164161
x_shape: Sequence[int],
165162
kernel: Sequence[int],
166163
strides: Sequence[int],
167164
out_shape: Sequence[int],
168165
pooling_type: str,
166+
pads_required: Sequence[int] | None = None,
169167
pads: Sequence[int] | None = None,
170168
dilations: Sequence[int] | None = None,
171169
count_include_pad: int = 0,
@@ -178,6 +176,7 @@ def pool(
178176
strides: the strides
179177
out_shape: the shape of the output tensor
180178
pooling_type: the pooling type, can be "AVG", "LPPOOL", or "MAX"
179+
pads_required: the required padding to make sure the sliding window does not go out-of-bound
181180
pads: the padding in an order of head_pad_1, head_pad_2, ..., tail_pad_1, tail_pad_2, ...
182181
dilations: the dilation
183182
count_include_pad: whether to include the padding in the calculation of average and lp pooling
@@ -187,25 +186,27 @@ def pool(
187186
y = np.zeros([x_shape[0], x_shape[1], *list(out_shape)], dtype=padded.dtype)
188187
if dilations is None:
189188
dilations = np.ones([spatial_size], dtype=np.int64)
189+
if pads_required is None:
190+
pads_required = np.zeros([spatial_size * 2], dtype=np.int64)
191+
elif len(pads_required) == 1:
192+
pads_required = pads_required * spatial_size * 2
190193
if pads is None:
191194
pads = np.zeros([spatial_size * 2], dtype=np.int64)
192195
elif len(pads) == 1:
193196
pads = pads * spatial_size * 2
194197
strides = strides or [1] * spatial_size
195198

196-
def lp_pool_p(x):
197-
return lp_pool(x, p)
198-
199+
# Iterate all the possible sliding windows
199200
for shape in itertools.product(
200-
range(x_shape[0]),
201-
range(x_shape[1]),
201+
range(x_shape[0]), # e.g. dim=0: [0]
202+
range(x_shape[1]), # e.g. dim=1: [0, 1]
202203
*[
203204
range(
204205
int(
205206
(
206207
x_shape[i + 2]
207-
+ pads[i]
208-
+ pads[i + spatial_size]
208+
+ pads_required[i]
209+
+ pads_required[i + spatial_size]
209210
- (1 + (kernel[i] - 1) * dilations[i])
210211
)
211212
/ strides[i]
@@ -216,30 +217,33 @@ def lp_pool_p(x):
216217
],
217218
):
218219
window = padded[shape[0], shape[1]]
219-
window_vals = np.array(
220-
[
221-
window[i]
222-
for i in list(
223-
itertools.product(
224-
*[
225-
range(
226-
strides[i] * shape[i + 2],
227-
strides[i] * shape[i + 2]
228-
+ (1 + (kernel[i] - 1) * dilations[i]),
229-
dilations[i],
230-
)
231-
for i in range(spatial_size)
232-
]
233-
)
220+
elements = []
221+
for i in range(spatial_size):
222+
# NOTE: The if condition is to avoid the case where the window is out of bound
223+
# we need to avoid the pixels that are out of bound being included in the window
224+
elements.extend(
225+
num
226+
for num in range(
227+
strides[i] * shape[i + 2],
228+
strides[i] * shape[i + 2] + (1 + (kernel[i] - 1) * dilations[i]),
229+
dilations[i],
234230
)
235-
]
236-
)
231+
if num < x_shape[i + 2] + pads[i] * 2
232+
)
233+
window_vals = np.array(
234+
[window[indices] for indices in itertools.product(elements)]
235+
)
236+
237237
if pooling_type == "AVG":
238238
f = np.average
239239
elif pooling_type == "MAX":
240240
f = np.max
241241
elif pooling_type == "LPPOOL":
242-
f = lp_pool_p
242+
243+
def lp_pool(x: np.array, p: int = p) -> float:
244+
return np.sum(np.abs(x) ** p) ** (1.0 / p)
245+
246+
f = lp_pool
243247
else:
244248
raise NotImplementedError(
245249
f"Pooling type {pooling_type} does not support. Should be AVG, MAX"
@@ -296,18 +300,19 @@ def _run(
296300
out_shape,
297301
pooling_type,
298302
pads,
303+
pads,
299304
dilations,
300305
count_include_pad,
301306
p,
302307
)
303308
return (y,)
304309
else:
305-
out_shape, pads = get_output_shape_explicit_padding(
310+
out_shape, extra_pads = get_output_shape_explicit_padding(
306311
pads, x_shape[2:], kernel_shape, strides, dilations, ceil_mode
307312
)
308313
# convert pads from [x1_begin, x2_begin,...,x1_end, x2_end,...] to [(x1_begin, x1_end), (x2_begin, x2_end),...]
309-
n_dims = len(pads) // 2
310-
pads_np = [(pads[i], pads[i + n_dims]) for i in range(n_dims)]
314+
n_dims = len(extra_pads) // 2
315+
pads_np = [(extra_pads[i], extra_pads[i + n_dims]) for i in range(n_dims)]
311316
padded = np.pad(
312317
x,
313318
((0, 0), (0, 0), *pads_np),
@@ -321,6 +326,7 @@ def _run(
321326
strides,
322327
out_shape,
323328
pooling_type,
329+
extra_pads,
324330
pads,
325331
dilations,
326332
count_include_pad,

0 commit comments

Comments
 (0)