-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathNmsOps.cpp
More file actions
63 lines (54 loc) · 2.35 KB
/
NmsOps.cpp
File metadata and controls
63 lines (54 loc) · 2.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
// ----------------------------------------------------------------------------
// - Open3D: www.open3d.org -
// ----------------------------------------------------------------------------
// Copyright (c) 2018-2024 www.open3d.org
// SPDX-License-Identifier: MIT
// ----------------------------------------------------------------------------
#include <vector>
#include "open3d/ml/contrib/Nms.h"
#include "open3d/ml/paddle/PaddleHelper.h"
#include "paddle/extension.h"
/// Kernel function of the `open3d_nms` Paddle custom op: runs non-maximum
/// suppression over \p boxes ranked by \p scores and returns the indices of
/// the boxes that are kept.
///
/// The suppression itself is delegated to open3d::ml::contrib::NmsCUDAKernel
/// (GPU / custom-device inputs, CUDA builds only) or
/// open3d::ml::contrib::NmsCPUKernel (everything else). The expected box
/// layout and overlap definition are those of those kernels (presumably
/// (N, 5) boxes — TODO confirm against open3d/ml/contrib/Nms.h).
///
/// \param boxes   float32 tensor whose first dimension is the number of
///                boxes N; must have rank >= 1.
/// \param scores  float32 per-box scores. Assumed to hold N values and to live
///                on the same device as \p boxes (not checked here).
/// \param nms_overlap_thresh Overlap threshold forwarded unchanged to the
///                kernel.
/// \return A one-element vector holding a 1-D int64 tensor of kept indices,
///         placed on the same device as \p boxes (empty if none are kept).
/// Errors: CHECK_TYPE presumably raises on a non-float32 input; PD_CHECK
/// raises for a 0-d \p boxes, or for a GPU/custom-device input in a build
/// without CUDA support.
std::vector<paddle::Tensor> Nms(paddle::Tensor& boxes,
                                paddle::Tensor& scores,
                                double nms_overlap_thresh) {
    CHECK_TYPE(boxes, phi::DataType::FLOAT32);
    CHECK_TYPE(scores, phi::DataType::FLOAT32);
    // shape()[0] is only valid for tensors with at least one dimension.
    PD_CHECK(!boxes.shape().empty(),
             "Nms expects boxes to have at least one dimension");
    const int64_t num_boxes = boxes.shape()[0];

    std::vector<int64_t> keep_indices_blob;
    if (boxes.is_gpu() || boxes.is_custom_device()) {
#ifdef BUILD_CUDA_MODULE
        keep_indices_blob = open3d::ml::contrib::NmsCUDAKernel(
                boxes.data<float>(), scores.data<float>(), num_boxes,
                nms_overlap_thresh);
#else
        PD_CHECK(false, "Nms was not compiled with CUDA support");
#endif
    } else {
        keep_indices_blob = open3d::ml::contrib::NmsCPUKernel(
                boxes.data<float>(), scores.data<float>(), num_boxes,
                nms_overlap_thresh);
    }

    // std::vector::data() is not guaranteed to be null for an empty vector,
    // so test emptiness explicitly instead of the pointer.
    if (keep_indices_blob.empty()) {
        return {InitializedEmptyTensor<int64_t>({0}, boxes.place())};
    }

    // Wrap the host buffer without a deleter: the memory stays owned by
    // keep_indices_blob and is freed when this function returns.
    paddle::IntArray out_shape(
            {static_cast<int64_t>(keep_indices_blob.size())});
    paddle::IntArray out_strides({1});
    paddle::Tensor host_view = paddle::from_blob(
            keep_indices_blob.data(), out_shape, out_strides,
            phi::DataType::INT64, phi::DataLayout::NCHW, phi::CPUPlace());
    // Block until the copy finishes. A non-blocking copy could still be
    // reading keep_indices_blob's storage after it has been released.
    return {host_view.copy_to(boxes.place(), /*blocking=*/true)};
}
std::vector<paddle::DataType> NmsInferDtype() {
return {paddle::DataType::INT64};
}
// Registers the `open3d_nms` custom operator with Paddle.
//   inputs : boxes, scores         (bound to Nms's two tensor parameters)
//   output : keep_indices          (dtype reported by NmsInferDtype: INT64)
//   attr   : nms_overlap_thresh    (double, bound to Nms's last parameter)
// NOTE(review): no SetInferShapeFn is registered, although the length of
// keep_indices depends on the data. Confirm that Paddle's default shape
// inference is acceptable wherever this op is used in static-graph mode.
PD_BUILD_OP(open3d_nms)
        .Inputs({"boxes", "scores"})
        .Outputs({"keep_indices"})
        .Attrs({"nms_overlap_thresh: double"})
        .SetKernelFn(PD_KERNEL(Nms))
        .SetInferDtypeFn(PD_INFER_DTYPE(NmsInferDtype));