Skip to content

Add duplicate filtering by document ID in HNSWlib search #623

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 100 additions & 71 deletions hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,16 @@ typedef unsigned int linklistsizeint;

template<typename dist_t>
class HierarchicalNSW : public AlgorithmInterface<dist_t> {

private:
std::vector<int> internal_id_to_doc_id_;
private:
// Looks up the document ID stored for an internal element ID.
//
// internal_id: index into internal_id_to_doc_id_.
// Returns the int stored at that index.
// Throws std::runtime_error when internal_id is not a valid index
// (i.e. it is >= internal_id_to_doc_id_.size()).
int getMetadata(tableint internal_id) const {
const bool id_in_range = internal_id < internal_id_to_doc_id_.size();
if (id_in_range) {
return internal_id_to_doc_id_[internal_id];
}
throw std::runtime_error("Internal ID out of range in getMetadata");
}
public:
static const tableint MAX_LABEL_OPERATION_LOCKS = 65536;
static const unsigned char DELETE_MARK = 0x01;
Expand Down Expand Up @@ -141,6 +151,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);
mult_ = 1 / log(1.0 * M_);
revSize_ = 1.0 / mult_;
internal_id_to_doc_id_.reserve(max_elements_);
}


Expand Down Expand Up @@ -332,6 +343,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
stop_condition->add_point_to_result(getExternalLabel(ep_id), ep_data, dist);
}
candidate_set.emplace(-dist, ep_id);
vl->mark_seen_doc(getMetadata(ep_id)); // Mark initial document ID
} else {
lowerBound = std::numeric_limits<dist_t>::max();
candidate_set.emplace(-lowerBound, ep_id);
Expand Down Expand Up @@ -361,75 +373,75 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
tableint current_node_id = current_node_pair.second;
int *data = (int *) get_linklist0(current_node_id);
size_t size = getListCount((linklistsizeint*)data);
// bool cur_node_deleted = isMarkedDeleted(current_node_id);
if (collect_metrics) {
metric_hops++;
metric_distance_computations+=size;
metric_distance_computations += size;
}

#ifdef USE_SSE
#ifdef USE_SSE
_mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
_mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
_mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0);
_mm_prefetch((char *) (data + 2), _MM_HINT_T0);
#endif
#endif

for (size_t j = 1; j <= size; j++) {
int candidate_id = *(data + j);
// if (candidate_id == 0) continue;
#ifdef USE_SSE
#ifdef USE_SSE
_mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0);
_mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_,
_MM_HINT_T0); ////////////
#endif
_mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0);
#endif
if (!(visited_array[candidate_id] == visited_array_tag)) {
visited_array[candidate_id] = visited_array_tag;

char *currObj1 = (getDataByInternalId(candidate_id));
dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_);

bool flag_consider_candidate;
if (!bare_bone_search && stop_condition) {
flag_consider_candidate = stop_condition->should_consider_candidate(dist, lowerBound);
} else {
flag_consider_candidate = top_candidates.size() < ef || lowerBound > dist;
}

if (flag_consider_candidate) {
candidate_set.emplace(-dist, candidate_id);
#ifdef USE_SSE
_mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ +
offsetLevel0_, ///////////
_MM_HINT_T0); ////////////////////////
#endif

if (bare_bone_search ||
(!isMarkedDeleted(candidate_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))) {
top_candidates.emplace(dist, candidate_id);
if (!bare_bone_search && stop_condition) {
stop_condition->add_point_to_result(getExternalLabel(candidate_id), currObj1, dist);
}
}
int doc_id = getMetadata(candidate_id);
if (!vl->is_doc_seen(doc_id)) { // Filter duplicates based on document ID
char *currObj1 = getDataByInternalId(candidate_id);
dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_);

