#include "knn_cpu.h"
#include "utils.h"
#include "utils/KDTreeVectorOfVectorsAdaptor.h"
#include "utils/nanoflann.hpp"
using torch::indexing::Slice;
using torch::indexing::None;
std::tuple<torch::Tensor, torch::Tensor>
knn_cpu(torch::Tensor x, torch::Tensor y,
torch::optional<torch::Tensor> ptr_x,
torch::optional<torch::Tensor> ptr_y, int64_t k,
int64_t num_workers) {
CHECK_CPU(x);
CHECK_INPUT(x.dim() == 2);
CHECK_CPU(y);
CHECK_INPUT(y.dim() == 2);
if (ptr_x.has_value()) {
CHECK_CPU(ptr_x.value());
CHECK_INPUT(ptr_x.value().dim() == 1);
}
if (ptr_y.has_value()) {
CHECK_CPU(ptr_y.value());
CHECK_INPUT(ptr_y.value().dim() == 1);
}
std::vector<size_t> out_vec = std::vector<size_t>();
torch::Tensor out_vec_dist_sqr = torch::empty({y.size(0) * k}, y.options());
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, x.scalar_type(), "_", [&] {
// See: nanoflann/examples/vector_of_vectors_example.cpp
auto x_data = x.data_ptr<scalar_t>();
auto y_data = y.data_ptr<scalar_t>();
auto out_vec_dist_sqr_data = out_vec_dist_sqr.data_ptr<scalar_t>();
typedef std::vector<std::vector<scalar_t>> vec_t;
if (!ptr_x.has_value()) { // Single example.
vec_t pts(x.size(0));
for (int64_t i = 0; i < x.size(0); i++) {
pts[i].resize(x.size(1));
for (int64_t j = 0; j < x.size(1); j++) {
pts[i][j] = x_data[i * x.size(1) + j];
}
}
typedef KDTreeVectorOfVectorsAdaptor<vec_t, scalar_t> my_kd_tree_t;
my_kd_tree_t mat_index(x.size(1), pts, 10);
mat_index.index->buildIndex();
std::vector<size_t> ret_index(k);
std::vector<scalar_t> out_dist_sqr(k);
for (int64_t i = 0; i < y.size(0); i++) {
size_t num_matches = mat_index.index->knnSearch(
y_data + i * y.size(1), k, &ret_index[0], &out_dist_sqr[0]);
for (size_t j = 0; j < num_matches; j++) {
out_vec_dist_sqr_data[out_vec.size() / 2] = out_dist_sqr[j];
out_vec.push_back(ret_index[j]);
out_vec.push_back(i);
}
}
} else { // Batch-wise.
auto ptr_x_data = ptr_x.value().data_ptr<int64_t>();
auto ptr_y_data = ptr_y.value().data_ptr<int64_t>();
for (int64_t b = 0; b < ptr_x.value().size(0) - 1; b++) {
auto x_start = ptr_x_data[b], x_end = ptr_x_data[b + 1];
auto y_start = ptr_y_data[b], y_end = ptr_y_data[b + 1];
if (x_start == x_end || y_start == y_end)
continue;
vec_t pts(x_end - x_start);
for (int64_t i = 0; i < x_end - x_start; i++) {
pts[i].resize(x.size(1));
for (int64_t j = 0; j < x.size(1); j++) {
pts[i][j] = x_data[(i + x_start) * x.size(1) + j];
}
}
typedef KDTreeVectorOfVectorsAdaptor<vec_t, scalar_t> my_kd_tree_t;
my_kd_tree_t mat_index(x.size(1), pts, 10);
mat_index.index->buildIndex();
std::vector<size_t> ret_index(k);
std::vector<scalar_t> out_dist_sqr(k);
for (int64_t i = y_start; i < y_end; i++) {
size_t num_matches = mat_index.index->knnSearch(
y_data + i * y.size(1), k, &ret_index[0], &out_dist_sqr[0]);
for (size_t j = 0; j < num_matches; j++) {
out_vec_dist_sqr_data[out_vec.size() / 2] = out_dist_sqr[j];
out_vec.push_back(x_start + ret_index[j]);
out_vec.push_back(i);
}
}
}
}
});
const int64_t size = out_vec.size() / 2;
auto out = torch::from_blob(out_vec.data(), {size, 2},
x.options().dtype(torch::kLong));
return std::make_tuple(
out.t().index_select(0, torch::tensor({1, 0})),
out_vec_dist_sqr.index({Slice(None, size)})
);
}
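
// ---------------------------------------------------------------------------
// Usage sketch (not part of the original file): a minimal, hypothetical
// example of calling knn_cpu on two random point clouds without batch
// pointers. Tensor names and sizes below are illustrative only.
//
//   torch::Tensor x = torch::rand({100, 3});  // reference points
//   torch::Tensor y = torch::rand({20, 3});   // query points
//   auto result = knn_cpu(x, y, torch::nullopt, torch::nullopt,
//                         /*k=*/3, /*num_workers=*/1);
//   torch::Tensor edge_index = std::get<0>(result);  // [2, num_matches],
//                                                    // row 0: query index in y,
//                                                    // row 1: neighbor index in x
//   torch::Tensor dist_sqr = std::get<1>(result);    // squared distances
// ---------------------------------------------------------------------------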