Skip to content

Commit b35f773

Browse files
committed
fix roi_pool bug.
1 parent 069c88a commit b35f773

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ void roi_pool_backward_npu(Tensor grad_output, Tensor rois, Tensor argmax,
7070
.Attr("spatial_scale_w", spatial_scale)
7171
.Attr("pool_channel", pooled_channel)
7272
.Run();
73-
at::Tensor res = y.contiguous();
73+
at::Tensor result = y.transpose(2, 3).transpose(1, 2);
74+
at::Tensor res = result.contiguous();
7475
grad_input.copy_(res);
7576
}
7677

0 commit comments

Comments
 (0)