bool flag_remove_extra = false;
bool flag_consider_candidate;
if (!bare_bone_search && stop_condition) {
flag_remove_extra = stop_condition->should_remove_extra();
flag_consider_candidate = stop_condition->should_consider_candidate(dist, lowerBound);
} else {
flag_remove_extra = top_candidates.size() > ef;
flag_consider_candidate = top_candidates.size() < ef || lowerBound > dist;
}
while (flag_remove_extra) {
tableint id = top_candidates.top().second;
top_candidates.pop();

if (flag_consider_candidate) {
candidate_set.emplace(-dist, candidate_id);
#ifdef USE_SSE
_mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ +
offsetLevel0_, _MM_HINT_T0);
#endif

if (bare_bone_search ||
(!isMarkedDeleted(candidate_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))) {
top_candidates.emplace(dist, candidate_id);
vl->mark_seen_doc(doc_id); // Mark document ID as seen
if (!bare_bone_search && stop_condition) {
stop_condition->add_point_to_result(getExternalLabel(candidate_id), currObj1, dist);
}
}

bool flag_remove_extra = false;
if (!bare_bone_search && stop_condition) {
stop_condition->remove_point_from_result(getExternalLabel(id), getDataByInternalId(id), dist);
flag_remove_extra = stop_condition->should_remove_extra();
} else {
flag_remove_extra = top_candidates.size() > ef;
}
}
while (flag_remove_extra) {
tableint id = top_candidates.top().second;
top_candidates.pop();
if (!bare_bone_search && stop_condition) {
stop_condition->remove_point_from_result(getExternalLabel(id), getDataByInternalId(id), dist);
flag_remove_extra = stop_condition->should_remove_extra();
} else {
flag_remove_extra = top_candidates.size() > ef;
}
}

if (!top_candidates.empty())
lowerBound = top_candidates.top().first;
if (!top_candidates.empty())
lowerBound = top_candidates.top().first;
}
}
}
}
Expand Down Expand Up @@ -956,38 +968,44 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
throw std::runtime_error("Replacement of deleted elements is disabled in constructor");
}

// lock all operations with element by label
std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label));
// Lock all operations with element by label
std::unique_lock<std::mutex> lock_label(getLabelOpMutex(label));
if (!replace_deleted) {
addPoint(data_point, label, -1);
addPoint(data_point, label, -1); // Call lower-level addPoint
return;
}
// check if there is vacant place

// Check if there is a vacant place
tableint internal_id_replaced;
std::unique_lock <std::mutex> lock_deleted_elements(deleted_elements_lock);
std::unique_lock<std::mutex> lock_deleted_elements(deleted_elements_lock);
bool is_vacant_place = !deleted_elements.empty();
if (is_vacant_place) {
internal_id_replaced = *deleted_elements.begin();
deleted_elements.erase(internal_id_replaced);
}
lock_deleted_elements.unlock();

// if there is no vacant place then add or update point
// else add point to vacant place
// If no vacant place, add normally; otherwise, replace deleted element
if (!is_vacant_place) {
addPoint(data_point, label, -1);
addPoint(data_point, label, -1); // Call lower-level addPoint
} else {
// we assume that there are no concurrent operations on deleted element
// Replace deleted element
labeltype label_replaced = getExternalLabel(internal_id_replaced);
setExternalLabel(internal_id_replaced, label);

std::unique_lock <std::mutex> lock_table(label_lookup_lock);
std::unique_lock<std::mutex> lock_table(label_lookup_lock);
label_lookup_.erase(label_replaced);
label_lookup_[label] = internal_id_replaced;
lock_table.unlock();

unmarkDeletedInternal(internal_id_replaced);
updatePoint(data_point, internal_id_replaced, 1.0);

// Update internal_id_to_doc_id_ for the replaced element
if (internal_id_replaced >= internal_id_to_doc_id_.size()) {
internal_id_to_doc_id_.resize(internal_id_replaced + 1);
}
internal_id_to_doc_id_[internal_id_replaced] = static_cast<int>(label);
}
}

Expand Down Expand Up @@ -1154,8 +1172,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
tableint cur_c = 0;
{
// Checking if the element with the same label already exists
// if so, updating it *instead* of creating a new element.
std::unique_lock <std::mutex> lock_table(label_lookup_lock);
std::unique_lock<std::mutex> lock_table(label_lookup_lock);
auto search = label_lookup_.find(label);
if (search != label_lookup_.end()) {
tableint existingInternalId = search->second;
Expand All @@ -1171,6 +1188,12 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
}
updatePoint(data_point, existingInternalId, 1.0);

// Update internal_id_to_doc_id_ for the existing element
if (existingInternalId >= internal_id_to_doc_id_.size()) {
internal_id_to_doc_id_.resize(existingInternalId + 1);
}
internal_id_to_doc_id_[existingInternalId] = static_cast<int>(label);

return existingInternalId;
}

