Skip to content

Commit c5a339f

Browse files
authored
Merge pull request #272 from algol83/feature/bounding-box
feat: implement findWithinBox
2 parents 81cd02b + fb90f3f commit c5a339f

File tree

2 files changed

+157
-10
lines changed

2 files changed

+157
-10
lines changed

include/nanoflann.hpp

Lines changed: 85 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
#include <istream>
5757
#include <limits> // std::numeric_limits
5858
#include <ostream>
59+
#include <stack>
5960
#include <stdexcept>
6061
#include <unordered_set>
6162
#include <vector>
@@ -933,10 +934,7 @@ class PooledAllocator
933934

934935
// use the standard C malloc to allocate memory
935936
void* m = ::malloc(blocksize);
936-
if (!m)
937-
{
938-
throw std::bad_alloc();
939-
}
937+
if (!m) { throw std::bad_alloc(); }
940938

941939
/* Fill first word of new block with pointer to previous block. */
942940
static_cast<void**>(m)[0] = base_;
@@ -1144,7 +1142,8 @@ class KDTreeBaseClass
11441142
*
11451143
* @param left index of the first vector
11461144
* @param right index of the last vector
1147-
* @param bbox bounding box used as input for splitting and output for parent node
1145+
* @param bbox bounding box used as input for splitting and output for
1146+
* parent node
11481147
*/
11491148
NodePtr divideTree(
11501149
Derived& obj, const Offset left, const Offset right, BoundingBox& bbox)
@@ -1217,7 +1216,8 @@ class KDTreeBaseClass
12171216
*
12181217
* @param left index of the first vector
12191218
* @param right index of the last vector
1220-
* @param bbox bounding box used as input for splitting and output for parent node
1219+
* @param bbox bounding box used as input for splitting and output for
1220+
* parent node
12211221
* @param thread_count count of std::async threads
12221222
* @param mutex mutex for mempool allocation
12231223
*/
@@ -1287,7 +1287,7 @@ class KDTreeBaseClass
12871287
BoundingBox left_bbox(bbox);
12881288
left_bbox[cutfeat].high = cutval;
12891289
node->child1 = this->divideTreeConcurrent(
1290-
obj, left, left + idx, left_bbox, thread_count, mutex);
1290+
obj, left, left + idx, left_bbox, thread_count, mutex);
12911291

12921292
if (right_future.valid())
12931293
{
@@ -1730,6 +1730,70 @@ class KDTreeSingleIndexAdaptor
17301730
return result.full();
17311731
}
17321732

1733+
/**
1734+
* Find all points contained within the specified bounding box. Their
1735+
* indices are stored inside the result object.
1736+
*
1737+
* Params:
1738+
* result = the result object in which the indices of the points
1739+
* within the bounding box are stored
1740+
* bbox = the bounding box defining the search region
1741+
*
1742+
* \tparam RESULTSET Should be any ResultSet<DistanceType>
1743+
* \return Number of points found within the bounding box.
1744+
* \sa findNeighbors, knnSearch, radiusSearch
1745+
*
1746+
* \note The search is inclusive - points on the boundary are included.
1747+
*/
1748+
template <typename RESULTSET>
1749+
Size findWithinBox(RESULTSET& result, const BoundingBox& bbox) const
1750+
{
1751+
if (this->size(*this) == 0) return 0;
1752+
if (!Base::root_node_)
1753+
throw std::runtime_error(
1754+
"[nanoflann] findWithinBox() called before building the "
1755+
"index.");
1756+
1757+
std::stack<NodePtr> stack;
1758+
stack.push(Base::root_node_);
1759+
1760+
while (!stack.empty())
1761+
{
1762+
const NodePtr node = stack.top();
1763+
stack.pop();
1764+
1765+
// If this is a leaf node, then do check and return.
1766+
// If they are equal, both pointers are nullptr.
1767+
if (node->child1 == node->child2)
1768+
{
1769+
for (Offset i = node->node_type.lr.left;
1770+
i < node->node_type.lr.right; ++i)
1771+
{
1772+
if (contains(bbox, Base::vAcc_[i]))
1773+
{
1774+
if (!result.addPoint(0, Base::vAcc_[i]))
1775+
{
1776+
// the resultset doesn't want to receive any more
1777+
// points, we're done searching!
1778+
return result.size();
1779+
}
1780+
}
1781+
}
1782+
}
1783+
else
1784+
{
1785+
const int idx = node->node_type.sub.divfeat;
1786+
const auto low_bound = node->node_type.sub.divlow;
1787+
const auto high_bound = node->node_type.sub.divhigh;
1788+
1789+
if (bbox[idx].low <= low_bound) stack.push(node->child1);
1790+
if (bbox[idx].high >= high_bound) stack.push(node->child2);
1791+
}
1792+
}
1793+
1794+
return result.size();
1795+
}
1796+
17331797
/**
17341798
* Find the "num_closest" nearest neighbors to the \a query_point[0:dim-1].
17351799
* Their indices and distances are stored in the provided pointers to
@@ -1831,8 +1895,8 @@ class KDTreeSingleIndexAdaptor
18311895
/** @} */
18321896

