// Unified diff against hnswlib/hnswalg.h (class HierarchicalNSW).
//
// What the hunks below change:
//   * Adds two members: link_list_size_offset_ (a std::vector) and
//     link_list_offset_memoized_elements_ (size_t, brace-initialised to 0).
//   * Sizes link_list_size_offset_ to max_elements_ right after length_memory_ is
//     allocated, and resizes it to new_max_elements next to element_levels_.resize().
//   * In the code that writes output_link_lists_, comments out the loop that walked
//     ids 0..cur_element_count and seeked past non-dirty elements, and replaces it with
//     a loop over elements_to_persist_ that seeks to getLinkListOffset(id) measured from
//     the beginning of the stream, then writes linkListSize followed by linkLists_[id]
//     (the payload is skipped when linkListSize is 0).
//   * Adds getLinkListOffset(i):
//       - throws std::runtime_error("Index out of bounds") when i >= cur_element_count;
//       - returns 0 for i == 0;
//       - otherwise returns link_list_size_offset_[i - 1]. Entry j of that table is the
//         running total of (linkListSize + sizeof(unsigned int)) over elements 0..j.
//         When i > link_list_offset_memoized_elements_, entries from
//         link_list_offset_memoized_elements_ through i are filled in and
//         link_list_offset_memoized_elements_ is then set to i + 1.
//
// NOTE(review): `i < 0` in getLinkListOffset can never be true because i is size_t.
// NOTE(review): getLinkListOffset returns unsigned int and accumulates into
//   link_list_size_offset_, whose element type is not visible here -- confirm the
//   link-list file can never exceed that range.
// NOTE(review): `dirty_elements_iter` and the "could store it or memoize" TODO remain as
//   context lines, but only the commented-out loop uses the iterator.
// NOTE(review): cached offsets are derived from element_levels_[j] at the time they are
//   first computed -- presumably levels never change after insertion; confirm.
// NOTE(review): this patch does not show link_list_size_offset_ being sized on the
//   loadPersistedIndex path -- verify against the full file.
// NOTE(review): template argument lists look stripped (e.g. `std::vector
//   link_list_size_offset_`, `SpaceInterface *s`) -- presumably an extraction artifact;
//   check against the real patch before applying.
diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index 3b4ea8c0..beeb84a1 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -83,7 +83,11 @@ class HierarchicalNSW : public AlgorithmInterface { std::ofstream output_data_level0_; // output stream for data level 0 std::ofstream output_length_; // output stream for length std::ofstream output_link_lists_; // output stream for link lists - bool _persist_file_handles_opened = false; // flag to check if file handles are opened + // Used to store the offset of the link list for each element so that we can seek to it during persist + std::vector link_list_size_offset_; + size_t link_list_offset_memoized_elements_{0}; + // flag to check if file handles are opened + bool _persist_file_handles_opened = false; HierarchicalNSW(SpaceInterface *s) { } @@ -154,6 +158,8 @@ class HierarchicalNSW : public AlgorithmInterface { length_memory_ = (char *) malloc(max_elements_ * sizeof(float)); if (length_memory_ == nullptr) throw std::runtime_error("Not enough memory"); + + link_list_size_offset_ = std::vector(max_elements_); cur_element_count = 0; @@ -645,6 +651,7 @@ class HierarchicalNSW : public AlgorithmInterface { visited_list_pool_ = new VisitedListPool(1, new_max_elements); element_levels_.resize(new_max_elements); + link_list_size_offset_.resize(new_max_elements); std::vector(new_max_elements).swap(link_list_locks_); @@ -870,16 +877,24 @@ class HierarchicalNSW : public AlgorithmInterface { this->output_link_lists_.seekp(0, std::ios::beg); auto dirty_elements_iter = elements_to_persist_.begin(); // TODO: don't need to iterate over potentially all elements, could store it or memoize - for (size_t i = 0; i < cur_element_count && dirty_elements_iter != elements_to_persist_.end(); i++) { - unsigned int linkListSize = element_levels_[i] > 0 ? 
size_links_per_element_ * element_levels_[i] : 0; - if (i == *dirty_elements_iter) { - writeBinaryPOD(this->output_link_lists_, linkListSize); - if (linkListSize) - this->output_link_lists_.write(linkLists_[i], linkListSize); - dirty_elements_iter = std::next(dirty_elements_iter); - } else { - this->output_link_lists_.seekp(linkListSize + sizeof(unsigned int), this->output_link_lists_.cur); - } + // for (size_t i = 0; i < cur_element_count && dirty_elements_iter != elements_to_persist_.end(); i++) { + // unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0; + // if (i == *dirty_elements_iter) { + // writeBinaryPOD(this->output_link_lists_, linkListSize); + // if (linkListSize) + // this->output_link_lists_.write(linkLists_[i], linkListSize); + // dirty_elements_iter = std::next(dirty_elements_iter); + // } else { + // this->output_link_lists_.seekp(linkListSize + sizeof(unsigned int), this->output_link_lists_.cur); + // } + // } + for (const auto& id : elements_to_persist_) { + auto offset = getLinkListOffset(id); + this->output_link_lists_.seekp(offset, this->output_link_lists_.beg); + unsigned int linkListSize = element_levels_[id] > 0 ? 
size_links_per_element_ * element_levels_[id] : 0; + writeBinaryPOD(this->output_link_lists_, linkListSize); + if (linkListSize) + this->output_link_lists_.write(linkLists_[id], linkListSize); } this->output_link_lists_.flush(); @@ -887,6 +902,31 @@ class HierarchicalNSW : public AlgorithmInterface { elements_to_persist_.clear(); } + unsigned int getLinkListOffset(size_t i){ + if (i >= cur_element_count || i < 0) + throw std::runtime_error("Index out of bounds"); + + if (i == 0) { + return 0; + } + + // If we have not memoized the offset, we need to calculate it + if (i > link_list_offset_memoized_elements_){ + size_t start_idx = link_list_offset_memoized_elements_; + size_t end_idx = i + 1; + for (size_t j = start_idx; j < end_idx; j++){ + unsigned int linkListSize = element_levels_[j] > 0 ? size_links_per_element_ * element_levels_[j] : 0; + if (j == 0) { + link_list_size_offset_[j] = linkListSize + sizeof(unsigned int); + } else { + link_list_size_offset_[j] = link_list_size_offset_[j - 1] + linkListSize + sizeof(unsigned int); + } + } + link_list_offset_memoized_elements_ = end_idx; + } + return link_list_size_offset_[i - 1]; + } + void loadPersistedIndex(SpaceInterface *s, size_t max_elements_i = 0){ std::ifstream input_header(this->getHeaderLocation(), std::ios::binary); if (!input_header.is_open())