Skip to content

Memoize link list offset #8

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 51 additions & 11 deletions hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,11 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
std::ofstream output_data_level0_; // output stream for data level 0
std::ofstream output_length_; // output stream for length
std::ofstream output_link_lists_; // output stream for link lists
bool _persist_file_handles_opened = false; // flag to check if file handles are opened
// link_list_size_offset_[j] caches the byte offset just past element j's record in the link list file (i.e. where element j + 1 begins), so that persist can seek straight to an element
std::vector<size_t> link_list_size_offset_;
size_t link_list_offset_memoized_elements_{0};
// flag to check if file handles are opened
bool _persist_file_handles_opened = false;

// Constructor taking a space pointer; the body is empty, so `s` is not used here
// and no work is performed beyond the in-class member initializers.
HierarchicalNSW(SpaceInterface<dist_t> *s) {
}
Expand Down Expand Up @@ -154,6 +158,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
length_memory_ = (char *) malloc(max_elements_ * sizeof(float));
if (length_memory_ == nullptr)
throw std::runtime_error("Not enough memory");

link_list_size_offset_ = std::vector<size_t>(max_elements_);

cur_element_count = 0;

Expand Down Expand Up @@ -645,6 +651,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
visited_list_pool_ = new VisitedListPool(1, new_max_elements);

element_levels_.resize(new_max_elements);
link_list_size_offset_.resize(new_max_elements);

std::vector<std::mutex>(new_max_elements).swap(link_list_locks_);

Expand Down Expand Up @@ -870,23 +877,56 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
this->output_link_lists_.seekp(0, std::ios::beg);
auto dirty_elements_iter = elements_to_persist_.begin();
// TODO: don't need to iterate over potentially all elements, could store it or memoize
for (size_t i = 0; i < cur_element_count && dirty_elements_iter != elements_to_persist_.end(); i++) {
unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0;
if (i == *dirty_elements_iter) {
writeBinaryPOD(this->output_link_lists_, linkListSize);
if (linkListSize)
this->output_link_lists_.write(linkLists_[i], linkListSize);
dirty_elements_iter = std::next(dirty_elements_iter);
} else {
this->output_link_lists_.seekp(linkListSize + sizeof(unsigned int), this->output_link_lists_.cur);
}
// for (size_t i = 0; i < cur_element_count && dirty_elements_iter != elements_to_persist_.end(); i++) {
// unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0;
// if (i == *dirty_elements_iter) {
// writeBinaryPOD(this->output_link_lists_, linkListSize);
// if (linkListSize)
// this->output_link_lists_.write(linkLists_[i], linkListSize);
// dirty_elements_iter = std::next(dirty_elements_iter);
// } else {
// this->output_link_lists_.seekp(linkListSize + sizeof(unsigned int), this->output_link_lists_.cur);
// }
// }
for (const auto& id : elements_to_persist_) {
auto offset = getLinkListOffset(id);
this->output_link_lists_.seekp(offset, this->output_link_lists_.beg);
unsigned int linkListSize = element_levels_[id] > 0 ? size_links_per_element_ * element_levels_[id] : 0;
writeBinaryPOD(this->output_link_lists_, linkListSize);
if (linkListSize)
this->output_link_lists_.write(linkLists_[id], linkListSize);
}
this->output_link_lists_.flush();

// Note: It would make sense to do a fsync here
elements_to_persist_.clear();
}

// Returns the byte offset, measured from the start of the link list file, at
// which element i's record begins. A record is an unsigned int size prefix
// followed by that many bytes of link list data (empty for level-0 elements).
//
// link_list_size_offset_[j] caches the offset just past element j's record,
// i.e. the start of element j + 1. link_list_offset_memoized_elements_ is the
// number of leading entries of that cache that have already been filled in.
//
// The result is a size_t rather than an unsigned int so that offsets beyond
// 4 GiB are not truncated before being handed to seekp.
//
// Throws std::runtime_error when i is not below cur_element_count.
size_t getLinkListOffset(size_t i) {
    // i is unsigned, so only the upper bound can be violated.
    if (i >= cur_element_count)
        throw std::runtime_error("Index out of bounds");

    // The first record always sits at the very beginning of the file.
    if (i == 0) {
        return 0;
    }

    // Entry i - 1 is needed; it is already cached whenever
    // i <= link_list_offset_memoized_elements_. Otherwise extend the cache.
    if (i > link_list_offset_memoized_elements_) {
        const size_t start_idx = link_list_offset_memoized_elements_;
        const size_t end_idx = i + 1;
        for (size_t j = start_idx; j < end_idx; j++) {
            // Same size computation as the persist loop, so the cached offsets
            // match the bytes that are actually written for each element.
            const unsigned int linkListSize =
                element_levels_[j] > 0 ? size_links_per_element_ * element_levels_[j] : 0;
            const size_t recordSize = linkListSize + sizeof(unsigned int);
            // Running total: previous end offset (0 before the first element) plus this record.
            const size_t previousEnd = (j == 0) ? 0 : link_list_size_offset_[j - 1];
            link_list_size_offset_[j] = previousEnd + recordSize;
        }
        link_list_offset_memoized_elements_ = end_idx;
    }

    // Element i starts where element i - 1 ends.
    return link_list_size_offset_[i - 1];
}

void loadPersistedIndex(SpaceInterface<dist_t> *s, size_t max_elements_i = 0){
std::ifstream input_header(this->getHeaderLocation(), std::ios::binary);
if (!input_header.is_open())
Expand Down