Skip to content

Commit f97ca29

Browse files
committed
update pool tests
1 parent 505a942 commit f97ca29

File tree

10 files changed

+24
-17
lines changed

10 files changed

+24
-17
lines changed
Binary file not shown.

onnx/backend/test/data/node/test_lppool_1d_default/test_data_set_0/output_0.pb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
ByJ����?��?aK@��&@���?�q�?̌s?;�)>wU�>�5�>5�?;��?8C?���>���>�?3h�?B�>8.^?�g%@�L$@&sy?TK�?�@i@')�?�@>�P�?k/�?�&�?���>|w@���?V]�>���?/�?˙�?���>GG�?Dv�?/�?X@�,�?�?���?�;�?��?���?qBf?�>k?�z?]�?"�?�S�>Q��>�i�>(?��+?�4?I�4?��U?�h�?~t�>���?�>�?rr?mHh?p�:?��:?��? ��?ܟ?�p:?�h~?��r?��?hӟ>%�?�E�?w�p?�r�?��?Cd
2-
@�^@��?�E�?6`�?tp�?|R�?��?E�z?2~?,O<?
1+
ByJ����?��?aK@��&@���?�q�?͌s?;�)>wU�>�5�>5�?;��?8C?���>���>�?3h�?B�>7.^?�g%@�L$@%sy?TK�?�@i@')�?�@>�P�?k/�?�&�?���>|w@���?V]�>���?/�?˙�?���>GG�?Dv�?/�?X@�,�?�?���?�;�?��?���?qBf?�>k?�z?]�?"�?�S�>P��>�i�>(?��+?�4?I�4?��U?�h�?~t�>���?�>�?rr?nHh?p�:?��:?��? ��?ܟ?�p:?�h~?��r?��?hӟ>%�?�E�?w�p?�r�?��?Cd
2+
@�^@��?�E�?5`�?tp�?|R�?��?E�z?2~?,O<?
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

onnx/reference/ops/op_pool_common.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -218,22 +218,29 @@ def pool(
218218
],
219219
):
220220
window = padded[shape[0], shape[1]]
221-
elements = []
222-
for i in range(spatial_size):
223-
# NOTE: The if condition is to avoid the case where the window is out of bound
224-
# we need to avoid the pixels that are out of bound being included in the window
225-
elements.extend(
226-
num
227-
for num in range(
228-
strides[i] * shape[i + 2],
229-
strides[i] * shape[i + 2] + (1 + (kernel[i] - 1) * dilations[i]),
230-
dilations[i],
221+
window_vals = np.array(
222+
[
223+
window[i]
224+
for i in list(
225+
itertools.product(
226+
*[
227+
[
228+
pixel
229+
for pixel in range(
230+
strides[i] * shape[i + 2],
231+
strides[i] * shape[i + 2]
232+
+ (1 + (kernel[i] - 1) * dilations[i]),
233+
dilations[i],
234+
)
235+
if pixel
236+
< x_shape[i + 2] + pads[i] + pads[spatial_size + i]
237+
]
238+
for i in range(spatial_size)
239+
]
240+
)
231241
)
232-
if num < x_shape[i + 2] + pads[i] + pads[i + spatial_size]
233-
)
234-
window_vals = np.array(
235-
[window[indices] for indices in itertools.product(elements)]
236-
)
242+
]
243+
)
237244

238245
if pooling_type == "AVG":
239246
f = np.average

0 commit comments

Comments
 (0)