Commit 2ec5da3

Merge pull request #22 from Binary2355/rc4main
npu knn/tnn bugfix
2 parents c34f6d3 + 61e41b3

4 files changed: +68 -2 lines changed


mmcv/ops/csrc/pytorch/npu/knn_npu.cpp

+21
@@ -0,0 +1,21 @@
+#include "pytorch_npu_helper.hpp"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+
+using namespace NPU_NAME_SPACE;
+using namespace std;
+
+void knn_forward_npu(int b, int n, int m, int nsample, const Tensor xyz,
+                     const Tensor new_xyz, Tensor idx, Tensor dist2) {
+  // transpose known from [B, N, 3] to [B, 3, N]
+  at::Tensor source = xyz.transpose(2, 1).contiguous();
+  at::Tensor target = new_xyz.contiguous();
+
+  bool is_from_knn = true;
+  EXEC_NPU_CMD(aclnnKnn, source, target, is_from_knn, dist2);
+}
+
+void knn_forward_impl(int b, int n, int m, int nsample, const Tensor xyz,
+                      const Tensor new_xyz, Tensor idx, Tensor dist2);
+
+REGISTER_NPU_IMPL(knn_forward_impl, knn_forward_npu);
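
For orientation, a minimal usage sketch of the Python op this kernel backs; it assumes an Ascend build of mmcv with torch_npu installed, and the shapes are illustrative:

import torch
from mmcv.ops import knn

# .npu() assumes torch_npu is installed and an Ascend device is visible
xyz = torch.rand(2, 256, 3).npu()        # known points, [B, N, 3]
center_xyz = torch.rand(2, 64, 3).npu()  # query centers, [B, npoint, 3]

# indices of the 16 nearest known points per center, returned as [B, k, npoint]
idx = knn(16, xyz, center_xyz)
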
mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp

+20
@@ -0,0 +1,20 @@
+#include "pytorch_npu_helper.hpp"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+
+using namespace NPU_NAME_SPACE;
+using namespace std;
+
+void three_nn_forward_npu(int b, int n, int m, const Tensor unknown,
+                          const Tensor known, Tensor dist2, Tensor idx) {
+  at::Tensor source = known.contiguous();
+  at::Tensor target = unknown.contiguous();
+
+  bool is_from_knn = false;
+  EXEC_NPU_CMD(aclnnKnn, source, target, is_from_knn, dist2);
+}
+
+void three_nn_forward_impl(int b, int n, int m, const Tensor unknown,
+                           const Tensor known, Tensor dist2, Tensor idx);
+
+REGISTER_NPU_IMPL(three_nn_forward_impl, three_nn_forward_npu);
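
Likewise, a minimal usage sketch for the three_nn entry point, under the same assumption of an Ascend build of mmcv with torch_npu; shapes are illustrative:

import torch
from mmcv.ops import three_nn

unknown = torch.rand(2, 512, 3).npu()  # query points, [B, N, 3]
known = torch.rand(2, 128, 3).npu()    # reference points, [B, m, 3]

# distances to and indices of the 3 nearest known points, both [B, N, 3]
dist, idx = three_nn(unknown, known)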

mmcv/ops/knn.py

+13 -2
@@ -55,12 +55,23 @@ def forward(ctx,
         center_xyz_device = center_xyz.get_device()
         assert center_xyz_device == xyz.get_device(), \
             'center_xyz and xyz should be put on the same device'
-        if torch.cuda.current_device() != center_xyz_device:
-            torch.cuda.set_device(center_xyz_device)
+        if xyz.device.type != 'npu':
+            if torch.cuda.current_device() != center_xyz_device:
+                torch.cuda.set_device(center_xyz_device)
 
         B, npoint, _ = center_xyz.shape
         N = xyz.shape[1]
 
+        if xyz.device.type == 'npu':
+            dist = center_xyz.new_zeros((B, npoint, N)).float()
+            ext_module.knn_forward(
+                xyz, center_xyz, torch.Tensor([]).npu(), dist, b=B, n=N, m=npoint, nsample=k)
+            dist2, idx = torch.topk(dist, k, dim=2, largest=False, sorted=True)
+            zeros_idx = torch.zeros(xyz.shape[0], center_xyz.shape[1], k, dtype=torch.int32).npu()
+            idx.where(dist2 >= 1e10, zeros_idx)
+            idx = idx.transpose(2, 1).contiguous()  # [B, k, npoint]
+            return idx.type(torch.IntTensor)
+
         idx = center_xyz.new_zeros((B, npoint, k)).int()
         dist2 = center_xyz.new_zeros((B, npoint, k)).float()
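
The NPU branch above fills a full [B, npoint, N] squared-distance matrix via aclnnKnn and then picks the k smallest per center with torch.topk. A rough pure-PyTorch sketch of that selection step, runnable on CPU with illustrative names (not the NPU kernel itself):

import torch

B, N, npoint, k = 2, 256, 64, 16
xyz = torch.rand(B, N, 3)              # known points
center_xyz = torch.rand(B, npoint, 3)  # query centers

# squared distances from every center to every known point, [B, npoint, N]
dist = torch.cdist(center_xyz, xyz).pow(2)

# keep the k smallest per center, then move k to dim 1 as the CUDA path does
dist2, idx = torch.topk(dist, k, dim=2, largest=False, sorted=True)
idx = idx.transpose(2, 1).contiguous()  # [B, k, npoint]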

mmcv/ops/three_nn.py

+14
@@ -34,6 +34,20 @@ def forward(ctx: Any, target: torch.Tensor,
 
         B, N, _ = target.size()
         m = source.size(1)
+        if source.device.type == 'npu':
+            # strict to fp32
+            source = source.transpose(2, 1).contiguous()
+            dtype_ = source.dtype
+            if dtype_ == torch.float16:
+                target = target.float()
+                source = source.float()
+            dist = target.new_empty(B, N, m)
+            ext_module.three_nn_forward(target, source, dist, torch.Tensor([]).npu(), b=B, n=N, m=m)
+            dist2, idx = torch.topk(dist, 3, dim=2, largest=False, sorted=True)
+            dist2 = torch.sqrt(dist2)
+            if dtype_ == torch.float16:
+                dist2 = dist2.half()
+            return dist2, idx.type(torch.IntTensor)
         dist2 = target.new_empty(B, N, 3)
         idx = target.new_empty(B, N, 3, dtype=torch.int32)
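
The NPU branch here does the same top-3 selection but also round-trips fp16 inputs through fp32 around the kernel call. A rough CPU-only sketch of that dtype handling, with illustrative names:

import torch

B, N, m = 2, 512, 128
unknown = torch.rand(B, N, 3, dtype=torch.float16)  # query points
known = torch.rand(B, m, 3, dtype=torch.float16)    # reference points

# upcast to fp32 for the distance computation, as the NPU branch does
dist = torch.cdist(unknown.float(), known.float()).pow(2)  # [B, N, m]

# three nearest neighbours per query, distances cast back to the input dtype
dist2, idx = torch.topk(dist, 3, dim=2, largest=False, sorted=True)
dist2 = torch.sqrt(dist2).half()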
