diff --git a/onnxruntime/core/providers/cpu/tensor/grid_sample.cc b/onnxruntime/core/providers/cpu/tensor/grid_sample.cc
index d673fcce223e6..9289f4affe1a3 100644
--- a/onnxruntime/core/providers/cpu/tensor/grid_sample.cc
+++ b/onnxruntime/core/providers/cpu/tensor/grid_sample.cc
@@ -1,8 +1,14 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-#include "core/providers/cpu/tensor/grid_sample.h"
+#include <cmath>
+#include <vector>
+
+#if defined(MLAS_NEON_INTRINSICS)
+#include <arm_neon.h>
+#endif
+#include "core/providers/cpu/tensor/grid_sample.h"
 
 #include "core/framework/element_type_lists.h"
 #include "core/framework/TensorSeq.h"
 #include "core/providers/common.h"
@@ -148,6 +154,181 @@ T GridSample<T>::PixelAtGrid3D(const T* image, int64_t d, int64_t h, int64_t w,
   return pixel;
 }
 
+namespace {
+
+constexpr uint8_t kTopLeftMask = 1u << 0;
+constexpr uint8_t kTopRightMask = 1u << 1;
+constexpr uint8_t kBottomLeftMask = 1u << 2;
+constexpr uint8_t kBottomRightMask = 1u << 3;
+constexpr uint8_t kAllNeighborsMask = kTopLeftMask | kTopRightMask | kBottomLeftMask | kBottomRightMask;
+
+template <typename T>
+struct BilinearSamplePlan2D {
+  int64_t x1;
+  int64_t x2;
+  int64_t y1;
+  int64_t y2;
+  T w11;
+  T w12;
+  T w21;
+  T w22;
+  uint8_t mask = 0;
+};
+// For each of the H_out * W_out grid points, reads its (nx, ny) pair and stores that point's four neighbor indices, four bilinear weights, and in-bounds mask in plans[idx].
+template <typename T>
+void PrecomputeBilinearSamplePlan2D(const T* grid_data,
+                                    int64_t H_out,
+                                    int64_t W_out,
+                                    int64_t H_in,
+                                    int64_t W_in,
+                                    std::vector<BilinearSamplePlan2D<T>>& plans) {
+  const int64_t point_count = H_out * W_out;
+
+  for (int64_t idx = 0; idx < point_count; ++idx) {
+    auto& plan = plans[onnxruntime::narrow<size_t>(idx)];
+    const T nx = grid_data[idx * 2];
+    const T ny = grid_data[idx * 2 + 1];
+    const T x = GsDenormalize(nx, W_in, false);
+    const T y = GsDenormalize(ny, H_in, false);
+
+    const int64_t x1 = static_cast<int64_t>(std::floor(x));
+    const int64_t y1 = static_cast<int64_t>(std::floor(y));
+    const int64_t x2 = x1 + 1;
+    const int64_t y2 = y1 + 1;
+
+    const T dx2 = static_cast<T>(x2) - x;
+    const T dx1 = x - static_cast<T>(x1);
+    const T dy2 = static_cast<T>(y2) - y;
+    const T dy1 = y - static_cast<T>(y1);
+
+    uint8_t mask = 0;
+    if (x1 >= 0 && x1 < W_in && y1 >= 0 && y1 < H_in) {
+      mask |= kTopLeftMask;
+    }
+    if (x2 >= 0 && x2 < W_in && y1 >= 0 && y1 < H_in) {
+      mask |= kTopRightMask;
+    }
+    if (x1 >= 0 && x1 < W_in && y2 >= 0 && y2 < H_in) {
+      mask |= kBottomLeftMask;
+    }
+    if (x2 >= 0 && x2 < W_in && y2 >= 0 && y2 < H_in) {
+      mask |= kBottomRightMask;
+    }
+
+    plan.x1 = x1;
+    plan.x2 = x2;
+    plan.y1 = y1;
+    plan.y2 = y2;
+    plan.w11 = dy2 * dx2;
+    plan.w12 = dy2 * dx1;
+    plan.w21 = dy1 * dx2;
+    plan.w22 = dy1 * dx1;
+    plan.mask = mask;
+  }
+}
+
+template <typename T>
+void EvaluatePlanForChannel(const T* input_data,
+                            T* output_data,
+                            int64_t W_in,
+                            const BilinearSamplePlan2D<T>* plan_data,
+                            int64_t point_count) {
+  for (int64_t idx = 0; idx < point_count; ++idx) {
+    const auto& plan = plan_data[idx];
+    if (plan.mask == 0) {
+      output_data[idx] = T{};
+      continue;
+    }
+
+#if defined(MLAS_NEON_INTRINSICS)
+    if constexpr (std::is_same_v<T, float>) {
+      if (plan.mask == kAllNeighborsMask) {
+        const float* row1_ptr = input_data + plan.y1 * W_in + plan.x1;
+        const float* row2_ptr = input_data + plan.y2 * W_in + plan.x1;
+
+        float32x2_t row1 = vld1_f32(row1_ptr);  // [p11, p12]
+        float32x2_t row2 = vld1_f32(row2_ptr);  // [p21, p22]
+        float32x4_t neighbors = vcombine_f32(row1, row2);
+
+        float32x2_t weights_row1 = vdup_n_f32(plan.w12);
+        weights_row1 = vset_lane_f32(plan.w11, weights_row1, 0);
+        float32x2_t weights_row2 = vdup_n_f32(plan.w22);
+        weights_row2 = vset_lane_f32(plan.w21, weights_row2, 0);
+        float32x4_t weights = vcombine_f32(weights_row1, weights_row2);
+
+        float32x4_t products = vmulq_f32(neighbors, weights);
+        float32x2_t sum_pairs = vadd_f32(vget_low_f32(products), vget_high_f32(products));
+        float32x2_t accum = vpadd_f32(sum_pairs, sum_pairs);
+        output_data[idx] = vget_lane_f32(accum, 0);
+        continue;
+      }
+    }
+#endif
+
+    T p11 = T{};
+    T p12 = T{};
+    T p21 = T{};
+    T p22 = T{};
+
+    if (plan.mask == kAllNeighborsMask) {
+      const int64_t row1 = plan.y1 * W_in;
+      const int64_t row2 = plan.y2 * W_in;
+      p11 = input_data[row1 + plan.x1];
+      p12 = input_data[row1 + plan.x2];
+      p21 = input_data[row2 + plan.x1];
+      p22 = input_data[row2 + plan.x2];
+    } else {
+      if (plan.mask & kTopLeftMask) {
+        p11 = input_data[plan.y1 * W_in + plan.x1];
+      }
+      if (plan.mask & kTopRightMask) {
+        p12 = input_data[plan.y1 * W_in + plan.x2];
+      }
+      if (plan.mask & kBottomLeftMask) {
+        p21 = input_data[plan.y2 * W_in + plan.x1];
+      }
+      if (plan.mask & kBottomRightMask) {
+        p22 = input_data[plan.y2 * W_in + plan.x2];
+      }
+    }
+
+    output_data[idx] = plan.w11 * p11 + plan.w12 * p12 + plan.w21 * p21 + plan.w22 * p22;
+  }
+}
+
+template <typename T>
+void TryRunBilinearZerosFastPath2D(const Tensor& input,
+                                   const Tensor& grid,
+                                   Tensor& output,
+                                   int64_t n,
+                                   int64_t C,
+                                   int64_t H_in,
+                                   int64_t W_in,
+                                   int64_t H_out,
+                                   int64_t W_out,
+                                   concurrency::ThreadPool* tp,
+                                   std::vector<BilinearSamplePlan2D<T>>& sampling_plan) {
+  const int64_t plane_in = H_in * W_in;
+  const int64_t plane_out = H_out * W_out;
+  sampling_plan.resize(onnxruntime::narrow<size_t>(plane_out));
+
+  const T* grid_data = grid.Data<T>() + n * plane_out * 2;
+  PrecomputeBilinearSamplePlan2D(grid_data, H_out, W_out, H_in, W_in, sampling_plan);
+
+  const T* input_data = input.Data<T>();
+  T* output_data = output.MutableData<T>();
+
+  concurrency::ThreadPool::TrySimpleParallelFor(
+      tp, onnxruntime::narrow<std::ptrdiff_t>(C),
+      [&](std::ptrdiff_t c) {
+        const T* X_data = input_data + (n * C + c) * plane_in;
+        T* Y_data = output_data + (n * C + c) * plane_out;
+        EvaluatePlanForChannel(X_data, Y_data, W_in, sampling_plan.data(), plane_out);
+      });
+}
+
+}  // namespace
+
 // When grid sampling, padding is applied before interpolation.
 // For instance, in bilinear mode and zeros padding-mode, pixel p at actual
 // image location (-0.5, -0.5)
