|
| 1 | +#include "pytorch_npu_helper.hpp" |
| 2 | + |
| 3 | +using namespace NPU_NAME_SPACE; |
| 4 | +using namespace std; |
| 5 | + |
| 6 | +void border_align_forward_impl(const Tensor &input, const Tensor &boxes, |
| 7 | + Tensor output, Tensor argmax_idx, |
| 8 | + const int pool_size); |
| 9 | + |
| 10 | +void border_align_forward_npu(const Tensor &input, const Tensor &boxes, |
| 11 | + Tensor output, Tensor argmax_idx, |
| 12 | + const int pool_size) { |
| 13 | + TORCH_CHECK(input.size(0) == boxes.size(0), |
| 14 | + "The batch sizes of feature map and rois must be the same."); |
| 15 | + TORCH_CHECK(input.size(1) % 4 == 0, |
| 16 | + "The number of channels must be divisible by 4."); |
| 17 | + TORCH_CHECK(pool_size >= 2, "The pool size should be larger than 2."); |
| 18 | + int32_t batch_size = input.size(0); |
| 19 | + int32_t channels = input.size(1); |
| 20 | + int32_t height = input.size(2); |
| 21 | + int32_t width = input.size(3); |
| 22 | + at::Tensor feature_map = input.permute({0, 2, 3, 1}).contiguous(); |
| 23 | + at::Tensor rois_map = boxes.contiguous(); |
| 24 | + at::Tensor temp_tensor = at::zeros( |
| 25 | + {batch_size, height * width, pool_size + 1, channels}, input.options()); |
| 26 | + EXEC_NPU_CMD(aclnnBorderAlign, feature_map, rois_map, pool_size, temp_tensor); |
| 27 | + auto max_result = temp_tensor.max(-2); |
| 28 | + at::Tensor output_ = std::get<0>(max_result).to(at::kFloat); |
| 29 | + output_ = output_.reshape({batch_size, height * width, 4, channels / 4}) |
| 30 | + .permute({0, 3, 1, 2}) |
| 31 | + .contiguous(); |
| 32 | + output.copy_(output_); |
| 33 | + at::Tensor argmax_idx_ = std::get<1>(max_result).to(at::kInt); |
| 34 | + argmax_idx_ = |
| 35 | + argmax_idx_.reshape({batch_size, height * width, 4, channels / 4}) |
| 36 | + .permute({0, 3, 1, 2}) |
| 37 | + .contiguous(); |
| 38 | + argmax_idx.copy_(argmax_idx_); |
| 39 | +} |
| 40 | +REGISTER_NPU_IMPL(border_align_forward_impl, border_align_forward_npu); |
| 41 | + |
| 42 | + |
| 43 | +void border_align_backward_impl(const Tensor &grad_output, const Tensor &boxes, |
| 44 | + const Tensor &argmax_idx, Tensor grad_input, |
| 45 | + const int pool_size); |
| 46 | + |
| 47 | +void border_align_backward_npu(const Tensor &grad_output, const Tensor &boxes, |
| 48 | + const Tensor &argmax_idx, Tensor grad_input, |
| 49 | + const int pool_size) { |
| 50 | + TORCH_CHECK(grad_output.dim() == 4, |
| 51 | + "grad_out.dim() must be 4, but got: ", grad_output.dim()); |
| 52 | + TORCH_CHECK(boxes.dim() == 3, "idx.dim() must be 3, but got: ", boxes.dim()); |
| 53 | + TORCH_CHECK(argmax_idx.dim() == 4, |
| 54 | + "argmax_idx.dim() must be 4, but got: ", argmax_idx.dim()); |
| 55 | + |
| 56 | + int32_t batch_size = grad_output.size(0); |
| 57 | + int32_t feat_channels = grad_output.size(1) * 4; |
| 58 | + int32_t channels = grad_output.size(1); |
| 59 | + int32_t box_size = boxes.size(1); |
| 60 | + int32_t height = grad_input.size(2); |
| 61 | + int32_t width = grad_input.size(3); |
| 62 | + |
| 63 | + EXEC_NPU_CMD(aclnnBorderAlignGrad, grad_output, boxes, argmax_idx, channels, |
| 64 | + box_size, height, width, pool_size, batch_size, grad_input); |
| 65 | +} |
| 66 | +REGISTER_NPU_IMPL(border_align_backward_impl, border_align_backward_npu); |
0 commit comments