Skip to content

Commit f9a45ad

Browse files
committed
GridSample operator performance improvement on bilinear interpolation mode
Signed-off-by: melkap01 <melike.kaptan@arm.com>
1 parent 0df5dbc commit f9a45ad

File tree

2 files changed

+232
-5
lines changed

2 files changed

+232
-5
lines changed

onnxruntime/core/providers/cpu/tensor/grid_sample.cc

Lines changed: 200 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT License.
33

4-
#include "core/providers/cpu/tensor/grid_sample.h"
4+
#include <type_traits>
5+
#include <vector>
6+
7+
#if defined(MLAS_NEON_INTRINSICS)
8+
#include <arm_neon.h>
9+
#endif
510

11+
#include "core/providers/cpu/tensor/grid_sample.h"
612
#include "core/framework/element_type_lists.h"
713
#include "core/framework/TensorSeq.h"
814
#include "core/providers/common.h"
@@ -148,6 +154,181 @@ T GridSample<T>::PixelAtGrid3D(const T* image, int64_t d, int64_t h, int64_t w,
148154
return pixel;
149155
}
150156

157+
namespace {
158+
159+
// Bit flags recording which of the four bilinear neighbors of a sample point
// lie inside the input image. One bit per neighbor.
constexpr uint8_t kTopLeftMask = 0x01;      // neighbor at (y1, x1)
constexpr uint8_t kTopRightMask = 0x02;     // neighbor at (y1, x2)
constexpr uint8_t kBottomLeftMask = 0x04;   // neighbor at (y2, x1)
constexpr uint8_t kBottomRightMask = 0x08;  // neighbor at (y2, x2)

// Value of the mask when every neighbor is inside the image.
constexpr uint8_t kAllNeighborsMask =
    kTopLeftMask | kTopRightMask | kBottomLeftMask | kBottomRightMask;
164+
165+
// Precomputed bilinear sampling data for a single output pixel.
//
// x1/x2 are the column indices and y1/y2 the row indices of the four
// neighboring input pixels. The indices are stored as computed and may fall
// outside the image; `mask` has one bit per neighbor that is set only when
// that neighbor is inside the image.
//
// w11, w12, w21 and w22 are the interpolation weights applied to the
// neighbors at (y1, x1), (y1, x2), (y2, x1) and (y2, x2) respectively.
//
// Every member carries a default initializer so that a default-initialized
// plan is fully defined: all indices and weights are zero and no neighbor is
// marked valid. Previously only `mask` was initialized.
template <typename T>
struct BilinearSamplePlan2D {
  int64_t x1 = 0;
  int64_t x2 = 0;
  int64_t y1 = 0;
  int64_t y2 = 0;
  T w11{};
  T w12{};
  T w21{};
  T w22{};
  uint8_t mask = 0;
};
177+
// PrecomputeBilinearSamplePlan2D, the loop runs across all H_out * W_out points, using the right nx/ny for each (oy, ox) and storing that point’s four indices, four weights, and mask in plans[idx]
178+
template <typename T>
179+
void PrecomputeBilinearSamplePlan2D(const T* grid_data,
180+
int64_t H_out,
181+
int64_t W_out,
182+
int64_t H_in,
183+
int64_t W_in,
184+
std::vector<BilinearSamplePlan2D<T>>& plans) {
185+
const int64_t point_count = H_out * W_out;
186+
187+
for (int64_t idx = 0; idx < point_count; ++idx) {
188+
auto& plan = plans[idx];
189+
const T nx = grid_data[idx * 2];
190+
const T ny = grid_data[idx * 2 + 1];
191+
const T x = GsDenormalize<T>(nx, W_in, false);
192+
const T y = GsDenormalize<T>(ny, H_in, false);
193+
194+
const int64_t x1 = static_cast<int64_t>(std::floor(x));
195+
const int64_t y1 = static_cast<int64_t>(std::floor(y));
196+
const int64_t x2 = x1 + 1;
197+
const int64_t y2 = y1 + 1;
198+
199+
const T dx2 = static_cast<T>(x2) - x;
200+
const T dx1 = x - static_cast<T>(x1);
201+
const T dy2 = static_cast<T>(y2) - y;
202+
const T dy1 = y - static_cast<T>(y1);
203+
204+
uint8_t mask = 0;
205+
if (x1 >= 0 && x1 < W_in && y1 >= 0 && y1 < H_in) {
206+
mask |= kTopLeftMask;
207+
}
208+
if (x2 >= 0 && x2 < W_in && y1 >= 0 && y1 < H_in) {
209+
mask |= kTopRightMask;
210+
}
211+
if (x1 >= 0 && x1 < W_in && y2 >= 0 && y2 < H_in) {
212+
mask |= kBottomLeftMask;
213+
}
214+
if (x2 >= 0 && x2 < W_in && y2 >= 0 && y2 < H_in) {
215+
mask |= kBottomRightMask;
216+
}
217+
218+
plan.x1 = x1;
219+
plan.x2 = x2;
220+
plan.y1 = y1;
221+
plan.y2 = y2;
222+
plan.w11 = dy2 * dx2;
223+
plan.w12 = dy2 * dx1;
224+
plan.w21 = dy1 * dx2;
225+
plan.w22 = dy1 * dx1;
226+
plan.mask = mask;
227+
}
228+
}
229+
230+
template <typename T>
231+
void EvaluatePlanForChannel(const T* input_data,
232+
T* output_data,
233+
int64_t W_in,
234+
const BilinearSamplePlan2D<T>* plan_data,
235+
int64_t point_count) {
236+
for (int64_t idx = 0; idx < point_count; ++idx) {
237+
const auto& plan = plan_data[idx];
238+
if (plan.mask == 0) {
239+
output_data[idx] = T{};
240+
continue;
241+
}
242+
243+
#if defined(MLAS_NEON_INTRINSICS)
244+
if constexpr (std::is_same_v<T, float>) {
245+
if (plan.mask == kAllNeighborsMask) {
246+
const float* row1_ptr = input_data + plan.y1 * W_in + plan.x1;
247+
const float* row2_ptr = input_data + plan.y2 * W_in + plan.x1;
248+
249+
float32x2_t row1 = vld1_f32(row1_ptr); // [p11, p12]
250+
float32x2_t row2 = vld1_f32(row2_ptr); // [p21, p22]
251+
float32x4_t neighbors = vcombine_f32(row1, row2);
252+
253+
float32x2_t weights_row1 = vdup_n_f32(plan.w12);
254+
weights_row1 = vset_lane_f32(plan.w11, weights_row1, 0);
255+
float32x2_t weights_row2 = vdup_n_f32(plan.w22);
256+
weights_row2 = vset_lane_f32(plan.w21, weights_row2, 0);
257+
float32x4_t weights = vcombine_f32(weights_row1, weights_row2);
258+
259+
float32x4_t products = vmulq_f32(neighbors, weights);
260+
float32x2_t sum_pairs = vadd_f32(vget_low_f32(products), vget_high_f32(products));
261+
float32x2_t accum = vpadd_f32(sum_pairs, sum_pairs);
262+
output_data[idx] = vget_lane_f32(accum, 0);
263+
continue;
264+
}
265+
}
266+
#endif
267+
268+
T p11 = T{};
269+
T p12 = T{};
270+
T p21 = T{};
271+
T p22 = T{};
272+
273+
if (plan.mask == kAllNeighborsMask) {
274+
const int64_t row1 = plan.y1 * W_in;
275+
const int64_t row2 = plan.y2 * W_in;
276+
p11 = input_data[row1 + plan.x1];
277+
p12 = input_data[row1 + plan.x2];
278+
p21 = input_data[row2 + plan.x1];
279+
p22 = input_data[row2 + plan.x2];
280+
} else {
281+
if (plan.mask & kTopLeftMask) {
282+
p11 = input_data[plan.y1 * W_in + plan.x1];
283+
}
284+
if (plan.mask & kTopRightMask) {
285+
p12 = input_data[plan.y1 * W_in + plan.x2];
286+
}
287+
if (plan.mask & kBottomLeftMask) {
288+
p21 = input_data[plan.y2 * W_in + plan.x1];
289+
}
290+
if (plan.mask & kBottomRightMask) {
291+
p22 = input_data[plan.y2 * W_in + plan.x2];
292+
}
293+
}
294+
295+
output_data[idx] = plan.w11 * p11 + plan.w12 * p12 + plan.w21 * p21 + plan.w22 * p22;
296+
}
297+
}
298+
299+
template <typename T>
300+
void TryRunBilinearZerosFastPath2D(const Tensor& input,
301+
const Tensor& grid,
302+
Tensor& output,
303+
int64_t n,
304+
int64_t C,
305+
int64_t H_in,
306+
int64_t W_in,
307+
int64_t H_out,
308+
int64_t W_out,
309+
concurrency::ThreadPool* tp,
310+
std::vector<BilinearSamplePlan2D<T>>& sampling_plan) {
311+
const int64_t plane_in = H_in * W_in;
312+
const int64_t plane_out = H_out * W_out;
313+
sampling_plan.resize(plane_out);
314+
315+
const T* grid_data = grid.Data<T>() + n * plane_out * 2;
316+
PrecomputeBilinearSamplePlan2D(grid_data, H_out, W_out, H_in, W_in, sampling_plan);
317+
318+
const T* input_data = input.Data<T>();
319+
T* output_data = output.MutableData<T>();
320+
321+
concurrency::ThreadPool::TrySimpleParallelFor(
322+
tp, onnxruntime::narrow<std::ptrdiff_t>(C),
323+
[&](std::ptrdiff_t c) {
324+
const T* X_data = input_data + (n * C + c) * plane_in;
325+
T* Y_data = output_data + (n * C + c) * plane_out;
326+
EvaluatePlanForChannel(X_data, Y_data, W_in, sampling_plan.data(), plane_out);
327+
});
328+
}
329+
330+
} // namespace
331+
151332
// When grid sampling, padding is applied before interpolation.
152333
// For instance, in bilinear mode and zeros padding-mode, pixel p at actual
153334
// image location (-0.5, -0.5)
@@ -210,13 +391,14 @@ Status GridSample<T>::Compute(OpKernelContext* context) const {
210391
T border[] = {x_min, y_min, x_max, y_max}; // l-t-r-b
211392

212393
concurrency::ThreadPool* tp = H_out * W_out > 64 ? context->GetOperatorThreadPool() : nullptr;
213-
for (int64_t n = 0; n < N; n++) {
214-
const T* grid_data = grid->Data<T>() + n * (H_out * W_out) * 2;
394+
395+
const auto run_generic_path_for_n = [&](int64_t n_idx) {
396+
const T* grid_data = grid->Data<T>() + n_idx * (H_out * W_out) * 2;
215397
concurrency::ThreadPool::TrySimpleParallelFor(
216398
tp, onnxruntime::narrow<std::ptrdiff_t>(C),
217399
[&](std::ptrdiff_t c) {
218-
const T* X_data = input->Data<T>() + (n * C + c) * (H_in * W_in);
219-
T* Y_data = Y.MutableData<T>() + (n * C + c) * (H_out * W_out);
400+
const T* X_data = input->Data<T>() + (n_idx * C + c) * (H_in * W_in);
401+
T* Y_data = Y.MutableData<T>() + (n_idx * C + c) * (H_out * W_out);
220402

221403
for (int64_t oy = 0; oy < H_out; oy++) {
222404
for (int64_t ox = 0; ox < W_out; ox++) {
@@ -265,6 +447,19 @@ Status GridSample<T>::Compute(OpKernelContext* context) const {
265447
}
266448
}
267449
});
450+
};
451+
452+
const bool can_use_fast_path = (mode_ == Linear && padding_mode_ == Zeros && !align_corners_);
453+
for (int64_t n = 0; n < N; n++) {
454+
if (can_use_fast_path) {
455+
// Fast path for linear mode with zeros padding and align_corners disabled: in-bounds neighbors are read directly and out-of-bounds neighbors contribute zero.
456+
// This fast path can be 2-3x faster than the generic path with boundary check and supports Neon optimization.
457+
// sampling_plan holds one precomputed plan entry per output pixel.
458+
std::vector<BilinearSamplePlan2D<T>> sampling_plan;
459+
TryRunBilinearZerosFastPath2D(*input, *grid, Y, n, C, H_in, W_in, H_out, W_out, tp, sampling_plan);
460+
} else {
461+
run_generic_path_for_n(n);
462+
}
268463
}
269464
} else if (data_dims == 3) {
270465
// sample 3d;

onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -727,6 +727,38 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_zeros_no_align_corner
727727
RunTests(test, GetExecutionProviders(20));
728728
}
729729

730+
TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_linear_zeros_mixed_bounds) {
  // Samples a 2x2 image at four grid points: one whose neighbors are all inside the image, and
  // three whose right, bottom, or right-and-bottom neighbors fall outside it and must be treated
  // as zero. This checks that the optimized bilinear path handles image boundaries the same way
  // as the generic implementation.
  OpTester tester("GridSample", 20);

  // Operator attributes.
  std::string mode = "linear";
  std::string padding_mode = "zeros";
  int64_t align_corners = 0;

  // Input image: a single 2x2 channel holding the values 1 through 4.
  std::initializer_list<int64_t> input_shape{1, 1, 2, 2};
  std::initializer_list<TypeParam> input_values{TypeParam(1.0f), TypeParam(2.0f), TypeParam(3.0f), TypeParam(4.0f)};

  // Sampling grid: four (nx, ny) pairs.
  std::initializer_list<int64_t> grid_shape{1, 2, 2, 2};
  std::initializer_list<TypeParam> grid_values{
      TypeParam(0.0f), TypeParam(0.0f),   // center: all neighbors in bounds
      TypeParam(0.9f), TypeParam(0.0f),   // near the right edge: right neighbors out of bounds
      TypeParam(0.0f), TypeParam(0.9f),   // near the bottom edge: bottom neighbors out of bounds
      TypeParam(0.9f), TypeParam(0.9f)};  // near the bottom-right corner: right and bottom out of bounds

  // Expected output, one value per grid point in the same order.
  std::initializer_list<int64_t> expected_shape{1, 1, 2, 2};
  std::initializer_list<TypeParam> expected_values{
      TypeParam(2.5f),    // all four pixels contribute
      TypeParam(1.8f),    // right neighbors replaced by zero
      TypeParam(2.1f),    // bottom neighbors replaced by zero
      TypeParam(1.44f)};  // only the top-left neighbor contributes

  tester.AddInput<TypeParam>("X", input_shape, input_values);
  tester.AddInput<TypeParam>("Grid", grid_shape, grid_values);
  tester.AddAttribute("mode", mode);
  tester.AddAttribute("padding_mode", padding_mode);
  tester.AddAttribute("align_corners", align_corners);
  tester.AddOutput<TypeParam>("Y", expected_shape, expected_values);
  RunTests(tester, GetExecutionProviders(20));
}
761+
730762
TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_border_align_corners) {
731763
OpTester test("GridSample", 20);
732764
std::string mode = "linear";

0 commit comments

Comments
 (0)