@@ -210,13 +391,14 @@ Status GridSample<T>::Compute(OpKernelContext* context) const {
     T border[] = {x_min, y_min, x_max, y_max};  // l-t-r-b
 
     concurrency::ThreadPool* tp = H_out * W_out > 64 ? context->GetOperatorThreadPool() : nullptr;
-    for (int64_t n = 0; n < N; n++) {
-      const T* grid_data = grid->Data<T>() + n * (H_out * W_out) * 2;
+
+    const auto run_generic_path_for_n = [&](int64_t n_idx) {
+      const T* grid_data = grid->Data<T>() + n_idx * (H_out * W_out) * 2;
       concurrency::ThreadPool::TrySimpleParallelFor(
           tp, onnxruntime::narrow<std::ptrdiff_t>(C),
           [&](std::ptrdiff_t c) {
-            const T* X_data = input->Data<T>() + (n * C + c) * (H_in * W_in);
-            T* Y_data = Y.MutableData<T>() + (n * C + c) * (H_out * W_out);
+            const T* X_data = input->Data<T>() + (n_idx * C + c) * (H_in * W_in);
+            T* Y_data = Y.MutableData<T>() + (n_idx * C + c) * (H_out * W_out);
 
             for (int64_t oy = 0; oy < H_out; oy++) {
               for (int64_t ox = 0; ox < W_out; ox++) {
@@ -265,6 +447,19 @@ Status GridSample<T>::Compute(OpKernelContext* context) const {
               }
             }
           });
+    };
+
+    const bool can_use_fast_path = (mode_ == Linear && padding_mode_ == Zeros && !align_corners_);
+    for (int64_t n = 0; n < N; n++) {
+      if (can_use_fast_path) {
+        // Fast path for linear mode with zeros padding and align_corners=0: out-of-bounds neighbors
+        // contribute zero, and per-neighbor boundary checks are skipped whenever all four neighbors are
+        // in bounds, which is 2-3x faster than the generic path and enables the NEON kernel.
+        std::vector<BilinearSamplePlan2D<T>> sampling_plan;
+        TryRunBilinearZerosFastPath2D(*input, *grid, Y, n, C, H_in, W_in, H_out, W_out, tp, sampling_plan);
+      } else {
+        run_generic_path_for_n(n);
+      }
     }
   } else if (data_dims == 3) {
     // sample 3d;
diff --git a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc
index 05cfb5c13d689..48311b19147ef 100644
--- a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc
@@ -727,6 +727,38 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_zeros_no_align_corner
   RunTests(test, GetExecutionProviders(20));
 }
 
+TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_linear_zeros_mixed_bounds) {
+  // Crafts grid points that mix fully in-bounds sampling with cases where the right, the bottom,
+  // or both neighbors fall outside the source image, so zero padding must be applied. This ensures
+  // the optimized bilinear fast path matches the generic implementation's boundary handling.
+  OpTester test("GridSample", 20);
+  std::string mode = "linear";
+  std::string padding_mode = "zeros";
+  int64_t align_corners = 0;
+  std::initializer_list<int64_t> X_shape{1, 1, 2, 2};
+  std::initializer_list<TypeParam> X_data{TypeParam(1.0f), TypeParam(2.0f), TypeParam(3.0f), TypeParam(4.0f)};
+  std::initializer_list<int64_t> Grid_shape{1, 2, 2, 2};
+  // (nx, ny) pairs: center (in bounds), right edge (x out), bottom edge (y out), corner (both out)
+  std::initializer_list<TypeParam> Grid_data{
+      TypeParam(0.0f), TypeParam(0.0f),   // center (all neighbors in bounds)
+      TypeParam(0.9f), TypeParam(0.0f),   // near right edge (right neighbors out of bounds)
+      TypeParam(0.0f), TypeParam(0.9f),   // near bottom edge (bottom neighbors out of bounds)
+      TypeParam(0.9f), TypeParam(0.9f)};  // near bottom-right corner (right and bottom neighbors out of bounds)
+  std::initializer_list<int64_t> Y_shape{1, 1, 2, 2};
+  std::initializer_list<TypeParam> Y_data{
+      TypeParam(2.5f),    // all neighbors in bounds
+      TypeParam(1.8f),    // right neighbors out of bounds
+      TypeParam(2.1f),    // bottom neighbors out of bounds
+      TypeParam(1.44f)};  // both right and bottom neighbors out of bounds
+  test.AddInput<TypeParam>("X", X_shape, X_data);
+  test.AddInput<TypeParam>("Grid", Grid_shape, Grid_data);
+  test.AddAttribute("mode", mode);
+  test.AddAttribute("padding_mode", padding_mode);
+  test.AddAttribute("align_corners", align_corners);
+  test.AddOutput<TypeParam>("Y", Y_shape, Y_data);
+  RunTests(test, GetExecutionProviders(20));
+}
+
 TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_border_align_corners) {
   OpTester test("GridSample", 20);
   std::string mode = "linear";