Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 200 additions & 5 deletions onnxruntime/core/providers/cpu/tensor/grid_sample.cc
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/cpu/tensor/grid_sample.h"
#include <type_traits>
#include <vector>

#if defined(MLAS_NEON_INTRINSICS)
#include <arm_neon.h>
#endif

#include "core/providers/cpu/tensor/grid_sample.h"
#include "core/framework/element_type_lists.h"
#include "core/framework/TensorSeq.h"
#include "core/providers/common.h"
Expand Down Expand Up @@ -148,6 +154,181 @@ T GridSample<T>::PixelAtGrid3D(const T* image, int64_t d, int64_t h, int64_t w,
return pixel;
}

namespace {

// Bit flags identifying which of the four bilinear neighbors of a sample
// point lie inside the input image. A set bit means that neighbor may be
// read directly from the input plane; a clear bit means it is padded (zero).
constexpr uint8_t kTopLeftMask = 1u << 0;
constexpr uint8_t kTopRightMask = 1u << 1;
constexpr uint8_t kBottomLeftMask = 1u << 2;
constexpr uint8_t kBottomRightMask = 1u << 3;
// All four neighbors in-bounds: the common interior-pixel case, eligible for
// the branch-free (and, on ARM, NEON-vectorized) evaluation path.
constexpr uint8_t kAllNeighborsMask = kTopLeftMask | kTopRightMask | kBottomLeftMask | kBottomRightMask;

// Precomputed bilinear sampling data for a single output pixel: the integer
// coordinates of the four surrounding input pixels, the four interpolation
// weights, and a validity mask built from the k*Mask flags above.
template <typename T>
struct BilinearSamplePlan2D {
  int64_t x1;   // left column index (floor of the denormalized x)
  int64_t x2;   // right column index (x1 + 1)
  int64_t y1;   // top row index (floor of the denormalized y)
  int64_t y2;   // bottom row index (y1 + 1)
  T w11;        // weight of (y1, x1)
  T w12;        // weight of (y1, x2)
  T w21;        // weight of (y2, x1)
  T w22;        // weight of (y2, x2)
  uint8_t mask = 0;  // which of the four neighbors are in bounds (0 => all padded)
};
// Computes, for every output pixel of one batch item, the bilinear sampling
// plan: the four neighbor coordinates, their interpolation weights, and a
// mask of which neighbors fall inside the input image. The grid is shared by
// all channels, so this runs once per (n) and the result is reused per channel.
//
// Assumes linear mode with align_corners == false (GsDenormalize is called
// with align_corners = false), matching the fast-path preconditions checked
// by the caller.
//
// grid_data points at H_out * W_out normalized (nx, ny) pairs; plans must
// already hold H_out * W_out entries and receives one entry per output pixel
// in row-major (oy, ox) order.
template <typename T>
void PrecomputeBilinearSamplePlan2D(const T* grid_data,
                                    int64_t H_out,
                                    int64_t W_out,
                                    int64_t H_in,
                                    int64_t W_in,
                                    std::vector<BilinearSamplePlan2D<T>>& plans) {
  const int64_t point_count = H_out * W_out;

  for (int64_t idx = 0; idx < point_count; ++idx) {
    auto& plan = plans[onnxruntime::narrow<size_t>(idx)];
    const T nx = grid_data[idx * 2];
    const T ny = grid_data[idx * 2 + 1];
    const T x = GsDenormalize<T>(nx, W_in, false);
    const T y = GsDenormalize<T>(ny, H_in, false);

    const int64_t x1 = static_cast<int64_t>(std::floor(x));
    const int64_t y1 = static_cast<int64_t>(std::floor(y));
    const int64_t x2 = x1 + 1;
    const int64_t y2 = y1 + 1;

    // Fractional distances to the neighbor grid lines; each weight below is
    // the area of the rectangle opposite its pixel (standard bilinear).
    const T dx2 = static_cast<T>(x2) - x;
    const T dx1 = x - static_cast<T>(x1);
    const T dy2 = static_cast<T>(y2) - y;
    const T dy1 = y - static_cast<T>(y1);

    // Hoist the per-axis bounds tests so each comparison is evaluated once
    // instead of twice (each axis result feeds two of the four neighbors).
    const bool x1_in = x1 >= 0 && x1 < W_in;
    const bool x2_in = x2 >= 0 && x2 < W_in;
    const bool y1_in = y1 >= 0 && y1 < H_in;
    const bool y2_in = y2 >= 0 && y2 < H_in;

    uint8_t mask = 0;
    if (x1_in && y1_in) {
      mask |= kTopLeftMask;
    }
    if (x2_in && y1_in) {
      mask |= kTopRightMask;
    }
    if (x1_in && y2_in) {
      mask |= kBottomLeftMask;
    }
    if (x2_in && y2_in) {
      mask |= kBottomRightMask;
    }

    plan.x1 = x1;
    plan.x2 = x2;
    plan.y1 = y1;
    plan.y2 = y2;
    plan.w11 = dy2 * dx2;
    plan.w12 = dy2 * dx1;
    plan.w21 = dy1 * dx2;
    plan.w22 = dy1 * dx1;
    plan.mask = mask;
  }
}

// Applies a precomputed sampling plan to a single (n, c) input plane, writing
// one bilinear sample per output pixel. Neighbors flagged out-of-bounds in
// the plan mask contribute zero (zeros padding mode).
template <typename T>
void EvaluatePlanForChannel(const T* input_data,
                            T* output_data,
                            int64_t W_in,
                            const BilinearSamplePlan2D<T>* plan_data,
                            int64_t point_count) {
  for (int64_t pt = 0; pt < point_count; ++pt) {
    const auto& entry = plan_data[pt];

    // All four neighbors out of bounds: the sample is pure padding.
    if (entry.mask == 0) {
      output_data[pt] = T{};
      continue;
    }

#if defined(MLAS_NEON_INTRINSICS)
    if constexpr (std::is_same_v<T, float>) {
      if (entry.mask == kAllNeighborsMask) {
        // Interior pixel: each row's two neighbors are contiguous (x2 == x1+1),
        // so one 4-lane multiply plus horizontal adds yields the weighted sum.
        const float* top_row = input_data + entry.y1 * W_in + entry.x1;
        const float* bottom_row = input_data + entry.y2 * W_in + entry.x1;

        float32x4_t neighbors = vcombine_f32(vld1_f32(top_row), vld1_f32(bottom_row));  // [p11, p12, p21, p22]

        float32x2_t w_top = vdup_n_f32(entry.w12);
        w_top = vset_lane_f32(entry.w11, w_top, 0);  // [w11, w12]
        float32x2_t w_bottom = vdup_n_f32(entry.w22);
        w_bottom = vset_lane_f32(entry.w21, w_bottom, 0);  // [w21, w22]
        float32x4_t weights = vcombine_f32(w_top, w_bottom);

        float32x4_t weighted = vmulq_f32(neighbors, weights);
        float32x2_t pair_sums = vadd_f32(vget_low_f32(weighted), vget_high_f32(weighted));
        output_data[pt] = vget_lane_f32(vpadd_f32(pair_sums, pair_sums), 0);
        continue;
      }
    }
#endif

    T p11{};
    T p12{};
    T p21{};
    T p22{};

    if (entry.mask == kAllNeighborsMask) {
      // Interior pixel: gather all four neighbors without per-neighbor checks.
      const int64_t top_offset = entry.y1 * W_in;
      const int64_t bottom_offset = entry.y2 * W_in;
      p11 = input_data[top_offset + entry.x1];
      p12 = input_data[top_offset + entry.x2];
      p21 = input_data[bottom_offset + entry.x1];
      p22 = input_data[bottom_offset + entry.x2];
    } else {
      // Border pixel: read only the in-bounds neighbors; the rest stay zero.
      if (entry.mask & kTopLeftMask) {
        p11 = input_data[entry.y1 * W_in + entry.x1];
      }
      if (entry.mask & kTopRightMask) {
        p12 = input_data[entry.y1 * W_in + entry.x2];
      }
      if (entry.mask & kBottomLeftMask) {
        p21 = input_data[entry.y2 * W_in + entry.x1];
      }
      if (entry.mask & kBottomRightMask) {
        p22 = input_data[entry.y2 * W_in + entry.x2];
      }
    }

    output_data[pt] = entry.w11 * p11 + entry.w12 * p12 + entry.w21 * p21 + entry.w22 * p22;
  }
}

template <typename T>
void TryRunBilinearZerosFastPath2D(const Tensor& input,
const Tensor& grid,
Tensor& output,
int64_t n,
int64_t C,
int64_t H_in,
int64_t W_in,
int64_t H_out,
int64_t W_out,
concurrency::ThreadPool* tp,
std::vector<BilinearSamplePlan2D<T>>& sampling_plan) {
const int64_t plane_in = H_in * W_in;
const int64_t plane_out = H_out * W_out;
sampling_plan.resize(onnxruntime::narrow<size_t>(plane_out));

const T* grid_data = grid.Data<T>() + n * plane_out * 2;
PrecomputeBilinearSamplePlan2D(grid_data, H_out, W_out, H_in, W_in, sampling_plan);

const T* input_data = input.Data<T>();
T* output_data = output.MutableData<T>();

concurrency::ThreadPool::TrySimpleParallelFor(
tp, onnxruntime::narrow<std::ptrdiff_t>(C),
[&](std::ptrdiff_t c) {
const T* X_data = input_data + (n * C + c) * plane_in;
T* Y_data = output_data + (n * C + c) * plane_out;
EvaluatePlanForChannel(X_data, Y_data, W_in, sampling_plan.data(), plane_out);
});
}

} // namespace

