Skip to content

Commit 5189ddd

Browse files
committed
fix three_interpolate bug.
1 parent c34f6d3 commit 5189ddd

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp

+9-5
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,21 @@ void three_interpolate_forward_npu(int b, int c, int m, int n,
1212
TORCH_CHECK((originDtype == at::kFloat || originDtype == at::kHalf),
1313
"three_interpolate_forward ascend only support fp32 and fp16.");
1414

15-
auto point_c_trans = points.transpose(1, 2);
16-
15+
auto point_c_trans = points.transpose(1, 2).to(at::kFloat);
16+
auto weight_cast = weight.to(at::kFloat);
17+
auto out_cast = out.to(at::kFloat);
1718
OpCommand cmd;
1819
cmd.Name("ThreeInterpolate")
1920
.Input(point_c_trans)
2021
.Input(idx)
21-
.Input(weight)
22-
.Output(out)
22+
.Input(weight_cast)
23+
.Output(out_cast)
2324
.Run();
2425

25-
auto output = out.view({b, n, c}).transpose(1, 2);
26+
if (originDtype == at::kHalf) {
27+
out_cast = out_cast.to(at::kHalf);
28+
}
29+
auto output = out_cast.view({b, n, c}).transpose(1, 2);
2630
auto res = output.contiguous();
2731
out.copy_(res);
2832
}

0 commit comments

Comments
 (0)