Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
217 changes: 209 additions & 8 deletions include/nanoflann.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,19 @@
#include <stdexcept>
#include <unordered_set>
#include <vector>
#include <chrono> // for std::chrono

/** Library version: 0xMmP (M=Major,m=minor,P=patch) */
#define NANOFLANN_VERSION 0x171

/** Returns the current time in seconds since epoch */
inline double getCurrentTime()
{
auto now = std::chrono::high_resolution_clock::now();
auto duration = now.time_since_epoch();
return std::chrono::duration_cast<std::chrono::duration<double>>(duration).count();
}

// Avoid conflicting declaration of min/max macros in Windows headers
#if !defined(NOMINMAX) && \
(defined(_WIN32) || defined(_WIN32_) || defined(WIN32) || defined(_WIN64))
Expand All @@ -77,8 +86,142 @@
#undef None
#endif

// Macros to control exception handling and timeout features
#ifndef NANOFLANN_DISABLE_EXCEPTIONS
#define NANOFLANN_ENABLE_EXCEPTIONS 1
#endif

#ifndef NANOFLANN_DISABLE_TIMEOUT
#define NANOFLANN_ENABLE_TIMEOUT 1
#endif

namespace nanoflann
{
/** @addtogroup exceptions_grp Exception classes
* @{ */

/** Base class for all nanoflann exceptions */
class Exception : public std::exception
{
protected:
std::string message_;

public:
explicit Exception(const std::string& msg) : message_(msg) {}

~Exception() noexcept override = default;

const char* what() const noexcept override
{
return message_.c_str();
}

const std::string& getMessage() const noexcept
{
return message_;
}
};

/** Exception thrown when a null pointer is encountered */
class NullPointerException : public Exception
{
public:
explicit NullPointerException(const std::string& msg) : Exception(msg) {}
};

/** Exception thrown when invalid data is encountered */
class InvalidDataException : public Exception
{
public:
explicit InvalidDataException(const std::string& msg) : Exception(msg) {}
};

/** Exception thrown when memory allocation fails */
class MemoryAllocationException : public Exception
{
public:
explicit MemoryAllocationException(const std::string& msg) : Exception(msg) {}
};

/** Exception thrown when an invalid parameter is provided */
class InvalidParameterException : public Exception
{
public:
explicit InvalidParameterException(const std::string& msg) : Exception(msg) {}
};

/** Exception thrown when a search operation times out */
class SearchTimeoutException : public Exception
{
private:
size_t searchedPoints_;
size_t foundResults_;

public:
SearchTimeoutException(const std::string& msg, size_t searchedPoints, size_t foundResults)
: Exception(msg), searchedPoints_(searchedPoints), foundResults_(foundResults) {}

size_t getSearchedPoints() const noexcept
{
return searchedPoints_;
}

size_t getFoundResults() const noexcept
{
return foundResults_;
}
};

/** Exception thrown when concurrent modification is detected */
class ConcurrentModificationException : public Exception
{
public:
explicit ConcurrentModificationException(const std::string& msg) : Exception(msg) {}
};

/** Exception thrown when invalid point data (NaN or infinity) is encountered */
class InvalidPointDataException : public Exception
{
public:
explicit InvalidPointDataException(const std::string& msg) : Exception(msg) {}
};

/** @} */

/** @addtogroup logging_grp Logging interface
* @{ */

/** Log level enum */
enum class LogLevel
{
DEBUG,
INFO,
WARNING,
ERROR
};

/** Log callback function type */
typedef std::function<void(LogLevel level, const std::string& message)> LogCallback;

/** Set a global log callback function */
inline void setLogCallback(const LogCallback& callback)
{
static LogCallback globalLogCallback;
globalLogCallback = callback;
}

/** Log a message using the global log callback */
inline void logMessage(LogLevel level, const std::string& message)
{
static LogCallback globalLogCallback;
if (globalLogCallback)
{
globalLogCallback(level, message);
}
}

/** @}

/** @addtogroup nanoflann_grp nanoflann C++ library for KD-trees
* @{ */

Expand Down Expand Up @@ -837,6 +980,18 @@ struct SearchParameters
bool sorted; //!< only for radius search, require neighbours sorted by
//!< distance (default: true)
};

/** Search options for KDTreeSingleIndexAdaptor::findNeighbors() with timeout support */
struct SearchParametersEx : public SearchParameters
{
SearchParametersEx(float eps_ = 0, bool sorted_ = true, double timeout_seconds_ = 0)
: SearchParameters(eps_, sorted_), timeout_seconds(timeout_seconds_), start_time(0)
{
}

double timeout_seconds; //!< maximum time allowed for search in seconds (0 = no timeout)
double start_time; //!< start time of the search (internal use)
};
/** @} */

