@@ -129,7 +129,7 @@ void define_kdtree(py::module& m) {
129129 std::vector<size_t > k_indices (pts.rows (), -1 );
130130 std::vector<double > k_sq_dists (pts.rows (), std::numeric_limits<double >::max ());
131131
132- #pragma omp parallel for num_threads(num_threads)
132+ #pragma omp parallel for num_threads(num_threads) schedule(guided, 4)
133133 for (int i = 0 ; i < pts.rows (); ++i) {
134134 const size_t found = traits::nearest_neighbor_search (kdtree, Eigen::Vector4d (pts (i, 0 ), pts (i, 1 ), pts (i, 2 ), 1.0 ), &k_indices[i], &k_sq_dists[i]);
135135 if (!found) {
@@ -154,9 +154,9 @@ void define_kdtree(py::module& m) {
154154
155155 Returns
156156 -------
157- k_indices : numpy.ndarray, shape (n,)
157+ k_indices : numpy.ndarray, shape (n, k )
158158 The indices of the nearest neighbors for each input point. If a neighbor was not found, the index is -1.
159- k_sq_dists : numpy.ndarray, shape (n,)
159+ k_sq_dists : numpy.ndarray, shape (n, k )
160160 The squared distances to the nearest neighbors for each input point.
161161 )""" )
162162
@@ -167,16 +167,21 @@ void define_kdtree(py::module& m) {
167167 throw std::invalid_argument (" pts must have shape (n, 3) or (n, 4)" );
168168 }
169169
170- std::vector<std::vector<size_t >> k_indices (pts.rows (), std::vector<size_t >(k, -1 ));
171- std::vector<std::vector<double >> k_sq_dists (pts.rows (), std::vector<double >(k, std::numeric_limits<double >::max ()));
170+ Eigen::Matrix<size_t , Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> k_indices (pts.rows (), k);
171+ Eigen::Matrix<double , Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> k_sq_dists (pts.rows (), k);
172+ k_indices.setConstant (-1 );
173+ k_sq_dists.setConstant (std::numeric_limits<double >::max ());
172174
173- #pragma omp parallel for num_threads(num_threads)
175+ #pragma omp parallel for num_threads(num_threads) schedule(guided, 4)
174176 for (int i = 0 ; i < pts.rows (); ++i) {
175- const size_t found = traits::knn_search (kdtree, Eigen::Vector4d (pts (i, 0 ), pts (i, 1 ), pts (i, 2 ), 1.0 ), k, k_indices[i].data (), k_sq_dists[i].data ());
177+ size_t * k_indices_begin = k_indices.data () + i * k;
178+ double * k_sq_dists_begin = k_sq_dists.data () + i * k;
179+
180+ const size_t found = traits::knn_search (kdtree, Eigen::Vector4d (pts (i, 0 ), pts (i, 1 ), pts (i, 2 ), 1.0 ), k, k_indices_begin, k_sq_dists_begin);
176181 if (found < k) {
177182 for (size_t j = found; j < k; ++j) {
178- k_indices[i] [j] = -1 ;
179- k_sq_dists[i] [j] = std::numeric_limits<double >::max ();
183+ k_indices_begin [j] = -1 ;
184+ k_sq_dists_begin [j] = std::numeric_limits<double >::max ();
180185 }
181186 }
182187 }
0 commit comments