Skip to content

Commit 2c5e9e6

Browse files
authored
improve batch_knn_search performance (#101)
1 parent ff63d5e, commit 2c5e9e6

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

src/python/kdtree.cpp

Lines changed: 14 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -129,7 +129,7 @@ void define_kdtree(py::module& m) {
129129
std::vector<size_t> k_indices(pts.rows(), -1);
130130
std::vector<double> k_sq_dists(pts.rows(), std::numeric_limits<double>::max());
131131

132-
#pragma omp parallel for num_threads(num_threads)
132+
#pragma omp parallel for num_threads(num_threads) schedule(guided, 4)
133133
for (int i = 0; i < pts.rows(); ++i) {
134134
const size_t found = traits::nearest_neighbor_search(kdtree, Eigen::Vector4d(pts(i, 0), pts(i, 1), pts(i, 2), 1.0), &k_indices[i], &k_sq_dists[i]);
135135
if (!found) {
@@ -154,9 +154,9 @@ void define_kdtree(py::module& m) {
154154
155155
Returns
156156
-------
157-
k_indices : numpy.ndarray, shape (n,)
157+
k_indices : numpy.ndarray, shape (n, k)
158158
The indices of the nearest neighbors for each input point. If a neighbor was not found, the index is -1.
159-
k_sq_dists : numpy.ndarray, shape (n,)
159+
k_sq_dists : numpy.ndarray, shape (n, k)
160160
The squared distances to the nearest neighbors for each input point.
161161
)""")
162162

@@ -167,16 +167,21 @@ void define_kdtree(py::module& m) {
167167
throw std::invalid_argument("pts must have shape (n, 3) or (n, 4)");
168168
}
169169

170-
std::vector<std::vector<size_t>> k_indices(pts.rows(), std::vector<size_t>(k, -1));
171-
std::vector<std::vector<double>> k_sq_dists(pts.rows(), std::vector<double>(k, std::numeric_limits<double>::max()));
170+
Eigen::Matrix<size_t, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> k_indices(pts.rows(), k);
171+
Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> k_sq_dists(pts.rows(), k);
172+
k_indices.setConstant(-1);
173+
k_sq_dists.setConstant(std::numeric_limits<double>::max());
172174

173-
#pragma omp parallel for num_threads(num_threads)
175+
#pragma omp parallel for num_threads(num_threads) schedule(guided, 4)
174176
for (int i = 0; i < pts.rows(); ++i) {
175-
const size_t found = traits::knn_search(kdtree, Eigen::Vector4d(pts(i, 0), pts(i, 1), pts(i, 2), 1.0), k, k_indices[i].data(), k_sq_dists[i].data());
177+
size_t* k_indices_begin = k_indices.data() + i * k;
178+
double* k_sq_dists_begin = k_sq_dists.data() + i * k;
179+
180+
const size_t found = traits::knn_search(kdtree, Eigen::Vector4d(pts(i, 0), pts(i, 1), pts(i, 2), 1.0), k, k_indices_begin, k_sq_dists_begin);
176181
if (found < k) {
177182
for (size_t j = found; j < k; ++j) {
178-
k_indices[i][j] = -1;
179-
k_sq_dists[i][j] = std::numeric_limits<double>::max();
183+
k_indices_begin[j] = -1;
184+
k_sq_dists_begin[j] = std::numeric_limits<double>::max();
180185
}
181186
}
182187
}

0 commit comments

Comments
 (0)