/** @addtogroup memalloc_grp Memory allocation
Expand Down Expand Up @@ -1723,7 +1878,16 @@ class KDTreeSingleIndexAdaptor
auto zero = static_cast<typename RESULTSET::DistanceType>(0);
assign(dists, (DIM > 0 ? DIM : Base::dim_), zero);
DistanceType dist = this->computeInitialDistances(*this, vec, dists);
searchLevel(result, vec, Base::root_node_, dist, dists, epsError);

// Check if searchParams is actually a SearchParametersEx
const SearchParametersEx* searchParamsEx = dynamic_cast<const SearchParametersEx*>(&searchParams);
double startTime = 0;
if (searchParamsEx && searchParamsEx->timeout_seconds > 0)
{
startTime = getCurrentTime();
}

searchLevel(result, vec, Base::root_node_, dist, dists, epsError, startTime, searchParamsEx ? searchParamsEx->timeout_seconds : 0);

if (searchParams.sorted) result.sort();

Expand Down Expand Up @@ -1885,8 +2049,15 @@ class KDTreeSingleIndexAdaptor
bool searchLevel(
RESULTSET& result_set, const ElementType* vec, const NodePtr node,
DistanceType mindist, distance_vector_t& dists,
const float epsError) const
const float epsError, const double startTime = 0, const double timeoutSeconds = 0) const
{
// Check for timeout
if (timeoutSeconds > 0 && getCurrentTime() - startTime > timeoutSeconds)
{
// Timeout reached, stop searching
return false;
}

// If this is a leaf node, then do check and return.
// If they are equal, both pointers are nullptr.
if (node->child1 == node->child2)
Expand All @@ -1895,6 +2066,13 @@ class KDTreeSingleIndexAdaptor
for (Offset i = node->node_type.lr.left;
i < node->node_type.lr.right; ++i)
{
// Check for timeout in loop
if (timeoutSeconds > 0 && getCurrentTime() - startTime > timeoutSeconds)
{
// Timeout reached, stop searching
return false;
}

const IndexType accessor = Base::vAcc_[i]; // reorder... : i;
DistanceType dist = distance_.evalMetric(
vec, accessor, (DIM > 0 ? DIM : Base::dim_));
Expand Down Expand Up @@ -1936,7 +2114,7 @@ class KDTreeSingleIndexAdaptor
}

/* Call recursively to search next level down. */
if (!searchLevel(result_set, vec, bestChild, mindist, dists, epsError))
if (!searchLevel(result_set, vec, bestChild, mindist, dists, epsError, startTime, timeoutSeconds))
{
// the resultset doesn't want to receive any more points, we're done
// searching!
Expand All @@ -1949,7 +2127,7 @@ class KDTreeSingleIndexAdaptor
if (mindist * epsError <= result_set.worstDist())
{
if (!searchLevel(
result_set, vec, otherChild, mindist, dists, epsError))
result_set, vec, otherChild, mindist, dists, epsError, startTime, timeoutSeconds))
{
// the resultset doesn't want to receive any more points, we're
// done searching!
Expand Down Expand Up @@ -2194,7 +2372,16 @@ class KDTreeSingleIndexDynamicAdaptor_
dists, (DIM > 0 ? DIM : Base::dim_),
static_cast<typename distance_vector_t::value_type>(0));
DistanceType dist = this->computeInitialDistances(*this, vec, dists);
searchLevel(result, vec, Base::root_node_, dist, dists, epsError);

// Check if searchParams is actually a SearchParametersEx
const SearchParametersEx* searchParamsEx = dynamic_cast<const SearchParametersEx*>(&searchParams);
double startTime = 0;
if (searchParamsEx && searchParamsEx->timeout_seconds > 0)
{
startTime = getCurrentTime();
}

searchLevel(result, vec, Base::root_node_, dist, dists, epsError, startTime, searchParamsEx ? searchParamsEx->timeout_seconds : 0);
return result.full();
}

Expand Down Expand Up @@ -2313,8 +2500,15 @@ class KDTreeSingleIndexDynamicAdaptor_
void searchLevel(
RESULTSET& result_set, const ElementType* vec, const NodePtr node,
DistanceType mindist, distance_vector_t& dists,
const float epsError) const
const float epsError, const double startTime = 0, const double timeoutSeconds = 0) const
{
// Check for timeout
if (timeoutSeconds > 0 && getCurrentTime() - startTime > timeoutSeconds)
{
// Timeout reached, stop searching
return;
}

// If this is a leaf node, then do check and return.
// If they are equal, both pointers are nullptr.
if (node->child1 == node->child2)
Expand All @@ -2323,6 +2517,13 @@ class KDTreeSingleIndexDynamicAdaptor_
for (Offset i = node->node_type.lr.left;
i < node->node_type.lr.right; ++i)
{
// Check for timeout in loop
if (timeoutSeconds > 0 && getCurrentTime() - startTime > timeoutSeconds)
{
// Timeout reached, stop searching
return;
}

const IndexType index = Base::vAcc_[i]; // reorder... : i;
if (treeIndex_[index] == -1) continue;
DistanceType dist = distance_.evalMetric(
Expand Down Expand Up @@ -2368,14 +2569,14 @@ class KDTreeSingleIndexDynamicAdaptor_
}

/* Call recursively to search next level down. */
searchLevel(result_set, vec, bestChild, mindist, dists, epsError);
searchLevel(result_set, vec, bestChild, mindist, dists, epsError, startTime, timeoutSeconds);

DistanceType dst = dists[idx];
mindist = mindist + cut_dist - dst;
dists[idx] = cut_dist;
if (mindist * epsError <= result_set.worstDist())
{
searchLevel(result_set, vec, otherChild, mindist, dists, epsError);
searchLevel(result_set, vec, otherChild, mindist, dists, epsError, startTime, timeoutSeconds);
}
dists[idx] = dst;
}
Expand Down