// When grid sampling, padding is applied before interpolation.
// For instance, in bilinear mode and zeros padding-mode, pixel p at actual
// image location (-0.5, -0.5)
Expand Down Expand Up @@ -210,13 +391,14 @@ Status GridSample<T>::Compute(OpKernelContext* context) const {
T border[] = {x_min, y_min, x_max, y_max}; // l-t-r-b

concurrency::ThreadPool* tp = H_out * W_out > 64 ? context->GetOperatorThreadPool() : nullptr;
for (int64_t n = 0; n < N; n++) {
const T* grid_data = grid->Data<T>() + n * (H_out * W_out) * 2;

const auto run_generic_path_for_n = [&](int64_t n_idx) {
const T* grid_data = grid->Data<T>() + n_idx * (H_out * W_out) * 2;
concurrency::ThreadPool::TrySimpleParallelFor(
tp, onnxruntime::narrow<std::ptrdiff_t>(C),
[&](std::ptrdiff_t c) {
const T* X_data = input->Data<T>() + (n * C + c) * (H_in * W_in);
T* Y_data = Y.MutableData<T>() + (n * C + c) * (H_out * W_out);
const T* X_data = input->Data<T>() + (n_idx * C + c) * (H_in * W_in);
T* Y_data = Y.MutableData<T>() + (n_idx * C + c) * (H_out * W_out);

for (int64_t oy = 0; oy < H_out; oy++) {
for (int64_t ox = 0; ox < W_out; ox++) {
Expand Down Expand Up @@ -265,6 +447,19 @@ Status GridSample<T>::Compute(OpKernelContext* context) const {
}
}
});
};

const bool can_use_fast_path = (mode_ == Linear && padding_mode_ == Zeros && !align_corners_);
for (int64_t n = 0; n < N; n++) {
if (can_use_fast_path) {
// Choose fast path when all 4 neighbors are within the image and use zero for out-of-boundary neighbors.
// This fast path can be 2-3x faster than the generic path with boundary check and supports Neon optimization.
// sampling_plan helps precomputing a separate plan entry per output pixel.
std::vector<BilinearSamplePlan2D<T>> sampling_plan;
TryRunBilinearZerosFastPath2D(*input, *grid, Y, n, C, H_in, W_in, H_out, W_out, tp, sampling_plan);
} else {
run_generic_path_for_n(n);
}
}
} else if (data_dims == 3) {
// sample 3d;
Expand Down
32 changes: 32 additions & 0 deletions onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,38 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_zeros_no_align_corner
RunTests(test, GetExecutionProviders(20));
}

TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_linear_zeros_mixed_bounds) {
  // Bilinear sampling with zeros padding on a 2x2 image, mixing one fully
  // in-bounds grid point with points whose right, bottom, or both neighbors
  // land outside the image and therefore must be treated as zero. This pins
  // down boundary handling so the optimized fast path matches the generic one.
  OpTester test("GridSample", 20);
  std::string interpolation_mode = "linear";
  std::string pad_mode = "zeros";
  int64_t align_corners_attr = 0;
  std::initializer_list<int64_t> input_shape{1, 1, 2, 2};
  std::initializer_list<TypeParam> input_values{TypeParam(1.0f), TypeParam(2.0f), TypeParam(3.0f), TypeParam(4.0f)};
  std::initializer_list<int64_t> grid_shape{1, 2, 2, 2};
  // Normalized (nx, ny) pairs, one per output pixel.
  std::initializer_list<TypeParam> grid_values{
      TypeParam(0.0f), TypeParam(0.0f),   // center: all four neighbors in bounds
      TypeParam(0.9f), TypeParam(0.0f),   // near right edge: right-side neighbors out of bounds
      TypeParam(0.0f), TypeParam(0.9f),   // near bottom edge: bottom-side neighbors out of bounds
      TypeParam(0.9f), TypeParam(0.9f)};  // near bottom-right corner: right and bottom neighbors out
  std::initializer_list<int64_t> output_shape{1, 1, 2, 2};
  std::initializer_list<TypeParam> expected_values{
      TypeParam(2.5f),    // plain average of all four pixels
      TypeParam(1.8f),    // right neighbors contribute zero
      TypeParam(2.1f),    // bottom neighbors contribute zero
      TypeParam(1.44f)};  // only the top-left neighbor is in bounds
  test.AddInput<TypeParam>("X", input_shape, input_values);
  test.AddInput<TypeParam>("Grid", grid_shape, grid_values);
  test.AddAttribute("mode", interpolation_mode);
  test.AddAttribute("padding_mode", pad_mode);
  test.AddAttribute("align_corners", align_corners_attr);
  test.AddOutput<TypeParam>("Y", output_shape, expected_values);
  RunTests(test, GetExecutionProviders(20));
}

TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_border_align_corners) {
OpTester test("GridSample", 20);
std::string mode = "linear";
Expand Down