diff --git a/onnxruntime/core/providers/cpu/nn/pool.cc b/onnxruntime/core/providers/cpu/nn/pool.cc index d6b9ed693432b..f38b63c0a9663 100644 --- a/onnxruntime/core/providers/cpu/nn/pool.cc +++ b/onnxruntime/core/providers/cpu/nn/pool.cc @@ -289,7 +289,7 @@ Status AveragePoolV19::Compute(OpKernelContext* context) const { RunLoop>(tp, onnxruntime::narrow(total_channels), {X_data, Y_data, x_step, y_step, dilation_h, pooled_height, stride_h(), - height, kernel_shape, pads, pool_attrs_.count_include_pad, p_}); + height, kernel_shape, pads, pool_attrs_.count_include_pad, p_, pool_attrs_.ceil_mode}); break; } @@ -301,7 +301,7 @@ Status AveragePoolV19::Compute(OpKernelContext* context) const { RunLoop>( tp, onnxruntime::narrow(total_channels), {X_data, Y_data, x_step, y_step, dilation_h, dilation_w, pooled_height, pooled_width, stride_h(), - stride_w(), height, width, kernel_shape, pads, pool_attrs_.count_include_pad, p_}); + stride_w(), height, width, kernel_shape, pads, pool_attrs_.count_include_pad, p_, pool_attrs_.ceil_mode}); break; } case 3: { @@ -314,7 +314,7 @@ Status AveragePoolV19::Compute(OpKernelContext* context) const { {X_data, Y_data, x_step, y_step, dilation_h, dilation_w, dilation_d, pooled_height, pooled_width, pooled_depth, stride_h(), stride_w(), stride_d(), height, - width, depth, kernel_shape, pads, pool_attrs_.count_include_pad, p_}); + width, depth, kernel_shape, pads, pool_attrs_.count_include_pad, p_, pool_attrs_.ceil_mode}); break; } default: diff --git a/onnxruntime/core/providers/cpu/nn/pool_functors.h b/onnxruntime/core/providers/cpu/nn/pool_functors.h index 476a9a0338969..fa3155eb70b1f 100644 --- a/onnxruntime/core/providers/cpu/nn/pool_functors.h +++ b/onnxruntime/core/providers/cpu/nn/pool_functors.h @@ -390,6 +390,7 @@ struct AveragePool1DTask final { gsl::span pads; bool count_include_pad; int64_t p; + int64_t ceil_mode; TensorOpCost Cost() { double loop_count = static_cast(pooled_height * kernel_shape[0]); return TensorOpCost{loop_count, loop_count, loop_count}; @@ -406,7 +407,9 @@ struct AveragePool1DTask final { for (int64_t ph = 0; ph < pooled_height; ++ph) { int64_t hstart = ph * stride_h - pads[0]; int64_t hend = hstart + kernel_shape[0] * dilation_h; - hend = std::min(hend, height + pads[1]); + if (ceil_mode) { + hend = std::min(hend, height + pads[1]); + } y_d[ph] = 0; int total_elements = 0; for (int64_t h = hstart; h < hend; h += dilation_h) { @@ -444,6 +447,7 @@ struct AveragePool2DTask final { gsl::span pads; bool count_include_pad; int64_t p; + int64_t ceil_mode; TensorOpCost Cost() { double loop_count = static_cast(pooled_height * pooled_width * kernel_shape[0] * kernel_shape[1]); @@ -462,11 +466,15 @@ struct AveragePool2DTask final { for (int64_t ph = 0; ph < pooled_height; ++ph) { int64_t hstart = ph * stride_h - pads[0]; int64_t hend = hstart + kernel_shape[0] * dilation_h; - hend = std::min(hend, height + pads[1]); + if (ceil_mode) { + hend = std::min(hend, height + pads[1]); + } for (int64_t pw = 0; pw < pooled_width; ++pw) { int64_t wstart = pw * stride_w - pads[1]; int64_t wend = wstart + kernel_shape[1] * dilation_w; - wend = std::min(wend, width + pads[3]); + if (ceil_mode) { + wend = std::min(wend, width + pads[3]); + } const int64_t pool_index = ph * pooled_width + pw; y_d[pool_index] = 0; int total_elements = 0; @@ -515,6 +523,7 @@ struct AveragePool3DTask { gsl::span pads; bool count_include_pad; int64_t p; + int64_t ceil_mode; void operator()(std::ptrdiff_t begin, std::ptrdiff_t end) const { for (std::ptrdiff_t c = begin; c < end; ++c) { @@ -535,15 +544,21 @@ struct AveragePool3DTask { for (int64_t ph = 0; ph < pooled_height; ++ph) { int64_t hstart = ph * stride_h - pads[0]; int64_t hend = hstart + kernel_shape[0] * dilation_h; - hend = std::min(hend, height + pads[1]); + if (ceil_mode) { + hend = std::min(hend, height + pads[1]); + } for (int64_t pw = 0; pw < pooled_width; ++pw) { int64_t wstart = pw * stride_w - pads[1]; int64_t wend = wstart + kernel_shape[1] * dilation_w; - wend = std::min(wend, width + pads[3]); + if (ceil_mode) { + wend = std::min(wend, width + pads[3]); + } for (int64_t pd = 0; pd < pooled_depth; ++pd) { int64_t dstart = pd * stride_d - pads[2]; int64_t dend = dstart + kernel_shape[2] * dilation_d; - dend = std::min(dend, depth + pads[5]); + if (ceil_mode) { + dend = std::min(dend, depth + pads[5]); + } const int64_t pool_index = ph * pooled_width * pooled_depth + pw * pooled_depth + pd; y_d[pool_index] = 0; int total_elements = 0;