Expand All @@ -1183,14 +1206,20 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
label_lookup_[label] = cur_c;
}

std::unique_lock <std::mutex> lock_el(link_list_locks_[cur_c]);
std::unique_lock<std::mutex> lock_el(link_list_locks_[cur_c]);
int curlevel = getRandomLevel(mult_);
if (level > 0)
curlevel = level;

element_levels_[cur_c] = curlevel;

std::unique_lock <std::mutex> templock(global);
// Populate internal_id_to_doc_id_ for the new element
if (cur_c >= internal_id_to_doc_id_.size()) {
internal_id_to_doc_id_.resize(cur_c + 1);
}
internal_id_to_doc_id_[cur_c] = static_cast<int>(label);

std::unique_lock<std::mutex> templock(global);
int maxlevelcopy = maxlevel_;
if (curlevel <= maxlevelcopy)
templock.unlock();
Expand All @@ -1199,7 +1228,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {

memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_);

// Initialisation of the data and label
// Initialization of the data and label
memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype));
memcpy(getDataByInternalId(cur_c), data_point, data_size_);

Expand All @@ -1218,7 +1247,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
while (changed) {
changed = false;
unsigned int *data;
std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
std::unique_lock<std::mutex> lock(link_list_locks_[currObj]);
data = get_linklist(currObj, level);
int size = getListCount(data);

Expand All @@ -1244,7 +1273,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
throw std::runtime_error("Level error");

std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates = searchBaseLayer(
currObj, data_point, level);
currObj, data_point, level);
if (epDeleted) {
top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy);
if (top_candidates.size() > ef_construction_)
Expand All @@ -1266,33 +1295,30 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
return cur_c;
}


std::priority_queue<std::pair<dist_t, labeltype >>
std::priority_queue<std::pair<dist_t, labeltype>>
searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const {
std::priority_queue<std::pair<dist_t, labeltype >> result;
std::priority_queue<std::pair<dist_t, labeltype>> result;
if (cur_element_count == 0) return result;

tableint currObj = enterpoint_node_;
dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);

// Traverse higher levels to find the best entry point
for (int level = maxlevel_; level > 0; level--) {
bool changed = true;
while (changed) {
changed = false;
unsigned int *data;

data = (unsigned int *) get_linklist(currObj, level);
unsigned int *data = (unsigned int *) get_linklist(currObj, level);
int size = getListCount(data);
metric_hops++;
metric_distance_computations+=size;
metric_distance_computations += size;

tableint *datal = (tableint *) (data + 1);
for (int i = 0; i < size; i++) {
tableint cand = datal[i];
if (cand < 0 || cand > max_elements_)
throw std::runtime_error("cand error");
dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);

if (d < curdist) {
curdist = d;
currObj = cand;
Expand All @@ -1302,16 +1328,18 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
}
}

// Base layer search with duplicate filtering handled in searchBaseLayerST
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
bool bare_bone_search = !num_deleted_ && !isIdAllowed;
if (bare_bone_search) {
top_candidates = searchBaseLayerST<true>(
currObj, query_data, std::max(ef_, k), isIdAllowed);
currObj, query_data, std::max(ef_, k), isIdAllowed);
} else {
top_candidates = searchBaseLayerST<false>(
currObj, query_data, std::max(ef_, k), isIdAllowed);
currObj, query_data, std::max(ef_, k), isIdAllowed);
}

// Extract top k results
while (top_candidates.size() > k) {
top_candidates.pop();
}
Expand All @@ -1320,6 +1348,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
result.push(std::pair<dist_t, labeltype>(rez.first, getExternalLabel(rez.second)));
top_candidates.pop();
}

return result;
}

Expand Down
Loading