@@ -55,7 +55,7 @@ def _npu_forward(ctx, input_tensor, offset, mask, weight, bias):
55
55
conv2d_bias = bias if len (bias ) > 0 else None
56
56
sort_index_fp , sort_index_bp = \
57
57
ModulatedDeformConv2dFunction ._calculate_sort_index (
58
- kernel_w , kernel_h , ctx .deform_groups )
58
+ kernel_h , kernel_w , ctx .deform_groups )
59
59
select_offset = offset .index_select (1 , sort_index_fp )
60
60
offset_all = torch .cat ([select_offset , mask ], dim = 1 )
61
61
import torch_npu
@@ -64,7 +64,7 @@ def _npu_forward(ctx, input_tensor, offset, mask, weight, bias):
64
64
weight ,
65
65
offset_all ,
66
66
conv2d_bias ,
67
- kernel_size = [kernel_w , kernel_h ],
67
+ kernel_size = [kernel_h , kernel_w ],
68
68
stride = [1 , 1 , ctx .stride [0 ], ctx .stride [1 ]],
69
69
padding = [
70
70
ctx .padding [0 ], ctx .padding [0 ], ctx .padding [1 ], ctx .padding [1 ]
@@ -87,7 +87,7 @@ def _npu_backward(ctx, grad_output):
87
87
grad_input , grad_weight , grad_offset_all , grad_bias = \
88
88
torch_npu .npu_deformable_conv2dbk (
89
89
input_tensor , grad_output , offset_out , weight , offset_all ,
90
- kernel_size = [weight .shape [3 ], weight .shape [2 ]],
90
+ kernel_size = [weight .shape [2 ], weight .shape [3 ]],
91
91
stride = [1 , 1 , ctx .stride [0 ], ctx .stride [1 ]],
92
92
padding = [ctx .padding [0 ], ctx .padding [0 ], ctx .padding [1 ],
93
93
ctx .padding [1 ]],
0 commit comments