diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h
index bef00170..40d3c47c 100644
--- a/hnswlib/hnswalg.h
+++ b/hnswlib/hnswalg.h
@@ -8,6 +8,8 @@
 #include
 #include
 #include
+#include <exception>
+#include <thread>
 
 namespace hnswlib {
 typedef unsigned int tableint;
@@ -77,11 +79,13 @@ class HierarchicalNSW : public AlgorithmInterface {
     HierarchicalNSW(
         SpaceInterface *s,
         const std::string &location,
         bool nmslib = false,
         size_t max_elements = 0,
-        bool allow_replace_deleted = false)
+        bool allow_replace_deleted = false,
+        // Appended last so that existing positional call sites keep binding to the same parameters.
+        bool parallel_load = false)
         : allow_replace_deleted_(allow_replace_deleted) {
-        loadIndex(location, s, max_elements);
+        loadIndex(location, s, max_elements, parallel_load);
     }
 
 
@@ -627,7 +631,7 @@ class HierarchicalNSW : public AlgorithmInterface {
     }
 
 
-    void loadIndex(const std::string &location, SpaceInterface *s, size_t max_elements_i = 0) {
+    void loadIndex(const std::string &location, SpaceInterface *s, size_t max_elements_i = 0, bool parallel_load = false) {
         std::ifstream input(location, std::ios::binary);
 
         if (!input.is_open())
@@ -690,7 +694,15 @@ class HierarchicalNSW : public AlgorithmInterface {
         data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_);
         if (data_level0_memory_ == nullptr)
             throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0");
-        input.read(data_level0_memory_, cur_element_count * size_data_per_element_);
+
+        size_t level0_size = cur_element_count * size_data_per_element_;
+        if (parallel_load) {
+            parallelLoadLevel0(location, pos);
+            // The chunk readers use their own streams, so move `input` past the block by hand.
+            input.seekg(pos + level0_size);
+        } else {
+            input.read(data_level0_memory_, level0_size);
+        }
 
         size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);
 
@@ -1267,5 +1279,91 @@ class HierarchicalNSW : public AlgorithmInterface {
         }
         std::cout << "integrity ok, checked " << connections_checked << " connections\n";
     }
+
+protected:
+    // Reads one slice of the level-0 block from disk into data_level0_memory_.
+    //
+    // location          path of the index file; each call opens its own stream on it.
+    // level0_start_pos  byte position in the file where the level-0 block begins.
+    // offset            start of the slice in bytes, relative to the block; also the
+    //                   destination offset inside data_level0_memory_.
+    // size              number of bytes to read; 0 is a no-op.
+    //
+    // Throws std::runtime_error if the file cannot be opened or the slice cannot be read in full.
+    void loadLevel0Chunk(const std::string &location, const size_t level0_start_pos, size_t offset, size_t size) {
+        if (size == 0)
+            return;
+
+        std::ifstream in_file(location, std::ios::binary);
+
+        if (!in_file.is_open())
+            throw std::runtime_error("Cannot open file");
+
+        in_file.seekg(level0_start_pos + offset);
+        in_file.read(data_level0_memory_ + offset, size);
+        // A short read would leave part of the slice uninitialized, so fail instead of continuing.
+        if (!in_file)
+            throw std::runtime_error("Cannot read level0 data from file");
+    }
+
+
+    // Fills data_level0_memory_ with the level-0 block of `location`, splitting the block into
+    // equally sized slices that are read concurrently (the last slice also takes the remainder).
+    //
+    // location          path of the index file.
+    // level0_start_pos  position in the file where the level-0 block begins.
+    //
+    // The slice count follows std::thread::hardware_concurrency(); the calling thread reads the
+    // last slice itself. Every started thread is joined before this function returns or throws,
+    // and the first failure reported by any slice is rethrown to the caller.
+    void parallelLoadLevel0(const std::string &location, std::streampos level0_start_pos) {
+        size_t level0_size = cur_element_count * size_data_per_element_;
+        if (level0_size == 0)
+            return;
+
+        // hardware_concurrency() may return 0 when the core count cannot be determined; never
+        // use fewer than one slice, nor more slices than there are bytes to read.
+        size_t num_chunks = std::thread::hardware_concurrency();
+        if (num_chunks == 0)
+            num_chunks = 1;
+        if (num_chunks > level0_size)
+            num_chunks = level0_size;
+        size_t part_sz = level0_size / num_chunks;
+
+        // One slot per slice: each slice writes only its own slot, so no locking is needed.
+        std::vector<std::exception_ptr> errors(num_chunks);
+        auto load_chunk = [&](size_t chunk_idx, size_t offset, size_t size) {
+            // An exception escaping a std::thread entry point calls std::terminate, so capture it.
+            try {
+                loadLevel0Chunk(location, level0_start_pos, offset, size);
+            } catch (...) {
+                errors[chunk_idx] = std::current_exception();
+            }
+        };
+
+        std::vector<std::thread> threads;
+        threads.reserve(num_chunks - 1);
+        size_t offset = 0;
+        for (size_t i = 0; i + 1 < num_chunks; i++) {
+            try {
+                threads.emplace_back(load_chunk, i, offset, part_sz);
+            } catch (...) {
+                // No more threads can be started: the calling thread reads the rest below.
+                break;
+            }
+            offset += part_sz;
+        }
+
+        // The calling thread reads everything that was not handed to a worker.
+        load_chunk(num_chunks - 1, offset, level0_size - offset);
+
+        for (auto& thread : threads)
+            thread.join();
+
+        for (const auto& error : errors) {
+            if (error)
+                std::rethrow_exception(error);
+        }
+    }
 };
 } // namespace hnswlib