Skip to content

Commit c151a35

Browse files
wuzheyi1028momo609
authored andcommitted
fix deformConv and modulatedDeformConv input kernel_size
1 parent 8552434 commit c151a35

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

mmcv/ops/deform_conv.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def _npu_backward(ctx, grad_output):
5757
grad_input, grad_weight, grad_offset_all, grad_bias = \
5858
torch_npu.npu_deformable_conv2dbk(
5959
input_tensor, grad_output, offset_out, weight, offset_all,
60-
kernel_size=[weight.shape[3], weight.shape[2]],
60+
kernel_size=[weight.shape[2], weight.shape[3]],
6161
stride=[1, 1, ctx.stride[0], ctx.stride[1]],
6262
padding=[ctx.padding[0], ctx.padding[0], ctx.padding[1],
6363
ctx.padding[1]],

mmcv/ops/modulated_deform_conv.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def _npu_forward(ctx, input_tensor, offset, mask, weight, bias):
5555
conv2d_bias = bias if len(bias) > 0 else None
5656
sort_index_fp, sort_index_bp = \
5757
ModulatedDeformConv2dFunction._calculate_sort_index(
58-
kernel_w, kernel_h, ctx.deform_groups)
58+
kernel_h, kernel_w, ctx.deform_groups)
5959
select_offset = offset.index_select(1, sort_index_fp)
6060
offset_all = torch.cat([select_offset, mask], dim=1)
6161
import torch_npu
@@ -64,7 +64,7 @@ def _npu_forward(ctx, input_tensor, offset, mask, weight, bias):
6464
weight,
6565
offset_all,
6666
conv2d_bias,
67-
kernel_size=[kernel_w, kernel_h],
67+
kernel_size=[kernel_h, kernel_w],
6868
stride=[1, 1, ctx.stride[0], ctx.stride[1]],
6969
padding=[
7070
ctx.padding[0], ctx.padding[0], ctx.padding[1], ctx.padding[1]
@@ -87,7 +87,7 @@ def _npu_backward(ctx, grad_output):
8787
grad_input, grad_weight, grad_offset_all, grad_bias = \
8888
torch_npu.npu_deformable_conv2dbk(
8989
input_tensor, grad_output, offset_out, weight, offset_all,
90-
kernel_size=[weight.shape[3], weight.shape[2]],
90+
kernel_size=[weight.shape[2], weight.shape[3]],
9191
stride=[1, 1, ctx.stride[0], ctx.stride[1]],
9292
padding=[ctx.padding[0], ctx.padding[0], ctx.padding[1],
9393
ctx.padding[1]],

0 commit comments

Comments
 (0)