|
1 | 1 | // Copyright (c) Microsoft Corporation. All rights reserved. |
2 | 2 | // Licensed under the MIT License. |
3 | 3 |
|
4 | | -#include "core/providers/cpu/tensor/grid_sample.h" |
| 4 | +#include <type_traits> |
| 5 | +#include <vector> |
| 6 | + |
| 7 | +#if defined(MLAS_NEON_INTRINSICS) |
| 8 | +#include <arm_neon.h> |
| 9 | +#endif |
5 | 10 |
|
| 11 | +#include "core/providers/cpu/tensor/grid_sample.h" |
6 | 12 | #include "core/framework/element_type_lists.h" |
7 | 13 | #include "core/framework/TensorSeq.h" |
8 | 14 | #include "core/providers/common.h" |
@@ -148,6 +154,181 @@ T GridSample<T>::PixelAtGrid3D(const T* image, int64_t d, int64_t h, int64_t w, |
148 | 154 | return pixel; |
149 | 155 | } |
150 | 156 |
|
| 157 | +namespace { |
| 158 | + |
| 159 | +constexpr uint8_t kTopLeftMask = 1u << 0; |
| 160 | +constexpr uint8_t kTopRightMask = 1u << 1; |
| 161 | +constexpr uint8_t kBottomLeftMask = 1u << 2; |
| 162 | +constexpr uint8_t kBottomRightMask = 1u << 3; |
| 163 | +constexpr uint8_t kAllNeighborsMask = kTopLeftMask | kTopRightMask | kBottomLeftMask | kBottomRightMask; |
| 164 | + |
| 165 | +template <typename T> |
| 166 | +struct BilinearSamplePlan2D { |
| 167 | + int64_t x1; |
| 168 | + int64_t x2; |
| 169 | + int64_t y1; |
| 170 | + int64_t y2; |
| 171 | + T w11; |
| 172 | + T w12; |
| 173 | + T w21; |
| 174 | + T w22; |
| 175 | + uint8_t mask = 0; |
| 176 | +}; |
|     | 177 | +// PrecomputeBilinearSamplePlan2D runs over all H_out * W_out grid points; for each (oy, ox) it denormalizes that point's (nx, ny) and stores the four neighbor indices, the four bilinear weights, and an in-bounds mask in plans[idx]. (A standalone worked example follows this hunk.)
| 178 | +template <typename T> |
| 179 | +void PrecomputeBilinearSamplePlan2D(const T* grid_data, |
| 180 | + int64_t H_out, |
| 181 | + int64_t W_out, |
| 182 | + int64_t H_in, |
| 183 | + int64_t W_in, |
| 184 | + std::vector<BilinearSamplePlan2D<T>>& plans) { |
| 185 | + const int64_t point_count = H_out * W_out; |
| 186 | + |
| 187 | + for (int64_t idx = 0; idx < point_count; ++idx) { |
| 188 | + auto& plan = plans[idx]; |
| 189 | + const T nx = grid_data[idx * 2]; |
| 190 | + const T ny = grid_data[idx * 2 + 1]; |
| 191 | + const T x = GsDenormalize<T>(nx, W_in, false); |
| 192 | + const T y = GsDenormalize<T>(ny, H_in, false); |
| 193 | + |
| 194 | + const int64_t x1 = static_cast<int64_t>(std::floor(x)); |
| 195 | + const int64_t y1 = static_cast<int64_t>(std::floor(y)); |
| 196 | + const int64_t x2 = x1 + 1; |
| 197 | + const int64_t y2 = y1 + 1; |
| 198 | + |
| 199 | + const T dx2 = static_cast<T>(x2) - x; |
| 200 | + const T dx1 = x - static_cast<T>(x1); |
| 201 | + const T dy2 = static_cast<T>(y2) - y; |
| 202 | + const T dy1 = y - static_cast<T>(y1); |
| 203 | + |
| 204 | + uint8_t mask = 0; |
| 205 | + if (x1 >= 0 && x1 < W_in && y1 >= 0 && y1 < H_in) { |
| 206 | + mask |= kTopLeftMask; |
| 207 | + } |
| 208 | + if (x2 >= 0 && x2 < W_in && y1 >= 0 && y1 < H_in) { |
| 209 | + mask |= kTopRightMask; |
| 210 | + } |
| 211 | + if (x1 >= 0 && x1 < W_in && y2 >= 0 && y2 < H_in) { |
| 212 | + mask |= kBottomLeftMask; |
| 213 | + } |
| 214 | + if (x2 >= 0 && x2 < W_in && y2 >= 0 && y2 < H_in) { |
| 215 | + mask |= kBottomRightMask; |
| 216 | + } |
| 217 | + |
| 218 | + plan.x1 = x1; |
| 219 | + plan.x2 = x2; |
| 220 | + plan.y1 = y1; |
| 221 | + plan.y2 = y2; |
| 222 | + plan.w11 = dy2 * dx2; |
| 223 | + plan.w12 = dy2 * dx1; |
| 224 | + plan.w21 = dy1 * dx2; |
| 225 | + plan.w22 = dy1 * dx1; |
| 226 | + plan.mask = mask; |
| 227 | + } |
| 228 | +} |
| 229 | + |
| 230 | +template <typename T> |
| 231 | +void EvaluatePlanForChannel(const T* input_data, |
| 232 | + T* output_data, |
| 233 | + int64_t W_in, |
| 234 | + const BilinearSamplePlan2D<T>* plan_data, |
| 235 | + int64_t point_count) { |
| 236 | + for (int64_t idx = 0; idx < point_count; ++idx) { |
| 237 | + const auto& plan = plan_data[idx]; |
| 238 | + if (plan.mask == 0) { |
| 239 | + output_data[idx] = T{}; |
| 240 | + continue; |
| 241 | + } |
| 242 | + |
| 243 | +#if defined(MLAS_NEON_INTRINSICS) |
| 244 | + if constexpr (std::is_same_v<T, float>) { |
| 245 | + if (plan.mask == kAllNeighborsMask) { |
| 246 | + const float* row1_ptr = input_data + plan.y1 * W_in + plan.x1; |
| 247 | + const float* row2_ptr = input_data + plan.y2 * W_in + plan.x1; |
| 248 | + |
| 249 | + float32x2_t row1 = vld1_f32(row1_ptr); // [p11, p12] |
| 250 | + float32x2_t row2 = vld1_f32(row2_ptr); // [p21, p22] |
| 251 | + float32x4_t neighbors = vcombine_f32(row1, row2); |
| 252 | + |
| 253 | + float32x2_t weights_row1 = vdup_n_f32(plan.w12); |
| 254 | + weights_row1 = vset_lane_f32(plan.w11, weights_row1, 0); |
| 255 | + float32x2_t weights_row2 = vdup_n_f32(plan.w22); |
| 256 | + weights_row2 = vset_lane_f32(plan.w21, weights_row2, 0); |
| 257 | + float32x4_t weights = vcombine_f32(weights_row1, weights_row2); |
| 258 | + |
| 259 | + float32x4_t products = vmulq_f32(neighbors, weights); |
| 260 | + float32x2_t sum_pairs = vadd_f32(vget_low_f32(products), vget_high_f32(products)); |
| 261 | + float32x2_t accum = vpadd_f32(sum_pairs, sum_pairs); |
| 262 | + output_data[idx] = vget_lane_f32(accum, 0); |
| 263 | + continue; |
| 264 | + } |
| 265 | + } |
| 266 | +#endif |
| 267 | + |
| 268 | + T p11 = T{}; |
| 269 | + T p12 = T{}; |
| 270 | + T p21 = T{}; |
| 271 | + T p22 = T{}; |
| 272 | + |
| 273 | + if (plan.mask == kAllNeighborsMask) { |
| 274 | + const int64_t row1 = plan.y1 * W_in; |
| 275 | + const int64_t row2 = plan.y2 * W_in; |
| 276 | + p11 = input_data[row1 + plan.x1]; |
| 277 | + p12 = input_data[row1 + plan.x2]; |
| 278 | + p21 = input_data[row2 + plan.x1]; |
| 279 | + p22 = input_data[row2 + plan.x2]; |
| 280 | + } else { |
| 281 | + if (plan.mask & kTopLeftMask) { |
| 282 | + p11 = input_data[plan.y1 * W_in + plan.x1]; |
| 283 | + } |
| 284 | + if (plan.mask & kTopRightMask) { |
| 285 | + p12 = input_data[plan.y1 * W_in + plan.x2]; |
| 286 | + } |
| 287 | + if (plan.mask & kBottomLeftMask) { |
| 288 | + p21 = input_data[plan.y2 * W_in + plan.x1]; |
| 289 | + } |
| 290 | + if (plan.mask & kBottomRightMask) { |
| 291 | + p22 = input_data[plan.y2 * W_in + plan.x2]; |
| 292 | + } |
| 293 | + } |
| 294 | + |
| 295 | + output_data[idx] = plan.w11 * p11 + plan.w12 * p12 + plan.w21 * p21 + plan.w22 * p22; |
| 296 | + } |
| 297 | +} |
| 298 | + |
| 299 | +template <typename T> |
| 300 | +void TryRunBilinearZerosFastPath2D(const Tensor& input, |
| 301 | + const Tensor& grid, |
| 302 | + Tensor& output, |
| 303 | + int64_t n, |
| 304 | + int64_t C, |
| 305 | + int64_t H_in, |
| 306 | + int64_t W_in, |
| 307 | + int64_t H_out, |
| 308 | + int64_t W_out, |
| 309 | + concurrency::ThreadPool* tp, |
| 310 | + std::vector<BilinearSamplePlan2D<T>>& sampling_plan) { |
| 311 | + const int64_t plane_in = H_in * W_in; |
| 312 | + const int64_t plane_out = H_out * W_out; |
| 313 | + sampling_plan.resize(plane_out); |
| 314 | + |
| 315 | + const T* grid_data = grid.Data<T>() + n * plane_out * 2; |
| 316 | + PrecomputeBilinearSamplePlan2D(grid_data, H_out, W_out, H_in, W_in, sampling_plan); |
| 317 | + |
| 318 | + const T* input_data = input.Data<T>(); |
| 319 | + T* output_data = output.MutableData<T>(); |
| 320 | + |
| 321 | + concurrency::ThreadPool::TrySimpleParallelFor( |
| 322 | + tp, onnxruntime::narrow<std::ptrdiff_t>(C), |
| 323 | + [&](std::ptrdiff_t c) { |
| 324 | + const T* X_data = input_data + (n * C + c) * plane_in; |
| 325 | + T* Y_data = output_data + (n * C + c) * plane_out; |
| 326 | + EvaluatePlanForChannel(X_data, Y_data, W_in, sampling_plan.data(), plane_out); |
| 327 | + }); |
| 328 | +} |
| 329 | + |
| 330 | +} // namespace |
| 331 | + |
151 | 332 | // When grid sampling, padding is applied before interpolation. |
152 | 333 | // For instance, in bilinear mode and zeros padding-mode, pixel p at actual |
153 | 334 | // image location (-0.5, -0.5) |
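
Reviewer aside, not part of the patch: a minimal standalone sketch of the per-point math that `PrecomputeBilinearSamplePlan2D` stores. The denormalization `((n + 1) * size - 1) / 2` is the usual non-align_corners mapping and is assumed here to match `GsDenormalize(n, size, false)`; all names in the sketch are hypothetical.

```cpp
// Standalone demo (hypothetical, not ONNX Runtime code) of one plan entry.
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t W_in = 8, H_in = 8;
  const float nx = 0.25f, ny = -0.5f;  // normalized grid coords in [-1, 1]

  // Non-align_corners denormalization, assumed to match GsDenormalize(n, size, false).
  const float x = ((nx + 1.0f) * W_in - 1.0f) / 2.0f;  // 4.5
  const float y = ((ny + 1.0f) * H_in - 1.0f) / 2.0f;  // 1.5

  const int64_t x1 = static_cast<int64_t>(std::floor(x)), x2 = x1 + 1;  // 4, 5
  const int64_t y1 = static_cast<int64_t>(std::floor(y)), y2 = y1 + 1;  // 1, 2

  // Each weight is one horizontal fraction times one vertical fraction,
  // so the four weights always sum to 1.
  const float dx1 = x - x1, dx2 = x2 - x;
  const float dy1 = y - y1, dy2 = y2 - y;
  const float w11 = dy2 * dx2, w12 = dy2 * dx1;
  const float w21 = dy1 * dx2, w22 = dy1 * dx1;

  // With zeros padding, any corner outside [0, W_in) x [0, H_in) is simply
  // masked off and its term vanishes; no clamping or reflection is needed.
  const bool tl_in = x1 >= 0 && x1 < W_in && y1 >= 0 && y1 < H_in;
  std::printf("weights %g %g %g %g (sum %g), top-left in-bounds: %d\n",
              w11, w12, w21, w22, w11 + w12 + w21 + w22, tl_in ? 1 : 0);
  return 0;
}
```

Because the weights always sum to 1 and out-of-bounds corners are masked to zero, evaluation never needs border-mode branching beyond the one mask test, which is what makes this plan reusable across every channel of a batch element.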
@@ -210,13 +391,14 @@ Status GridSample<T>::Compute(OpKernelContext* context) const { |
210 | 391 | T border[] = {x_min, y_min, x_max, y_max}; // l-t-r-b |
211 | 392 |
|
212 | 393 | concurrency::ThreadPool* tp = H_out * W_out > 64 ? context->GetOperatorThreadPool() : nullptr; |
213 | | - for (int64_t n = 0; n < N; n++) { |
214 | | - const T* grid_data = grid->Data<T>() + n * (H_out * W_out) * 2; |
| 394 | + |
| 395 | + const auto run_generic_path_for_n = [&](int64_t n_idx) { |
| 396 | + const T* grid_data = grid->Data<T>() + n_idx * (H_out * W_out) * 2; |
215 | 397 | concurrency::ThreadPool::TrySimpleParallelFor( |
216 | 398 | tp, onnxruntime::narrow<std::ptrdiff_t>(C), |
217 | 399 | [&](std::ptrdiff_t c) { |
218 | | - const T* X_data = input->Data<T>() + (n * C + c) * (H_in * W_in); |
219 | | - T* Y_data = Y.MutableData<T>() + (n * C + c) * (H_out * W_out); |
| 400 | + const T* X_data = input->Data<T>() + (n_idx * C + c) * (H_in * W_in); |
| 401 | + T* Y_data = Y.MutableData<T>() + (n_idx * C + c) * (H_out * W_out); |
220 | 402 |
|
221 | 403 | for (int64_t oy = 0; oy < H_out; oy++) { |
222 | 404 | for (int64_t ox = 0; ox < W_out; ox++) { |
@@ -265,6 +447,19 @@ Status GridSample<T>::Compute(OpKernelContext* context) const { |
265 | 447 | } |
266 | 448 | } |
267 | 449 | }); |
| 450 | + }; |
| 451 | + |
| 452 | + const bool can_use_fast_path = (mode_ == Linear && padding_mode_ == Zeros && !align_corners_); |
| 453 | + for (int64_t n = 0; n < N; n++) { |
| 454 | + if (can_use_fast_path) { |
|     | 455 | +      // Fast path: with zeros padding, out-of-bounds neighbors contribute zero, so a per-pixel plan (indices, weights, mask) can be precomputed once and reused across channels.
|     | 456 | +      // This avoids the generic path's per-sample boundary handling, can be 2-3x faster, and enables the NEON reduction in EvaluatePlanForChannel.
|     | 457 | +      // sampling_plan holds one precomputed entry per output pixel of this batch element.
| 458 | + std::vector<BilinearSamplePlan2D<T>> sampling_plan; |
| 459 | + TryRunBilinearZerosFastPath2D(*input, *grid, Y, n, C, H_in, W_in, H_out, W_out, tp, sampling_plan); |
| 460 | + } else { |
| 461 | + run_generic_path_for_n(n); |
| 462 | + } |
268 | 463 | } |
269 | 464 | } else if (data_dims == 3) { |
270 | 465 | // sample 3d; |
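
A closing aside on the NEON block in `EvaluatePlanForChannel`: the intrinsic sequence is a 4-lane dot product. It relies on the plan always setting x2 = x1 + 1, so p11/p12 (and p21/p22) sit contiguously in each input row for `vld1_f32`, and on the all-in-bounds mask check, which keeps the two-float loads from reading past the image border. A scalar mirror of the lane arithmetic (illustrative sketch only; the function name is hypothetical):

```cpp
// Scalar mirror (hypothetical) of the NEON reduction in EvaluatePlanForChannel.
#include <cstdio>

float WeightedSumScalar(const float neighbors[4], const float weights[4]) {
  // vmulq_f32: lane-wise products {p11*w11, p12*w12, p21*w21, p22*w22}
  float products[4];
  for (int i = 0; i < 4; ++i) products[i] = neighbors[i] * weights[i];
  // vadd_f32 of the low and high halves: {lane0 + lane2, lane1 + lane3}
  const float sum_pairs[2] = {products[0] + products[2], products[1] + products[3]};
  // vpadd_f32 then vget_lane_f32(_, 0): the full weighted sum
  return sum_pairs[0] + sum_pairs[1];
}

int main() {
  const float p[4] = {1.0f, 2.0f, 3.0f, 4.0f};      // p11, p12, p21, p22
  const float w[4] = {0.25f, 0.25f, 0.25f, 0.25f};  // bilinear weights, sum 1
  std::printf("%g\n", WeightedSumScalar(p, w));     // prints 2.5
  return 0;
}
```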