18331897
public:
1834-
/** Make sure the auxiliary list \a vind has the same size than the current
1835-
* dataset, and re-generate if size has changed. */
1898+
/** Make sure the auxiliary list \a vind has the same size than the
1899+
* current dataset, and re-generate if size has changed. */
18361900
void init_vind()
18371901
{
18381902
// Create a permutable array of indices to the input vectors.
@@ -1875,6 +1939,17 @@ class KDTreeSingleIndexAdaptor
18751939
}
18761940
}
18771941

1942+
bool contains(const BoundingBox& bbox, IndexType idx) const
1943+
{
1944+
const auto dims = (DIM > 0 ? DIM : Base::dim_);
1945+
for (Dimension i = 0; i < dims; ++i)
1946+
{
1947+
const auto point = this->dataset_.kdtree_get_pt(idx, i);
1948+
if (point < bbox[i].low || point > bbox[i].high) return false;
1949+
}
1950+
return true;
1951+
}
1952+
18781953
/**
18791954
* Performs an exact search in the tree starting from a node.
18801955
* \tparam RESULTSET Should be any ResultSet<DistanceType>
@@ -1897,7 +1972,7 @@ class KDTreeSingleIndexAdaptor
18971972
{
18981973
const IndexType accessor = Base::vAcc_[i]; // reorder... : i;
18991974
DistanceType dist = distance_.evalMetric(
1900-
vec, accessor, (DIM > 0 ? DIM : Base::dim_));
1975+
vec, accessor, (DIM > 0 ? DIM : Base::dim_));
19011976
if (dist < worst_dist)
19021977
{
19031978
if (!result_set.addPoint(dist, Base::vAcc_[i]))

tests/test_main.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,61 @@ void L2_vs_bruteforce_test(
224224
}
225225
}
226226

227+
template <typename NUM>
228+
void box_L2_vs_bruteforce_test(const size_t nSamples, const size_t DIM)
229+
{
230+
std::vector<std::vector<NUM>> samples;
231+
232+
const NUM max_range = NUM(20.0);
233+
234+
// Generate points:
235+
generateRandomPointCloud(samples, nSamples, DIM, max_range);
236+
237+
typedef KDTreeVectorOfVectorsAdaptor<std::vector<std::vector<NUM>>, NUM>
238+
my_kd_tree_t;
239+
240+
// Query box:
241+
typename my_kd_tree_t::index_t::BoundingBox query_box(DIM);
242+
for (size_t d = 0; d < DIM; d++)
243+
{
244+
query_box[d].low =
245+
static_cast<NUM>(max_range * (rand() % 1000) / (1000.0));
246+
query_box[d].high =
247+
static_cast<NUM>(max_range * (rand() % 1000) / (1000.0));
248+
if (query_box[d].low > query_box[d].high)
249+
std::swap(query_box[d].low, query_box[d].high);
250+
}
251+
252+
// construct a kd-tree index:
253+
// Dimensionality set at run-time (default: L2)
254+
// ------------------------------------------------------------
255+
my_kd_tree_t mat_index(DIM /*dim*/, samples, 10 /* max leaf */);
256+
257+
// do a knn search
258+
std::vector<size_t> ret_indexes(nSamples);
259+
std::vector<NUM> out_dists_sqr(nSamples);
260+
261+
nanoflann::KNNResultSet<NUM> resultSet(nSamples);
262+
263+
resultSet.init(&ret_indexes[0], &out_dists_sqr[0]);
264+
const auto nFound = mat_index.index->findWithinBox(resultSet, query_box);
265+
266+
// Brute force:
267+
std::set<size_t /*idx*/> bf_nn;
268+
for (size_t i = 0; i < nSamples; i++)
269+
{
270+
if (mat_index.index->contains(query_box, i)) bf_nn.insert(i);
271+
}
272+
273+
// Compare:
274+
EXPECT_EQ(bf_nn.size(), nFound);
275+
276+
for (size_t i = 0; i < nFound; ++i)
277+
{
278+
EXPECT_TRUE(bf_nn.find(ret_indexes[i]) != bf_nn.end());
279+
}
280+
}
281+
227282
template <typename NUM>
228283
void rknn_L2_vs_bruteforce_test(
229284
const size_t nSamples, const size_t DIM, const size_t numToSearch,
@@ -666,6 +721,23 @@ TEST(kdtree, L2_vs_bruteforce)
666721
}
667722
}
668723

724+
TEST(kdtree, box_L2_vs_bruteforce)
725+
{
726+
srand(static_cast<unsigned int>(time(nullptr)));
727+
for (int i = 0; i < 500; i++)
728+
{
729+
box_L2_vs_bruteforce_test<float>(10, 2);
730+
731+
box_L2_vs_bruteforce_test<float>(100, 2);
732+
box_L2_vs_bruteforce_test<float>(100, 3);
733+
box_L2_vs_bruteforce_test<float>(100, 7);
734+
735+
box_L2_vs_bruteforce_test<double>(100, 2);
736+
box_L2_vs_bruteforce_test<double>(100, 3);
737+
box_L2_vs_bruteforce_test<double>(100, 7);
738+
}
739+
}
740+
669741
TEST(kdtree, L2_vs_bruteforce_rknn)
670742
{
671743
srand(static_cast<unsigned int>(time(nullptr)));

0 commit comments

Comments
 (0)