diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index e269ae69..64cc929c 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -14,6 +14,31 @@ namespace hnswlib { typedef unsigned int tableint; typedef unsigned int linklistsizeint; +struct Stats { + size_t nodes = 0; + size_t edges = 0; + size_t allocated_bytes = 0; + size_t max_edges = 0; +}; + +struct InternalParameters { + size_t max_elements = 0; + size_t num_deleted = 0; + size_t M = 0; + size_t maxM = 0; + size_t maxM0 = 0; + size_t ef_construction = 0; + size_t ef = 0; + double mult = 0; + size_t maxlevel = 0; + + size_t size_data_per_element = 0; + size_t size_links_per_element = 0; + size_t size_links_level0 = 0; + + size_t bytes_per_vector = 0; +}; + template class HierarchicalNSW : public AlgorithmInterface { public: @@ -51,6 +76,7 @@ class HierarchicalNSW : public AlgorithmInterface { char **linkLists_{nullptr}; std::vector element_levels_; // keeps level of each element + // Size of each vector in bytes. size_t data_size_{0}; DISTFUNC fstdistfunc_; @@ -92,7 +118,8 @@ class HierarchicalNSW : public AlgorithmInterface { size_t M = 16, size_t ef_construction = 200, size_t random_seed = 100, - bool allow_replace_deleted = false) + bool allow_replace_deleted = false, + size_t ef = 10) : label_op_locks_(MAX_LABEL_OPERATION_LOCKS), link_list_locks_(max_elements), element_levels_(max_elements), @@ -112,7 +139,7 @@ class HierarchicalNSW : public AlgorithmInterface { maxM_ = M_; maxM0_ = M_ * 2; ef_construction_ = std::max(ef_construction, M_); - ef_ = 10; + ef_ = ef; level_generator_.seed(random_seed); update_probability_generator_.seed(random_seed + 1); @@ -322,7 +349,7 @@ class HierarchicalNSW : public AlgorithmInterface { std::priority_queue, std::vector>, CompareByFirst> candidate_set; dist_t lowerBound; - if (bare_bone_search || + if (bare_bone_search || (!isMarkedDeleted(ep_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id))))) { char* ep_data = getDataByInternalId(ep_id); dist_t dist = fstdistfunc_(data_point, ep_data, dist_func_param_); @@ -403,7 +430,7 @@ class HierarchicalNSW : public AlgorithmInterface { _MM_HINT_T0); //////////////////////// #endif - if (bare_bone_search || + if (bare_bone_search || (!isMarkedDeleted(candidate_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))) { top_candidates.emplace(dist, candidate_id); if (!bare_bone_search && stop_condition) { @@ -682,7 +709,7 @@ class HierarchicalNSW : public AlgorithmInterface { return size; } - void saveIndex(const std::string &location) { + void saveIndex(const std::string &location) override { std::ofstream output(location, std::ios::binary); std::streampos position; @@ -826,7 +853,7 @@ class HierarchicalNSW : public AlgorithmInterface { std::vector getDataByLabel(labeltype label) const { // lock all operations with element by label std::unique_lock lock_label(getLabelOpMutex(label)); - + std::unique_lock lock_table(label_lookup_lock); auto search = label_lookup_.find(label); if (search == label_lookup_.end() || isMarkedDeleted(search->second)) { @@ -888,7 +915,7 @@ class HierarchicalNSW : public AlgorithmInterface { /* * Removes the deleted mark of the node, does NOT really change the current graph. - * + * * Note: the method is not safe to use when replacement of deleted elements is enabled, * because elements marked as deleted can be completely removed by addPoint */ @@ -951,7 +978,7 @@ class HierarchicalNSW : public AlgorithmInterface { * Adds point. Updates the point if it is already in the index. * If replacement of deleted elements is enabled: replaces previously deleted point if any, updating it with new point */ - void addPoint(const void *data_point, labeltype label, bool replace_deleted = false) { + void addPoint(const void *data_point, labeltype label, bool replace_deleted = false) override { if ((allow_replace_deleted_ == false) && (replace_deleted == true)) { throw std::runtime_error("Replacement of deleted elements is disabled in constructor"); } @@ -1268,7 +1295,7 @@ class HierarchicalNSW : public AlgorithmInterface { std::priority_queue> - searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const { + searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const override { std::priority_queue> result; if (cur_element_count == 0) return result; @@ -1377,6 +1404,27 @@ class HierarchicalNSW : public AlgorithmInterface { return result; } + int getMaxLevel() const override { + return maxlevel_; + } + + InternalParameters getInternalParameters() { + InternalParameters params; + params.max_elements = max_elements_; + params.size_data_per_element = size_data_per_element_; + params.size_links_per_element = size_links_per_element_; + params.num_deleted = num_deleted_; + params.M = M_; + params.maxM = maxM_; + params.maxM0 = maxM0_; + params.ef_construction = ef_construction_; + params.ef = ef_; + params.mult = mult_; + params.maxlevel = maxlevel_; + params.size_links_level0 = size_links_level0_; + params.bytes_per_vector = data_size_; + return params; + } void checkIntegrity() { int connections_checked = 0; @@ -1408,5 +1456,56 @@ class HierarchicalNSW : public AlgorithmInterface { } std::cout << "integrity ok, checked " << connections_checked << " connections\n"; } + + // Populate index statistics in the Stats array. + Stats getStats(Stats* stats_per_level, int max_level) const { + if (max_level < 0) { + return {}; + } + max_level = std::min(maxlevel_, max_level); + size_t node_head_bytes = size_links_level0_ + sizeof(labeltype); // Node header size + + // Iterate through all the elements + auto num_elements = cur_element_count.load(std::memory_order_seq_cst); + for (size_t i = 0; i != num_elements; ++i) { + tableint internal_id = static_cast(i); + auto node_level = element_levels_[internal_id]; + if (node_level < 0) { + // This should not happen in practice. + continue; + } + + // Base level (0) + stats_per_level[0].nodes++; + stats_per_level[0].edges += getListCount(get_linklist_at_level(internal_id, 0)); + stats_per_level[0].allocated_bytes += node_head_bytes; + + size_t max_level_for_node = std::min( + static_cast(node_level), static_cast(max_level)); + for (size_t l = 1; l <= max_level_for_node; ++l) { + stats_per_level[l].nodes++; + stats_per_level[l].edges += getListCount(get_linklist_at_level(internal_id, l)); + stats_per_level[l].allocated_bytes += size_links_per_element_; + } + } + + // Compute max_edges based on the node count at each level + stats_per_level[0].max_edges = stats_per_level[0].nodes * maxM0_; + for (int l = 1; l <= max_level; ++l) { + stats_per_level[l].max_edges = stats_per_level[l].nodes * maxM_; + } + + // Aggregate stats across all levels + Stats result{}; + for (auto l = 0; l <= max_level; ++l) { + result.nodes += stats_per_level[l].nodes; + result.edges += stats_per_level[l].edges; + result.allocated_bytes += stats_per_level[l].allocated_bytes; + result.max_edges += stats_per_level[l].max_edges; + } + + return result; + } + }; } // namespace hnswlib diff --git a/hnswlib/hnswlib.h b/hnswlib/hnswlib.h index 7ccfbba5..9d9b7d32 100644 --- a/hnswlib/hnswlib.h +++ b/hnswlib/hnswlib.h @@ -196,6 +196,9 @@ class AlgorithmInterface { searchKnnCloserFirst(const void* query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const; virtual void saveIndex(const std::string &location) = 0; + + virtual int getMaxLevel() const = 0; + virtual ~AlgorithmInterface(){ } };