diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h
index e269ae69..0cfd18bd 100644
--- a/hnswlib/hnswalg.h
+++ b/hnswlib/hnswalg.h
@@ -16,6 +16,16 @@ typedef unsigned int linklistsizeint;
 
 template<typename dist_t>
 class HierarchicalNSW : public AlgorithmInterface<dist_t> {
+
+ private:
+    std::vector<int> internal_id_to_doc_id_;
+
+    int getMetadata(tableint internal_id) const {
+        if (internal_id >= internal_id_to_doc_id_.size()) {
+            throw std::runtime_error("Internal ID out of range in getMetadata");
+        }
+        return internal_id_to_doc_id_[internal_id];
+    }
 public:
     static const tableint MAX_LABEL_OPERATION_LOCKS = 65536;
     static const unsigned char DELETE_MARK = 0x01;
@@ -141,6 +151,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
         size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);
         mult_ = 1 / log(1.0 * M_);
         revSize_ = 1.0 / mult_;
+        internal_id_to_doc_id_.reserve(max_elements_);
     }
 
 
@@ -332,6 +343,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
                 stop_condition->add_point_to_result(getExternalLabel(ep_id), ep_data, dist);
             }
             candidate_set.emplace(-dist, ep_id);
+            vl->mark_seen_doc(getMetadata(ep_id));  // Mark initial document ID
         } else {
             lowerBound = std::numeric_limits<dist_t>::max();
             candidate_set.emplace(-lowerBound, ep_id);
@@ -361,75 +373,75 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
             tableint current_node_id = current_node_pair.second;
             int *data = (int *) get_linklist0(current_node_id);
             size_t size = getListCount((linklistsizeint*)data);
-//                bool cur_node_deleted = isMarkedDeleted(current_node_id);
             if (collect_metrics) {
                 metric_hops++;
-                metric_distance_computations+=size;
+                metric_distance_computations += size;
             }
 
-#ifdef USE_SSE
+            #ifdef USE_SSE
             _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
             _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
             _mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0);
             _mm_prefetch((char *) (data + 2), _MM_HINT_T0);
-#endif
+            #endif
 
             for (size_t j = 1; j <= size; j++) {
                 int candidate_id = *(data + j);
-//                    if (candidate_id == 0) continue;
-#ifdef USE_SSE
+                #ifdef USE_SSE
                 _mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0);
-                _mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_,
-                            _MM_HINT_T0);  ////////////
-#endif
+                _mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0);
+                #endif
                 if (!(visited_array[candidate_id] == visited_array_tag)) {
                     visited_array[candidate_id] = visited_array_tag;
 
-                    char *currObj1 = (getDataByInternalId(candidate_id));
-                    dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_);
-
-                    bool flag_consider_candidate;
-                    if (!bare_bone_search && stop_condition) {
-                        flag_consider_candidate = stop_condition->should_consider_candidate(dist, lowerBound);
-                    } else {
-                        flag_consider_candidate = top_candidates.size() < ef || lowerBound > dist;
-                    }
-
-                    if (flag_consider_candidate) {
-                        candidate_set.emplace(-dist, candidate_id);
-#ifdef USE_SSE
-                        _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ +
-                                        offsetLevel0_,  ///////////
-                                        _MM_HINT_T0);  ////////////////////////
-#endif
-
-                        if (bare_bone_search ||
-                            (!isMarkedDeleted(candidate_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))) {
-                            top_candidates.emplace(dist, candidate_id);
-                            if (!bare_bone_search && stop_condition) {
-                                stop_condition->add_point_to_result(getExternalLabel(candidate_id), currObj1, dist);
-                            }
-                        }
+                    int doc_id = getMetadata(candidate_id);
+                    if (!vl->is_doc_seen(doc_id)) {  // Filter duplicates based on document ID
+                        char *currObj1 = getDataByInternalId(candidate_id);
+                        dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_);
 
-                        bool flag_remove_extra = false;
+                        bool flag_consider_candidate;
                         if (!bare_bone_search && stop_condition) {
-                            flag_remove_extra = stop_condition->should_remove_extra();
+                            flag_consider_candidate = stop_condition->should_consider_candidate(dist, lowerBound);
                         } else {
-                            flag_remove_extra = top_candidates.size() > ef;
+                            flag_consider_candidate = top_candidates.size() < ef || lowerBound > dist;
                         }
-                        while (flag_remove_extra) {
-                            tableint id = top_candidates.top().second;
-                            top_candidates.pop();
+
+                        if (flag_consider_candidate) {
+                            candidate_set.emplace(-dist, candidate_id);
+                            #ifdef USE_SSE
+                            _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ +
+                                            offsetLevel0_, _MM_HINT_T0);
+                            #endif
+
+                            if (bare_bone_search ||
+                                (!isMarkedDeleted(candidate_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))) {
+                                top_candidates.emplace(dist, candidate_id);
+                                vl->mark_seen_doc(doc_id);  // Mark document ID as seen
+                                if (!bare_bone_search && stop_condition) {
+                                    stop_condition->add_point_to_result(getExternalLabel(candidate_id), currObj1, dist);
+                                }
+                            }
+
+                            bool flag_remove_extra = false;
                             if (!bare_bone_search && stop_condition) {
-                                stop_condition->remove_point_from_result(getExternalLabel(id), getDataByInternalId(id), dist);
                                 flag_remove_extra = stop_condition->should_remove_extra();
                             } else {
                                 flag_remove_extra = top_candidates.size() > ef;
                             }
-                        }
+                            while (flag_remove_extra) {
+                                tableint id = top_candidates.top().second;
+                                top_candidates.pop();
+                                if (!bare_bone_search && stop_condition) {
+                                    stop_condition->remove_point_from_result(getExternalLabel(id), getDataByInternalId(id), dist);
+                                    flag_remove_extra = stop_condition->should_remove_extra();
+                                } else {
+                                    flag_remove_extra = top_candidates.size() > ef;
+                                }
+                            }
 
-                        if (!top_candidates.empty())
-                            lowerBound = top_candidates.top().first;
+                            if (!top_candidates.empty())
+                                lowerBound = top_candidates.top().first;
+                        }
                     }
                 }
             }
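
In the hunk above, the duplicate filter short-circuits before any distance computation, and a document ID is only marked as seen when an element actually enters top_candidates. The bookkeeping is easiest to check outside the graph traversal; below is a minimal runnable Python model of it (heapq stands in for the C++ priority queue, and the function and variable names are illustrative, not part of the patch):

    import heapq

    def dedup_top_k(candidates, k):
        # candidates: (distance, doc_id) pairs in traversal order.
        seen_docs = set()              # mirrors VisitedList::seen_doc_ids
        top = []                       # max-heap on distance via negation
        for dist, doc_id in candidates:
            if doc_id in seen_docs:    # vl->is_doc_seen(doc_id): skip duplicates
                continue
            heapq.heappush(top, (-dist, doc_id))
            seen_docs.add(doc_id)      # vl->mark_seen_doc(doc_id)
            if len(top) > k:           # mirrors the flag_remove_extra loop
                heapq.heappop(top)     # evict the farthest element
        return sorted((-neg, doc) for neg, doc in top)

    print(dedup_top_k([(0.1, 7), (0.2, 7), (0.3, 3), (0.5, 3), (0.4, 9)], k=2))
    # [(0.1, 7), (0.3, 3)]

The model makes two consequences of the patch visible: an evicted document stays marked as seen (mark_seen_doc is never undone), and the first element of a document reached during traversal claims its doc_id even if a closer element of the same document turns up later.
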
@@ -956,15 +968,16 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
             throw std::runtime_error("Replacement of deleted elements is disabled in constructor");
         }
 
-        // lock all operations with element by label
-        std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label));
+        // Lock all operations with element by label
+        std::unique_lock<std::mutex> lock_label(getLabelOpMutex(label));
         if (!replace_deleted) {
-            addPoint(data_point, label, -1);
+            addPoint(data_point, label, -1);  // Call lower-level addPoint
             return;
         }
-        // check if there is vacant place
+
+        // Check if there is a vacant place
         tableint internal_id_replaced;
-        std::unique_lock <std::mutex> lock_deleted_elements(deleted_elements_lock);
+        std::unique_lock<std::mutex> lock_deleted_elements(deleted_elements_lock);
         bool is_vacant_place = !deleted_elements.empty();
         if (is_vacant_place) {
             internal_id_replaced = *deleted_elements.begin();
@@ -972,22 +985,27 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
         }
         lock_deleted_elements.unlock();
 
-        // if there is no vacant place then add or update point
-        // else add point to vacant place
+        // If no vacant place, add normally; otherwise, replace deleted element
        if (!is_vacant_place) {
-            addPoint(data_point, label, -1);
+            addPoint(data_point, label, -1);  // Call lower-level addPoint
         } else {
-            // we assume that there are no concurrent operations on deleted element
+            // Replace deleted element
             labeltype label_replaced = getExternalLabel(internal_id_replaced);
             setExternalLabel(internal_id_replaced, label);
 
-            std::unique_lock <std::mutex> lock_table(label_lookup_lock);
+            std::unique_lock<std::mutex> lock_table(label_lookup_lock);
             label_lookup_.erase(label_replaced);
             label_lookup_[label] = internal_id_replaced;
             lock_table.unlock();
 
             unmarkDeletedInternal(internal_id_replaced);
             updatePoint(data_point, internal_id_replaced, 1.0);
+
+            // Update internal_id_to_doc_id_ for the replaced element
+            if (internal_id_replaced >= internal_id_to_doc_id_.size()) {
+                internal_id_to_doc_id_.resize(internal_id_replaced + 1);
+            }
+            internal_id_to_doc_id_[internal_id_replaced] = static_cast<int>(label);
         }
     }
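
Because the replacement path above refreshes the mapping as well, a reused slot cannot leak the previous document ID into search results. A sketch of an end-to-end check through the Python bindings, assuming this patch is built into the module (init_index's allow_replace_deleted, mark_deleted, and add_items' replace_deleted are existing upstream hnswlib APIs):

    import hnswlib
    import numpy as np

    dim = 8
    index = hnswlib.Index(space='l2', dim=dim)
    index.init_index(max_elements=10, ef_construction=100, M=16,
                     allow_replace_deleted=True)

    data = np.random.random((3, dim)).astype(np.float32)
    index.add_items(data[0], 0)
    index.add_items(data[1], 1)

    index.mark_deleted(1)                              # frees the slot of label 1
    index.add_items(data[2], 2, replace_deleted=True)  # reuses that internal id

    labels, _ = index.knn_query(data[2], k=2)
    assert 2 in labels[0] and 1 not in labels[0]       # stale doc id must not leak
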
@@ -1154,8 +1172,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
         tableint cur_c = 0;
         {
             // Checking if the element with the same label already exists
-            // if so, updating it *instead* of creating a new element.
-            std::unique_lock <std::mutex> lock_table(label_lookup_lock);
+            std::unique_lock<std::mutex> lock_table(label_lookup_lock);
             auto search = label_lookup_.find(label);
             if (search != label_lookup_.end()) {
                 tableint existingInternalId = search->second;
@@ -1171,6 +1188,12 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
                 }
 
                 updatePoint(data_point, existingInternalId, 1.0);
+                // Update internal_id_to_doc_id_ for the existing element
+                if (existingInternalId >= internal_id_to_doc_id_.size()) {
+                    internal_id_to_doc_id_.resize(existingInternalId + 1);
+                }
+                internal_id_to_doc_id_[existingInternalId] = static_cast<int>(label);
+
                 return existingInternalId;
             }
 
@@ -1183,14 +1206,20 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
             label_lookup_[label] = cur_c;
         }
 
-        std::unique_lock <std::mutex> lock_el(link_list_locks_[cur_c]);
+        std::unique_lock<std::mutex> lock_el(link_list_locks_[cur_c]);
         int curlevel = getRandomLevel(mult_);
         if (level > 0)
             curlevel = level;
 
         element_levels_[cur_c] = curlevel;
 
-        std::unique_lock <std::mutex> templock(global);
+        // Populate internal_id_to_doc_id_ for the new element
+        if (cur_c >= internal_id_to_doc_id_.size()) {
+            internal_id_to_doc_id_.resize(cur_c + 1);
+        }
+        internal_id_to_doc_id_[cur_c] = static_cast<int>(label);
+
+        std::unique_lock<std::mutex> templock(global);
         int maxlevelcopy = maxlevel_;
         if (curlevel <= maxlevelcopy)
             templock.unlock();
@@ -1199,7 +1228,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
 
         memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_);
 
-        // Initialisation of the data and label
+        // Initialization of the data and label
         memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype));
         memcpy(getDataByInternalId(cur_c), data_point, data_size_);
 
@@ -1218,7 +1247,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
                 while (changed) {
                     changed = false;
                     unsigned int *data;
-                    std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
+                    std::unique_lock<std::mutex> lock(link_list_locks_[currObj]);
                     data = get_linklist(currObj, level);
                     int size = getListCount(data);
 
@@ -1244,7 +1273,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
                     throw std::runtime_error("Level error");
 
                 std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates = searchBaseLayer(
-                        currObj, data_point, level);
+                    currObj, data_point, level);
                 if (epDeleted) {
                     top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy);
                     if (top_candidates.size() > ef_construction_)
@@ -1266,25 +1295,23 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
         return cur_c;
     }
 
-
-    std::priority_queue<std::pair<dist_t, labeltype >>
+    std::priority_queue<std::pair<dist_t, labeltype>>
     searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const {
-        std::priority_queue<std::pair<dist_t, labeltype >> result;
+        std::priority_queue<std::pair<dist_t, labeltype>> result;
         if (cur_element_count == 0) return result;
 
         tableint currObj = enterpoint_node_;
         dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
 
+        // Traverse higher levels to find the best entry point
         for (int level = maxlevel_; level > 0; level--) {
             bool changed = true;
             while (changed) {
                 changed = false;
-                unsigned int *data;
-
-                data = (unsigned int *) get_linklist(currObj, level);
+                unsigned int *data = (unsigned int *) get_linklist(currObj, level);
                 int size = getListCount(data);
                 metric_hops++;
-                metric_distance_computations+=size;
+                metric_distance_computations += size;
 
                 tableint *datal = (tableint *) (data + 1);
                 for (int i = 0; i < size; i++) {
@@ -1292,7 +1319,6 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
                     if (cand < 0 || cand > max_elements_)
                         throw std::runtime_error("cand error");
                     dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);
-
                     if (d < curdist) {
                         curdist = d;
                         currObj = cand;
@@ -1302,16 +1328,18 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
             }
         }
 
+        // Base layer search with duplicate filtering handled in searchBaseLayerST
         std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
         bool bare_bone_search = !num_deleted_ && !isIdAllowed;
         if (bare_bone_search) {
             top_candidates = searchBaseLayerST<true>(
-                    currObj, query_data, std::max(ef_, k), isIdAllowed);
+                currObj, query_data, std::max(ef_, k), isIdAllowed);
         } else {
             top_candidates = searchBaseLayerST<false>(
-                    currObj, query_data, std::max(ef_, k), isIdAllowed);
+                currObj, query_data, std::max(ef_, k), isIdAllowed);
         }
 
+        // Extract top k results
         while (top_candidates.size() > k) {
             top_candidates.pop();
         }
@@ -1320,6 +1348,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
             result.push(std::pair<dist_t, labeltype>(rez.first, getExternalLabel(rez.second)));
             top_candidates.pop();
         }
+
         return result;
     }
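
With those changes, searchKnn deduplicates by label (document ID) transparently; callers see at most one result per document. A minimal usage sketch against a patched build, mirroring the tests added below:

    import hnswlib
    import numpy as np

    dim = 16
    index = hnswlib.Index(space='l2', dim=dim)
    index.init_index(max_elements=100, ef_construction=200, M=16)

    np.random.seed(0)
    vectors = np.random.random((6, dim)).astype(np.float32)

    # Three chunks of document 0, one chunk each for documents 1-3.
    for vec, doc_id in zip(vectors, [0, 0, 0, 1, 2, 3]):
        index.add_items(vec, doc_id)

    query = np.random.random((1, dim)).astype(np.float32)
    labels, distances = index.knn_query(query, k=3)
    assert len(set(labels[0])) == 3   # each doc id appears at most once

Note that k is now bounded by the number of distinct document IDs rather than the number of elements; asking for more unique documents than exist raises the usual "contiguous 2D array" RuntimeError from the bindings, as the tests below exercise.
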
diff --git a/hnswlib/visited_list_pool.h b/hnswlib/visited_list_pool.h
index 2e201ec4..51ad17b6 100644
--- a/hnswlib/visited_list_pool.h
+++ b/hnswlib/visited_list_pool.h
@@ -3,6 +3,7 @@
 #include <mutex>
 #include <string.h>
 #include <deque>
+#include <unordered_set>  // Added for document ID tracking
 
 namespace hnswlib {
 typedef unsigned short int vl_type;
@@ -12,6 +13,7 @@ class VisitedList {
     vl_type curV;
     vl_type *mass;
     unsigned int numelements;
+    std::unordered_set<int> seen_doc_ids;  // Track seen document IDs
 
     VisitedList(int numelements1) {
         curV = -1;
@@ -25,10 +27,20 @@ class VisitedList {
             memset(mass, 0, sizeof(vl_type) * numelements);
             curV++;
         }
+        seen_doc_ids.clear();  // Reset seen document IDs
+    }
+
+    void mark_seen_doc(int doc_id) {
+        seen_doc_ids.insert(doc_id);
+    }
+
+    bool is_doc_seen(int doc_id) const {
+        return seen_doc_ids.count(doc_id) > 0;
     }
 
     ~VisitedList() { delete[] mass; }
 };
+
 ///////////////////////////////////////////////////////////
 //
 // Class for multi-threaded pool-management of VisitedLists
@@ -50,7 +62,7 @@ class VisitedListPool {
     VisitedList *getFreeVisitedList() {
         VisitedList *rez;
         {
-            std::unique_lock <std::mutex> lock(poolguard);
+            std::unique_lock<std::mutex> lock(poolguard);
             if (pool.size() > 0) {
                 rez = pool.front();
                 pool.pop_front();
@@ -63,7 +75,7 @@ class VisitedListPool {
     }
 
     void releaseVisitedList(VisitedList *vl) {
-        std::unique_lock <std::mutex> lock(poolguard);
+        std::unique_lock<std::mutex> lock(poolguard);
         pool.push_front(vl);
     }
 
@@ -75,4 +87,4 @@ class VisitedListPool {
         }
     }
 };
-}  // namespace hnswlib
+}  // namespace hnswlib
\ No newline at end of file
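
The seen-set lives on the pooled VisitedList and is cleared in reset(), so deduplication state cannot leak between queries that reuse a list from the pool. A small Python model of that contract (illustrative only; the real structure is the C++ class above):

    class VisitedListModel:
        def __init__(self):
            self.visited = set()          # stands in for the vl_type tag array
            self.seen_doc_ids = set()     # mirrors the new unordered_set member

        def reset(self):                  # called when a list is handed out again
            self.visited.clear()
            self.seen_doc_ids.clear()

        def mark_seen_doc(self, doc_id):
            self.seen_doc_ids.add(doc_id)

        def is_doc_seen(self, doc_id):
            return doc_id in self.seen_doc_ids

    vl = VisitedListModel()
    vl.mark_seen_doc(42)
    assert vl.is_doc_seen(42)
    vl.reset()                            # back to the pool / next query
    assert not vl.is_doc_seen(42)

One cost worth noting: unlike the O(1) tag array, the unordered_set allocates and hashes per accepted candidate, so deduplication adds a small per-query overhead that grows with the number of distinct documents touched.
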
diff --git a/tests/python/duplicate_reject_test.py b/tests/python/duplicate_reject_test.py
new file mode 100644
index 00000000..ab68ec10
--- /dev/null
+++ b/tests/python/duplicate_reject_test.py
@@ -0,0 +1,130 @@
+import hnswlib
+import numpy as np
+
+def test_basic_duplicate_filtering():
+    """Test basic duplicate filtering with enough unique IDs."""
+    dim = 10
+    max_elements = 100
+    k = 3
+    index = hnswlib.Index(space='l2', dim=dim)
+    index.init_index(max_elements=max_elements, ef_construction=200, M=16)
+
+    np.random.seed(42)
+    data = np.random.random((5, dim)).astype(np.float32)
+
+    index.add_items(data[0], 0)  # doc_id = 0
+    index.add_items(data[1], 0)  # doc_id = 0 (duplicate)
+    index.add_items(data[2], 1)  # doc_id = 1
+    index.add_items(data[3], 2)  # doc_id = 2
+    index.add_items(data[4], 3)  # doc_id = 3
+
+    query = np.random.random((1, dim)).astype(np.float32)
+    labels, distances = index.knn_query(query, k=k)
+
+    print("Basic Test - Labels:", labels)
+    print("Basic Test - Distances:", distances)
+    unique_doc_ids = set(labels[0])
+    assert len(unique_doc_ids) == k, f"Expected {k} unique IDs, got {len(unique_doc_ids)}"
+    assert all(label in [0, 1, 2, 3] for label in labels[0])
+    print("Basic Test passed: Duplicate filtering works with enough unique IDs.")
+
+def test_insufficient_unique_ids():
+    """Test behavior when there are fewer unique IDs than k."""
+    dim = 10
+    max_elements = 100
+    k = 3
+    index = hnswlib.Index(space='l2', dim=dim)
+    index.init_index(max_elements=max_elements, ef_construction=200, M=16)
+
+    np.random.seed(42)
+    data = np.random.random((5, dim)).astype(np.float32)
+
+    index.add_items(data[0], 0)  # doc_id = 0
+    index.add_items(data[1], 0)  # doc_id = 0 (duplicate)
+    index.add_items(data[2], 0)  # doc_id = 0 (duplicate)
+    index.add_items(data[3], 0)  # doc_id = 0 (duplicate)
+    index.add_items(data[4], 3)  # doc_id = 3
+
+    query = np.random.random((1, dim)).astype(np.float32)
+    try:
+        labels, distances = index.knn_query(query, k=k)
+        print("Insufficient IDs Test - Labels:", labels)
+        print("Insufficient IDs Test - Distances:", distances)
+        unique_doc_ids = set(labels[0])
+        assert len(unique_doc_ids) <= 2, "Should have at most 2 unique IDs"
+    except RuntimeError as e:
+        print(f"Insufficient IDs Test - Expected error caught: {e}")
+        assert "contiguous 2D array" in str(e)
+    print("Insufficient IDs Test passed: Correctly errors with too few unique IDs.")
+
+def test_single_doc_id():
+    """Test when all items share the same document ID."""
+    dim = 10
+    max_elements = 100
+    k = 1  # Set k=1 since only 1 unique ID is possible
+    index = hnswlib.Index(space='l2', dim=dim)
+    index.init_index(max_elements=max_elements, ef_construction=200, M=16)
+
+    np.random.seed(42)
+    data = np.random.random((5, dim)).astype(np.float32)
+
+    for i in range(5):
+        index.add_items(data[i], 0)  # All doc_id = 0
+
+    query = np.random.random((1, dim)).astype(np.float32)
+    labels, distances = index.knn_query(query, k=k)
+
+    print("Single ID Test - Labels:", labels)
+    print("Single ID Test - Distances:", distances)
+    assert len(labels[0]) == 1, "Should return exactly 1 result"
+    assert labels[0][0] == 0, "Only doc_id 0 should be returned"
+    print("Single ID Test passed: Correctly returns one result for a single doc ID.")
+
+def test_empty_index():
+    """Test behavior with an empty index."""
+    dim = 10
+    max_elements = 100
+    k = 3
+    index = hnswlib.Index(space='l2', dim=dim)
+    index.init_index(max_elements=max_elements, ef_construction=200, M=16)
+
+    query = np.random.random((1, dim)).astype(np.float32)
+    labels, distances = index.knn_query(query, k=k)
+
+    print("Empty Index Test - Labels:", labels)
+    print("Empty Index Test - Distances:", distances)
+    assert len(labels[0]) == 0, "Empty index should return no results"
+    print("Empty Index Test passed: Handles empty index correctly.")
+
+def test_large_dataset():
+    """Test with a large dataset and many duplicates."""
+    dim = 10
+    max_elements = 1000
+    k = 5
+    index = hnswlib.Index(space='l2', dim=dim)
+    index.init_index(max_elements=max_elements, ef_construction=200, M=16)
+
+    np.random.seed(42)
+    data = np.random.random((100, dim)).astype(np.float32)
+
+    # Add 100 points: 20 unique doc IDs, 5 duplicates each
+    for i in range(100):
+        doc_id = i // 5  # doc_ids 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, ..., 19
+        index.add_items(data[i], doc_id)
+
+    query = np.random.random((1, dim)).astype(np.float32)
+    labels, distances = index.knn_query(query, k=k)
+
+    print("Large Dataset Test - Labels:", labels)
+    print("Large Dataset Test - Distances:", distances)
+    unique_doc_ids = set(labels[0])
+    assert len(unique_doc_ids) == k, f"Expected {k} unique IDs, got {len(unique_doc_ids)}"
+    assert all(label in range(20) for label in labels[0])
+    print("Large Dataset Test passed: Correctly filters duplicates in large dataset.")
+
+if __name__ == "__main__":
+    test_basic_duplicate_filtering()
+    # test_insufficient_unique_ids()
+    test_single_doc_id()
+    # test_empty_index()
+    test_large_dataset()
\ No newline at end of file