From 43a8b9c9cdcd050c2a41456f12398da81d84bd6c Mon Sep 17 00:00:00 2001 From: LHT129 Date: Tue, 20 May 2025 11:05:35 +0800 Subject: [PATCH 01/42] speed up ci compile (#703) Signed-off-by: LHT129 Signed-off-by: suguan.dx --- .circleci/fresh_ci_cache.commit | 2 +- .github/workflows/asan_build_and_test.yml | 40 +++++++++++++++++++++-- .github/workflows/coverage.yml | 20 +++++++++++- .github/workflows/tsan_build_and_test.yml | 23 +++++++++++-- scripts/change_mtime.sh | 7 ++++ 5 files changed, 86 insertions(+), 6 deletions(-) create mode 100644 scripts/change_mtime.sh diff --git a/.circleci/fresh_ci_cache.commit b/.circleci/fresh_ci_cache.commit index 09d2dfaaf..5c8c4dacd 100644 --- a/.circleci/fresh_ci_cache.commit +++ b/.circleci/fresh_ci_cache.commit @@ -1 +1 @@ -fc8c1a3aad5ce0fd4704b8a7bbdf6255d2b79cd6 \ No newline at end of file +cce585e1f18168e5a98e5f3f4d08785cd4e120b2 diff --git a/.github/workflows/asan_build_and_test.yml b/.github/workflows/asan_build_and_test.yml index 9584cbe8c..d50644f76 100644 --- a/.github/workflows/asan_build_and_test.yml +++ b/.github/workflows/asan_build_and_test.yml @@ -21,11 +21,25 @@ jobs: - name: Free Disk Space run: rm -rf /useless/hostedtoolcache - uses: actions/checkout@v4 + with: + fetch-depth: '0' + - name: Change Time + run: | + git config --global --add safe.directory /__w/vsag/vsag + sh -x ./scripts/change_mtime.sh - name: Load Cache uses: actions/cache@v4 with: - path: ./build/ + path: ./build.tar.gz key: build-${{ hashFiles('./CMakeLists.txt') }}-${{ hashFiles('./.circleci/fresh_ci_cache.commit') }} + - name: Extract + run: | + if [ -f build.tar.gz ]; then + tar -xzvf build.tar.gz; + rm -rf build.tar.gz; + else + mkdir -p build; + fi - name: Make Asan run: export CMAKE_GENERATOR="Ninja"; make asan - name: Save Test @@ -36,6 +50,10 @@ jobs: compression-level: 1 retention-days: 1 overwrite: 'true' + - name: Compress + run: | + sudo apt install pigz -y; + tar -cf - ./build | pigz -p 4 > build.tar.gz build_asan_aarch64: name: 
Asan Build Aarch64 @@ -47,13 +65,27 @@ jobs: - name: Free Disk Space run: rm -rf /opt/hostedtoolcache - uses: actions/checkout@v4 + with: + fetch-depth: '0' + - name: Change Time + run: | + git config --global --add safe.directory /__w/vsag/vsag + sh -x ./scripts/change_mtime.sh - name: Prepare Env run: sudo bash ./scripts/deps/install_deps_ubuntu.sh - name: Load Cache uses: actions/cache@v4 with: - path: ./build/ + path: ./build.tar.gz key: build-aarch-${{ hashFiles('./CMakeLists.txt') }}-${{ hashFiles('./.circleci/fresh_ci_cache.commit') }} + - name: Extract + run: | + if [ -f build.tar.gz ]; then + tar -xzvf build.tar.gz; + rm -rf build.tar.gz; + else + mkdir -p build; + fi - name: Make Asan run: export CMAKE_GENERATOR="Ninja"; make asan - name: Save Test @@ -64,6 +96,10 @@ jobs: compression-level: 1 retention-days: 1 overwrite: 'true' + - name: Compress + run: | + sudo apt install pigz -y; + tar -cf - ./build | pigz -p 4 > build.tar.gz test_asan_x86: name: Test X86 diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index e0155653e..e686bfaf7 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -17,6 +17,12 @@ jobs: image: vsaglib/vsag:ci-x86 steps: - uses: actions/checkout@v4 + with: + fetch-depth: '0' + - name: Change Time + run: | + git config --global --add safe.directory /__w/vsag/vsag + sh -x ./scripts/change_mtime.sh - name: Install lcov run: | apt update @@ -27,8 +33,16 @@ jobs: - name: Load Cache uses: actions/cache@v4 with: - path: ./build/ + path: ./build.tar.gz key: build-cov-${{ hashFiles('./CMakeLists.txt') }}-${{ hashFiles('./.circleci/fresh_ci_cache.commit') }} + - name: Extract + run: | + if [ -f build.tar.gz ]; then + tar -xzvf build.tar.gz; + rm -rf build.tar.gz; + else + mkdir -p build; + fi - name: Compile with Coverage Flags run: export CMAKE_GENERATOR="Ninja"; make cov - name: Run Test @@ -48,3 +62,7 @@ jobs: flags: cpp token: ${{ secrets.CODECOV_TOKEN }} verbose: true + - name: 
Compress + run: | + sudo apt install pigz -y; + tar -cf - ./build | pigz -p 4 > build.tar.gz \ No newline at end of file diff --git a/.github/workflows/tsan_build_and_test.yml b/.github/workflows/tsan_build_and_test.yml index e461c2c6c..8ea51d1a6 100644 --- a/.github/workflows/tsan_build_and_test.yml +++ b/.github/workflows/tsan_build_and_test.yml @@ -17,11 +17,25 @@ jobs: image: vsaglib/vsag:ci-x86 steps: - uses: actions/checkout@v4 + with: + fetch-depth: '0' + - name: Change Time + run: | + git config --global --add safe.directory /__w/vsag/vsag + sh -x ./scripts/change_mtime.sh - name: Load Cache uses: actions/cache@v4 with: - path: ./build/ - key: build-${{ hashFiles('./CMakeLists.txt') }}-${{ hashFiles('./.circleci/fresh_ci_cache.commit') }}-tsan + path: ./build.tar.gz + key: build-tsan-${{ hashFiles('./CMakeLists.txt') }}-${{ hashFiles('./.circleci/fresh_ci_cache.commit') }} + - name: Extract + run: | + if [ -f build.tar.gz ]; then + tar -xzvf build.tar.gz; + rm -rf build.tar.gz; + else + mkdir -p build; + fi - name: Make Tsan run: export CMAKE_GENERATOR="Ninja"; make tsan - name: Save Test @@ -32,6 +46,11 @@ jobs: compression-level: 1 retention-days: 1 overwrite: 'true' + - name: Compress + run: | + sudo apt install pigz -y; + tar -cf - ./build | pigz -p 4 > build.tar.gz + test_tsan: name: Run TSAN Tests needs: build_tsan diff --git a/scripts/change_mtime.sh b/scripts/change_mtime.sh new file mode 100644 index 000000000..fdd0ca949 --- /dev/null +++ b/scripts/change_mtime.sh @@ -0,0 +1,7 @@ +commit=$(git rev-parse HEAD) +git ls-files | while read -r file; do + time=$(git log -1 --pretty=%cd --date=iso -- "$file") + if [ -n "$time" ]; then + touch -d "$time" "$file" + fi +done \ No newline at end of file From a28b5c76f1b3132bc9108cbc36c6fc7bafb0b555 Mon Sep 17 00:00:00 2001 From: LHT129 Date: Tue, 20 May 2025 11:12:58 +0800 Subject: [PATCH 02/42] kmeans train for huge K and data count (#717) Signed-off-by: LHT129 Signed-off-by: suguan.dx --- 
src/impl/kmeans_cluster.cpp | 149 ++++++++++-------- src/impl/kmeans_cluster.h | 10 ++ .../product_quantizer_test.cpp | 4 +- src/quantization/quantizer_test.h | 2 +- 4 files changed, 92 insertions(+), 73 deletions(-) diff --git a/src/impl/kmeans_cluster.cpp b/src/impl/kmeans_cluster.cpp index 7c92caea0..e380c8ff7 100644 --- a/src/impl/kmeans_cluster.cpp +++ b/src/impl/kmeans_cluster.cpp @@ -42,7 +42,6 @@ KMeansCluster::~KMeansCluster() { Vector KMeansCluster::Run(uint32_t k, const float* datas, uint64_t count, int iter) { - // Allocate space for centroids if (k_centroids_ != nullptr) { allocator_->Deallocate(k_centroids_); k_centroids_ = nullptr; @@ -50,7 +49,6 @@ KMeansCluster::Run(uint32_t k, const float* datas, uint64_t count, int iter) { uint64_t size = static_cast(k) * static_cast(dim_) * sizeof(float); k_centroids_ = static_cast(allocator_->Allocate(size)); - // Initialize centroids randomly std::random_device rd; std::mt19937 gen(rd()); std::uniform_int_distribution dis(0, count - 1); @@ -61,78 +59,21 @@ KMeansCluster::Run(uint32_t k, const float* datas, uint64_t count, int iter) { } } + Vector labels(count, -1, this->allocator_); + std::vector mutexes(k); + std::vector> futures; + constexpr uint64_t query_count_bs = 65536; ByteBuffer y_sqr_buffer(static_cast(k) * sizeof(float), allocator_); - ByteBuffer distances_buffer(static_cast(k) * count * sizeof(float), allocator_); + ByteBuffer distances_buffer(static_cast(k) * query_count_bs * sizeof(float), + allocator_); auto* y_sqr = reinterpret_cast(y_sqr_buffer.data); auto* distances = reinterpret_cast(distances_buffer.data); - - Vector labels(count, -1, this->allocator_); - bool have_empty = false; - std::vector mutexes(k); for (int it = 0; it < iter; ++it) { - std::atomic has_converged = true; - std::vector> futures; - - auto compute_ip_func = [&](uint64_t start, uint64_t end) -> void { - for (uint64_t i = start; i < end; ++i) { - y_sqr[i] = FP32ComputeIP(k_centroids_ + i * dim_, k_centroids_ + i * dim_, 
dim_); - } - }; - auto bs = 1024; - for (uint64_t i = 0; i < static_cast(k); i += bs) { - futures.emplace_back(thread_pool_->GeneralEnqueue( - compute_ip_func, i, std::min(i + bs, static_cast(k)))); - } - for (auto& future : futures) { - future.wait(); - } - futures.clear(); - - cblas_sgemm(CblasColMajor, - CblasTrans, - CblasNoTrans, - static_cast(k), - static_cast(count), - dim_, - -2.0F, - k_centroids_, - dim_, - datas, - dim_, - 0.0F, - distances, - static_cast(k)); + this->find_nearest_one_with_blas(datas, count, k, query_count_bs, y_sqr, distances, labels); + constexpr uint64_t bs = 1024; - // Assign labels to each data point - auto assign_labels_func = [&](uint64_t start, uint64_t end) { - omp_set_num_threads(1); - for (uint64_t i = start; i < end; ++i) { - cblas_saxpy(static_cast(k), 1.0, y_sqr, 1, distances + i * k, 1); - auto* min_elem = std::min_element(distances + i * k, distances + i * k + k); - auto min_index = std::distance(distances + i * k, min_elem); - if (min_index != labels[i]) { - labels[i] = static_cast(min_index); - has_converged.store(false); - } - } - }; - for (uint64_t i = 0; i < count; i += bs) { - futures.emplace_back( - thread_pool_->GeneralEnqueue(assign_labels_func, i, std::min(i + bs, count))); - } - for (auto& future : futures) { - future.wait(); - } - futures.clear(); - - if (has_converged.load() and not have_empty) { - break; - } - - // Update centroids Vector counts(k, 0, allocator_); Vector new_centroids(static_cast(k) * dim_, 0.0F, allocator_); - have_empty = false; auto update_centroids_func = [&](uint64_t start, uint64_t end) { omp_set_num_threads(1); @@ -165,12 +106,10 @@ KMeansCluster::Run(uint32_t k, const float* datas, uint64_t count, int iter) { 1.0F / static_cast(counts[j]), new_centroids.data() + j * static_cast(dim_), 1); - // Copy new centroids to k_centroids_ std::copy(new_centroids.data() + j * static_cast(dim_), new_centroids.data() + (j + 1) * static_cast(dim_), k_centroids_ + j * static_cast(dim_)); } else { - 
have_empty = true; auto index = dis(gen); for (int s = 0; s < dim_; ++s) { k_centroids_[j * dim_ + s] = datas[index * dim_ + s]; @@ -181,4 +120,76 @@ KMeansCluster::Run(uint32_t k, const float* datas, uint64_t count, int iter) { return labels; } +void +KMeansCluster::find_nearest_one_with_blas(const float* query, + const uint64_t query_count, + const uint64_t k, + const uint64_t query_count_bs, + float* y_sqr, + float* distances, + Vector& labels) { + if (k_centroids_ == nullptr) { + return; + } + auto& thread_pool = this->thread_pool_; + auto bs = 1024; + std::vector> futures; + + auto wait_futures_and_clear = [&]() { + for (auto& future : futures) { + future.wait(); + } + futures.clear(); + }; + + auto compute_ip_func = [&](uint64_t start, uint64_t end) -> void { + for (uint64_t i = start; i < end; ++i) { + y_sqr[i] = FP32ComputeIP(k_centroids_ + i * dim_, k_centroids_ + i * dim_, dim_); + } + }; + for (uint64_t i = 0; i < static_cast(k); i += bs) { + futures.emplace_back(thread_pool->GeneralEnqueue( + compute_ip_func, i, std::min(i + bs, static_cast(k)))); + } + wait_futures_and_clear(); + + for (uint64_t i = 0; i < query_count; i += query_count_bs) { + auto end = std::min(i + query_count_bs, query_count); + auto cur_query_count = end - i; + auto* cur_label = labels.data() + i; + + cblas_sgemm(CblasColMajor, + CblasTrans, + CblasNoTrans, + static_cast(k), + static_cast(cur_query_count), + dim_, + -2.0F, + k_centroids_, + dim_, + query + i * dim_, + dim_, + 0.0F, + distances, + static_cast(k)); + + auto assign_labels_func = [&](uint64_t start, uint64_t end) -> void { + omp_set_num_threads(1); + for (uint64_t i = start; i < end; ++i) { + cblas_saxpy(static_cast(k), 1.0, y_sqr, 1, distances + i * k, 1); + auto* min_elem = std::min_element(distances + i * k, distances + i * k + k); + auto min_index = std::distance(distances + i * k, min_elem); + if (min_index != cur_label[i]) { + cur_label[i] = static_cast(min_index); + } + } + }; + for (uint64_t j = 0; j < 
cur_query_count; j += bs) { + futures.emplace_back(thread_pool->GeneralEnqueue( + assign_labels_func, j, std::min(j + bs, cur_query_count))); + } + wait_futures_and_clear(); + } +} + } // namespace vsag diff --git a/src/impl/kmeans_cluster.h b/src/impl/kmeans_cluster.h index 192118c19..d376c3fee 100644 --- a/src/impl/kmeans_cluster.h +++ b/src/impl/kmeans_cluster.h @@ -35,6 +35,16 @@ class KMeansCluster { public: float* k_centroids_{nullptr}; +private: + void + find_nearest_one_with_blas(const float* query, + const uint64_t query_count, + const uint64_t k, + const uint64_t query_count_bs, + float* y_sqr, + float* distances, + Vector& labels); + private: Allocator* const allocator_{nullptr}; diff --git a/src/quantization/product_quantization/product_quantizer_test.cpp b/src/quantization/product_quantization/product_quantizer_test.cpp index a05fb7d89..be5416c8b 100644 --- a/src/quantization/product_quantization/product_quantizer_test.cpp +++ b/src/quantization/product_quantization/product_quantizer_test.cpp @@ -105,9 +105,7 @@ TEST_CASE("ProductQuantizer Serialize and Deserialize", "[ut][ProductQuantizer]" float error = 8.0F / 255.0F; int64_t pq_dim; for (auto dim : dims) { - if (dim % 4 == 0) { - pq_dim = dim / 4; - } else if (dim % 2 == 0) { + if (dim % 2 == 0) { pq_dim = dim / 2; } else { pq_dim = dim; diff --git a/src/quantization/quantizer_test.h b/src/quantization/quantizer_test.h index 240226203..8b676a3a6 100644 --- a/src/quantization/quantizer_test.h +++ b/src/quantization/quantizer_test.h @@ -233,7 +233,7 @@ TestComputer(Quantizer& quant, bool retrain = true, float unbounded_numeric_error_rate = 1.0f, float unbounded_related_error_rate = 1.0f) { - auto query_count = 100; + auto query_count = 10; bool need_normalize = true; if constexpr (metric == vsag::MetricType::METRIC_TYPE_COSINE) { need_normalize = false; From 75ebbaf265e1faeec72f79950420f9c14c797f34 Mon Sep 17 00:00:00 2001 From: inabao <37021995+inabao@users.noreply.github.com> Date: Tue, 20 May 2025 
11:20:40 +0800 Subject: [PATCH 03/42] fix double assignment for max_degree in HNSW Merge (#724) Signed-off-by: jinjiabao.jjb Signed-off-by: suguan.dx --- src/index/hnsw.cpp | 2 +- tests/test_hnsw_new.cpp | 23 +++++++++++------------ 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/index/hnsw.cpp b/src/index/hnsw.cpp index 4ce8a0744..a75d23471 100644 --- a/src/index/hnsw.cpp +++ b/src/index/hnsw.cpp @@ -1205,7 +1205,7 @@ HNSW::merge(const std::vector& merge_units) { { SlowTaskTimer t1("odescent build"); auto odescent_param = std::make_shared(); - odescent_param->max_degree = static_cast(2 * graph_param_ptr->max_degree_); + odescent_param->max_degree = static_cast(graph_param_ptr->max_degree_); ODescent graph(odescent_param, flatten_interface, index_common_param_.allocator_.get(), diff --git a/tests/test_hnsw_new.cpp b/tests/test_hnsw_new.cpp index 874b4c25f..f071cacad 100644 --- a/tests/test_hnsw_new.cpp +++ b/tests/test_hnsw_new.cpp @@ -62,8 +62,8 @@ HNSWTestIndex::GenerateHNSWBuildParametersString(const std::string& metric_type, "metric_type": "{}", "dim": {}, "hnsw": {{ - "max_degree": 64, - "ef_construction": 500, + "max_degree": 16, + "ef_construction": 200, "use_static": {} }} }} @@ -88,8 +88,8 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "dtype": "float32", "metric_type": "l2", "hnsw": {{ - "max_degree": 64, - "ef_construction": 500 + "max_degree": 32, + "ef_construction": 200 }} }})"; REQUIRE_THROWS(TestFactory(name, param, false)); @@ -103,8 +103,8 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "metric_type": "{}", "dim": 23, "hnsw": {{ - "max_degree": 64, - "ef_construction": 500 + "max_degree": 32, + "ef_construction": 300 }} }})"; auto param = fmt::format(param_tmp, metric); @@ -119,8 +119,8 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "metric_type": "l2", "dim": 23, "hnsw": {{ - "max_degree": 64, - "ef_construction": 500 + "max_degree": 32, + "ef_construction": 300 }} }})"; auto param = 
fmt::format(param_tmp, datatype); @@ -152,7 +152,7 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "metric_type": "l2", "dim": 35, "hnsw": {{ - "ef_construction": 500 + "ef_construction": 300 }} }})", R"({{ @@ -160,7 +160,7 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "metric_type": "l2", "dim": 35, "hnsw": {{ - "max_degree": 64, + "max_degree": 32, }} }})"); REQUIRE_THROWS(TestFactory(name, param, false)); @@ -176,7 +176,7 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "dim": 35, "hnsw": {{ "max_degree": {}, - "ef_construction": 500 + "ef_construction": 300 }} }})"; auto param = fmt::format(param_temp, max_degree); @@ -239,7 +239,6 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, vsag::Options::Instance().set_block_size_limit(size); auto dims_ = fixtures::get_common_used_dims(20); for (auto& dim : dims_) { - std::cout << dim << std::endl; auto param = GenerateHNSWBuildParametersString(metric_type, dim); auto index = TestFactory(name, param, true); auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric_type); From bb27dee7cdc0d81d1ada7cf5880a0dd9a162a264 Mon Sep 17 00:00:00 2001 From: inabao <37021995+inabao@users.noreply.github.com> Date: Tue, 20 May 2025 11:50:59 +0800 Subject: [PATCH 04/42] resolve not properly handling invalid IDs in CalDistanceById (#720) Signed-off-by: jinjiabao.jjb Signed-off-by: suguan.dx --- src/algorithm/hgraph.cpp | 6 +++++- tests/test_index.cpp | 13 +++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/algorithm/hgraph.cpp b/src/algorithm/hgraph.cpp index 56f4b3f37..10a126ad7 100644 --- a/src/algorithm/hgraph.cpp +++ b/src/algorithm/hgraph.cpp @@ -736,18 +736,22 @@ HGraph::CalDistanceById(const float* query, const int64_t* ids, int64_t count) c result->Distances(distances); auto computer = flat->FactoryComputer(query); Vector inner_ids(count, 0, allocator_); + Vector invalid_id_loc(allocator_); { std::shared_lock lock(this->label_lookup_mutex_); for (int64_t i = 
0; i < count; ++i) { auto iter = this->label_table_->label_remap_.find(ids[i]); if (iter == this->label_table_->label_remap_.end()) { logger::debug(fmt::format("failed to find id: {}", ids[i])); - distances[i] = -1; + invalid_id_loc.push_back(i); continue; } inner_ids[i] = iter->second; } flat->Query(distances, computer, inner_ids.data(), count); + for (unsigned int i : invalid_id_loc) { + distances[i] = -1; + } } return result; } diff --git a/tests/test_index.cpp b/tests/test_index.cpp index a1c637d93..c35dca826 100644 --- a/tests/test_index.cpp +++ b/tests/test_index.cpp @@ -723,6 +723,19 @@ TestIndex::TestBatchCalcDistanceById(const IndexPtr& index, result.value()->GetDistances()[j]) < error); } } + SECTION("test non-existing id") { + int64_t test_num = 10; + std::vector no_exist_ids(test_num); + for (int i = 0; i < test_num; ++i) { + no_exist_ids[i] = -i - 1; + } + auto result = + index->CalDistanceById(queries->GetFloat32Vectors(), no_exist_ids.data(), test_num); + for (int i = 0; i < test_num; ++i) { + fixtures::dist_t dist = result.value()->GetDistances()[i]; + REQUIRE(dist == -1); + } + } } void From 04ef4abb0b722d180b0fd67e64d0b2c81d193fba Mon Sep 17 00:00:00 2001 From: LHT129 Date: Tue, 20 May 2025 17:18:58 +0800 Subject: [PATCH 05/42] remove mtime change on ci test (#736) Signed-off-by: LHT129 Signed-off-by: suguan.dx --- .circleci/fresh_ci_cache.commit | 2 +- .github/workflows/asan_build_and_test.yml | 8 -------- .github/workflows/coverage.yml | 4 ---- .github/workflows/tsan_build_and_test.yml | 4 ---- 4 files changed, 1 insertion(+), 17 deletions(-) diff --git a/.circleci/fresh_ci_cache.commit b/.circleci/fresh_ci_cache.commit index 5c8c4dacd..b56443f02 100644 --- a/.circleci/fresh_ci_cache.commit +++ b/.circleci/fresh_ci_cache.commit @@ -1 +1 @@ -cce585e1f18168e5a98e5f3f4d08785cd4e120b2 +c13456cac7cd8c69a554d7f425bf658e3d0ea338 diff --git a/.github/workflows/asan_build_and_test.yml b/.github/workflows/asan_build_and_test.yml index d50644f76..217b5ad41 
100644 --- a/.github/workflows/asan_build_and_test.yml +++ b/.github/workflows/asan_build_and_test.yml @@ -23,10 +23,6 @@ jobs: - uses: actions/checkout@v4 with: fetch-depth: '0' - - name: Change Time - run: | - git config --global --add safe.directory /__w/vsag/vsag - sh -x ./scripts/change_mtime.sh - name: Load Cache uses: actions/cache@v4 with: @@ -67,10 +63,6 @@ jobs: - uses: actions/checkout@v4 with: fetch-depth: '0' - - name: Change Time - run: | - git config --global --add safe.directory /__w/vsag/vsag - sh -x ./scripts/change_mtime.sh - name: Prepare Env run: sudo bash ./scripts/deps/install_deps_ubuntu.sh - name: Load Cache diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index e686bfaf7..bded3db04 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -19,10 +19,6 @@ jobs: - uses: actions/checkout@v4 with: fetch-depth: '0' - - name: Change Time - run: | - git config --global --add safe.directory /__w/vsag/vsag - sh -x ./scripts/change_mtime.sh - name: Install lcov run: | apt update diff --git a/.github/workflows/tsan_build_and_test.yml b/.github/workflows/tsan_build_and_test.yml index 8ea51d1a6..fe605cbf5 100644 --- a/.github/workflows/tsan_build_and_test.yml +++ b/.github/workflows/tsan_build_and_test.yml @@ -19,10 +19,6 @@ jobs: - uses: actions/checkout@v4 with: fetch-depth: '0' - - name: Change Time - run: | - git config --global --add safe.directory /__w/vsag/vsag - sh -x ./scripts/change_mtime.sh - name: Load Cache uses: actions/cache@v4 with: From a87e05c43b7ecd33bd8edaf523f211d4541fb2a8 Mon Sep 17 00:00:00 2001 From: LHT129 Date: Wed, 21 May 2025 07:37:31 +0800 Subject: [PATCH 06/42] add simd operator to cal residual (#733) Signed-off-by: LHT129 Signed-off-by: suguan.dx --- src/simd/avx.cpp | 21 +++++++++++++++++++++ src/simd/avx2.cpp | 21 +++++++++++++++++++++ src/simd/avx512.cpp | 21 +++++++++++++++++++++ src/simd/fp32_simd.cpp | 23 +++++++++++++++++++++++ src/simd/fp32_simd.h | 12 
++++++++++++ src/simd/fp32_simd_test.cpp | 35 +++++++++++++++++++++++++++++++++++ src/simd/generic.cpp | 7 +++++++ src/simd/sse.cpp | 21 +++++++++++++++++++++ 8 files changed, 161 insertions(+) diff --git a/src/simd/avx.cpp b/src/simd/avx.cpp index e9e73671e..e1b34e988 100644 --- a/src/simd/avx.cpp +++ b/src/simd/avx.cpp @@ -271,6 +271,27 @@ FP32ComputeL2SqrBatch4(const float* query, #endif } +void +FP32Sub(const float* x, const float* y, float* z, uint64_t dim) { +#if defined(ENABLE_AVX) + if (dim < 8) { + return sse::FP32Sub(x, y, z, dim); + } + int i = 0; + for (; i + 7 < dim; i += 8) { + __m256 a = _mm256_loadu_ps(x + i); + __m256 b = _mm256_loadu_ps(y + i); + __m256 c = _mm256_sub_ps(a, b); + _mm256_storeu_ps(z + i, c); + } + if (i < dim) { + sse::FP32Sub(x + i, y + i, z + i, dim - i); + } +#else + sse::FP32Sub(x, y, z, dim); +#endif +} + #if defined(ENABLE_AVX) __inline __m256i __attribute__((__always_inline__)) load_8_short(const uint16_t* data) { return _mm256_set_epi16(data[7], diff --git a/src/simd/avx2.cpp b/src/simd/avx2.cpp index e703b921d..d9d23f434 100644 --- a/src/simd/avx2.cpp +++ b/src/simd/avx2.cpp @@ -265,6 +265,27 @@ FP32ComputeL2SqrBatch4(const float* query, #endif } +void +FP32Sub(const float* x, const float* y, float* z, uint64_t dim) { +#if defined(ENABLE_AVX2) + if (dim < 8) { + return sse::FP32Sub(x, y, z, dim); + } + int i = 0; + for (; i + 7 < dim; i += 8) { + __m256 a = _mm256_loadu_ps(x + i); + __m256 b = _mm256_loadu_ps(y + i); + __m256 c = _mm256_sub_ps(a, b); + _mm256_storeu_ps(z + i, c); + } + if (i < dim) { + sse::FP32Sub(x + i, y + i, z + i, dim - i); + } +#else + return sse::FP32Sub(x, y, z, dim); +#endif +} + #if defined(ENABLE_AVX2) __inline __m256i __attribute__((__always_inline__)) load_8_short(const uint16_t* data) { __m128i bf16 = _mm_loadu_si128(reinterpret_cast(data)); diff --git a/src/simd/avx512.cpp b/src/simd/avx512.cpp index 25d8f7f70..eb3d9c172 100644 --- a/src/simd/avx512.cpp +++ b/src/simd/avx512.cpp @@ -308,6 
+308,27 @@ FP32ComputeL2SqrBatch4(const float* query, #endif } +void +FP32Sub(const float* x, const float* y, float* z, uint64_t dim) { +#if defined(ENABLE_AVX512) + if (dim < 16) { + return avx2::FP32Sub(x, y, z, dim); + } + uint64_t i = 0; + for (; i + 15 < dim; i += 16) { + __m512 x_vec = _mm512_loadu_ps(x + i); + __m512 y_vec = _mm512_loadu_ps(y + i); + __m512 diff_vec = _mm512_sub_ps(x_vec, y_vec); + _mm512_storeu_ps(z + i, diff_vec); + } + if (dim > i) { + avx2::FP32Sub(x + i, y + i, z + i, dim - i); + } +#else + return avx2::FP32Sub(x, y, z, dim); +#endif +} + #if defined(ENABLE_AVX512) __inline __m512i __attribute__((__always_inline__)) load_16_short(const uint16_t* data) { __m256i bf16 = _mm256_loadu_si256(reinterpret_cast(data)); diff --git a/src/simd/fp32_simd.cpp b/src/simd/fp32_simd.cpp index 22fcc586d..10f2119a7 100644 --- a/src/simd/fp32_simd.cpp +++ b/src/simd/fp32_simd.cpp @@ -110,4 +110,27 @@ GetFP32ComputeL2SqrBatch4() { return generic::FP32ComputeL2SqrBatch4; } FP32ComputeBatch4Type FP32ComputeL2SqrBatch4 = GetFP32ComputeL2SqrBatch4(); + +static FP32SubType +GetFP32Sub() { + if (SimdStatus::SupportAVX512()) { +#if defined(ENABLE_AVX512) + return avx512::FP32Sub; +#endif + } else if (SimdStatus::SupportAVX2()) { +#if defined(ENABLE_AVX2) + return avx2::FP32Sub; +#endif + } else if (SimdStatus::SupportAVX()) { +#if defined(ENABLE_AVX) + return avx::FP32Sub; +#endif + } else if (SimdStatus::SupportSSE()) { +#if defined(ENABLE_SSE) + return sse::FP32Sub; +#endif + } + return generic::FP32Sub; +} +FP32SubType FP32Sub = GetFP32Sub(); } // namespace vsag diff --git a/src/simd/fp32_simd.h b/src/simd/fp32_simd.h index 675acdf6b..836dd1457 100644 --- a/src/simd/fp32_simd.h +++ b/src/simd/fp32_simd.h @@ -47,6 +47,8 @@ FP32ComputeL2SqrBatch4(const float* query, float& result2, float& result3, float& result4); +void +FP32Sub(const float* x, const float* y, float* z, uint64_t dim); } // namespace generic namespace sse { @@ -76,6 +78,8 @@ 
FP32ComputeL2SqrBatch4(const float* query, float& result2, float& result3, float& result4); +void +FP32Sub(const float* x, const float* y, float* z, uint64_t dim); } // namespace sse namespace avx { @@ -105,6 +109,8 @@ FP32ComputeL2SqrBatch4(const float* query, float& result2, float& result3, float& result4); +void +FP32Sub(const float* x, const float* y, float* z, uint64_t dim); } // namespace avx namespace avx2 { @@ -134,6 +140,8 @@ FP32ComputeL2SqrBatch4(const float* query, float& result2, float& result3, float& result4); +void +FP32Sub(const float* x, const float* y, float* z, uint64_t dim); } // namespace avx2 namespace avx512 { @@ -163,6 +171,8 @@ FP32ComputeL2SqrBatch4(const float* query, float& result2, float& result3, float& result4); +void +FP32Sub(const float* x, const float* y, float* z, uint64_t dim); } // namespace avx512 using FP32ComputeType = float (*)(const float* query, const float* codes, uint64_t dim); @@ -182,4 +192,6 @@ using FP32ComputeBatch4Type = void (*)(const float* query, extern FP32ComputeBatch4Type FP32ComputeIPBatch4; extern FP32ComputeBatch4Type FP32ComputeL2SqrBatch4; +using FP32SubType = void (*)(const float* x, const float* y, float* z, uint64_t dim); +extern FP32SubType FP32Sub; } // namespace vsag diff --git a/src/simd/fp32_simd_test.cpp b/src/simd/fp32_simd_test.cpp index b14a241af..9e1de729e 100644 --- a/src/simd/fp32_simd_test.cpp +++ b/src/simd/fp32_simd_test.cpp @@ -45,6 +45,40 @@ using namespace vsag; } \ }; +#define TEST_FP32_SUB_ACCURACY(Func) \ + { \ + std::vector gt(dim, 0.0F); \ + generic::Func(vec1.data() + i * dim, vec2.data() + i * dim, gt.data(), dim); \ + std::vector sse_gt(dim, 0.0F); \ + if (SimdStatus::SupportSSE()) { \ + sse::Func(vec1.data() + i * dim, vec2.data() + i * dim, sse_gt.data(), dim); \ + for (uint64_t j = 0; j < dim; ++j) { \ + REQUIRE(fixtures::dist_t(gt[j]) == fixtures::dist_t(sse_gt[j])); \ + } \ + } \ + std::vector avx_gt(dim, 0.0F); \ + if (SimdStatus::SupportAVX()) { \ + 
avx::Func(vec1.data() + i * dim, vec2.data() + i * dim, avx_gt.data(), dim); \ + for (uint64_t j = 0; j < dim; ++j) { \ + REQUIRE(fixtures::dist_t(gt[j]) == fixtures::dist_t(avx_gt[j])); \ + } \ + } \ + std::vector avx2_gt(dim, 0.0F); \ + if (SimdStatus::SupportAVX2()) { \ + avx2::Func(vec1.data() + i * dim, vec2.data() + i * dim, avx2_gt.data(), dim); \ + for (uint64_t j = 0; j < dim; ++j) { \ + REQUIRE(fixtures::dist_t(gt[j]) == fixtures::dist_t(avx2_gt[j])); \ + } \ + } \ + std::vector avx512_gt(dim, 0.0F); \ + if (SimdStatus::SupportAVX512()) { \ + avx512::Func(vec1.data() + i * dim, vec2.data() + i * dim, avx512_gt.data(), dim); \ + for (uint64_t j = 0; j < dim; ++j) { \ + REQUIRE(fixtures::dist_t(gt[j]) == fixtures::dist_t(avx512_gt[j])); \ + } \ + } \ + }; + #define TEST_FP32_COMPUTE_ACCURACY_BATCH4(Func, FuncBatch4) \ { \ std::vector gts(4); \ @@ -142,6 +176,7 @@ TEST_CASE("FP32 SIMD Compute", "[ut][simd]") { for (uint64_t i = 0; i < count; ++i) { TEST_FP32_COMPUTE_ACCURACY(FP32ComputeIP); TEST_FP32_COMPUTE_ACCURACY(FP32ComputeL2Sqr); + TEST_FP32_SUB_ACCURACY(FP32Sub); } for (uint64_t i = 0; i < count; i += 4) { TEST_FP32_COMPUTE_ACCURACY_BATCH4(FP32ComputeIP, FP32ComputeIPBatch4); diff --git a/src/simd/generic.cpp b/src/simd/generic.cpp index 6e7dce844..9e2cf4c7a 100644 --- a/src/simd/generic.cpp +++ b/src/simd/generic.cpp @@ -133,6 +133,13 @@ FP32ComputeL2SqrBatch4(const float* query, } } +void +FP32Sub(const float* x, const float* y, float* z, uint64_t dim) { + for (uint64_t i = 0; i < dim; ++i) { + z[i] = x[i] - y[i]; + } +} + union FP32Struct { uint32_t int_value; float float_value; diff --git a/src/simd/sse.cpp b/src/simd/sse.cpp index 031785996..869344e55 100644 --- a/src/simd/sse.cpp +++ b/src/simd/sse.cpp @@ -259,6 +259,27 @@ FP32ComputeL2SqrBatch4(const float* query, #endif } +void +FP32Sub(const float* x, const float* y, float* z, uint64_t dim) { +#if defined(ENABLE_SSE) + if (dim < 4) { + return generic::FP32Sub(x, y, z, dim); + } + int64_t i = 
0; + for (; i + 3 < dim; i += 4) { + __m128 a = _mm_loadu_ps(x + i); + __m128 b = _mm_loadu_ps(y + i); + __m128 c = _mm_sub_ps(a, b); + _mm_storeu_ps(z + i, c); + } + if (i < dim) { + generic::FP32Sub(x + i, y + i, z + i, dim - i); + } +#else + return generic::FP32Sub(x, y, z, dim); +#endif +} + float BF16ComputeIP(const uint8_t* query, const uint8_t* codes, uint64_t dim) { #if defined(ENABLE_SSE) From 779ac49c0bee0160c70f4cf6125075529af277d6 Mon Sep 17 00:00:00 2001 From: LHT129 Date: Wed, 21 May 2025 07:37:58 +0800 Subject: [PATCH 07/42] add computable operators for bitset (#721) Signed-off-by: LHT129 Signed-off-by: suguan.dx --- include/vsag/bitset.h | 27 ++++++++ src/bitset_impl.cpp | 27 ++++++++ src/bitset_impl.h | 11 +++- src/bitset_impl_test.cpp | 130 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 194 insertions(+), 1 deletion(-) diff --git a/include/vsag/bitset.h b/include/vsag/bitset.h index b246b6fb2..6ac5a0496 100644 --- a/include/vsag/bitset.h +++ b/include/vsag/bitset.h @@ -86,6 +86,33 @@ class Bitset { virtual uint64_t Count() = 0; + /** + * @brief Performs a bitwise OR operation on the current bitset with another bitset. + * + * @param another The bitset to perform the OR operation with. + * @return void + */ + virtual void + Or(const Bitset& another) = 0; + + /** + * @brief Performs a bitwise AND operation on the current bitset with another bitset. + * + * @param another The bitset to perform the AND operation with. + * @return void + */ + virtual void + And(const Bitset& another) = 0; + + /** + * @brief Performs a bitwise XOR operation on the current bitset with another bitset. + * + * @param another The bitset to perform the XOR operation with. 
+ * @return void + */ + virtual void + Xor(const Bitset& another) = 0; + public: /** * For debugging diff --git a/src/bitset_impl.cpp b/src/bitset_impl.cpp index d0330e17d..bbea2e38b 100644 --- a/src/bitset_impl.cpp +++ b/src/bitset_impl.cpp @@ -66,4 +66,31 @@ BitsetImpl::Dump() { return r_.toString(); } +void +BitsetImpl::Or(const Bitset& another) { + const auto* another_ptr = reinterpret_cast(&another); + std::lock(mutex_, another_ptr->mutex_); + std::lock_guard lock(mutex_, std::adopt_lock); + std::lock_guard lock_other(another_ptr->mutex_, std::adopt_lock); + r_ |= another_ptr->r_; +} + +void +BitsetImpl::And(const Bitset& another) { + const auto* another_ptr = reinterpret_cast(&another); + std::lock(mutex_, another_ptr->mutex_); + std::lock_guard lock(mutex_, std::adopt_lock); + std::lock_guard lock_other(another_ptr->mutex_, std::adopt_lock); + r_ &= another_ptr->r_; +} + +void +BitsetImpl::Xor(const Bitset& another) { + const auto* another_ptr = reinterpret_cast(&another); + std::lock(mutex_, another_ptr->mutex_); + std::lock_guard lock(mutex_, std::adopt_lock); + std::lock_guard lock_other(another_ptr->mutex_, std::adopt_lock); + r_ ^= another_ptr->r_; +} + } // namespace vsag diff --git a/src/bitset_impl.h b/src/bitset_impl.h index 3a922216b..8437e189b 100644 --- a/src/bitset_impl.h +++ b/src/bitset_impl.h @@ -48,8 +48,17 @@ class BitsetImpl : public Bitset { std::string Dump() override; + void + Or(const Bitset& another) override; + + void + And(const Bitset& another) override; + + void + Xor(const Bitset& another) override; + private: - std::mutex mutex_; + mutable std::mutex mutex_; roaring::Roaring r_; }; diff --git a/src/bitset_impl_test.cpp b/src/bitset_impl_test.cpp index 0776dfb86..7ab1b3e01 100644 --- a/src/bitset_impl_test.cpp +++ b/src/bitset_impl_test.cpp @@ -49,6 +49,136 @@ TEST_CASE("BitsetImpl Test", "[ut][bitset]") { REQUIRE(dumped == "{100}"); } +TEST_CASE("BitsetImpl Or Test", "[ut][bitset]") { + SECTION("both empty") { + vsag::BitsetImpl 
bitset1; + vsag::BitsetImpl bitset2; + bitset1.Or(bitset2); + REQUIRE(bitset1.Count() == 0); + REQUIRE(bitset1.Dump() == "{}"); + } + + SECTION("empty and non-empty") { + vsag::BitsetImpl bitset1; + vsag::BitsetImpl bitset2; + bitset2.Set(100, true); + bitset1.Or(bitset2); + REQUIRE(bitset1.Test(100)); + REQUIRE(bitset1.Count() == 1); + REQUIRE(bitset1.Dump() == "{100}"); + } + + SECTION("disjoint sets") { + vsag::BitsetImpl bitset1; + vsag::BitsetImpl bitset2; + bitset1.Set(100, true); + bitset2.Set(200, true); + bitset1.Or(bitset2); + REQUIRE(bitset1.Test(100)); + REQUIRE(bitset1.Test(200)); + REQUIRE(bitset1.Count() == 2); + REQUIRE(bitset1.Dump() == "{100,200}"); + } + + SECTION("overlapping sets") { + vsag::BitsetImpl bitset1; + vsag::BitsetImpl bitset2; + bitset1.Set(100, true); + bitset2.Set(100, true); + bitset1.Or(bitset2); + REQUIRE(bitset1.Count() == 1); + REQUIRE(bitset1.Dump() == "{100}"); + } +} + +TEST_CASE("BitsetImpl And Test", "[ut][bitset]") { + SECTION("both empty") { + vsag::BitsetImpl bitset1; + vsag::BitsetImpl bitset2; + bitset1.And(bitset2); + REQUIRE(bitset1.Count() == 0); + REQUIRE(bitset1.Dump() == "{}"); + } + + SECTION("empty and non-empty") { + vsag::BitsetImpl bitset1; + vsag::BitsetImpl bitset2; + bitset2.Set(100, true); + bitset1.And(bitset2); + REQUIRE(bitset1.Count() == 0); + } + + SECTION("common elements") { + vsag::BitsetImpl bitset1; + vsag::BitsetImpl bitset2; + bitset1.Set(100, true); + bitset1.Set(200, true); + bitset2.Set(200, true); + bitset2.Set(300, true); + bitset1.And(bitset2); + REQUIRE(bitset1.Count() == 1); + REQUIRE(bitset1.Test(200)); + REQUIRE_FALSE(bitset1.Test(100)); + REQUIRE_FALSE(bitset1.Test(300)); + REQUIRE(bitset1.Dump() == "{200}"); + } + + SECTION("no common elements") { + vsag::BitsetImpl bitset1; + vsag::BitsetImpl bitset2; + bitset1.Set(100, true); + bitset2.Set(200, true); + bitset1.And(bitset2); + REQUIRE(bitset1.Count() == 0); + REQUIRE(bitset1.Dump() == "{}"); + } +} + +TEST_CASE("BitsetImpl 
Xor Test", "[ut][bitset]") { + SECTION("both empty") { + vsag::BitsetImpl bitset1; + vsag::BitsetImpl bitset2; + bitset1.Xor(bitset2); + REQUIRE(bitset1.Count() == 0); + REQUIRE(bitset1.Dump() == "{}"); + } + + SECTION("empty and non-empty") { + vsag::BitsetImpl bitset1; + vsag::BitsetImpl bitset2; + bitset2.Set(100, true); + bitset1.Xor(bitset2); + REQUIRE(bitset1.Test(100)); + REQUIRE(bitset1.Count() == 1); + REQUIRE(bitset1.Dump() == "{100}"); + } + + SECTION("partial overlap") { + vsag::BitsetImpl bitset1; + vsag::BitsetImpl bitset2; + bitset1.Set(100, true); + bitset1.Set(200, true); + bitset2.Set(200, true); + bitset2.Set(300, true); + bitset1.Xor(bitset2); + REQUIRE(bitset1.Count() == 2); + REQUIRE(bitset1.Test(100)); + REQUIRE_FALSE(bitset1.Test(200)); + REQUIRE(bitset1.Test(300)); + REQUIRE(bitset1.Dump() == "{100,300}"); + } + + SECTION("identical sets") { + vsag::BitsetImpl bitset1; + vsag::BitsetImpl bitset2; + bitset1.Set(100, true); + bitset2.Set(100, true); + bitset1.Xor(bitset2); + REQUIRE(bitset1.Count() == 0); + REQUIRE(bitset1.Dump() == "{}"); + } +} + TEST_CASE("Roaring Bitmap Test", "[ut][bitset]") { Roaring r1; for (uint32_t i = 100; i < 1000; i++) { From 05ce90d77fa2478c8993a328c33d2f836588efa5 Mon Sep 17 00:00:00 2001 From: LHT129 Date: Wed, 21 May 2025 07:38:31 +0800 Subject: [PATCH 08/42] add new interface for attr search (#729) - add request search object - add attribute for datasets - add new interface on index search Signed-off-by: LHT129 Signed-off-by: suguan.dx --- include/vsag/attribute.h | 53 +++++++++++++++++++++++++++++++++++ include/vsag/constants.h | 1 + include/vsag/dataset.h | 19 +++++++++++++ include/vsag/index.h | 14 +++++++++ include/vsag/search_request.h | 38 +++++++++++++++++++++++++ src/constants.cpp | 1 + src/dataset_impl.h | 17 ++++++++++- 7 files changed, 142 insertions(+), 1 deletion(-) create mode 100644 include/vsag/attribute.h create mode 100644 include/vsag/search_request.h diff --git a/include/vsag/attribute.h 
b/include/vsag/attribute.h new file mode 100644 index 000000000..45c05f4cb --- /dev/null +++ b/include/vsag/attribute.h @@ -0,0 +1,53 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +namespace vsag { + +enum AttrValueType { + INT32 = 1, + UINT32 = 2, + INT64 = 3, + UINT64 = 4, + INT8 = 5, + UINT8 = 6, + STRING = 7, + FIXSIZE_STRING = 8, +}; + +class Attribute { +public: + std::string name_{}; + virtual AttrValueType + GetValueType() = 0; +}; + +template +class AttributeValue : public Attribute { +public: + std::vector value_{}; +}; + +struct AttributeSet { +public: + std::vector attrs_; +}; + +} // namespace vsag diff --git a/include/vsag/constants.h b/include/vsag/constants.h index 7c63af48e..10bf7f193 100644 --- a/include/vsag/constants.h +++ b/include/vsag/constants.h @@ -31,6 +31,7 @@ extern const char* const DISTS; extern const char* const FLOAT32_VECTORS; extern const char* const SPARSE_VECTORS; extern const char* const INT8_VECTORS; +extern const char* const ATTRIBUTE_SETS; extern const char* const DATASET_PATHS; extern const char* const EXTRA_INFOS; extern const char* const EXTRA_INFO_SIZE; diff --git a/include/vsag/dataset.h b/include/vsag/dataset.h index 6dc5b3e75..7d2d19c0c 100644 --- a/include/vsag/dataset.h +++ b/include/vsag/dataset.h @@ -19,8 +19,10 @@ #include #include #include +#include #include "allocator.h" +#include "attribute.h" #include 
"constants.h" namespace vsag { @@ -201,6 +203,23 @@ class Dataset : public std::enable_shared_from_this { virtual const SparseVector* GetSparseVectors() const = 0; + /** + * @brief Sets the attribute sets for the dataset. + * @param attr_sets Pointer to the attribute sets. + * + * @return DatasetPtr A shared pointer to the dataset with updated attribute sets. + */ + virtual DatasetPtr + AttributeSets(const AttributeSet* attr_sets) = 0; + + /** + * @brief Retrieves the attribute sets of the dataset. + * + * @return const AttributeSet* Pointer to the attribute sets. + */ + virtual const AttributeSet* + GetAttributeSets() const = 0; + /** * @brief Sets the paths array for the dataset. * diff --git a/include/vsag/index.h b/include/vsag/index.h index 7f41f5af6..35af316d9 100644 --- a/include/vsag/index.h +++ b/include/vsag/index.h @@ -32,6 +32,7 @@ #include "vsag/index_features.h" #include "vsag/iterator_context.h" #include "vsag/readerset.h" +#include "vsag/search_request.h" namespace vsag { @@ -182,6 +183,19 @@ class Index { throw std::runtime_error("Index doesn't support new filter"); } + /** + * @brief Performing search with request on index + * + * @param request @see SearchRequest + * @return result contains + * - num_elements: 1 + * - ids, distances: length is (num_elements * k) + */ + virtual tl::expected + SearchWithRequest(const SearchRequest& request) const { + throw std::runtime_error("Index doesn't support Search With Request"); + } + /** * @brief Performing single KNN search on index * diff --git a/include/vsag/search_request.h b/include/vsag/search_request.h new file mode 100644 index 000000000..50fe0badc --- /dev/null +++ b/include/vsag/search_request.h @@ -0,0 +1,38 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include + +#include "filter.h" + +namespace vsag { +enum class SearchMode { + KNN_SEARCH = 1, + RANGE_SEARCH = 2, +}; + +class SearchRequest { +public: + SearchMode mode_{SearchMode::KNN_SEARCH}; + int64_t topk_; + + bool enable_attribute_filter_{false}; + std::string attribute_filter_str_; + FilterPtr filter_{nullptr}; +}; + +} // namespace vsag diff --git a/src/constants.cpp b/src/constants.cpp index 6b8a5d5fd..1ddf1d5b6 100644 --- a/src/constants.cpp +++ b/src/constants.cpp @@ -34,6 +34,7 @@ const char* const DISTS = "dists"; const char* const FLOAT32_VECTORS = "f32_vectors"; const char* const SPARSE_VECTORS = "sparse_vectors"; const char* const INT8_VECTORS = "i8_vectors"; +const char* const ATTRIBUTE_SETS = "attribute_sets"; const char* const DATASET_PATHS = "paths"; const char* const EXTRA_INFOS = "extra_infos"; const char* const EXTRA_INFO_SIZE = "extra_info_size"; diff --git a/src/dataset_impl.h b/src/dataset_impl.h index 39bbf5ba1..ac3898030 100644 --- a/src/dataset_impl.h +++ b/src/dataset_impl.h @@ -35,7 +35,8 @@ class DatasetImpl : public Dataset { const int8_t*, const int64_t*, const std::string*, - const SparseVector*>; + const SparseVector*, + const AttributeSet*>; public: DatasetImpl() = default; @@ -203,6 +204,20 @@ class DatasetImpl : public Dataset { return nullptr; } + DatasetPtr + AttributeSets(const AttributeSet* attr_sets) override { + this->data_[ATTRIBUTE_SETS] = attr_sets; + return shared_from_this(); + } + + const AttributeSet* + GetAttributeSets() const override { + if (auto iter = 
this->data_.find(ATTRIBUTE_SETS); iter != this->data_.end()) { + return std::get(iter->second); + } + return nullptr; + } + DatasetPtr Paths(const std::string* paths) override { this->data_[DATASET_PATHS] = paths; From 7fee097844364b3139dfa7c5372d549deab17c73 Mon Sep 17 00:00:00 2001 From: LHT129 Date: Wed, 21 May 2025 11:16:07 +0800 Subject: [PATCH 09/42] add fast index create for inner index (#737) - only support hgraph, bruteforce Signed-off-by: LHT129 Signed-off-by: suguan.dx --- src/algorithm/brute_force.cpp | 16 ++++- src/algorithm/inner_index_interface.cpp | 54 +++++++++++++++ src/algorithm/inner_index_interface.h | 7 ++ src/algorithm/inner_index_interface_test.cpp | 69 ++++++++++++++++++++ src/utils/util_functions.cpp | 11 ++++ src/utils/util_functions.h | 3 + 6 files changed, 158 insertions(+), 2 deletions(-) create mode 100644 src/algorithm/inner_index_interface_test.cpp diff --git a/src/algorithm/brute_force.cpp b/src/algorithm/brute_force.cpp index 9d9cfd8e3..a5f2c25f8 100644 --- a/src/algorithm/brute_force.cpp +++ b/src/algorithm/brute_force.cpp @@ -266,8 +266,20 @@ ParamPtr BruteForce::CheckAndMappingExternalParam(const JsonType& external_param, const IndexCommonParam& common_param) { const std::unordered_map> external_mapping = { - {BRUTE_FORCE_QUANTIZATION_TYPE, {QUANTIZATION_PARAMS_KEY, QUANTIZATION_TYPE_KEY}}, - {BRUTE_FORCE_IO_TYPE, {IO_PARAMS_KEY, IO_TYPE_KEY}}}; + { + BRUTE_FORCE_QUANTIZATION_TYPE, + { + QUANTIZATION_PARAMS_KEY, + QUANTIZATION_TYPE_KEY, + }, + }, + { + BRUTE_FORCE_IO_TYPE, + { + IO_PARAMS_KEY, + IO_TYPE_KEY, + }, + }}; if (common_param.data_type_ == DataTypes::DATA_TYPE_INT8) { throw VsagException(ErrorType::INVALID_ARGUMENT, diff --git a/src/algorithm/inner_index_interface.cpp b/src/algorithm/inner_index_interface.cpp index 367283075..3fb184e61 100644 --- a/src/algorithm/inner_index_interface.cpp +++ b/src/algorithm/inner_index_interface.cpp @@ -16,8 +16,11 @@ #include "inner_index_interface.h" #include "base_filter_functor.h" 
+#include "brute_force.h" #include "empty_index_binary_set.h" +#include "hgraph.h" #include "utils/slow_task_timer.h" +#include "utils/util_functions.h" namespace vsag { @@ -216,4 +219,55 @@ InnerIndexInterface::Clone(const IndexCommonParam& param) { return index; } +InnerIndexPtr +InnerIndexInterface::FastCreateIndex(const std::string& index_fast_str, + const IndexCommonParam& common_param) { + auto strs = split_string(index_fast_str, fast_string_delimiter); + if (strs.size() < 2) { + throw VsagException(ErrorType::INVALID_ARGUMENT, "fast str is too short"); + } + if (strs[0] == INDEX_TYPE_HGRAPH) { + if (strs.size() < 3) { + throw VsagException(ErrorType::INVALID_ARGUMENT, "fast str(hgraph) is too short"); + } + constexpr const char* build_string_temp = R"( + {{ + "max_degree": {}, + "base_quantization_type": "{}", + "use_reorder": {}, + "precise_quantization_type": "{}" + }} + )"; + auto max_degree = std::stoi(strs[1]); + auto base_quantization_type = strs[2]; + bool use_reorder = false; + std::string precise_quantization_type = "fp32"; + if (strs.size() == 4) { + use_reorder = true; + precise_quantization_type = strs[3]; + } + JsonType json = JsonType::parse(fmt::format(build_string_temp, + max_degree, + base_quantization_type, + use_reorder, + precise_quantization_type)); + auto param_ptr = HGraph::CheckAndMappingExternalParam(json, common_param); + return std::make_shared(param_ptr, common_param); + } + if (strs[0] == INDEX_BRUTE_FORCE) { + constexpr const char* build_string_temp = R"( + {{ + "quantization_type": "{}" + }} + )"; + JsonType json = JsonType::parse(fmt::format(build_string_temp, strs[1])); + auto param_ptr = BruteForce::CheckAndMappingExternalParam(json, common_param); + return std::make_shared(param_ptr, common_param); + } + throw VsagException(ErrorType::INVALID_ARGUMENT, + fmt::format("not support fast string create type: {}," + " only support bruteforce and hgraph", + strs[0])); +} + } // namespace vsag diff --git 
a/src/algorithm/inner_index_interface.h b/src/algorithm/inner_index_interface.h index 349197fba..bcdef1a67 100644 --- a/src/algorithm/inner_index_interface.h +++ b/src/algorithm/inner_index_interface.h @@ -35,10 +35,17 @@ using InnerIndexPtr = std::shared_ptr; class InnerIndexInterface { public: + InnerIndexInterface() = default; + explicit InnerIndexInterface(ParamPtr index_param, const IndexCommonParam& common_param); virtual ~InnerIndexInterface() = default; + constexpr static char fast_string_delimiter = '|'; + + static InnerIndexPtr + FastCreateIndex(const std::string& index_fast_str, const IndexCommonParam& common_param); + [[nodiscard]] virtual std::string GetName() const = 0; diff --git a/src/algorithm/inner_index_interface_test.cpp b/src/algorithm/inner_index_interface_test.cpp new file mode 100644 index 000000000..e5fe6bfef --- /dev/null +++ b/src/algorithm/inner_index_interface_test.cpp @@ -0,0 +1,69 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "inner_index_interface.h" + +#include +#include + +#include "brute_force.h" +#include "hgraph.h" +#include "safe_allocator.h" + +using namespace vsag; + +TEST_CASE("Fast Create Index", "[ut][InnerIndexInterface]") { + IndexCommonParam common_param; + common_param.dim_ = 128; + common_param.thread_pool_ = SafeThreadPool::FactoryDefaultThreadPool(); + common_param.allocator_ = SafeAllocator::FactoryDefaultAllocator(); + common_param.metric_ = MetricType::METRIC_TYPE_L2SQR; + + SECTION("HGraph created with minimal parameters") { + std::string index_fast_str = "hgraph|100|fp16"; + auto index = InnerIndexInterface::FastCreateIndex(index_fast_str, common_param); + REQUIRE(index != nullptr); + REQUIRE(dynamic_cast(index.get()) != nullptr); + } + + SECTION("HGraph created with optional parameters") { + std::string index_fast_str = "hgraph|100|sq8|fp32"; + auto index = InnerIndexInterface::FastCreateIndex(index_fast_str, common_param); + REQUIRE(index != nullptr); + REQUIRE(dynamic_cast(index.get()) != nullptr); + } + + SECTION("BruteForce created") { + std::string index_fast_str = "brute_force|fp32"; + auto index = InnerIndexInterface::FastCreateIndex(index_fast_str, common_param); + REQUIRE(index != nullptr); + REQUIRE(dynamic_cast(index.get()) != nullptr); + } + + SECTION("Unsupported index type returns null") { + std::string index_fast_str = "UNKNOWN|other"; + REQUIRE_THROWS(InnerIndexInterface::FastCreateIndex(index_fast_str, common_param)); + } + + SECTION("Invalid parameter count for HGraph (too few)") { + std::string index_fast_str = "hgraph|100"; + REQUIRE_THROWS(InnerIndexInterface::FastCreateIndex(index_fast_str, common_param)); + } + + SECTION("Invalid parameter count for BruteForce (too few)") { + std::string index_fast_str = "bruteforce"; + REQUIRE_THROWS(InnerIndexInterface::FastCreateIndex(index_fast_str, common_param)); + } +} diff --git a/src/utils/util_functions.cpp b/src/utils/util_functions.cpp index e38ce2ebe..f3ffaa76e 100644 --- 
a/src/utils/util_functions.cpp +++ b/src/utils/util_functions.cpp @@ -136,4 +136,15 @@ check_equal_on_string_stream(std::stringstream& s1, std::stringstream& s2) { return true; } +std::vector +split_string(const std::string& str, const char delimiter) { + std::vector tokens; + std::stringstream ss(str); + std::string token; + while (std::getline(ss, token, delimiter)) { + tokens.push_back(token); + } + return tokens; +} + } // namespace vsag diff --git a/src/utils/util_functions.h b/src/utils/util_functions.h index a7b0b1413..b2f8499f6 100644 --- a/src/utils/util_functions.h +++ b/src/utils/util_functions.h @@ -76,4 +76,7 @@ next_multiple_of_power_of_two(uint64_t x, uint64_t n); bool check_equal_on_string_stream(std::stringstream& s1, std::stringstream& s2); +std::vector +split_string(const std::string& str, const char delimiter); + } // namespace vsag From 93b7ff8f339e0116ae7fd5f439d38af0015e6aad Mon Sep 17 00:00:00 2001 From: LHT129 Date: Wed, 21 May 2025 14:15:09 +0800 Subject: [PATCH 10/42] add option for hgraph to build by base quantization (#735) Signed-off-by: LHT129 Signed-off-by: suguan.dx --- include/vsag/constants.h | 1 + src/algorithm/hgraph.cpp | 13 +++++++++++-- src/algorithm/hgraph.h | 1 + src/algorithm/hgraph_parameter.cpp | 4 ++++ src/algorithm/hgraph_parameter.h | 1 + src/constants.cpp | 1 + src/inner_string_params.h | 2 ++ 7 files changed, 21 insertions(+), 2 deletions(-) diff --git a/include/vsag/constants.h b/include/vsag/constants.h index 10bf7f193..c4ddf622c 100644 --- a/include/vsag/constants.h +++ b/include/vsag/constants.h @@ -133,6 +133,7 @@ extern const char* const RABITQ_BITS_PER_DIM_QUERY; extern const char* const HGRAPH_USE_REORDER; extern const char* const HGRAPH_USE_ELP_OPTIMIZER; extern const char* const HGRAPH_IGNORE_REORDER; +extern const char* const HGRAPH_BUILD_BY_BASE_QUANTIZATION; extern const char* const HGRAPH_BASE_QUANTIZATION_TYPE; extern const char* const HGRAPH_GRAPH_MAX_DEGREE; extern const char* const 
HGRAPH_BUILD_EF_CONSTRUCTION; diff --git a/src/algorithm/hgraph.cpp b/src/algorithm/hgraph.cpp index 10a126ad7..cf2f56510 100644 --- a/src/algorithm/hgraph.cpp +++ b/src/algorithm/hgraph.cpp @@ -38,6 +38,7 @@ HGraph::HGraph(const HGraphParameterPtr& hgraph_param, const vsag::IndexCommonPa use_reorder_(hgraph_param->use_reorder), use_elp_optimizer_(hgraph_param->use_elp_optimizer), ignore_reorder_(hgraph_param->ignore_reorder), + build_by_base_(hgraph_param->build_by_base), ef_construct_(hgraph_param->ef_construction), build_thread_count_(hgraph_param->build_thread_count), odescent_param_(hgraph_param->odescent_param), @@ -157,7 +158,8 @@ HGraph::build_by_odescent(const DatasetPtr& data) { } } this->resize(total_count_); - auto build_data = use_reorder_ ? this->high_precise_codes_ : this->basic_flatten_codes_; + auto build_data = (use_reorder_ and not build_by_base_) ? this->high_precise_codes_ + : this->basic_flatten_codes_; { odescent_param_->max_degree = bottom_graph_->MaximumDegree(); ODescent odescent_builder(odescent_param_, build_data, allocator_, this->build_pool_.get()); @@ -818,7 +820,7 @@ HGraph::graph_add_one(const void* data, int level, InnerIdType inner_id) { LockGuard cur_lock(neighbors_mutex_, inner_id); auto flatten_codes = basic_flatten_codes_; - if (use_reorder_) { + if (use_reorder_ and not build_by_base_) { flatten_codes = high_precise_codes_; } for (auto j = this->route_graphs_.size() - 1; j > level; --j) { @@ -1000,6 +1002,7 @@ static const std::string HGRAPH_PARAMS_TEMPLATE = "{HGRAPH_USE_REORDER_KEY}": false, "{HGRAPH_USE_ENV_OPTIMIZER}": false, "{HGRAPH_IGNORE_REORDER_KEY}": false, + "{HGRAPH_BUILD_BY_BASE_QUANTIZATION_KEY}": false, "{HGRAPH_GRAPH_KEY}": { "{IO_PARAMS_KEY}": { "{IO_TYPE_KEY}": "{IO_TYPE_VALUE_BLOCK_MEMORY_IO}", @@ -1077,6 +1080,12 @@ HGraph::CheckAndMappingExternalParam(const JsonType& external_param, HGRAPH_IGNORE_REORDER_KEY, }, }, + { + HGRAPH_BUILD_BY_BASE_QUANTIZATION, + { + HGRAPH_BUILD_BY_BASE_QUANTIZATION_KEY, + }, 
+ }, { HGRAPH_BASE_QUANTIZATION_TYPE, { diff --git a/src/algorithm/hgraph.h b/src/algorithm/hgraph.h index e72c5719e..9574ababb 100644 --- a/src/algorithm/hgraph.h +++ b/src/algorithm/hgraph.h @@ -222,6 +222,7 @@ class HGraph : public InnerIndexInterface { mutable bool use_reorder_{false}; bool use_elp_optimizer_{false}; bool ignore_reorder_{false}; + bool build_by_base_{false}; BasicSearcherPtr searcher_; diff --git a/src/algorithm/hgraph_parameter.cpp b/src/algorithm/hgraph_parameter.cpp index 7f08b6b01..8cf3fc1b5 100644 --- a/src/algorithm/hgraph_parameter.cpp +++ b/src/algorithm/hgraph_parameter.cpp @@ -45,6 +45,10 @@ HGraphParameter::FromJson(const JsonType& json) { this->ignore_reorder = json[HGRAPH_IGNORE_REORDER_KEY]; } + if (json.contains(HGRAPH_BUILD_BY_BASE_QUANTIZATION_KEY)) { + this->build_by_base = json[HGRAPH_BUILD_BY_BASE_QUANTIZATION_KEY]; + } + CHECK_ARGUMENT(json.contains(HGRAPH_BASE_CODES_KEY), fmt::format("hgraph parameters must contains {}", HGRAPH_BASE_CODES_KEY)); const auto& base_codes_json = json[HGRAPH_BASE_CODES_KEY]; diff --git a/src/algorithm/hgraph_parameter.h b/src/algorithm/hgraph_parameter.h index 9c177b4af..c2b21b256 100644 --- a/src/algorithm/hgraph_parameter.h +++ b/src/algorithm/hgraph_parameter.h @@ -49,6 +49,7 @@ class HGraphParameter : public Parameter { bool use_reorder{false}; bool use_elp_optimizer{false}; bool ignore_reorder{false}; + bool build_by_base{false}; uint64_t ef_construction{400}; uint64_t build_thread_count{100}; diff --git a/src/constants.cpp b/src/constants.cpp index 1ddf1d5b6..c35c055b2 100644 --- a/src/constants.cpp +++ b/src/constants.cpp @@ -136,6 +136,7 @@ const char* const RABITQ_BITS_PER_DIM_QUERY = "rabitq_bits_per_dim_query"; const char* const HGRAPH_USE_REORDER = HGRAPH_USE_REORDER_KEY; const char* const HGRAPH_USE_ELP_OPTIMIZER = HGRAPH_USE_ELP_OPTIMIZER_KEY; const char* const HGRAPH_IGNORE_REORDER = "ignore_reorder"; +const char* const HGRAPH_BUILD_BY_BASE_QUANTIZATION = "build_by_base"; const 
char* const HGRAPH_BASE_QUANTIZATION_TYPE = "base_quantization_type"; const char* const HGRAPH_GRAPH_MAX_DEGREE = "max_degree"; const char* const HGRAPH_BUILD_EF_CONSTRUCTION = "ef_construction"; diff --git a/src/inner_string_params.h b/src/inner_string_params.h index ada665634..9d747b066 100644 --- a/src/inner_string_params.h +++ b/src/inner_string_params.h @@ -29,6 +29,7 @@ const char* const INDEX_TYPE_IVF = "ivf"; const char* const HGRAPH_USE_REORDER_KEY = "use_reorder"; const char* const HGRAPH_USE_ELP_OPTIMIZER_KEY = "use_elp_optimizer"; const char* const HGRAPH_IGNORE_REORDER_KEY = "ignore_reorder"; +const char* const HGRAPH_BUILD_BY_BASE_QUANTIZATION_KEY = "build_by_base"; const char* const HGRAPH_GRAPH_KEY = "graph"; const char* const HGRAPH_BASE_CODES_KEY = "base_codes"; const char* const HGRAPH_PRECISE_CODES_KEY = "precise_codes"; @@ -101,6 +102,7 @@ const std::unordered_map DEFAULT_MAP = { {"HGRAPH_USE_REORDER_KEY", HGRAPH_USE_REORDER_KEY}, {"HGRAPH_USE_ELP_OPTIMIZER_KEY", HGRAPH_USE_ELP_OPTIMIZER_KEY}, {"HGRAPH_IGNORE_REORDER_KEY", HGRAPH_IGNORE_REORDER_KEY}, + {"HGRAPH_BUILD_BY_BASE_QUANTIZATION_KEY", HGRAPH_BUILD_BY_BASE_QUANTIZATION_KEY}, {"HGRAPH_GRAPH_KEY", HGRAPH_GRAPH_KEY}, {"HGRAPH_BASE_CODES_KEY", HGRAPH_BASE_CODES_KEY}, {"HGRAPH_PRECISE_CODES_KEY", HGRAPH_PRECISE_CODES_KEY}, From 58400cff039f47db1ebabed714bf9038e98dd3d9 Mon Sep 17 00:00:00 2001 From: inabao <37021995+inabao@users.noreply.github.com> Date: Wed, 21 May 2025 16:20:07 +0800 Subject: [PATCH 11/42] resolving core dumps caused by incorrect resizing in sparse datacell (#719) Signed-off-by: jinjiabao.jjb Signed-off-by: suguan.dx --- src/data_cell/sparse_vector_datacell.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/data_cell/sparse_vector_datacell.h b/src/data_cell/sparse_vector_datacell.h index 9cbdfb08d..ebc4a4cef 100644 --- a/src/data_cell/sparse_vector_datacell.h +++ b/src/data_cell/sparse_vector_datacell.h @@ -67,7 +67,7 @@ class SparseVectorDataCell : 
public FlattenInterface { if (new_capacity <= this->max_capacity_) { return; } - size_t io_size = (new_capacity - this->max_capacity_) * max_code_size_ + current_offset_; + size_t io_size = (new_capacity - total_count_) * max_code_size_ + current_offset_; this->max_capacity_ = new_capacity; uint8_t end_flag = 127; // the value is meaingless, only to occupy the position for io allocate From dc6dd79a9a9861477077ce02ac97f261321ce7e3 Mon Sep 17 00:00:00 2001 From: inabao <37021995+inabao@users.noreply.github.com> Date: Wed, 21 May 2025 16:30:43 +0800 Subject: [PATCH 12/42] make the member access of Resource private (#728) Signed-off-by: jinjiabao.jjb Signed-off-by: suguan.dx --- include/vsag/resource.h | 2 +- src/index/index_common_param.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/include/vsag/resource.h b/include/vsag/resource.h index b66f04a5f..11b54f726 100644 --- a/include/vsag/resource.h +++ b/include/vsag/resource.h @@ -94,7 +94,7 @@ class Resource { return this->thread_pool; } -public: +private: ///< Shared pointer to the allocator associated with this resource. 
std::shared_ptr allocator; diff --git a/src/index/index_common_param.cpp b/src/index/index_common_param.cpp index d83007f23..9dcd65be6 100644 --- a/src/index/index_common_param.cpp +++ b/src/index/index_common_param.cpp @@ -87,7 +87,7 @@ IndexCommonParam IndexCommonParam::CheckAndCreate(JsonType& params, const std::shared_ptr& resource) { IndexCommonParam result; result.allocator_ = resource->GetAllocator(); - result.thread_pool_ = std::dynamic_pointer_cast(resource->thread_pool); + result.thread_pool_ = std::dynamic_pointer_cast(resource->GetThreadPool()); // Check and Fill DataType CHECK_ARGUMENT(params.contains(PARAMETER_DTYPE), From cbf999f944f462383217b0350ae4e602867266d9 Mon Sep 17 00:00:00 2001 From: LHT129 Date: Thu, 22 May 2025 10:45:49 +0800 Subject: [PATCH 13/42] fix non-memory io's incorrect result on large dataset (#742) Signed-off-by: LHT129 Signed-off-by: suguan.dx --- src/data_cell/flatten_datacell.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/data_cell/flatten_datacell.h b/src/data_cell/flatten_datacell.h index f820df566..700201242 100644 --- a/src/data_cell/flatten_datacell.h +++ b/src/data_cell/flatten_datacell.h @@ -254,11 +254,11 @@ FlattenDataCell::query(float* result_dists, this->prefetch_depth_code_ * 64); } if (not this->io_->InMemory() and id_count > 1) { - ByteBuffer codes(id_count * this->code_size_, allocator_); + ByteBuffer codes(static_cast(id_count) * this->code_size_, allocator_); Vector sizes(id_count, this->code_size_, allocator_); Vector offsets(id_count, this->code_size_, allocator_); for (int64_t i = 0; i < id_count; ++i) { - offsets[i] = idx[i] * code_size_; + offsets[i] = static_cast(idx[i]) * this->code_size_; } this->io_->MultiRead(codes.data, sizes.data(), offsets.data(), id_count); computer->ScanBatchDists(id_count, codes.data, result_dists); From 59afb0343a7c711f860c9bee7c7fed313aee9517 Mon Sep 17 00:00:00 2001 From: LHT129 Date: Fri, 23 May 2025 11:03:34 +0800 Subject: [PATCH 14/42] add 
merge support for ivfpqfs (#745) Signed-off-by: LHT129 Signed-off-by: suguan.dx --- src/algorithm/ivf.cpp | 11 +++--- src/data_cell/bucket_datacell.h | 34 +++++++++++++++++++ src/data_cell/bucket_interface.h | 3 ++ .../pq_fastscan_quantizer.h | 33 ++++++++++++++++++ .../pq_fastscan_quantizer_test.cpp | 33 ++++++++++++++++++ src/quantization/quantizer.h | 3 ++ 6 files changed, 112 insertions(+), 5 deletions(-) diff --git a/src/algorithm/ivf.cpp b/src/algorithm/ivf.cpp index 05d14964b..d83baa2f6 100644 --- a/src/algorithm/ivf.cpp +++ b/src/algorithm/ivf.cpp @@ -170,8 +170,6 @@ IVF::InitFeatures() { if (this->bucket_->GetQuantizerName() == QUANTIZATION_TYPE_VALUE_PQFS) { this->index_feature_list_->SetFeature(IndexFeature::SUPPORT_ADD_AFTER_BUILD, false); - // TODO(LHT): merge on ivfpqfs - this->index_feature_list_->SetFeature(IndexFeature::SUPPORT_MERGE_INDEX, false); } } @@ -180,9 +178,6 @@ IVF::Build(const DatasetPtr& base) { this->Train(base); // TODO(LHT): duplicate auto result = this->Add(base); - if (this->bucket_->GetQuantizerName() == QUANTIZATION_TYPE_VALUE_PQFS) { - this->bucket_->Package(); - } return result; } @@ -205,6 +200,7 @@ IVF::Add(const DatasetPtr& base) { if (not partition_strategy_->is_trained_) { throw VsagException(ErrorType::INTERNAL_ERROR, "ivf index add without train error"); } + this->bucket_->Unpack(); auto num_element = base->GetNumElements(); const auto* ids = base->GetIds(); const auto* vectors = base->GetFloat32Vectors(); @@ -213,6 +209,7 @@ IVF::Add(const DatasetPtr& base) { bucket_->InsertVector(vectors + i * dim_, buckets[i], i + total_elements_); this->label_table_->Insert(i + total_elements_, ids[i]); } + this->bucket_->Package(); if (use_reorder_) { this->reorder_codes_->BatchInsertVector(base->GetFloat32Vectors(), base->GetNumElements()); } @@ -281,9 +278,11 @@ IVF::GetNumElements() const { void IVF::Merge(const std::vector& merge_units) { + this->bucket_->Unpack(); for (const auto& unit : merge_units) { 
this->merge_one_unit(unit); } + this->bucket_->Package(); } void @@ -413,7 +412,9 @@ IVF::merge_one_unit(const MergeUnit& unit) { std::dynamic_pointer_cast>(unit.index)->GetInnerIndex()); auto bias = this->total_elements_; this->label_table_->MergeOther(other_index->label_table_, bias); + other_index->bucket_->Unpack(); this->bucket_->MergeOther(other_index->bucket_, bias); + other_index->bucket_->Package(); if (this->use_reorder_) { this->reorder_codes_->MergeOther(other_index->reorder_codes_, bias); diff --git a/src/data_cell/bucket_datacell.h b/src/data_cell/bucket_datacell.h index 31d116423..9b27397d7 100644 --- a/src/data_cell/bucket_datacell.h +++ b/src/data_cell/bucket_datacell.h @@ -75,6 +75,13 @@ class BucketDataCell : public BucketInterface { } } + void + Unpack() override { + if (GetQuantizerName() == QUANTIZATION_TYPE_VALUE_PQFS) { + this->unpack_fastscan(); + } + } + void ExportModel(const BucketInterfacePtr& other) const override; @@ -128,6 +135,9 @@ class BucketDataCell : public BucketInterface { inline void package_fastscan(); + inline void + unpack_fastscan(); + private: std::shared_ptr quantizer_{nullptr}; @@ -300,6 +310,30 @@ BucketDataCell::package_fastscan() { } } +template +void +BucketDataCell::unpack_fastscan() { + ByteBuffer buffer(code_size_ * 32, this->allocator_); + for (int64_t i = 0; i < this->bucket_count_; ++i) { + auto bucket_size = (this->bucket_sizes_[i] + 31) / 32 * 32; + if (bucket_size == 0) { + continue; + } + bool need_release = false; + const auto* codes = this->datas_[i]->Read(code_size_ * bucket_size, 0, need_release); + InnerIdType begin = 0; + while (begin < bucket_size) { + const uint8_t* src_block = codes + begin * code_size_; + quantizer_->Unpack32(src_block, buffer.data); + this->datas_[i]->Write(buffer.data, code_size_ * 32, begin * code_size_); + begin += 32; + } + if (need_release) { + this->datas_[i]->Release(codes); + } + } +} + template void BucketDataCell::ExportModel(const BucketInterfacePtr& other) const { 
diff --git a/src/data_cell/bucket_interface.h b/src/data_cell/bucket_interface.h index 0081808e8..8b6eb7063 100644 --- a/src/data_cell/bucket_interface.h +++ b/src/data_cell/bucket_interface.h @@ -102,6 +102,9 @@ class BucketInterface { virtual void Package(){}; + virtual void + Unpack(){}; + public: BucketIdType bucket_count_{0}; uint32_t code_size_{0}; diff --git a/src/quantization/product_quantization/pq_fastscan_quantizer.h b/src/quantization/product_quantization/pq_fastscan_quantizer.h index 50f41ab97..a6b6d5f66 100644 --- a/src/quantization/product_quantization/pq_fastscan_quantizer.h +++ b/src/quantization/product_quantization/pq_fastscan_quantizer.h @@ -97,6 +97,9 @@ class PQFastScanQuantizer : public Quantizer> { void Package32(const uint8_t* codes, uint8_t* packaged_codes) const override; + void + Unpack32(const uint8_t* packaged_codes, uint8_t* codes) const override; + private: [[nodiscard]] inline const float* get_codebook_data(int64_t subspace_idx, int64_t centroid_num) const { @@ -410,4 +413,34 @@ PQFastScanQuantizer::Package32(const uint8_t* codes, uint8_t* packaged_c } } +template +void +PQFastScanQuantizer::Unpack32(const uint8_t* packaged_codes, uint8_t* codes) const { + constexpr int32_t mapper[32] = {0, 16, 8, 24, 1, 17, 9, 25, 2, 18, 10, 26, 3, 19, 11, 27, + 4, 20, 12, 28, 5, 21, 13, 29, 6, 22, 14, 30, 7, 23, 15, 31}; + + for (int i = 0; i < this->pq_dim_; ++i) { + for (int j = 0; j < BLOCK_SIZE_PACKAGE; ++j) { + int block_base = i * (BLOCK_SIZE_PACKAGE / 2) + (j / 2); + uint8_t byte = packaged_codes[block_base]; + + uint8_t code; + if (j % 2 == 0) { + code = byte & 0x0F; + } else { + code = byte >> 4; + } + int64_t vector_index = mapper[j]; + + int64_t code_offset = vector_index * this->code_size_ + (i / 2); + + if (i % 2 == 0) { + codes[code_offset] = (codes[code_offset] & 0xF0) | code; + } else { + codes[code_offset] = (codes[code_offset] & 0x0F) | (code << 4); + } + } + } +} + } // namespace vsag diff --git 
a/src/quantization/product_quantization/pq_fastscan_quantizer_test.cpp b/src/quantization/product_quantization/pq_fastscan_quantizer_test.cpp index 5b328b69f..b47c94373 100644 --- a/src/quantization/product_quantization/pq_fastscan_quantizer_test.cpp +++ b/src/quantization/product_quantization/pq_fastscan_quantizer_test.cpp @@ -47,6 +47,39 @@ TEST_CASE("PQFSQuantizer Encode and Decode", "[ut][PQFSQuantizer]") { } } +template +void +TestPackageUnpackMetricPQFS(uint64_t dim, int64_t pq_dim) { + auto allocator = SafeAllocator::FactoryDefaultAllocator(); + PQFastScanQuantizer quantizer(dim, pq_dim, allocator.get()); + constexpr int count = PQFastScanQuantizer::BLOCK_SIZE_PACKAGE; + size_t code_size = quantizer.GetCodeSize(); + std::vector original_codes(count * code_size); + for (size_t i = 0; i < original_codes.size(); ++i) { + original_codes[i] = static_cast(rand() % 256); + } + + std::vector packaged(code_size * count); + quantizer.Package32(original_codes.data(), packaged.data()); + std::vector unpacked_codes(count * code_size); + quantizer.Unpack32(packaged.data(), unpacked_codes.data()); + for (size_t i = 0; i < original_codes.size(); ++i) { + REQUIRE(original_codes[i] == unpacked_codes[i]); + } +} + +TEST_CASE("PQFSQuantizer Package32 and Unpack32", "[ut][PQFSQuantizer]") { + constexpr MetricType metrics[2] = {MetricType::METRIC_TYPE_L2SQR, MetricType::METRIC_TYPE_IP}; + auto allocator = SafeAllocator::FactoryDefaultAllocator(); + for (auto dim : dims) { + int64_t pq_dim = dim; + TestPackageUnpackMetricPQFS(dim, pq_dim); + TestPackageUnpackMetricPQFS(dim, pq_dim / 2); + TestPackageUnpackMetricPQFS(dim, pq_dim); + TestPackageUnpackMetricPQFS(dim, pq_dim / 2); + } +} + template void TestComputerBatchPQFS(PQFastScanQuantizer& quant, diff --git a/src/quantization/quantizer.h b/src/quantization/quantizer.h index 9a9371b9c..faa997956 100644 --- a/src/quantization/quantizer.h +++ b/src/quantization/quantizer.h @@ -215,6 +215,9 @@ class Quantizer { virtual void 
Package32(const uint8_t* codes, uint8_t* packaged_codes) const {}; + virtual void + Unpack32(const uint8_t* codes, uint8_t* packaged_codes) const {}; + /** * @brief Get the size of the encoded code in bytes. * From 75da847223cbf26ee920969d75faad5040baed24 Mon Sep 17 00:00:00 2001 From: LHT129 Date: Fri, 23 May 2025 16:59:25 +0800 Subject: [PATCH 15/42] speedup kmeans by use hgraph for large k (#750) Signed-off-by: LHT129 Signed-off-by: suguan.dx --- src/impl/kmeans_cluster.cpp | 56 +++++++++++++++++++++++++++++++++++-- src/impl/kmeans_cluster.h | 9 ++++++ 2 files changed, 62 insertions(+), 3 deletions(-) diff --git a/src/impl/kmeans_cluster.cpp b/src/impl/kmeans_cluster.cpp index e380c8ff7..f049ea1e0 100644 --- a/src/impl/kmeans_cluster.cpp +++ b/src/impl/kmeans_cluster.cpp @@ -20,16 +20,16 @@ #include +#include "algorithm/inner_index_interface.h" #include "byte_buffer.h" +#include "safe_allocator.h" #include "simd/fp32_simd.h" namespace vsag { - KMeansCluster::KMeansCluster(int32_t dim, Allocator* allocator, SafeThreadPoolPtr thread_pool) : dim_(dim), allocator_(allocator), thread_pool_(std::move(thread_pool)) { if (thread_pool_ == nullptr) { this->thread_pool_ = SafeThreadPool::FactoryDefaultThreadPool(); - // this->thread_pool_->SetPoolSize(10); } } @@ -69,7 +69,12 @@ KMeansCluster::Run(uint32_t k, const float* datas, uint64_t count, int iter) { auto* y_sqr = reinterpret_cast(y_sqr_buffer.data); auto* distances = reinterpret_cast(distances_buffer.data); for (int it = 0; it < iter; ++it) { - this->find_nearest_one_with_blas(datas, count, k, query_count_bs, y_sqr, distances, labels); + if (k < THRESHOLD_FOR_HGRAPH) { + this->find_nearest_one_with_blas( + datas, count, k, query_count_bs, y_sqr, distances, labels); + } else { + this->find_nearest_one_with_hgraph(datas, count, k, query_count_bs, labels); + } constexpr uint64_t bs = 1024; Vector counts(k, 0, allocator_); @@ -192,4 +197,49 @@ KMeansCluster::find_nearest_one_with_blas(const float* query, } } +void 
+KMeansCluster::find_nearest_one_with_hgraph(const float* query, + const uint64_t query_count, + const uint64_t k, + const uint64_t query_count_bs, + Vector& labels) { + IndexCommonParam param; + param.dim_ = dim_; + param.allocator_ = std::make_shared(this->allocator_); + param.thread_pool_ = this->thread_pool_; + param.metric_ = MetricType::METRIC_TYPE_L2SQR; + + auto hgraph = InnerIndexInterface::FastCreateIndex("hgraph|32|fp32", param); + auto base = Dataset::Make(); + Vector ids(k, allocator_); + std::iota(ids.begin(), ids.end(), 0); + base->Dim(dim_) + ->NumElements(static_cast(k)) + ->Float32Vectors(this->k_centroids_) + ->Ids(ids.data()) + ->Owner(false); + hgraph->Build(base); + FilterPtr filter = nullptr; + constexpr const char* search_param = R"({"hgraph":{"ef_search":10}})"; + auto func = [&](const uint64_t begin, const uint64_t end) -> void { + for (uint64_t j = begin; j < end; ++j) { + auto q = Dataset::Make(); + q->Owner(false) + ->Float32Vectors(query + j * this->dim_) + ->NumElements(1) + ->Dim(this->dim_); + auto ret = hgraph->KnnSearch(q, 1, search_param, filter); + labels[j] = static_cast(ret->GetIds()[0]); + } + }; + std::vector> futures; + for (uint64_t i = 0; i < query_count; i += query_count_bs) { + futures.emplace_back( + thread_pool_->GeneralEnqueue(func, i, std::min(i + query_count_bs, query_count))); + } + for (auto& future : futures) { + future.wait(); + } +} + } // namespace vsag diff --git a/src/impl/kmeans_cluster.h b/src/impl/kmeans_cluster.h index d376c3fee..c14e8b2a0 100644 --- a/src/impl/kmeans_cluster.h +++ b/src/impl/kmeans_cluster.h @@ -45,12 +45,21 @@ class KMeansCluster { float* distances, Vector& labels); + void + find_nearest_one_with_hgraph(const float* query, + const uint64_t query_count, + const uint64_t k, + const uint64_t query_count_bs, + Vector& labels); + private: Allocator* const allocator_{nullptr}; SafeThreadPoolPtr thread_pool_{nullptr}; const int32_t dim_{0}; + + static constexpr uint64_t THRESHOLD_FOR_HGRAPH 
= 10000ULL; }; } // namespace vsag From 39c19c2f43baca2f61bc2f999bc90a59b5c525bd Mon Sep 17 00:00:00 2001 From: ShawnShawnYou Date: Sun, 25 May 2025 00:34:15 +0800 Subject: [PATCH 16/42] fix several bugs for performance (#730) * fix several bugs for performance in sq and optimizer Signed-off-by: zhongxiaoyao.zxy Signed-off-by: suguan.dx --- src/algorithm/hgraph.cpp | 6 +----- .../scalar_quantization_trainer.cpp | 13 ++++++++++--- .../scalar_quantization_trainer.h | 2 +- .../scalar_quantization/sq8_uniform_quantizer.h | 3 ++- 4 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/algorithm/hgraph.cpp b/src/algorithm/hgraph.cpp index cf2f56510..2acc8d0c2 100644 --- a/src/algorithm/hgraph.cpp +++ b/src/algorithm/hgraph.cpp @@ -956,11 +956,7 @@ HGraph::elp_optimize() { param.topk = 10; param.is_inner_id_allowed = nullptr; searcher_->SetMockParameters(bottom_graph_, basic_flatten_codes_, pool_, param, dim_); - optimizer_->RegisterParameter( - RuntimeParameter(PREFETCH_DEPTH_CODE, - 1, - (static_cast(basic_flatten_codes_->code_size_) + 63.0F) / 64.0F + 2, - 1)); + // TODO(ZXY): optimize PREFETCH_DEPTH_CODE and add default value for the others optimizer_->RegisterParameter(RuntimeParameter(PREFETCH_STRIDE_CODE, 1, 10, 1)); optimizer_->RegisterParameter(RuntimeParameter(PREFETCH_STRIDE_VISIT, 1, 10, 1)); optimizer_->Optimize(searcher_); diff --git a/src/quantization/scalar_quantization/scalar_quantization_trainer.cpp b/src/quantization/scalar_quantization/scalar_quantization_trainer.cpp index b695d6a40..e48520414 100644 --- a/src/quantization/scalar_quantization/scalar_quantization_trainer.cpp +++ b/src/quantization/scalar_quantization/scalar_quantization_trainer.cpp @@ -55,11 +55,17 @@ ScalarQuantizationTrainer::TrainUniform(const float* data, std::vector lower(dim_); if (mode == CLASSIC) { this->classic_train(sample_datas.data(), sample_count, upper.data(), lower.data()); + upper_bound = *std::max_element(upper.begin(), upper.end()); + lower_bound = 
*std::min_element(lower.begin(), lower.end()); } else if (mode == TRUNC_BOUND) { this->trunc_bound_train(sample_datas.data(), sample_count, upper.data(), lower.data()); + upper_bound = *std::min_element(upper.begin(), upper.end()); + lower_bound = *std::max_element(lower.begin(), lower.end()); + if (lower_bound > upper_bound) { + // case for count == 1 or trunc_rate > 0.5 + std::swap(lower_bound, upper_bound); + } } - upper_bound = *std::max_element(upper.begin(), upper.end()); - lower_bound = *std::min_element(lower.begin(), lower.end()); } void @@ -84,10 +90,11 @@ ScalarQuantizationTrainer::trunc_bound_train(const float* data, float* upper_bound, float* lower_bound) const { double ignore_rate = 0.001; - if (this->dim_ == 4) { + if (this->bits_ == 4) { ignore_rate = this->trunc_rate_; } auto ignore_count = static_cast(static_cast(count - 1) * ignore_rate); + ignore_count = ignore_count < 1 ? 1 : ignore_count; for (uint64_t i = 0; i < dim_; ++i) { std::priority_queue, std::greater<>> heap_max; std::priority_queue, std::less<>> heap_min; diff --git a/src/quantization/scalar_quantization/scalar_quantization_trainer.h b/src/quantization/scalar_quantization/scalar_quantization_trainer.h index 520b37833..77d2efd5c 100644 --- a/src/quantization/scalar_quantization/scalar_quantization_trainer.h +++ b/src/quantization/scalar_quantization/scalar_quantization_trainer.h @@ -82,7 +82,7 @@ class ScalarQuantizationTrainer { uint64_t max_sample_count_{MAX_DEFAULT_SAMPLE}; - constexpr static uint64_t MAX_DEFAULT_SAMPLE{65536}; + constexpr static uint64_t MAX_DEFAULT_SAMPLE{100000}; }; } // namespace vsag diff --git a/src/quantization/scalar_quantization/sq8_uniform_quantizer.h b/src/quantization/scalar_quantization/sq8_uniform_quantizer.h index 88d023e0a..7270473f1 100644 --- a/src/quantization/scalar_quantization/sq8_uniform_quantizer.h +++ b/src/quantization/scalar_quantization/sq8_uniform_quantizer.h @@ -159,7 +159,8 @@ SQ8UniformQuantizer::TrainImpl(const DataType* data, 
uint64_t count) { } ScalarQuantizationTrainer trainer(this->dim_, 8); - trainer.TrainUniform(data, count, this->diff_, this->lower_bound_, need_normalize); + trainer.TrainUniform( + data, count, this->diff_, this->lower_bound_, need_normalize, SQTrainMode::CLASSIC); this->diff_ -= this->lower_bound_; From caefbeab544496ec854f26a7b89c9aab713e64c4 Mon Sep 17 00:00:00 2001 From: LHT129 Date: Sun, 25 May 2025 22:47:34 +0800 Subject: [PATCH 17/42] introduce fast bitset (#753) Signed-off-by: LHT129 Signed-off-by: suguan.dx --- .circleci/fresh_ci_cache.commit | 2 +- .github/workflows/asan_build_and_test.yml | 30 ++-- include/vsag/bitset.h | 27 --- mockimpl/CMakeLists.txt | 9 +- src/CMakeLists.txt | 8 +- src/algorithm/hnswlib/hnswalg.cpp | 40 ++--- src/algorithm/hnswlib/hnswalg.h | 1 + src/base_filter_functor.h | 2 +- src/impl/bitset/bitset.cpp | 42 +++++ src/impl/bitset/computable_bitset.cpp | 33 ++++ src/impl/bitset/computable_bitset.h | 103 +++++++++++ src/impl/bitset/fast_bitset.cpp | 156 +++++++++++++++++ src/impl/bitset/fast_bitset.h | 67 ++++++++ src/impl/bitset/fast_bitset_test.cpp | 160 ++++++++++++++++++ .../bitset/sparse_bitset.cpp} | 68 ++++---- .../bitset/sparse_bitset.h} | 25 ++- .../bitset/sparse_bitset_test.cpp} | 79 +++++---- src/index/hnsw.cpp | 2 +- 18 files changed, 713 insertions(+), 141 deletions(-) create mode 100644 src/impl/bitset/bitset.cpp create mode 100644 src/impl/bitset/computable_bitset.cpp create mode 100644 src/impl/bitset/computable_bitset.h create mode 100644 src/impl/bitset/fast_bitset.cpp create mode 100644 src/impl/bitset/fast_bitset.h create mode 100644 src/impl/bitset/fast_bitset_test.cpp rename src/{bitset_impl.cpp => impl/bitset/sparse_bitset.cpp} (57%) rename src/{bitset_impl.h => impl/bitset/sparse_bitset.h} (71%) rename src/{bitset_impl_test.cpp => impl/bitset/sparse_bitset_test.cpp} (82%) diff --git a/.circleci/fresh_ci_cache.commit b/.circleci/fresh_ci_cache.commit index b56443f02..78b4ac80c 100644 --- 
a/.circleci/fresh_ci_cache.commit +++ b/.circleci/fresh_ci_cache.commit @@ -1 +1 @@ -c13456cac7cd8c69a554d7f425bf658e3d0ea338 +d2491912b3b18b2d8745cd7468001a91eab89692 diff --git a/.github/workflows/asan_build_and_test.yml b/.github/workflows/asan_build_and_test.yml index 217b5ad41..a744eecff 100644 --- a/.github/workflows/asan_build_and_test.yml +++ b/.github/workflows/asan_build_and_test.yml @@ -21,8 +21,6 @@ jobs: - name: Free Disk Space run: rm -rf /useless/hostedtoolcache - uses: actions/checkout@v4 - with: - fetch-depth: '0' - name: Load Cache uses: actions/cache@v4 with: @@ -38,6 +36,12 @@ jobs: fi - name: Make Asan run: export CMAKE_GENERATOR="Ninja"; make asan + - name: Compress + run: | + sudo apt install pigz -y; + tar -cf - ./build | pigz -p 4 > build.tar.gz + - name: Clean + run: find ./build -type f -name "*.o" -exec rm -f {} + - name: Save Test uses: actions/upload-artifact@v4 with: @@ -46,10 +50,6 @@ jobs: compression-level: 1 retention-days: 1 overwrite: 'true' - - name: Compress - run: | - sudo apt install pigz -y; - tar -cf - ./build | pigz -p 4 > build.tar.gz build_asan_aarch64: name: Asan Build Aarch64 @@ -61,8 +61,6 @@ jobs: - name: Free Disk Space run: rm -rf /opt/hostedtoolcache - uses: actions/checkout@v4 - with: - fetch-depth: '0' - name: Prepare Env run: sudo bash ./scripts/deps/install_deps_ubuntu.sh - name: Load Cache @@ -80,18 +78,20 @@ jobs: fi - name: Make Asan run: export CMAKE_GENERATOR="Ninja"; make asan + - name: Compress + run: | + sudo apt install pigz -y; + tar -cf - ./build | pigz -p 4 > build.tar.gz + - name: Clean + run: find ./build -type f -name "*.o" -exec rm -f {} + - name: Save Test uses: actions/upload-artifact@v4 with: - path: ./build + path: ./build/ name: test_aarch64-${{ github.run_id }} compression-level: 1 retention-days: 1 overwrite: 'true' - - name: Compress - run: | - sudo apt install pigz -y; - tar -cf - ./build | pigz -p 4 > build.tar.gz test_asan_x86: name: Test X86 @@ -117,7 +117,7 @@ jobs: uses: 
actions/download-artifact@v4 with: name: test_x86-${{ github.run_id }} - path: ./build/ + path: ./build - name: Do Asan Test In ${{ matrix.test_type }} run: | echo leak:libomp.so > omp.supp @@ -155,7 +155,7 @@ jobs: uses: actions/download-artifact@v4 with: name: test_aarch64-${{ github.run_id }} - path: ./build/ + path: ./build - name: Do Asan Test In ${{ matrix.test_type }} run: | echo leak:libomp.so > omp.supp diff --git a/include/vsag/bitset.h b/include/vsag/bitset.h index 6ac5a0496..b246b6fb2 100644 --- a/include/vsag/bitset.h +++ b/include/vsag/bitset.h @@ -86,33 +86,6 @@ class Bitset { virtual uint64_t Count() = 0; - /** - * @brief Performs a bitwise OR operation on the current bitset with another bitset. - * - * @param another The bitset to perform the OR operation with. - * @return void - */ - virtual void - Or(const Bitset& another) = 0; - - /** - * @brief Performs a bitwise AND operation on the current bitset with another bitset. - * - * @param another The bitset to perform the AND operation with. - * @return void - */ - virtual void - And(const Bitset& another) = 0; - - /** - * @brief Performs a bitwise XOR operation on the current bitset with another bitset. - * - * @param another The bitset to perform the XOR operation with. 
- * @return void - */ - virtual void - Xor(const Bitset& another) = 0; - public: /** * For debugging diff --git a/mockimpl/CMakeLists.txt b/mockimpl/CMakeLists.txt index 3b6e2f87a..27998bd63 100644 --- a/mockimpl/CMakeLists.txt +++ b/mockimpl/CMakeLists.txt @@ -2,7 +2,10 @@ set (MOCK_SRCS vsag/simpleflat.cpp vsag/factory.cpp - ../src/bitset_impl.cpp + ../src/impl/bitset/bitset.cpp + ../src/impl/bitset/computable_bitset.cpp + ../src/impl/bitset/sparse_bitset.cpp + ../src/impl/bitset/fast_bitset.cpp ../src/constants.cpp ../src/dataset_impl.cpp ) @@ -10,8 +13,8 @@ set (MOCK_SRCS add_library (vsag_mockimpl SHARED ${MOCK_SRCS}) add_library (vsag_mockimpl_static STATIC ${MOCK_SRCS}) -target_link_libraries (vsag_mockimpl roaring) -target_link_libraries (vsag_mockimpl_static roaring) +target_link_libraries (vsag_mockimpl roaring fmt::fmt-header-only) +target_link_libraries (vsag_mockimpl_static roaring fmt::fmt-header-only) add_dependencies (vsag_mockimpl version_mockimpl roaring) add_dependencies (vsag_mockimpl_static version_mockimpl roaring) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index d1d06213c..65a9ef458 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -6,8 +6,10 @@ add_subdirectory (quantization) file (GLOB CPP_SRCS "*.cpp") list (FILTER CPP_SRCS EXCLUDE REGEX "_test.cpp") -file (GLOB CPP_CONJUGATE_GRAPH_SRCS "impl/*.cpp") -list (FILTER CPP_CONJUGATE_GRAPH_SRCS EXCLUDE REGEX "_test.cpp") +file (GLOB CPP_IMPL_SRCS "impl/*.cpp") +file (GLOB CPP_IMPL_EXTRA_SRCS "impl/**/*.cpp") +list (APPEND CPP_IMPL_SRCS ${CPP_IMPL_EXTRA_SRCS}) +list (FILTER CPP_IMPL_SRCS EXCLUDE REGEX "_test.cpp") file (GLOB CPP_INDEX_SRCS "index/*.cpp") list (FILTER CPP_INDEX_SRCS EXCLUDE REGEX "_test.cpp") @@ -23,7 +25,7 @@ list (FILTER CPP_ALGORITHM_SRCS EXCLUDE REGEX "_test.cpp") file (GLOB CPP_UTILS_SRCS "utils/*.cpp") list (FILTER CPP_UTILS_SRCS EXCLUDE REGEX "_test.cpp") -set (VSAG_SRCS ${CPP_SRCS} ${CPP_INDEX_SRCS} ${CPP_CONJUGATE_GRAPH_SRCS} +set (VSAG_SRCS 
${CPP_SRCS} ${CPP_INDEX_SRCS} ${CPP_IMPL_SRCS} ${CPP_DATA_CELL_SRCS} ${CPP_ALGORITHM_SRCS} ${CPP_UTILS_SRCS}) add_library (vsag SHARED ${VSAG_SRCS}) diff --git a/src/algorithm/hnswlib/hnswalg.cpp b/src/algorithm/hnswlib/hnswalg.cpp index e5779a197..daafb1b58 100644 --- a/src/algorithm/hnswlib/hnswalg.cpp +++ b/src/algorithm/hnswlib/hnswalg.cpp @@ -391,18 +391,18 @@ HierarchicalNSW::searchBaseLayer(InnerIdType ep_id, const void* data_point, int size_t size = getListCount((linklistsizeint*)data); auto* datal = (InnerIdType*)(data + 1); #ifdef USE_SSE - _mm_prefetch((char*)(visited_array + *(data + 1)), _MM_HINT_T0); - _mm_prefetch((char*)(visited_array + *(data + 1) + 64), _MM_HINT_T0); - _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0); - _mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0); + vsag::PrefetchLines((char*)(visited_array + *(data + 1)), 64); + vsag::PrefetchLines((char*)(visited_array + *(data + 1) + 64), 64); + vsag::PrefetchLines(getDataByInternalId(*datal), 64); + vsag::PrefetchLines(getDataByInternalId(*(datal + 1)), 64); #endif for (size_t j = 0; j < size; j++) { InnerIdType candidate_id = *(datal + j); #ifdef USE_SSE size_t pre_l = std::min(j, size - 2); - _mm_prefetch((char*)(visited_array + *(datal + pre_l + 1)), _MM_HINT_T0); - _mm_prefetch(getDataByInternalId(*(datal + pre_l + 1)), _MM_HINT_T0); + vsag::PrefetchLines((char*)(visited_array + *(datal + pre_l + 1)), 64); + vsag::PrefetchLines(getDataByInternalId(*(datal + pre_l + 1)), 64); #endif if (visited_array[candidate_id] == visited_array_tag) continue; @@ -413,7 +413,7 @@ HierarchicalNSW::searchBaseLayer(InnerIdType ep_id, const void* data_point, int if (top_candidates.size() < ef_construction_ || lower_bound > dist1) { candidateSet.emplace(-dist1, candidate_id); #ifdef USE_SSE - _mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0); + vsag::PrefetchLines(getDataByInternalId(candidateSet.top().second), 64); #endif if (not isMarkedDeleted(candidate_id)) @@ 
-502,10 +502,10 @@ HierarchicalNSW::searchBaseLayerST(InnerIdType ep_id, auto vector_data_ptr = data_level0_memory_->GetElementPtr((*(data + 1)), offset_data_); #ifdef ENABLE_SSE - _mm_prefetch((char*)(visited_array + *(data + 1)), _MM_HINT_T0); - _mm_prefetch((char*)(visited_array + *(data + 1) + 64), _MM_HINT_T0); + vsag::PrefetchLines((char*)(visited_array + *(data + 1)), 64); + vsag::PrefetchLines((char*)(visited_array + *(data + 1) + 64), 64); vsag::PrefetchLines(vector_data_ptr, data_size_); - _mm_prefetch((char*)(data + 2), _MM_HINT_T0); + vsag::PrefetchLines((char*)(data + 2), 64); #endif for (size_t j = 1; j <= size; j++) { @@ -515,8 +515,8 @@ HierarchicalNSW::searchBaseLayerST(InnerIdType ep_id, vector_data_ptr = data_level0_memory_->GetElementPtr( (*(data + pre_l + prefetch_jump_code_size_)), offset_data_); #ifdef ENABLE_SSE - _mm_prefetch((char*)(visited_array + *(data + pre_l + prefetch_jump_code_size_)), - _MM_HINT_T0); + vsag::PrefetchLines( + (char*)(visited_array + *(data + pre_l + prefetch_jump_code_size_)), 64); vsag::PrefetchLines(vector_data_ptr, data_size_); #endif } @@ -535,7 +535,7 @@ HierarchicalNSW::searchBaseLayerST(InnerIdType ep_id, vector_data_ptr = data_level0_memory_->GetElementPtr(candidate_set.top().second, offsetLevel0_); #ifdef ENABLE_SSE - _mm_prefetch(vector_data_ptr, _MM_HINT_T0); + vsag::PrefetchLines(vector_data_ptr, 64); #endif if ((!has_deletions || !isMarkedDeleted(candidate_id)) && @@ -614,10 +614,10 @@ HierarchicalNSW::searchBaseLayerST(InnerIdType ep_id, auto vector_data_ptr = data_level0_memory_->GetElementPtr((*(data + 1)), offset_data_); #ifdef USE_SSE - _mm_prefetch((char*)(visited_array + *(data + 1)), _MM_HINT_T0); - _mm_prefetch((char*)(visited_array + *(data + 1) + 64), _MM_HINT_T0); - _mm_prefetch(vector_data_ptr, _MM_HINT_T0); - _mm_prefetch((char*)(data + 2), _MM_HINT_T0); + vsag::PrefetchLines((char*)(visited_array + *(data + 1)), 64); + vsag::PrefetchLines((char*)(visited_array + *(data + 1) + 64), 64); + 
vsag::PrefetchLines(vector_data_ptr, 64); + vsag::PrefetchLines((char*)(data + 2), 64); #endif for (size_t j = 1; j <= size; j++) { @@ -626,8 +626,8 @@ HierarchicalNSW::searchBaseLayerST(InnerIdType ep_id, vector_data_ptr = data_level0_memory_->GetElementPtr((*(data + pre_l + 1)), offset_data_); #ifdef USE_SSE - _mm_prefetch((char*)(visited_array + *(data + pre_l + 1)), _MM_HINT_T0); - _mm_prefetch(vector_data_ptr, _MM_HINT_T0); //////////// + vsag::PrefetchLines((char*)(visited_array + *(data + pre_l + 1)), 64); + vsag::PrefetchLines(vector_data_ptr, 64); //////////// #endif if (visited_array[candidate_id] != visited_array_tag) { visited_array[candidate_id] = visited_array_tag; @@ -642,7 +642,7 @@ HierarchicalNSW::searchBaseLayerST(InnerIdType ep_id, vector_data_ptr = data_level0_memory_->GetElementPtr(candidate_set.top().second, offsetLevel0_); #ifdef USE_SSE - _mm_prefetch(vector_data_ptr, _MM_HINT_T0); //////////////////////// + vsag::PrefetchLines(vector_data_ptr, 64); //////////////////////// #endif if ((!has_deletions || !isMarkedDeleted(candidate_id)) && diff --git a/src/algorithm/hnswlib/hnswalg.h b/src/algorithm/hnswlib/hnswalg.h index 5be512957..90ded5823 100644 --- a/src/algorithm/hnswlib/hnswalg.h +++ b/src/algorithm/hnswlib/hnswalg.h @@ -43,6 +43,7 @@ #include "visited_list_pool.h" #include "vsag/dataset.h" #include "vsag/iterator_context.h" + namespace hnswlib { using InnerIdType = vsag::InnerIdType; using linklistsizeint = unsigned int; diff --git a/src/base_filter_functor.h b/src/base_filter_functor.h index 10d9553e1..95c082027 100644 --- a/src/base_filter_functor.h +++ b/src/base_filter_functor.h @@ -17,11 +17,11 @@ #include -#include "bitset_impl.h" #include "common.h" #include "data_cell/extra_info_interface.h" #include "label_table.h" #include "typing.h" +#include "vsag/bitset.h" #include "vsag/filter.h" namespace vsag { diff --git a/src/impl/bitset/bitset.cpp b/src/impl/bitset/bitset.cpp new file mode 100644 index 000000000..336ca25b6 --- 
/dev/null +++ b/src/impl/bitset/bitset.cpp @@ -0,0 +1,42 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "vsag/bitset.h" + +#include +#include +#include + +#include "computable_bitset.h" + +namespace vsag { + +BitsetPtr +Bitset::Random(int64_t length) { + auto bitset = ComputableBitset::MakeInstance(ComputableBitsetType::SparseBitset); + static auto gen = + std::bind(std::uniform_int_distribution<>(0, 1), // NOLINT(modernize-avoid-bind) + std::default_random_engine()); + for (int64_t i = 0; i < length; ++i) { + bitset->Set(i, gen() != 0); + } + return bitset; +} + +BitsetPtr +Bitset::Make() { + return ComputableBitset::MakeInstance(ComputableBitsetType::SparseBitset); +} +} // namespace vsag diff --git a/src/impl/bitset/computable_bitset.cpp b/src/impl/bitset/computable_bitset.cpp new file mode 100644 index 000000000..ee50f1a52 --- /dev/null +++ b/src/impl/bitset/computable_bitset.cpp @@ -0,0 +1,33 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "computable_bitset.h" + +#include "fast_bitset.h" +#include "sparse_bitset.h" +#include "vsag_exception.h" + +namespace vsag { +ComputableBitsetPtr +ComputableBitset::MakeInstance(ComputableBitsetType type, Allocator* allocator) { + if (type == ComputableBitsetType::SparseBitset) { + return std::make_shared(); + } + if (type == ComputableBitsetType::FastBitset) { + return std::make_shared(allocator); + } + throw VsagException(ErrorType::INTERNAL_ERROR, "Unknown bitset type"); +} +} // namespace vsag diff --git a/src/impl/bitset/computable_bitset.h b/src/impl/bitset/computable_bitset.h new file mode 100644 index 000000000..304666b20 --- /dev/null +++ b/src/impl/bitset/computable_bitset.h @@ -0,0 +1,103 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "stream_reader.h" +#include "stream_writer.h" +#include "vsag/bitset.h" + +namespace vsag { +class ComputableBitset; +using ComputableBitsetPtr = std::shared_ptr; + +enum class ComputableBitsetType { + SparseBitset, + FastBitset, +}; + +/** + * @brief ComputableBitset is a base class for bitsets that can be computed. + * + * @note ComputableBitset is a base class for bitsets that can be computed. + * It provides a set of methods that can be used to perform bitwise operations on the bitset. 
+ */ +class ComputableBitset : public Bitset { +public: + static ComputableBitsetPtr + MakeInstance(ComputableBitsetType type, Allocator* allocator = nullptr); + +public: + ComputableBitset() = default; + + ~ComputableBitset() override = default; + + /** + * @brief Performs a bitwise OR operation on the current bitset with another bitset. + * + * @param another The bitset to perform the OR operation with. + * @return void + */ + virtual void + Or(const Bitset& another) = 0; + + /** + * @brief Performs a bitwise AND operation on the current bitset with another bitset. + * + * @param another The bitset to perform the AND operation with. + * @return void + */ + virtual void + And(const Bitset& another) = 0; + + /** + * @brief Performs a bitwise XOR operation on the current bitset with another bitset. + * + * @param another The bitset to perform the XOR operation with. + * @return void + */ + virtual void + Xor(const Bitset& another) = 0; + + /** + * @brief Performs a bitwise NOT operation on the current bitset. + * + * @return void + */ + virtual void + Not() = 0; + + /** + * @brief Serializes the bitset to a stream. + * + * @param writer The stream writer to write the serialized bitset to. + * @return void + * @note The serialized bitset is written to the stream in a format that can be deserialized by the Deserialize method. + */ + virtual void + Serialize(StreamWriter& writer) const = 0; + + /** + * @brief Deserializes the bitset from a stream. + * + * @param reader The stream reader to read the serialized bitset from. + * @return void + * @note The serialized bitset is read from the stream and deserialized into the current bitset. 
+ */ + virtual void + Deserialize(StreamReader& reader) = 0; +}; + +} // namespace vsag diff --git a/src/impl/bitset/fast_bitset.cpp b/src/impl/bitset/fast_bitset.cpp new file mode 100644 index 000000000..80bed2593 --- /dev/null +++ b/src/impl/bitset/fast_bitset.cpp @@ -0,0 +1,156 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fast_bitset.h" + +#include "vsag_exception.h" + +namespace vsag { + +void +FastBitset::Set(int64_t pos, bool value) { + std::lock_guard lock(mutex_); + auto capacity = data_.size() * 64; + if (pos >= capacity) { + data_.resize((pos / 64) + 1, 0); + } + auto word_index = pos / 64; + auto bit_index = pos % 64; + if (value) { + data_[word_index] |= (1ULL << bit_index); + } else { + data_[word_index] &= ~(1ULL << bit_index); + } +} + +bool +FastBitset::Test(int64_t pos) { + std::shared_lock lock(mutex_); + auto capacity = data_.size() * 64; + if (pos >= capacity) { + return false; + } + auto word_index = pos / 64; + auto bit_index = pos % 64; + return (data_[word_index] & (1ULL << bit_index)) != 0; +} + +uint64_t +FastBitset::Count() { + std::shared_lock lock(mutex_); + uint64_t count = 0; + for (auto word : data_) { + count += __builtin_popcountll(word); + } + return count; +} +void +FastBitset::Or(const Bitset& another) { + const auto* fast_another = dynamic_cast(&another); + if (fast_another == nullptr) { + throw VsagException(ErrorType::INTERNAL_ERROR, "bitset not match"); 
+ } + std::lock(mutex_, fast_another->mutex_); + std::lock_guard lock1(mutex_, std::adopt_lock); + std::lock_guard lock2(fast_another->mutex_, std::adopt_lock); + auto max_size = std::max(data_.size(), fast_another->data_.size()); + data_.resize(max_size, 0); + // TODO(LHT): use SIMD + for (uint64_t i = 0; i < max_size; ++i) { + data_[i] |= fast_another->data_[i]; + } +} + +void +FastBitset::And(const Bitset& another) { + const auto* fast_another = dynamic_cast(&another); + if (fast_another == nullptr) { + throw VsagException(ErrorType::INTERNAL_ERROR, "bitset not match"); + } + std::lock(mutex_, fast_another->mutex_); + std::lock_guard lock1(mutex_, std::adopt_lock); + std::lock_guard lock2(fast_another->mutex_, std::adopt_lock); + auto max_size = std::max(data_.size(), fast_another->data_.size()); + data_.resize(max_size, 0); + // TODO(LHT): use SIMD + for (uint64_t i = 0; i < max_size; ++i) { + data_[i] &= fast_another->data_[i]; + } +} + +void +FastBitset::Xor(const Bitset& another) { + const auto* fast_another = dynamic_cast(&another); + if (fast_another == nullptr) { + throw VsagException(ErrorType::INTERNAL_ERROR, "bitset not match"); + } + std::lock(mutex_, fast_another->mutex_); + std::lock_guard lock1(mutex_, std::adopt_lock); + std::lock_guard lock2(fast_another->mutex_, std::adopt_lock); + auto max_size = std::max(data_.size(), fast_another->data_.size()); + data_.resize(max_size, 0); + // TODO(LHT): use SIMD + for (uint64_t i = 0; i < max_size; ++i) { + data_[i] ^= fast_another->data_[i]; + } +} + +std::string +FastBitset::Dump() { + std::shared_lock lock(mutex_); + std::string result = "{"; + auto capacity = data_.size() * 64; + int count = 0; + for (int64_t i = 0; i < capacity; ++i) { + if (Test(i)) { + if (count == 0) { + result += std::to_string(i); + } else { + result += "," + std::to_string(i); + } + ++count; + } + } + result += "}"; + return result; +} + +void +FastBitset::Not() { + std::lock_guard lock(mutex_); + for (auto& word : data_) { + 
word = ~word; + } +} + +void +FastBitset::Serialize(StreamWriter& writer) const { + std::shared_lock lock(mutex_); + uint64_t size = data_.size(); + StreamWriter::WriteObj(writer, size); + if (size > 0) { + writer.Write(reinterpret_cast(data_.data()), size * sizeof(uint64_t)); + } +} + +void +FastBitset::Deserialize(StreamReader& reader) { + std::lock_guard lock(mutex_); + uint64_t size; + StreamReader::ReadObj(reader, size); + data_.resize(size); + reader.Read(reinterpret_cast(data_.data()), size * sizeof(uint64_t)); +} +} // namespace vsag diff --git a/src/impl/bitset/fast_bitset.h b/src/impl/bitset/fast_bitset.h new file mode 100644 index 000000000..0253c8a07 --- /dev/null +++ b/src/impl/bitset/fast_bitset.h @@ -0,0 +1,67 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include + +#include "computable_bitset.h" +#include "safe_allocator.h" +#include "typing.h" + +namespace vsag { +class FastBitset : public ComputableBitset { +public: + explicit FastBitset(Allocator* allocator) + : ComputableBitset(), allocator_(allocator), data_(allocator){}; + + ~FastBitset() override = default; + + void + Set(int64_t pos, bool value) override; + + bool + Test(int64_t pos) override; + + uint64_t + Count() override; + + void + Or(const Bitset& another) override; + + void + And(const Bitset& another) override; + + void + Xor(const Bitset& another) override; + + void + Not() override; + + void + Serialize(StreamWriter& writer) const override; + + void + Deserialize(StreamReader& reader) override; + + std::string + Dump() override; + +private: + Vector data_; + mutable std::shared_mutex mutex_; + Allocator* const allocator_{nullptr}; +}; +} // namespace vsag diff --git a/src/impl/bitset/fast_bitset_test.cpp b/src/impl/bitset/fast_bitset_test.cpp new file mode 100644 index 000000000..48e252811 --- /dev/null +++ b/src/impl/bitset/fast_bitset_test.cpp @@ -0,0 +1,160 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "fast_bitset.h" + +#include +#include + +#include "fixtures.h" + +using namespace vsag; + +TEST_CASE("FastBitset basic operations", "[ut][FastBitset]") { + auto allocator = SafeAllocator::FactoryDefaultAllocator(); + FastBitset bs(allocator.get()); + + SECTION("Initial state") { + REQUIRE_FALSE(bs.Test(0)); + REQUIRE(bs.Count() == 0); + REQUIRE(bs.Dump() == "{}"); + } + + SECTION("Single bit set") { + bs.Set(10, true); + REQUIRE(bs.Test(10)); + REQUIRE(bs.Count() == 1); + REQUIRE(bs.Dump() == "{10}"); + + bs.Set(10, false); + REQUIRE_FALSE(bs.Test(0)); + REQUIRE(bs.Count() == 0); + REQUIRE(bs.Dump() == "{}"); + } + + SECTION("Multiple bits in same word") { + bs.Set(1, true); + bs.Set(3, true); + bs.Set(5, true); + REQUIRE(bs.Count() == 3); + REQUIRE(bs.Dump() == "{1,3,5}"); + } + + SECTION("Multiple bit set in different word") { + bs.Set(0, true); + bs.Set(103, true); + REQUIRE(bs.Test(0)); + REQUIRE(bs.Test(103)); + REQUIRE(bs.Count() == 2); + REQUIRE(bs.Dump() == "{0,103}"); + + bs.Set(0, false); + REQUIRE_FALSE(bs.Test(0)); + REQUIRE(bs.Test(103)); + REQUIRE(bs.Count() == 1); + REQUIRE(bs.Dump() == "{103}"); + } + + SECTION("Resize on large position") { + bs.Set(12800, true); + REQUIRE(bs.Test(12800)); + REQUIRE(bs.Count() == 1); + REQUIRE(bs.Dump() == "{12800}"); + } + + SECTION("Serialize and deserialize") { + fixtures::TempDir dir("fast_bitset"); + auto path = dir.GenerateRandomFile(); + std::ofstream ofs(path, std::ios::binary); + IOStreamWriter writer(ofs); + + bs.Set(101, true); + bs.Serialize(writer); + ofs.close(); + + std::ifstream ifs(path, std::ios::binary); + IOStreamReader reader(ifs); + FastBitset bitset2(allocator.get()); + bitset2.Deserialize(reader); + ifs.close(); + auto dump = bitset2.Dump(); + REQUIRE(dump == "{101}"); + } +} + +TEST_CASE("FastBitset bitwise operations", "[ut][FastBitset]") { + auto allocator = SafeAllocator::FactoryDefaultAllocator(); + + FastBitset a(allocator.get()); + FastBitset b(allocator.get()); + + 
SECTION("OR operation") { + a.Set(10, true); + b.Set(111, true); + a.Or(b); + + REQUIRE(a.Test(10)); + REQUIRE(a.Test(111)); + REQUIRE(a.Count() == 2); + REQUIRE(a.Dump() == "{10,111}"); + + FastBitset c(allocator.get()); + c.Set(64, true); + c.Set(111, true); + a.Or(c); + REQUIRE(a.Test(64)); + REQUIRE(a.Count() == 3); + } + + SECTION("AND operation") { + a.Set(2, true); + a.Set(215, true); + b.Set(215, true); + b.Set(1928, true); + a.And(b); + + REQUIRE_FALSE(a.Test(2)); + REQUIRE(a.Test(215)); + REQUIRE_FALSE(a.Test(1928)); + REQUIRE(a.Count() == 1); + REQUIRE(a.Dump() == "{215}"); + } + + SECTION("XOR operation") { + a.Set(100, true); + a.Set(1001, true); + b.Set(1001, true); + b.Set(2025, true); + a.Xor(b); + + REQUIRE(a.Test(100)); + REQUIRE_FALSE(a.Test(1001)); + REQUIRE(a.Test(2025)); + REQUIRE(a.Count() == 2); + REQUIRE(a.Dump() == "{100,2025}"); + } + + SECTION("NOT operation") { + a.Set(100, true); + a.Set(1001, true); + a.Not(); + REQUIRE_FALSE(a.Test(100)); + REQUIRE_FALSE(a.Test(1001)); + a.Not(); + REQUIRE(a.Test(100)); + REQUIRE(a.Test(1001)); + REQUIRE(a.Count() == 2); + } +} diff --git a/src/bitset_impl.cpp b/src/impl/bitset/sparse_bitset.cpp similarity index 57% rename from src/bitset_impl.cpp rename to src/impl/bitset/sparse_bitset.cpp index bbea2e38b..aa87611da 100644 --- a/src/bitset_impl.cpp +++ b/src/impl/bitset/sparse_bitset.cpp @@ -13,35 +13,15 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "bitset_impl.h" +#include "sparse_bitset.h" #include -#include #include -#include -#include namespace vsag { -BitsetPtr -Bitset::Random(int64_t length) { - auto bitset = std::make_shared(); - static auto gen = - std::bind(std::uniform_int_distribution<>(0, 1), // NOLINT(modernize-avoid-bind) - std::default_random_engine()); - for (int64_t i = 0; i < length; ++i) { - bitset->Set(i, gen() != 0); - } - return bitset; -} - -BitsetPtr -Bitset::Make() { - return std::make_shared(); -} - void -BitsetImpl::Set(int64_t pos, bool value) { +SparseBitset::Set(int64_t pos, bool value) { std::lock_guard lock(mutex_); if (value) { r_.add(pos); @@ -51,24 +31,24 @@ BitsetImpl::Set(int64_t pos, bool value) { } bool -BitsetImpl::Test(int64_t pos) { +SparseBitset::Test(int64_t pos) { std::lock_guard lock(mutex_); return r_.contains(pos); } uint64_t -BitsetImpl::Count() { +SparseBitset::Count() { return r_.cardinality(); } std::string -BitsetImpl::Dump() { +SparseBitset::Dump() { return r_.toString(); } void -BitsetImpl::Or(const Bitset& another) { - const auto* another_ptr = reinterpret_cast(&another); +SparseBitset::Or(const Bitset& another) { + const auto* another_ptr = reinterpret_cast(&another); std::lock(mutex_, another_ptr->mutex_); std::lock_guard lock(mutex_, std::adopt_lock); std::lock_guard lock_other(another_ptr->mutex_, std::adopt_lock); @@ -76,8 +56,8 @@ BitsetImpl::Or(const Bitset& another) { } void -BitsetImpl::And(const Bitset& another) { - const auto* another_ptr = reinterpret_cast(&another); +SparseBitset::And(const Bitset& another) { + const auto* another_ptr = reinterpret_cast(&another); std::lock(mutex_, another_ptr->mutex_); std::lock_guard lock(mutex_, std::adopt_lock); std::lock_guard lock_other(another_ptr->mutex_, std::adopt_lock); @@ -85,12 +65,38 @@ BitsetImpl::And(const Bitset& another) { } void -BitsetImpl::Xor(const Bitset& another) { - const auto* another_ptr = reinterpret_cast(&another); +SparseBitset::Xor(const Bitset& another) { + const 
auto* another_ptr = reinterpret_cast(&another); std::lock(mutex_, another_ptr->mutex_); std::lock_guard lock(mutex_, std::adopt_lock); std::lock_guard lock_other(another_ptr->mutex_, std::adopt_lock); r_ ^= another_ptr->r_; } +void +SparseBitset::Not() { + std::lock_guard lock(mutex_); + r_.flipClosed(r_.minimum(), r_.maximum()); +} + +void +SparseBitset::Serialize(StreamWriter& writer) const { + std::lock_guard lock(mutex_); + uint64_t size = r_.getSizeInBytes(); + StreamWriter::WriteObj(writer, size); + std::vector buffer(size); + r_.write(buffer.data()); + writer.Write(buffer.data(), size); +} + +void +SparseBitset::Deserialize(StreamReader& reader) { + std::lock_guard lock(mutex_); + uint64_t size; + StreamReader::ReadObj(reader, size); + std::vector buffer(size); + reader.Read(buffer.data(), size); + r_ = roaring::Roaring::readSafe(buffer.data(), size); +} + } // namespace vsag diff --git a/src/bitset_impl.h b/src/impl/bitset/sparse_bitset.h similarity index 71% rename from src/bitset_impl.h rename to src/impl/bitset/sparse_bitset.h index 8437e189b..72cf43b85 100644 --- a/src/bitset_impl.h +++ b/src/impl/bitset/sparse_bitset.h @@ -21,19 +21,19 @@ #include #include -#include "vsag/bitset.h" +#include "computable_bitset.h" namespace vsag { -class BitsetImpl : public Bitset { +class SparseBitset : public ComputableBitset { public: - BitsetImpl() = default; - ~BitsetImpl() override = default; + SparseBitset() = default; + ~SparseBitset() override = default; - BitsetImpl(const BitsetImpl&) = delete; - BitsetImpl& - operator=(const BitsetImpl&) = delete; - BitsetImpl(BitsetImpl&&) = delete; + SparseBitset(const SparseBitset&) = delete; + SparseBitset& + operator=(const SparseBitset&) = delete; + SparseBitset(SparseBitset&&) = delete; public: void @@ -57,6 +57,15 @@ class BitsetImpl : public Bitset { void Xor(const Bitset& another) override; + void + Not() override; + + void + Serialize(StreamWriter& writer) const override; + + void + Deserialize(StreamReader& 
reader) override; + private: mutable std::mutex mutex_; roaring::Roaring r_; diff --git a/src/bitset_impl_test.cpp b/src/impl/bitset/sparse_bitset_test.cpp similarity index 82% rename from src/bitset_impl_test.cpp rename to src/impl/bitset/sparse_bitset_test.cpp index 7ab1b3e01..974470ccd 100644 --- a/src/bitset_impl_test.cpp +++ b/src/impl/bitset/sparse_bitset_test.cpp @@ -13,18 +13,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "bitset_impl.h" +#include "sparse_bitset.h" #include #include #include #include -using namespace roaring; +#include "fixtures.h" -TEST_CASE("BitsetImpl Test", "[ut][bitset]") { - vsag::BitsetImpl bitset; +using namespace roaring; +using namespace vsag; +TEST_CASE("SparseBitset Test", "[ut][bitset]") { + SparseBitset bitset; // empty REQUIRE(bitset.Count() == 0); @@ -47,20 +49,35 @@ TEST_CASE("BitsetImpl Test", "[ut][bitset]") { bitset.Set(100, true); auto dumped = bitset.Dump(); REQUIRE(dumped == "{100}"); + + fixtures::TempDir dir("sparse_bitset"); + auto path = dir.GenerateRandomFile(); + std::ofstream ofs(path, std::ios::binary); + IOStreamWriter writer(ofs); + bitset.Serialize(writer); + ofs.close(); + + std::ifstream ifs(path, std::ios::binary); + IOStreamReader reader(ifs); + SparseBitset bitset2; + bitset2.Deserialize(reader); + ifs.close(); + dumped = bitset.Dump(); + REQUIRE(dumped == "{100}"); } -TEST_CASE("BitsetImpl Or Test", "[ut][bitset]") { +TEST_CASE("SparseBitset Or Test", "[ut][bitset]") { SECTION("both empty") { - vsag::BitsetImpl bitset1; - vsag::BitsetImpl bitset2; + SparseBitset bitset1; + SparseBitset bitset2; bitset1.Or(bitset2); REQUIRE(bitset1.Count() == 0); REQUIRE(bitset1.Dump() == "{}"); } SECTION("empty and non-empty") { - vsag::BitsetImpl bitset1; - vsag::BitsetImpl bitset2; + SparseBitset bitset1; + SparseBitset bitset2; bitset2.Set(100, true); bitset1.Or(bitset2); REQUIRE(bitset1.Test(100)); @@ -69,8 +86,8 @@ TEST_CASE("BitsetImpl Or 
Test", "[ut][bitset]") { } SECTION("disjoint sets") { - vsag::BitsetImpl bitset1; - vsag::BitsetImpl bitset2; + SparseBitset bitset1; + SparseBitset bitset2; bitset1.Set(100, true); bitset2.Set(200, true); bitset1.Or(bitset2); @@ -81,8 +98,8 @@ TEST_CASE("BitsetImpl Or Test", "[ut][bitset]") { } SECTION("overlapping sets") { - vsag::BitsetImpl bitset1; - vsag::BitsetImpl bitset2; + SparseBitset bitset1; + SparseBitset bitset2; bitset1.Set(100, true); bitset2.Set(100, true); bitset1.Or(bitset2); @@ -91,26 +108,26 @@ TEST_CASE("BitsetImpl Or Test", "[ut][bitset]") { } } -TEST_CASE("BitsetImpl And Test", "[ut][bitset]") { +TEST_CASE("SparseBitset And Test", "[ut][bitset]") { SECTION("both empty") { - vsag::BitsetImpl bitset1; - vsag::BitsetImpl bitset2; + SparseBitset bitset1; + SparseBitset bitset2; bitset1.And(bitset2); REQUIRE(bitset1.Count() == 0); REQUIRE(bitset1.Dump() == "{}"); } SECTION("empty and non-empty") { - vsag::BitsetImpl bitset1; - vsag::BitsetImpl bitset2; + SparseBitset bitset1; + SparseBitset bitset2; bitset2.Set(100, true); bitset1.And(bitset2); REQUIRE(bitset1.Count() == 0); } SECTION("common elements") { - vsag::BitsetImpl bitset1; - vsag::BitsetImpl bitset2; + SparseBitset bitset1; + SparseBitset bitset2; bitset1.Set(100, true); bitset1.Set(200, true); bitset2.Set(200, true); @@ -124,8 +141,8 @@ TEST_CASE("BitsetImpl And Test", "[ut][bitset]") { } SECTION("no common elements") { - vsag::BitsetImpl bitset1; - vsag::BitsetImpl bitset2; + SparseBitset bitset1; + SparseBitset bitset2; bitset1.Set(100, true); bitset2.Set(200, true); bitset1.And(bitset2); @@ -134,18 +151,18 @@ TEST_CASE("BitsetImpl And Test", "[ut][bitset]") { } } -TEST_CASE("BitsetImpl Xor Test", "[ut][bitset]") { +TEST_CASE("SparseBitset Xor Test", "[ut][bitset]") { SECTION("both empty") { - vsag::BitsetImpl bitset1; - vsag::BitsetImpl bitset2; + SparseBitset bitset1; + SparseBitset bitset2; bitset1.Xor(bitset2); REQUIRE(bitset1.Count() == 0); REQUIRE(bitset1.Dump() == "{}"); } 
SECTION("empty and non-empty") { - vsag::BitsetImpl bitset1; - vsag::BitsetImpl bitset2; + SparseBitset bitset1; + SparseBitset bitset2; bitset2.Set(100, true); bitset1.Xor(bitset2); REQUIRE(bitset1.Test(100)); @@ -154,8 +171,8 @@ TEST_CASE("BitsetImpl Xor Test", "[ut][bitset]") { } SECTION("partial overlap") { - vsag::BitsetImpl bitset1; - vsag::BitsetImpl bitset2; + SparseBitset bitset1; + SparseBitset bitset2; bitset1.Set(100, true); bitset1.Set(200, true); bitset2.Set(200, true); @@ -169,8 +186,8 @@ TEST_CASE("BitsetImpl Xor Test", "[ut][bitset]") { } SECTION("identical sets") { - vsag::BitsetImpl bitset1; - vsag::BitsetImpl bitset2; + SparseBitset bitset1; + SparseBitset bitset2; bitset1.Set(100, true); bitset2.Set(100, true); bitset1.Xor(bitset2); diff --git a/src/index/hnsw.cpp b/src/index/hnsw.cpp index a75d23471..d8490ee52 100644 --- a/src/index/hnsw.cpp +++ b/src/index/hnsw.cpp @@ -1120,7 +1120,7 @@ HNSW::ExtractDataAndGraph(FlattenInterfacePtr& data, auto cur_element_count = hnsw->getCurrentElementCount(); int64_t origin_data_num = data->total_count_; int64_t valid_id_count = 0; - BitsetPtr bitset = std::make_shared(); + auto bitset = Bitset::Make(); for (auto i = 0; i < cur_element_count; ++i) { int64_t id = hnsw->getExternalLabel(i); auto [is_exist, new_id] = func(id); From 9dc7ef23dcb08f51a3101e00e1a45f2fe1020bb7 Mon Sep 17 00:00:00 2001 From: LHT129 Date: Mon, 26 May 2025 14:52:43 +0800 Subject: [PATCH 18/42] add logger for kmeans train (#757) Signed-off-by: LHT129 Signed-off-by: suguan.dx --- src/impl/kmeans_cluster.cpp | 54 +++++++++++++++++++++++++----------- src/impl/kmeans_cluster.h | 8 +++--- src/utils/util_functions.cpp | 10 +++++++ src/utils/util_functions.h | 3 ++ 4 files changed, 55 insertions(+), 20 deletions(-) diff --git a/src/impl/kmeans_cluster.cpp b/src/impl/kmeans_cluster.cpp index f049ea1e0..3e89c7da7 100644 --- a/src/impl/kmeans_cluster.cpp +++ b/src/impl/kmeans_cluster.cpp @@ -22,9 +22,10 @@ #include 
"algorithm/inner_index_interface.h" #include "byte_buffer.h" +#include "logger.h" #include "safe_allocator.h" #include "simd/fp32_simd.h" - +#include "utils/util_functions.h" namespace vsag { KMeansCluster::KMeansCluster(int32_t dim, Allocator* allocator, SafeThreadPoolPtr thread_pool) : dim_(dim), allocator_(allocator), thread_pool_(std::move(thread_pool)) { @@ -62,18 +63,29 @@ KMeansCluster::Run(uint32_t k, const float* datas, uint64_t count, int iter) { Vector labels(count, -1, this->allocator_); std::vector mutexes(k); std::vector> futures; - constexpr uint64_t query_count_bs = 65536; ByteBuffer y_sqr_buffer(static_cast(k) * sizeof(float), allocator_); - ByteBuffer distances_buffer(static_cast(k) * query_count_bs * sizeof(float), - allocator_); + ByteBuffer distances_buffer(static_cast(k) * QUERY_BS * sizeof(float), allocator_); auto* y_sqr = reinterpret_cast(y_sqr_buffer.data); auto* distances = reinterpret_cast(distances_buffer.data); + double error = std::numeric_limits::max(); + + logger::debug("KMeansCluster::Run k: {}, count: {}, iter: {}", k, count, iter); + if (k < THRESHOLD_FOR_HGRAPH) { + logger::debug("KMeansCluster::Run use blas"); + } else { + logger::debug("KMeansCluster::Run use hgraph"); + } + for (int it = 0; it < iter; ++it) { + logger::debug("[{}] KMeansCluster::Run iter: {}/{}, cur loss is {}", + get_current_time(), + it, + iter, + error); if (k < THRESHOLD_FOR_HGRAPH) { - this->find_nearest_one_with_blas( - datas, count, k, query_count_bs, y_sqr, distances, labels); + error = this->find_nearest_one_with_blas(datas, count, k, y_sqr, distances, labels); } else { - this->find_nearest_one_with_hgraph(datas, count, k, query_count_bs, labels); + error = this->find_nearest_one_with_hgraph(datas, count, k, labels); } constexpr uint64_t bs = 1024; @@ -125,17 +137,19 @@ KMeansCluster::Run(uint32_t k, const float* datas, uint64_t count, int iter) { return labels; } -void +double KMeansCluster::find_nearest_one_with_blas(const float* query, const 
uint64_t query_count, const uint64_t k, - const uint64_t query_count_bs, float* y_sqr, float* distances, Vector& labels) { + double error = 0.0; + if (k_centroids_ == nullptr) { - return; + throw VsagException(ErrorType::INTERNAL_ERROR, "k_centroids_ is nullptr"); } + auto& thread_pool = this->thread_pool_; auto bs = 1024; std::vector> futures; @@ -158,8 +172,8 @@ KMeansCluster::find_nearest_one_with_blas(const float* query, } wait_futures_and_clear(); - for (uint64_t i = 0; i < query_count; i += query_count_bs) { - auto end = std::min(i + query_count_bs, query_count); + for (uint64_t i = 0; i < query_count; i += QUERY_BS) { + auto end = std::min(i + QUERY_BS, query_count); auto cur_query_count = end - i; auto* cur_label = labels.data() + i; @@ -184,6 +198,7 @@ KMeansCluster::find_nearest_one_with_blas(const float* query, cblas_saxpy(static_cast(k), 1.0, y_sqr, 1, distances + i * k, 1); auto* min_elem = std::min_element(distances + i * k, distances + i * k + k); auto min_index = std::distance(distances + i * k, min_elem); + error += static_cast(*min_elem); if (min_index != cur_label[i]) { cur_label[i] = static_cast(min_index); } @@ -195,14 +210,19 @@ KMeansCluster::find_nearest_one_with_blas(const float* query, } wait_futures_and_clear(); } + return error / static_cast(query_count); } -void +double KMeansCluster::find_nearest_one_with_hgraph(const float* query, const uint64_t query_count, const uint64_t k, - const uint64_t query_count_bs, Vector& labels) { + if (k_centroids_ == nullptr) { + throw VsagException(ErrorType::INTERNAL_ERROR, "k_centroids_ is nullptr"); + } + double error = 0.0; + IndexCommonParam param; param.dim_ = dim_; param.allocator_ = std::make_shared(this->allocator_); @@ -230,16 +250,18 @@ KMeansCluster::find_nearest_one_with_hgraph(const float* query, ->Dim(this->dim_); auto ret = hgraph->KnnSearch(q, 1, search_param, filter); labels[j] = static_cast(ret->GetIds()[0]); + error += static_cast(ret->GetDistances()[0]); } }; std::vector> futures; - 
for (uint64_t i = 0; i < query_count; i += query_count_bs) { + for (uint64_t i = 0; i < query_count; i += QUERY_BS) { futures.emplace_back( - thread_pool_->GeneralEnqueue(func, i, std::min(i + query_count_bs, query_count))); + thread_pool_->GeneralEnqueue(func, i, std::min(i + QUERY_BS, query_count))); } for (auto& future : futures) { future.wait(); } + return error / static_cast(query_count); } } // namespace vsag diff --git a/src/impl/kmeans_cluster.h b/src/impl/kmeans_cluster.h index c14e8b2a0..9330f0d52 100644 --- a/src/impl/kmeans_cluster.h +++ b/src/impl/kmeans_cluster.h @@ -36,20 +36,18 @@ class KMeansCluster { float* k_centroids_{nullptr}; private: - void + double find_nearest_one_with_blas(const float* query, const uint64_t query_count, const uint64_t k, - const uint64_t query_count_bs, float* y_sqr, float* distances, Vector& labels); - void + double find_nearest_one_with_hgraph(const float* query, const uint64_t query_count, const uint64_t k, - const uint64_t query_count_bs, Vector& labels); private: @@ -60,6 +58,8 @@ class KMeansCluster { const int32_t dim_{0}; static constexpr uint64_t THRESHOLD_FOR_HGRAPH = 10000ULL; + + static constexpr uint64_t QUERY_BS = 65536ULL; }; } // namespace vsag diff --git a/src/utils/util_functions.cpp b/src/utils/util_functions.cpp index f3ffaa76e..3292fecc1 100644 --- a/src/utils/util_functions.cpp +++ b/src/utils/util_functions.cpp @@ -147,4 +147,14 @@ split_string(const std::string& str, const char delimiter) { return tokens; } +std::string +get_current_time() { + auto now = std::chrono::system_clock::now(); + auto now_c = std::chrono::system_clock::to_time_t(now); + std::tm* now_tm = std::localtime(&now_c); + std::ostringstream oss; + oss << std::put_time(now_tm, "%Y-%m-%d %H:%M:%S"); + return oss.str(); +} + } // namespace vsag diff --git a/src/utils/util_functions.h b/src/utils/util_functions.h index b2f8499f6..d4cd774c4 100644 --- a/src/utils/util_functions.h +++ b/src/utils/util_functions.h @@ -79,4 +79,7 @@ 
check_equal_on_string_stream(std::stringstream& s1, std::stringstream& s2); std::vector split_string(const std::string& str, const char delimiter); +std::string +get_current_time(); + } // namespace vsag From 11207ba5d277844175bae7a828cb7e5f60b7e428 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 31 Mar 2025 15:58:24 +0800 Subject: [PATCH 19/42] support gno-imi partition Signed-off-by: suguan.dx --- examples/cpp/106_index_ivf.cpp | 4 +- examples/cpp/108_index_gno_imi.cpp | 108 +++++ examples/cpp/CMakeLists.txt | 3 + include/vsag/constants.h | 2 + src/algorithm/ivf.cpp | 100 ++++- src/algorithm/ivf.h | 1 + src/algorithm/ivf_parameter.cpp | 22 +- src/algorithm/ivf_parameter.h | 14 +- src/algorithm/ivf_parameter_test.cpp | 64 ++- .../ivf_partition/gno_imi_parameter.cpp | 50 +++ .../ivf_partition/gno_imi_parameter.h | 41 ++ .../ivf_partition/gno_imi_parameter_test.cpp | 32 ++ .../ivf_partition/gno_imi_partition.cpp | 390 ++++++++++++++++++ .../ivf_partition/gno_imi_partition.h | 73 ++++ .../ivf_partition/gno_imi_partition_test.cpp | 182 ++++++++ .../ivf_partition/ivf_nearest_partition.cpp | 11 +- .../ivf_partition/ivf_nearest_partition.h | 11 +- .../ivf_nearest_partition_test.cpp | 13 +- .../ivf_partition/ivf_partition_strategy.h | 21 + .../ivf_partition_strategy_parameter.cpp | 73 ++++ .../ivf_partition_strategy_parameter.h | 54 +++ .../ivf_partition_strategy_parameter_test.cpp | 38 ++ src/constants.cpp | 5 + src/impl/basic_searcher.h | 2 + src/impl/kmeans_cluster.cpp | 45 +- src/impl/kmeans_cluster.h | 5 +- src/inner_string_params.h | 22 +- 27 files changed, 1335 insertions(+), 51 deletions(-) create mode 100644 examples/cpp/108_index_gno_imi.cpp create mode 100644 src/algorithm/ivf_partition/gno_imi_parameter.cpp create mode 100644 src/algorithm/ivf_partition/gno_imi_parameter.h create mode 100644 src/algorithm/ivf_partition/gno_imi_parameter_test.cpp create mode 100644 src/algorithm/ivf_partition/gno_imi_partition.cpp create mode 100644 
src/algorithm/ivf_partition/gno_imi_partition.h create mode 100644 src/algorithm/ivf_partition/gno_imi_partition_test.cpp create mode 100644 src/algorithm/ivf_partition/ivf_partition_strategy_parameter.cpp create mode 100644 src/algorithm/ivf_partition/ivf_partition_strategy_parameter.h create mode 100644 src/algorithm/ivf_partition/ivf_partition_strategy_parameter_test.cpp diff --git a/examples/cpp/106_index_ivf.cpp b/examples/cpp/106_index_ivf.cpp index 5d71e6632..5b5657b75 100644 --- a/examples/cpp/106_index_ivf.cpp +++ b/examples/cpp/106_index_ivf.cpp @@ -49,7 +49,9 @@ main(int argc, char** argv) { "dim": 128, "index_param": { "buckets_count": 50, - "base_quantization_type": "fp32" + "base_quantization_type": "fp32", + "partition_strategy_type": "ivf", + "ivf_train_type": "kmeans" } } )"; diff --git a/examples/cpp/108_index_gno_imi.cpp b/examples/cpp/108_index_gno_imi.cpp new file mode 100644 index 000000000..cb055b543 --- /dev/null +++ b/examples/cpp/108_index_gno_imi.cpp @@ -0,0 +1,108 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include +#include + +int +main(int argc, char** argv) { + vsag::init(); + + /******************* Prepare Base Dataset *****************/ + int64_t num_vectors = 10000; + int64_t dim = 128; + std::vector ids(num_vectors); + std::vector datas(num_vectors * dim); + std::mt19937 rng(47); + std::uniform_real_distribution distrib_real; + for (int64_t i = 0; i < num_vectors; ++i) { + ids[i] = i; + } + for (int64_t i = 0; i < dim * num_vectors; ++i) { + datas[i] = distrib_real(rng); + } + + auto base = vsag::Dataset::Make(); + base->NumElements(num_vectors) + ->Dim(dim) + ->Ids(ids.data()) + ->Float32Vectors(datas.data()) + ->Owner(false); + + /******************* Create IVF Index *****************/ + std::string ivf_build_params = R"( + { + "dtype": "float32", + "metric_type": "l2", + "dim": 128, + "index_param": { + "base_quantization_type": "fp32", + "partition_strategy_type": "gno_imi", + "ivf_train_type": "kmeans", + "first_order_buckets_count": 10, + "second_order_buckets_count": 10 + } + })"; + auto index = vsag::Factory::CreateIndex("ivf", ivf_build_params).value(); + + /******************* Build IVF Index *****************/ + + auto path = "example_gno_imi.index"; + if (auto build_result = index->Build(base); build_result.has_value()) { + std::ofstream outfile(path, std::ios::out | std::ios::binary); + auto result = index->Serialize(outfile); + outfile.close(); + std::cout << "After Build(), Index IVF contains: " << index->GetNumElements() << std::endl; + } else if (build_result.error().type == vsag::ErrorType::INTERNAL_ERROR) { + std::cerr << "Failed to build index: internalError" << std::endl; + exit(-1); + } + + std::ifstream infile(path, std::ios::binary); + auto index2 = vsag::Factory::CreateIndex("ivf", ivf_build_params).value(); + index2->Deserialize(infile); + infile.close(); + + /******************* Prepare Query Dataset *****************/ + std::vector query_vector(dim); + for (int64_t i = 0; i < dim; ++i) { + query_vector[i] = 
distrib_real(rng); + } + + auto query = vsag::Dataset::Make(); + query->NumElements(1)->Dim(dim)->Float32Vectors(query_vector.data())->Owner(false); + + /******************* KnnSearch For IVF Index *****************/ + auto ivf_search_parameters = R"( + { + "ivf": { + "scan_buckets_count": 20, + "first_order_scan_ratio": 0.8 + } + })"; + int64_t topk = 10; + auto result = index->KnnSearch(query, topk, ivf_search_parameters).value(); + auto result2 = index2->KnnSearch(query, topk, ivf_search_parameters).value(); + /******************* Print Search Result *****************/ + std::cout << "results: " << std::endl; + for (int64_t i = 0; i < result->GetDim(); ++i) { + std::cout << result->GetIds()[i] << ": " << result->GetDistances()[i] << std::endl; + std::cout << result2->GetIds()[i] << ": " << result2->GetDistances()[i] << std::endl; + } + + return 0; +} diff --git a/examples/cpp/CMakeLists.txt b/examples/cpp/CMakeLists.txt index 5ad1b32ee..bbeaaa8bc 100644 --- a/examples/cpp/CMakeLists.txt +++ b/examples/cpp/CMakeLists.txt @@ -19,6 +19,9 @@ target_link_libraries (106_index_ivf vsag) add_executable (107_index_pyramid 107_index_pyramid.cpp) target_link_libraries (107_index_pyramid vsag) +add_executable (108_index_gno_imi 108_index_gno_imi.cpp) +target_link_libraries (108_index_gno_imi vsag) + add_executable (201_custom_allocator 201_custom_allocator.cpp) target_link_libraries (201_custom_allocator vsag) diff --git a/include/vsag/constants.h b/include/vsag/constants.h index c4ddf622c..5075b5dcb 100644 --- a/include/vsag/constants.h +++ b/include/vsag/constants.h @@ -165,4 +165,6 @@ extern const char* const IVF_PRECISE_QUANTIZATION_TYPE; extern const char* const IVF_PRECISE_IO_TYPE; extern const char* const IVF_PRECISE_FILE_PATH; +extern const char* const GNO_IMI_FIRST_ORDER_BUCKETS_COUNT; +extern const char* const GNO_IMI_SECOND_ORDER_BUCKETS_COUNT; } // namespace vsag diff --git a/src/algorithm/ivf.cpp b/src/algorithm/ivf.cpp index d83baa2f6..9becf5e19 100644 --- 
a/src/algorithm/ivf.cpp +++ b/src/algorithm/ivf.cpp @@ -15,9 +15,12 @@ #include "ivf.h" +#include + #include "impl/basic_searcher.h" #include "index/index_impl.h" #include "inner_string_params.h" +#include "ivf_partition/gno_imi_partition.h" #include "ivf_partition/ivf_nearest_partition.h" #include "utils/standard_heap.h" #include "utils/util_functions.h" @@ -28,7 +31,6 @@ static constexpr const char* IVF_PARAMS_TEMPLATE = R"( { "type": "{INDEX_TYPE_IVF}", - "{IVF_TRAIN_TYPE_KEY}": "{IVF_TRAIN_TYPE_KMEANS}", "{BUCKET_PARAMS_KEY}": { "{IO_PARAMS_KEY}": { "{IO_TYPE_KEY}": "{IO_TYPE_VALUE_BLOCK_MEMORY_IO}" @@ -42,6 +44,15 @@ static constexpr const char* IVF_PARAMS_TEMPLATE = }, "{BUCKETS_COUNT_KEY}": 10 }, + "{IVF_PARTITION_STRATEGY_PARAMS_KEY}": { + "{IVF_PARTITION_STRATEGY_TYPE_KEY}": "{IVF_PARTITION_STRATEGY_TYPE_NEAREST}", + "{IVF_TRAIN_TYPE_KEY}": "{IVF_TRAIN_TYPE_KMEANS}", + "{IVF_PARTITION_STRATEGY_TYPE_GNO_IMI}": { + "{GNO_IMI_FIRST_ORDER_BUCKETS_COUNT_KEY}": 10, + "{GNO_IMI_SECOND_ORDER_BUCKETS_COUNT_KEY}": 10 + } + }, + "{BUCKET_PER_DATA_KEY}": 1, "{IVF_USE_REORDER_KEY}": false, "{IVF_PRECISE_CODES_KEY}": { "{IO_PARAMS_KEY}": { @@ -82,7 +93,27 @@ IVF::CheckAndMappingExternalParam(const JsonType& external_param, }, { IVF_TRAIN_TYPE, - {IVF_TRAIN_TYPE_KEY}, + {IVF_PARTITION_STRATEGY_PARAMS_KEY, IVF_TRAIN_TYPE_KEY}, + }, + { + IVF_PARTITION_STRATEGY_TYPE_KEY, + {IVF_PARTITION_STRATEGY_PARAMS_KEY, IVF_PARTITION_STRATEGY_TYPE_KEY}, + }, + { + GNO_IMI_FIRST_ORDER_BUCKETS_COUNT, + {IVF_PARTITION_STRATEGY_PARAMS_KEY, + IVF_PARTITION_STRATEGY_TYPE_GNO_IMI, + GNO_IMI_FIRST_ORDER_BUCKETS_COUNT_KEY}, + }, + { + GNO_IMI_SECOND_ORDER_BUCKETS_COUNT, + {IVF_PARTITION_STRATEGY_PARAMS_KEY, + IVF_PARTITION_STRATEGY_TYPE_GNO_IMI, + GNO_IMI_SECOND_ORDER_BUCKETS_COUNT_KEY}, + }, + { + BUCKET_PER_DATA_KEY, + {BUCKET_PER_DATA_KEY}, }, { IVF_USE_REORDER, @@ -114,13 +145,20 @@ IVF::CheckAndMappingExternalParam(const JsonType& external_param, } IVF::IVF(const IVFParameterPtr& param, 
const IndexCommonParam& common_param) - : InnerIndexInterface(param, common_param) { + : InnerIndexInterface(param, common_param), buckets_per_data_(param->buckets_per_data) { this->bucket_ = BucketInterface::MakeInstance(param->bucket_param, common_param); if (this->bucket_ == nullptr) { throw VsagException(ErrorType::INTERNAL_ERROR, "bucket init error"); } - this->partition_strategy_ = std::make_shared( - bucket_->bucket_count_, common_param, param->partition_train_type); + if (param->ivf_partition_strategy_parameter->partition_strategy_type == + IVFPartitionStrategyType::IVF) { + this->partition_strategy_ = std::make_shared( + bucket_->bucket_count_, common_param, param->ivf_partition_strategy_parameter); + } else if (param->ivf_partition_strategy_parameter->partition_strategy_type == + IVFPartitionStrategyType::GNO_IMI) { + this->partition_strategy_ = std::make_shared( + common_param, param->ivf_partition_strategy_parameter); + } this->use_reorder_ = param->use_reorder; if (this->use_reorder_) { this->reorder_codes_ = FlattenInterface::MakeInstance(param->flatten_param, common_param); @@ -204,10 +242,14 @@ IVF::Add(const DatasetPtr& base) { auto num_element = base->GetNumElements(); const auto* ids = base->GetIds(); const auto* vectors = base->GetFloat32Vectors(); - auto buckets = partition_strategy_->ClassifyDatas(vectors, num_element, 1); + auto buckets = partition_strategy_->ClassifyDatas(vectors, num_element, buckets_per_data_); for (int64_t i = 0; i < num_element; ++i) { - bucket_->InsertVector(vectors + i * dim_, buckets[i], i + total_elements_); - this->label_table_->Insert(i + total_elements_, ids[i]); + for (int64_t j = 0; j < buckets_per_data_; ++j) { + auto idx = i * buckets_per_data_ + j; + bucket_->InsertVector( + vectors + i * dim_, buckets[idx], idx + total_elements_ * buckets_per_data_); + this->label_table_->Insert(idx + total_elements_ * buckets_per_data_, ids[i]); + } } this->bucket_->Package(); if (use_reorder_) { @@ -312,6 +354,7 @@ 
IVF::Deserialize(StreamReader& reader) { this->reorder_codes_->Deserialize(reader); } } + InnerSearchParam IVF::create_search_param(const std::string& parameters, const FilterPtr& filter) const { InnerSearchParam param; @@ -324,6 +367,7 @@ IVF::create_search_param(const std::string& parameters, const FilterPtr& filter) param.scan_bucket_size = std::min(static_cast(search_param.scan_buckets_count), bucket_->bucket_count_); param.factor = search_param.topk_factor; + param.first_order_scan_ratio = search_param.first_order_scan_ratio; return param; } @@ -363,7 +407,7 @@ DistHeapPtr IVF::search(const DatasetPtr& query, const InnerSearchParam& param) const { auto search_result = std::make_shared>(allocator_, -1); auto candidate_buckets = - partition_strategy_->ClassifyDatas(query->GetFloat32Vectors(), 1, param.scan_bucket_size); + partition_strategy_->ClassifyDatasForSearch(query->GetFloat32Vectors(), 1, param); auto computer = bucket_->FactoryComputer(query->GetFloat32Vectors()); Vector dist(allocator_); auto cur_heap_top = std::numeric_limits::max(); @@ -374,6 +418,17 @@ IVF::search(const DatasetPtr& query, const InnerSearchParam& param) const { topk = std::numeric_limits::max(); } } + + // Scale topk to ensure sufficient candidates after deduplication when buckets_per_data_ > 1 + int64_t origin_topk = topk; + if (buckets_per_data_ > 1) { + if (topk <= std::numeric_limits::max() / buckets_per_data_) { + topk *= buckets_per_data_; + } else { + topk = std::numeric_limits::max(); + } + } + const auto& ft = param.is_inner_id_allowed; for (auto& bucket_id : candidate_buckets) { auto bucket_size = bucket_->GetBucketSize(bucket_id); @@ -402,6 +457,33 @@ IVF::search(const DatasetPtr& query, const InnerSearchParam& param) const { } } } + + // Deduplicate ids when buckets_per_data_ > 1 + if (buckets_per_data_ > 1) { + std::unordered_map id_to_min_dist; + while (!search_result->Empty()) { + auto [dist_val, id] = search_result->Top(); + search_result->Pop(); + // Keep the smallest 
distance for each id + if (id_to_min_dist.find(id) == id_to_min_dist.end() || dist_val < id_to_min_dist[id]) { + id_to_min_dist[id] = dist_val; + } + } + + auto cur_heap_top = std::numeric_limits::max(); + for (const auto& [id, dist_val] : id_to_min_dist) { + if (dist_val < cur_heap_top) { + search_result->Push(dist_val, id); + } + if (search_result->Size() > origin_topk) { + search_result->Pop(); + } + if (not search_result->Empty() and search_result->Size() == origin_topk) { + cur_heap_top = search_result->Top().first; + } + } + } + return search_result; } diff --git a/src/algorithm/ivf.h b/src/algorithm/ivf.h index 9ff0f7472..92cd6dd32 100644 --- a/src/algorithm/ivf.h +++ b/src/algorithm/ivf.h @@ -114,6 +114,7 @@ class IVF : public InnerIndexInterface { BucketInterfacePtr bucket_{nullptr}; IVFPartitionStrategyPtr partition_strategy_{nullptr}; + BucketIdType buckets_per_data_; int64_t total_elements_{0}; diff --git a/src/algorithm/ivf_parameter.cpp b/src/algorithm/ivf_parameter.cpp index 8f314f9e3..eb3896ebc 100644 --- a/src/algorithm/ivf_parameter.cpp +++ b/src/algorithm/ivf_parameter.cpp @@ -25,16 +25,26 @@ IVFParameter::IVFParameter() = default; void IVFParameter::FromJson(const JsonType& json) { - if (json[IVF_TRAIN_TYPE_KEY] == IVF_TRAIN_TYPE_KMEANS) { - this->partition_train_type = IVFNearestPartitionTrainerType::KMeansTrainer; - } else if (json[IVF_TRAIN_TYPE_KEY] == IVF_TRAIN_TYPE_RANDOM) { - this->partition_train_type = IVFNearestPartitionTrainerType::RandomTrainer; - } this->bucket_param = std::make_shared(); CHECK_ARGUMENT(json.contains(BUCKET_PARAMS_KEY), fmt::format("ivf parameters must contains {}", BUCKET_PARAMS_KEY)); this->bucket_param->FromJson(json[BUCKET_PARAMS_KEY]); + this->ivf_partition_strategy_parameter = std::make_shared(); + if (json.contains(IVF_PARTITION_STRATEGY_PARAMS_KEY)) { + this->ivf_partition_strategy_parameter->FromJson(json[IVF_PARTITION_STRATEGY_PARAMS_KEY]); + } + + if 
(this->ivf_partition_strategy_parameter->partition_strategy_type == + IVFPartitionStrategyType::GNO_IMI) { + this->bucket_param->buckets_count = static_cast( + this->ivf_partition_strategy_parameter->gnoimi_param->first_order_buckets_count * + this->ivf_partition_strategy_parameter->gnoimi_param->second_order_buckets_count); + } + + if (json.contains(BUCKET_PER_DATA_KEY)) { + this->buckets_per_data = json[BUCKET_PER_DATA_KEY]; + } if (json.contains(IVF_USE_REORDER_KEY)) { this->use_reorder = json[IVF_USE_REORDER_KEY]; } @@ -53,6 +63,8 @@ IVFParameter::ToJson() { JsonType json; json["type"] = INDEX_IVF; json[BUCKET_PARAMS_KEY] = this->bucket_param->ToJson(); + json[IVF_PARTITION_STRATEGY_PARAMS_KEY] = this->ivf_partition_strategy_parameter->ToJson(); + json[BUCKET_PER_DATA_KEY] = this->buckets_per_data; json[IVF_USE_REORDER_KEY] = this->use_reorder; if (use_reorder) { json[IVF_PRECISE_CODES_KEY] = this->flatten_param->ToJson(); diff --git a/src/algorithm/ivf_parameter.h b/src/algorithm/ivf_parameter.h index 5f52fb886..ee659d7b2 100644 --- a/src/algorithm/ivf_parameter.h +++ b/src/algorithm/ivf_parameter.h @@ -15,6 +15,7 @@ #pragma once #include "algorithm/ivf_partition/ivf_nearest_partition.h" +#include "algorithm/ivf_partition/ivf_partition_strategy_parameter.h" #include "data_cell/bucket_datacell_parameter.h" #include "data_cell/flatten_datacell_parameter.h" #include "fmt/format-inl.h" @@ -35,15 +36,13 @@ class IVFParameter : public Parameter { public: BucketDataCellParamPtr bucket_param{nullptr}; - + IVFPartitionStrategyParametersPtr ivf_partition_strategy_parameter{nullptr}; + BucketIdType buckets_per_data{1}; bool use_residual{false}; bool use_reorder{false}; FlattenDataCellParamPtr flatten_param{nullptr}; - - IVFNearestPartitionTrainerType partition_train_type{ - IVFNearestPartitionTrainerType::KMeansTrainer}; }; using IVFParameterPtr = std::shared_ptr; @@ -69,13 +68,18 @@ class IVFSearchParameters { if 
(params[INDEX_TYPE_IVF].contains(IVF_SEARCH_PARAM_FACTOR)) { obj.topk_factor = params[INDEX_TYPE_IVF][IVF_SEARCH_PARAM_FACTOR]; } + + if (params[INDEX_TYPE_IVF].contains(GNO_IMI_SEARCH_PARAM_FIRST_ORDER_SCAN_RATIO)) { + obj.first_order_scan_ratio = + params[INDEX_TYPE_IVF][GNO_IMI_SEARCH_PARAM_FIRST_ORDER_SCAN_RATIO]; + } return obj; } public: int64_t scan_buckets_count{30}; - float topk_factor{2.0F}; + float first_order_scan_ratio{1.0F}; private: IVFSearchParameters() = default; diff --git a/src/algorithm/ivf_parameter_test.cpp b/src/algorithm/ivf_parameter_test.cpp index b1aeb0618..11947b078 100644 --- a/src/algorithm/ivf_parameter_test.cpp +++ b/src/algorithm/ivf_parameter_test.cpp @@ -22,7 +22,6 @@ TEST_CASE("IVF Parameters Test", "[ut][IVFParameter]") { auto param_str = R"({ "type": "ivf", - "ivf_train_type": "random", "buckets_params": { "io_params": { "type": "block_memory_io" @@ -42,11 +41,72 @@ TEST_CASE("IVF Parameters Test", "[ut][IVFParameter]") { } } })"; + vsag::JsonType param_json = vsag::JsonType::parse(param_str); auto param = std::make_shared(); param->FromJson(param_json); REQUIRE(param->bucket_param->buckets_count == 3); - REQUIRE(param->partition_train_type == vsag::IVFNearestPartitionTrainerType::RandomTrainer); + REQUIRE(param->ivf_partition_strategy_parameter->partition_strategy_type == + vsag::IVFPartitionStrategyType::IVF); + REQUIRE(param->ivf_partition_strategy_parameter->partition_train_type == + vsag::IVFNearestPartitionTrainerType::KMeansTrainer); + REQUIRE(param->buckets_per_data == 1); REQUIRE(param->use_reorder == true); REQUIRE(param->flatten_param->quantizer_parameter->GetTypeName() == "fp32"); + + param_str = R"({ + "type": "ivf", + "buckets_params": { + "io_params": { + "type": "block_memory_io" + }, + "quantization_params": { + "type": "fp32" + }, + "buckets_count": 3 + }, + "partition_strategy": { + "partition_strategy_type": "gno_imi", + "ivf_train_type": "random", + "gno_imi": { + "first_order_buckets_count": 200, + 
"second_order_buckets_count": 50 + } + }, + "buckets_per_data": 2 + })"; + param_json = vsag::JsonType::parse(param_str); + param = std::make_shared(); + param->FromJson(param_json); + REQUIRE(param->bucket_param->buckets_count == 200 * 50); + REQUIRE(param->ivf_partition_strategy_parameter->partition_strategy_type == + vsag::IVFPartitionStrategyType::GNO_IMI); + REQUIRE(param->ivf_partition_strategy_parameter->partition_train_type == + vsag::IVFNearestPartitionTrainerType::RandomTrainer); + REQUIRE(param->ivf_partition_strategy_parameter->gnoimi_param->first_order_buckets_count == + 200); + REQUIRE(param->ivf_partition_strategy_parameter->gnoimi_param->second_order_buckets_count == + 50); + REQUIRE(param->buckets_per_data == 2); + + param_str = R"( + { + "ivf": { + "scan_buckets_count": 10 + } + })"; + auto search_param = vsag::IVFSearchParameters::FromJson(param_str); + REQUIRE(search_param.scan_buckets_count == 10); + REQUIRE(search_param.first_order_scan_ratio == 1.0f); + + param_str = R"( + { + "ivf": { + "scan_buckets_count": 20, + "first_order_scan_ratio": 0.1 + } + })"; + search_param = vsag::IVFSearchParameters::FromJson(param_str); + REQUIRE(search_param.scan_buckets_count == 20); + REQUIRE(search_param.first_order_scan_ratio == 0.1f); } diff --git a/src/algorithm/ivf_partition/gno_imi_parameter.cpp b/src/algorithm/ivf_partition/gno_imi_parameter.cpp new file mode 100644 index 000000000..8e891ccd2 --- /dev/null +++ b/src/algorithm/ivf_partition/gno_imi_parameter.cpp @@ -0,0 +1,50 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gno_imi_parameter.h" + +#include + +#include + +#include "inner_string_params.h" +#include "vsag/constants.h" + +namespace vsag { + +GNOIMIParameter::GNOIMIParameter() = default; + +void +GNOIMIParameter::FromJson(const JsonType& json) { + CHECK_ARGUMENT( + json.contains(GNO_IMI_FIRST_ORDER_BUCKETS_COUNT_KEY), + fmt::format("ivf parameters must contains {}", GNO_IMI_FIRST_ORDER_BUCKETS_COUNT_KEY)); + this->first_order_buckets_count = json[GNO_IMI_FIRST_ORDER_BUCKETS_COUNT_KEY]; + + if (json.contains(GNO_IMI_SECOND_ORDER_BUCKETS_COUNT_KEY)) { + this->second_order_buckets_count = json[GNO_IMI_SECOND_ORDER_BUCKETS_COUNT_KEY]; + } else { + this->second_order_buckets_count = this->first_order_buckets_count; + } +} + +JsonType +GNOIMIParameter::ToJson() { + JsonType json; + json[GNO_IMI_FIRST_ORDER_BUCKETS_COUNT_KEY] = this->first_order_buckets_count; + json[GNO_IMI_SECOND_ORDER_BUCKETS_COUNT_KEY] = this->second_order_buckets_count; + return json; +} +} // namespace vsag diff --git a/src/algorithm/ivf_partition/gno_imi_parameter.h b/src/algorithm/ivf_partition/gno_imi_parameter.h new file mode 100644 index 000000000..c33850cdc --- /dev/null +++ b/src/algorithm/ivf_partition/gno_imi_parameter.h @@ -0,0 +1,41 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "data_cell/bucket_datacell_parameter.h" +#include "fmt/format-inl.h" +#include "inner_string_params.h" +#include "parameter.h" +#include "typing.h" + +namespace vsag { +class GNOIMIParameter : public Parameter { +public: + explicit GNOIMIParameter(); + + void + FromJson(const JsonType& json) override; + + JsonType + ToJson() override; + +public: + BucketIdType first_order_buckets_count{100}; + BucketIdType second_order_buckets_count{100}; +}; + +using GNOIMIParameterPtr = std::shared_ptr; + +} // namespace vsag diff --git a/src/algorithm/ivf_partition/gno_imi_parameter_test.cpp b/src/algorithm/ivf_partition/gno_imi_parameter_test.cpp new file mode 100644 index 000000000..7d55789b8 --- /dev/null +++ b/src/algorithm/ivf_partition/gno_imi_parameter_test.cpp @@ -0,0 +1,32 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "gno_imi_parameter.h" + +#include + +#include "parameter_test.h" + +TEST_CASE("GNO-IMI Parameters Test", "[ut][GNOIMIParameter]") { + auto param_str = R"({ + "first_order_buckets_count": 200, + "second_order_buckets_count": 50 + })"; + vsag::JsonType param_json = vsag::JsonType::parse(param_str); + auto param = std::make_shared(); + param->FromJson(param_json); + REQUIRE(param->first_order_buckets_count == 200); + REQUIRE(param->second_order_buckets_count == 50); +} diff --git a/src/algorithm/ivf_partition/gno_imi_partition.cpp b/src/algorithm/ivf_partition/gno_imi_partition.cpp new file mode 100644 index 000000000..40b36aa71 --- /dev/null +++ b/src/algorithm/ivf_partition/gno_imi_partition.cpp @@ -0,0 +1,390 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "gno_imi_partition.h" + +#include + +#include +#include + +#include "impl/kmeans_cluster.h" +#include "inner_string_params.h" +#include "safe_allocator.h" +#include "utils/util_functions.h" + +namespace vsag { + +static constexpr const char* SEARCH_PARAM_TEMPLATE_STR = R"( +{{ + "hnsw": {{ + "ef_search": {} + }} +}} +)"; + +// C = A * B^T +void +matmul(const float* A, const float* B, float* C, int64_t M, int64_t N, int64_t K) { + cblas_sgemm(CblasColMajor, + CblasTrans, + CblasNoTrans, + static_cast(N), + static_cast(M), + static_cast(K), + 1.0F, + B, + static_cast(K), + A, + static_cast(K), + 0.0F, + C, + static_cast(N)); +} + +GNOIMIPartition::GNOIMIPartition(const IndexCommonParam& common_param, + const IVFPartitionStrategyParametersPtr& param) + : IVFPartitionStrategy(common_param, + param->gnoimi_param->first_order_buckets_count * + param->gnoimi_param->second_order_buckets_count), + bucket_count_s_(param->gnoimi_param->first_order_buckets_count), + bucket_count_t_(param->gnoimi_param->second_order_buckets_count), + data_centroids_s_(allocator_), + data_centroids_t_(allocator_), + norms_s_(allocator_), + norms_t_(allocator_), + precomputed_terms_st_(allocator_), + common_param_(common_param) { + data_centroids_s_.resize(bucket_count_s_ * dim_); + data_centroids_t_.resize(bucket_count_t_ * dim_); + norms_s_.resize(bucket_count_s_); + norms_t_.resize(bucket_count_t_); + precomputed_terms_st_.resize(static_cast(bucket_count_s_) * bucket_count_t_); + + param_ptr_ = std::make_shared(); + param_ptr_->flatten_param = std::make_shared(); + JsonType memory_json = { + {"type", IO_TYPE_VALUE_BLOCK_MEMORY_IO}, + }; + param_ptr_->flatten_param->io_parameter = IOParameter::GetIOParameterByJson(memory_json); + JsonType quantizer_json = { + {"type", QUANTIZATION_TYPE_VALUE_FP32}, + }; + param_ptr_->flatten_param->quantizer_parameter = + QuantizerParameter::GetQuantizerParameterByJson(quantizer_json); +} + +void +GNOIMIPartition::Train(const DatasetPtr dataset) { + 
auto dim = this->dim_; + auto centroids_s = Dataset::Make(); + auto centroids_t = Dataset::Make(); + const auto* vectors = dataset->GetFloat32Vectors(); + auto num_element = dataset->GetNumElements(); + Vector ids_centroids_s(this->bucket_count_s_, allocator_); + Vector ids_centroids_t(this->bucket_count_t_, allocator_); + Vector data_centroids_s_tmp(this->bucket_count_s_ * dim_, allocator_); + Vector data_centroids_t_tmp(this->bucket_count_t_ * dim_, allocator_); + + std::iota(ids_centroids_s.begin(), ids_centroids_s.end(), 0); + std::iota(ids_centroids_t.begin(), ids_centroids_t.end(), 0); + centroids_s->Ids(ids_centroids_s.data()) + ->Dim(dim) + ->Float32Vectors(data_centroids_s_tmp.data()) + ->NumElements(this->bucket_count_s_) + ->Owner(false); + centroids_t->Ids(ids_centroids_t.data()) + ->Dim(dim) + ->Float32Vectors(data_centroids_t_tmp.data()) + ->NumElements(this->bucket_count_t_) + ->Owner(false); + + KMeansCluster cls(static_cast(dim), this->allocator_); + Vector residuals(vectors, vectors + num_element * dim, allocator_); + + auto train_and_get_residual = [&, this](const DatasetPtr& centroids, + float* data_centroids, + double* err) { + cls.Run(centroids->GetNumElements(), residuals.data(), num_element, 30, err); + memcpy(data_centroids, cls.k_centroids_, dim * centroids->GetNumElements() * sizeof(float)); + BruteForce route_index(param_ptr_, common_param_); + auto build_result = route_index.Build(centroids); + auto assign = this->inner_classify_datas(route_index, residuals.data(), num_element); + this->GetResidual(num_element, vectors, residuals.data(), data_centroids, assign.data()); + }; + + // train loop + double min_err = std::numeric_limits::max(); + for (size_t i = 0; i < 2; ++i) { + double err_to_s = 0.0; + double err_to_t = 0.0; + train_and_get_residual(centroids_s, data_centroids_s_tmp.data(), &err_to_s); + logger::info("gnoimi train iter: {}, err of centroids_s: {}", i, err_to_s); + + train_and_get_residual(centroids_t, 
data_centroids_t_tmp.data(), &err_to_t); + logger::info("gnoimi train iter: {}, err of centroids_t: {}", i, err_to_t); + + if (err_to_t < min_err) { + min_err = err_to_t; + std::copy(data_centroids_s_tmp.begin(), + data_centroids_s_tmp.end(), + data_centroids_s_.begin()); + std::copy(data_centroids_t_tmp.begin(), + data_centroids_t_tmp.end(), + data_centroids_t_.begin()); + } + } + + for (BucketIdType i = 0; i < bucket_count_s_; ++i) { + auto norm_sqr = FP32ComputeIP( + data_centroids_s_.data() + i * dim_, data_centroids_s_.data() + i * dim_, dim_); + norms_s_[i] = norm_sqr / 2; + } + + Vector> norms_t(bucket_count_t_, this->allocator_); + for (BucketIdType i = 0; i < bucket_count_t_; ++i) { + auto norm_sqr = FP32ComputeIP( + data_centroids_t_.data() + i * dim_, data_centroids_t_.data() + i * dim_, dim_); + norms_t[i].first = norm_sqr / 2; + norms_t[i].second = i; + } + + // Rearrange data_centroids_t_ in descending order of their norms (std::greater). + std::sort(norms_t.begin(), norms_t.end(), std::greater<>()); + std::vector temp_data(bucket_count_t_ * dim_, 0.0); + for (BucketIdType i = 0; i < bucket_count_t_; ++i) { + BucketIdType src_idx = norms_t[i].second; + size_t src_offset = src_idx * dim_; + size_t dst_offset = i * dim_; + std::copy(data_centroids_t_.data() + src_offset, + data_centroids_t_.data() + src_offset + dim_, + temp_data.data() + dst_offset); + norms_t_[i] = norms_t[i].first; + } + std::copy(temp_data.begin(), temp_data.end(), data_centroids_t_.data()); + + Vector ip_st(static_cast(bucket_count_s_) * bucket_count_t_, allocator_); + matmul(data_centroids_s_.data(), + data_centroids_t_.data(), + ip_st.data(), + bucket_count_s_, + bucket_count_t_, + dim_); + for (BucketIdType i = 0; i < bucket_count_s_ * bucket_count_t_; ++i) { + BucketIdType cur_bucket_id_t = i % bucket_count_t_; + precomputed_terms_st_[i] = norms_t_[cur_bucket_id_t] + ip_st[i]; + } + + this->is_trained_ = true; +} + +Vector +GNOIMIPartition::ClassifyDatas(const void* datas, int64_t count,
BucketIdType buckets_per_data) { + Vector result(buckets_per_data * count, this->allocator_); + inner_joint_classify_datas( + reinterpret_cast(datas), count, buckets_per_data, result.data()); + return result; +} + +Vector +GNOIMIPartition::ClassifyDatasForSearch(const void* datas, + int64_t count, + const InnerSearchParam& param) { + auto buckets_per_data = param.scan_bucket_size; + Vector result(buckets_per_data * count, this->allocator_); + auto candidate_count_s = bucket_count_s_; + Vector candidate_s_id(candidate_count_s, this->allocator_); + Vector candidate_s_dist(candidate_count_s, this->allocator_); + Vector dist_to_s(bucket_count_s_ * count, this->allocator_); + Vector dist_to_t(bucket_count_t_ * count, this->allocator_); + auto* dist_to_s_data = dist_to_s.data(); + auto* dist_to_t_data = dist_to_t.data(); + auto* candidate_s_id_data = candidate_s_id.data(); + auto* candidate_s_dist_data = candidate_s_dist.data(); + + matmul(reinterpret_cast(datas), + data_centroids_s_.data(), + dist_to_s_data, + count, + bucket_count_s_, + dim_); + matmul(reinterpret_cast(datas), + data_centroids_t_.data(), + dist_to_t_data, + count, + bucket_count_t_, + dim_); + + for (size_t i = 0; i < count; i++) { + auto qnorm = FP32ComputeIP(reinterpret_cast(datas) + i * dim_, + reinterpret_cast(datas) + i * dim_, + dim_) / + 2; + MaxHeap heap(this->allocator_); + for (size_t j = 0; j < bucket_count_s_; ++j) { + auto dist_term_s = norms_s_[j] - dist_to_s_data[i * bucket_count_s_ + j]; + if (heap.size() < candidate_count_s || dist_term_s < heap.top().first) { + heap.emplace(dist_term_s, j); + } + if (heap.size() > candidate_count_s) { + heap.pop(); + } + } + for (auto j = static_cast(candidate_count_s - 1); j >= 0; --j) { + candidate_s_id_data[j] = static_cast(heap.top().second); + candidate_s_dist_data[j] = heap.top().first; + heap.pop(); + } + CHECK_ARGUMENT(heap.empty(), fmt::format("Unexpected non-empty heap after pop candidates")); + + auto scan_bucket_count_s = static_cast( + 
std::floor(static_cast(bucket_count_s_) * param.first_order_scan_ratio)); + scan_bucket_count_s = std::max(scan_bucket_count_s, 1); + for (size_t j = 0; j < scan_bucket_count_s; ++j) { + for (size_t k = 0; k < bucket_count_t_; ++k) { + auto cur_bucket_id_s = candidate_s_id_data[j]; + auto cur_bucket_id_t = k; + float dist_term_st = candidate_s_dist_data[j] + + precomputed_terms_st_[static_cast( + cur_bucket_id_s * bucket_count_t_) + + cur_bucket_id_t] - + dist_to_t_data[i * bucket_count_t_ + cur_bucket_id_t]; + + auto cur_bucket_id_global = + static_cast(cur_bucket_id_s) * bucket_count_t_ + cur_bucket_id_t; + if (heap.size() < buckets_per_data || dist_term_st < heap.top().first) { + heap.emplace(dist_term_st, cur_bucket_id_global); + } + if (heap.size() > buckets_per_data) { + heap.pop(); + } + } + } + BucketIdType size = std::min((BucketIdType)heap.size(), buckets_per_data); + for (auto j = static_cast(size - 1); j >= 0 && !heap.empty(); --j) { + result[i * buckets_per_data + j] = static_cast(heap.top().second); + heap.pop(); + } + } + return result; +} + +void +GNOIMIPartition::Serialize(StreamWriter& writer) { + IVFPartitionStrategy::Serialize(writer); + StreamWriter::WriteObj(writer, this->bucket_count_s_); + StreamWriter::WriteObj(writer, this->bucket_count_t_); + StreamWriter::WriteVector(writer, this->data_centroids_s_); + StreamWriter::WriteVector(writer, this->data_centroids_t_); + StreamWriter::WriteVector(writer, this->norms_s_); + StreamWriter::WriteVector(writer, this->norms_t_); + StreamWriter::WriteVector(writer, this->precomputed_terms_st_); +} +void +GNOIMIPartition::Deserialize(StreamReader& reader) { + IVFPartitionStrategy::Deserialize(reader); + StreamReader::ReadObj(reader, this->bucket_count_s_); + StreamReader::ReadObj(reader, this->bucket_count_t_); + StreamReader::ReadVector(reader, this->data_centroids_s_); + StreamReader::ReadVector(reader, this->data_centroids_t_); + StreamReader::ReadVector(reader, this->norms_s_); + 
StreamReader::ReadVector(reader, this->norms_t_); + StreamReader::ReadVector(reader, this->precomputed_terms_st_); +} + +Vector +GNOIMIPartition::inner_classify_datas(BruteForce& route_index, const float* datas, int64_t count) { + BucketIdType buckets_per_data = 1; + Vector result(buckets_per_data * count, this->allocator_); + for (int64_t i = 0; i < count; ++i) { + auto query = Dataset::Make(); + query->Dim(this->dim_) + ->Float32Vectors(datas + i * this->dim_) + ->NumElements(1) + ->Owner(false); + auto search_param = fmt::format( + SEARCH_PARAM_TEMPLATE_STR, std::max(10L, static_cast(buckets_per_data * 1.2))); + FilterPtr filter = nullptr; + auto search_result = route_index.KnnSearch(query, buckets_per_data, search_param, filter); + const auto* result_ids = search_result->GetIds(); + + for (int64_t j = 0; j < buckets_per_data; ++j) { + result[i * buckets_per_data + j] = static_cast(result_ids[j]); + } + } + return result; +} + +void +GNOIMIPartition::inner_joint_classify_datas(const float* datas, + int64_t count, + BucketIdType buckets_per_data, + BucketIdType* result) { + Vector dist_to_s(bucket_count_s_ * count, this->allocator_); + Vector dist_to_t(bucket_count_t_ * count, this->allocator_); + Vector> precomputed_terms_s(bucket_count_s_, this->allocator_); + + matmul(datas, data_centroids_s_.data(), dist_to_s.data(), count, bucket_count_s_, dim_); + matmul(datas, data_centroids_t_.data(), dist_to_t.data(), count, bucket_count_t_, dim_); + // |x - s - t|^2 = |x|^2 + |s|^2 + |t|^2 - 2xs - 2xt + 2st (all terms below are halved) + // precomputed_terms_s: |x - s|^2 / 2 = |s|^2 / 2 - xs + |x|^2 / 2 + // precomputed_terms_st: |t|^2 / 2 + st + float total_err = 0.0; + for (size_t i = 0; i < count; ++i) { + auto data_norm = FP32ComputeIP(datas + i * dim_, datas + i * dim_, dim_); + for (BucketIdType j = 0; j < bucket_count_s_; ++j) { + precomputed_terms_s[j].first = + norms_s_[j] - dist_to_s[i * bucket_count_s_ + j] + data_norm / 2; + precomputed_terms_s[j].second = j; + }
std::sort(precomputed_terms_s.begin(), precomputed_terms_s.end()); + + MaxHeap heap(this->allocator_); + for (size_t j = 0; j < bucket_count_s_; ++j) { + float cur_precomputed_term_s = precomputed_terms_s[j].first; + BucketIdType cur_bucket_id_s = precomputed_terms_s[j].second; + + for (BucketIdType k = 0; k < bucket_count_t_; ++k) { + BucketIdType cur_bucket_id_t = k; + if (heap.size() >= buckets_per_data && + std::sqrt(cur_precomputed_term_s) - std::sqrt(norms_t_[cur_bucket_id_t]) > + std::sqrt(heap.top().first)) { + break; + } + + int cur_bucket_id_global = cur_bucket_id_s * bucket_count_t_ + cur_bucket_id_t; + float dist = cur_precomputed_term_s - dist_to_t[i * bucket_count_t_ + k] + + precomputed_terms_st_[cur_bucket_id_global]; + + if (heap.size() < buckets_per_data || dist < heap.top().first) { + heap.emplace(dist, cur_bucket_id_global); + } + if (heap.size() > buckets_per_data) { + heap.pop(); + } + } + } + + for (auto j = static_cast(buckets_per_data - 1); j >= 0; --j) { + result[i * buckets_per_data + j] = static_cast(heap.top().second); + if (j == 0) { + total_err += heap.top().first; + } + heap.pop(); + } + } +} + +} // namespace vsag diff --git a/src/algorithm/ivf_partition/gno_imi_partition.h b/src/algorithm/ivf_partition/gno_imi_partition.h new file mode 100644 index 000000000..f6ad97c3b --- /dev/null +++ b/src/algorithm/ivf_partition/gno_imi_partition.h @@ -0,0 +1,73 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "algorithm/brute_force.h" +#include "algorithm/brute_force_parameter.h" +#include "algorithm/inner_index_interface.h" +#include "index/index_common_param.h" +#include "ivf_nearest_partition.h" +#include "ivf_partition_strategy.h" +#include "ivf_partition_strategy_parameter.h" +#include "vsag/index.h" +namespace vsag { + +class GNOIMIPartition : public IVFPartitionStrategy { +public: + explicit GNOIMIPartition(const IndexCommonParam& common_param, + const IVFPartitionStrategyParametersPtr& param); + + void + Train(const DatasetPtr dataset) override; + + Vector + ClassifyDatas(const void* datas, int64_t count, BucketIdType buckets_per_data) override; + + Vector + ClassifyDatasForSearch(const void* datas, + int64_t count, + const InnerSearchParam& param) override; + + void + Serialize(StreamWriter& writer) override; + + void + Deserialize(StreamReader& reader) override; + +public: + IVFNearestPartitionTrainerType trainer_type_{IVFNearestPartitionTrainerType::KMeansTrainer}; + IndexCommonParam common_param_; + std::shared_ptr param_ptr_{nullptr}; + BucketIdType bucket_count_s_{0}; + BucketIdType bucket_count_t_{0}; + Vector data_centroids_s_; + Vector data_centroids_t_; + // precomputed terms for S and T to speed up the distance computation + Vector norms_s_; + Vector norms_t_; + Vector precomputed_terms_st_; + +private: + Vector + inner_classify_datas(BruteForce& route_index, const float* datas, int64_t count); + + void + inner_joint_classify_datas(const float* data, + int64_t count, + BucketIdType buckets_per_data, + BucketIdType* result); +}; + +} // namespace vsag diff --git a/src/algorithm/ivf_partition/gno_imi_partition_test.cpp b/src/algorithm/ivf_partition/gno_imi_partition_test.cpp new file mode 100644 index 000000000..60ba92eeb --- /dev/null +++ b/src/algorithm/ivf_partition/gno_imi_partition_test.cpp @@ -0,0 +1,182 @@ + +// 
Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gno_imi_partition.h" + +#include + +#include "algorithm/ivf_parameter.h" +#include "fixtures.h" +#include "impl/basic_searcher.h" +#include "safe_allocator.h" + +using namespace vsag; + +TEST_CASE("GNO-IMI Partition Basic Test", "[ut][GNOIMIPartition]") { + auto allocator = SafeAllocator::FactoryDefaultAllocator(); + int64_t dim = 128; + IndexCommonParam param; + param.dim_ = 128; + param.metric_ = MetricType::METRIC_TYPE_L2SQR; + param.allocator_ = allocator; + auto param_str = R"({ + "partition_strategy_type": "gno_imi", + "ivf_train_type": "kmeans", + "gno_imi": { + "first_order_buckets_count": 10, + "second_order_buckets_count": 10 + } + })"; + vsag::JsonType param_json = vsag::JsonType::parse(param_str); + auto strategy_param = std::make_shared(); + strategy_param->FromJson(param_json); + auto partition = std::make_unique(param, strategy_param); + + auto dataset = Dataset::Make(); + int64_t data_count = 10000L; + auto vec = fixtures::generate_vectors(data_count, dim, true, 95); + dataset->Float32Vectors(vec.data())->Dim(dim)->NumElements(data_count)->Owner(false); + + partition->Train(dataset); + auto class_result = partition->ClassifyDatas(vec.data(), data_count, 1); + REQUIRE(class_result.size() == data_count); + + param_str = R"( + { + "ivf": { + "scan_buckets_count": 1, + "first_order_scan_ratio": 0.1 + } + })"; + auto search_param = 
IVFSearchParameters::FromJson(param_str); + InnerSearchParam inner_search_param; + inner_search_param.scan_bucket_size = search_param.scan_buckets_count; + inner_search_param.first_order_scan_ratio = search_param.first_order_scan_ratio; + REQUIRE(inner_search_param.scan_bucket_size == 1); + REQUIRE(inner_search_param.first_order_scan_ratio == 0.1f); + size_t match_count = 0; + for (int64_t i = 0; i < data_count; ++i) { + auto query = Dataset::Make(); + query->Dim(dim)->Float32Vectors(vec.data() + i * dim)->NumElements(1)->Owner(false); + auto result = + partition->ClassifyDatasForSearch(vec.data() + i * dim, 1, inner_search_param); + auto id = result[0]; + if (id == class_result[i]) { + match_count++; + } + } + std::cout << "match count(first_order_scan_ratio=0.1): " << match_count << std::endl; + + inner_search_param.first_order_scan_ratio = 0.2f; + match_count = 0; + for (int64_t i = 0; i < data_count; ++i) { + auto query = Dataset::Make(); + query->Dim(dim)->Float32Vectors(vec.data() + i * dim)->NumElements(1)->Owner(false); + auto result = + partition->ClassifyDatasForSearch(vec.data() + i * dim, 1, inner_search_param); + auto id = result[0]; + if (id == class_result[i]) { + match_count++; + } + } + std::cout << "match count(first_order_scan_ratio=0.2): " << match_count << std::endl; + + inner_search_param.first_order_scan_ratio = 1.0f; + match_count = 0; + for (int64_t i = 0; i < data_count; ++i) { + auto query = Dataset::Make(); + query->Dim(dim)->Float32Vectors(vec.data() + i * dim)->NumElements(1)->Owner(false); + auto result = + partition->ClassifyDatasForSearch(vec.data() + i * dim, 1, inner_search_param); + auto id = result[0]; + // REQUIRE(id == class_result[i]); + if (id == class_result[i]) { + match_count++; + } + } + REQUIRE(match_count > 9990); + std::cout << "match count(first_order_scan_ratio=1.0): " << match_count << std::endl; +} + +TEST_CASE("GNO-IMI Partition Serialize Test", "[ut][GNOIMIPartition]") { + auto allocator = 
SafeAllocator::FactoryDefaultAllocator(); + int64_t dim = 128; + IndexCommonParam param; + param.dim_ = 128; + param.metric_ = MetricType::METRIC_TYPE_L2SQR; + param.allocator_ = allocator; + auto param_str = R"({ + "partition_strategy_type": "gno_imi", + "ivf_train_type": "kmeans", + "gno_imi": { + "first_order_buckets_count": 10, + "second_order_buckets_count": 10 + } + })"; + vsag::JsonType param_json = vsag::JsonType::parse(param_str); + auto strategy_param = std::make_shared(); + strategy_param->FromJson(param_json); + auto partition = std::make_unique(param, strategy_param); + + auto dataset = Dataset::Make(); + int64_t data_count = 10000L; + auto vec = fixtures::generate_vectors(data_count, dim, true, 95); + dataset->Float32Vectors(vec.data())->Dim(dim)->NumElements(data_count)->Owner(false); + + partition->Train(dataset); + auto class_result = partition->ClassifyDatas(vec.data(), data_count, 1); + REQUIRE(class_result.size() == data_count); + + auto dir = fixtures::TempDir("serialize"); + auto path = dir.GenerateRandomFile(); + std::ofstream outfile(path, std::ios::out | std::ios::binary); + IOStreamWriter writer(outfile); + partition->Serialize(writer); + outfile.close(); + partition = std::make_unique(param, strategy_param); + + std::ifstream infile(path, std::ios::in | std::ios::binary); + IOStreamReader reader(infile); + partition->Deserialize(reader); + infile.close(); + + param_str = R"( + { + "ivf": { + "scan_buckets_count": 1, + "first_order_scan_ratio": 1.0 + } + })"; + + auto search_param = IVFSearchParameters::FromJson(param_str); + InnerSearchParam inner_search_param; + inner_search_param.scan_bucket_size = search_param.scan_buckets_count; + inner_search_param.first_order_scan_ratio = search_param.first_order_scan_ratio; + + size_t match_count = 0; + FilterPtr filter = nullptr; + for (int64_t i = 0; i < data_count; ++i) { + auto query = Dataset::Make(); + query->Dim(dim)->Float32Vectors(vec.data() + i * dim)->NumElements(1)->Owner(false); + auto 
result = + partition->ClassifyDatasForSearch(vec.data() + i * dim, 1, inner_search_param); + auto id = result[0]; + if (id == class_result[i]) { + match_count++; + } + } + REQUIRE(match_count > 9990); + std::cout << "match count(first_order_scan_ratio=1.0): " << match_count << std::endl; +} diff --git a/src/algorithm/ivf_partition/ivf_nearest_partition.cpp b/src/algorithm/ivf_partition/ivf_nearest_partition.cpp index b4b40b2ad..f6749c573 100644 --- a/src/algorithm/ivf_partition/ivf_nearest_partition.cpp +++ b/src/algorithm/ivf_partition/ivf_nearest_partition.cpp @@ -35,8 +35,9 @@ static constexpr const char* SEARCH_PARAM_TEMPLATE_STR = R"( IVFNearestPartition::IVFNearestPartition(BucketIdType bucket_count, const IndexCommonParam& common_param, - IVFNearestPartitionTrainerType trainer_type) - : IVFPartitionStrategy(common_param, bucket_count), trainer_type_(trainer_type) { + IVFPartitionStrategyParametersPtr param) + : IVFPartitionStrategy(common_param, bucket_count), + ivf_partition_strategy_param_(std::move(param)) { this->factory_router_index(common_param); } @@ -53,7 +54,8 @@ IVFNearestPartition::Train(const DatasetPtr dataset) { ->NumElements(this->bucket_count_) ->Owner(false); - if (trainer_type_ == IVFNearestPartitionTrainerType::KMeansTrainer) { + if (ivf_partition_strategy_param_->partition_train_type == + IVFNearestPartitionTrainerType::KMeansTrainer) { constexpr int32_t kmeans_iter_count = 25; KMeansCluster cls(static_cast(dim), this->allocator_); cls.Run(this->bucket_count_, @@ -61,7 +63,8 @@ IVFNearestPartition::Train(const DatasetPtr dataset) { dataset->GetNumElements(), kmeans_iter_count); memcpy(data.data(), cls.k_centroids_, dim * this->bucket_count_ * sizeof(float)); - } else if (trainer_type_ == IVFNearestPartitionTrainerType::RandomTrainer) { + } else if (ivf_partition_strategy_param_->partition_train_type == + IVFNearestPartitionTrainerType::RandomTrainer) { auto selected = select_k_numbers(dataset->GetNumElements(), this->bucket_count_); for 
(int i = 0; i < bucket_count_; ++i) { memcpy(data.data() + i * dim, diff --git a/src/algorithm/ivf_partition/ivf_nearest_partition.h b/src/algorithm/ivf_partition/ivf_nearest_partition.h index e612cf8dc..f6d40ed17 100644 --- a/src/algorithm/ivf_partition/ivf_nearest_partition.h +++ b/src/algorithm/ivf_partition/ivf_nearest_partition.h @@ -18,20 +18,15 @@ #include "algorithm/inner_index_interface.h" #include "index/index_common_param.h" #include "ivf_partition_strategy.h" +#include "ivf_partition_strategy_parameter.h" #include "vsag/index.h" namespace vsag { -enum class IVFNearestPartitionTrainerType { - RandomTrainer = 0, - KMeansTrainer = 1, -}; - class IVFNearestPartition : public IVFPartitionStrategy { public: explicit IVFNearestPartition(BucketIdType bucket_count, const IndexCommonParam& common_param, - IVFNearestPartitionTrainerType trainer_type = - IVFNearestPartitionTrainerType::KMeansTrainer); + IVFPartitionStrategyParametersPtr param); void Train(const DatasetPtr dataset) override; @@ -46,7 +41,7 @@ class IVFNearestPartition : public IVFPartitionStrategy { Deserialize(StreamReader& reader) override; public: - IVFNearestPartitionTrainerType trainer_type_{IVFNearestPartitionTrainerType::KMeansTrainer}; + IVFPartitionStrategyParametersPtr ivf_partition_strategy_param_{nullptr}; InnerIndexPtr route_index_ptr_{nullptr}; diff --git a/src/algorithm/ivf_partition/ivf_nearest_partition_test.cpp b/src/algorithm/ivf_partition/ivf_nearest_partition_test.cpp index daf1c9fda..49b25c690 100644 --- a/src/algorithm/ivf_partition/ivf_nearest_partition_test.cpp +++ b/src/algorithm/ivf_partition/ivf_nearest_partition_test.cpp @@ -30,8 +30,9 @@ TEST_CASE("IVF Nearest Partition Basic Test", "[ut][IVFNearestPartition]") { param.dim_ = 128; param.metric_ = MetricType::METRIC_TYPE_L2SQR; param.allocator_ = allocator; - auto partition = std::make_unique( - bucket_count, param, IVFNearestPartitionTrainerType::KMeansTrainer); + IVFPartitionStrategyParametersPtr strategy_param = + 
std::make_shared(); + auto partition = std::make_unique(bucket_count, param, strategy_param); auto dataset = Dataset::Make(); int64_t data_count = 1000L; @@ -68,8 +69,9 @@ TEST_CASE("IVF Nearest Partition Serialize Test", "[ut][IVFNearestPartition]") { param.dim_ = 128; param.metric_ = MetricType::METRIC_TYPE_L2SQR; param.allocator_ = allocator; - auto partition = std::make_unique( - bucket_count, param, IVFNearestPartitionTrainerType::KMeansTrainer); + IVFPartitionStrategyParametersPtr strategy_param = + std::make_shared(); + auto partition = std::make_unique(bucket_count, param, strategy_param); auto dataset = Dataset::Make(); int64_t data_count = 1000L; @@ -86,8 +88,7 @@ TEST_CASE("IVF Nearest Partition Serialize Test", "[ut][IVFNearestPartition]") { IOStreamWriter writer(outfile); partition->Serialize(writer); outfile.close(); - partition = std::make_unique( - bucket_count, param, IVFNearestPartitionTrainerType::KMeansTrainer); + partition = std::make_unique(bucket_count, param, strategy_param); std::ifstream infile(path, std::ios::in | std::ios::binary); IOStreamReader reader(infile); diff --git a/src/algorithm/ivf_partition/ivf_partition_strategy.h b/src/algorithm/ivf_partition/ivf_partition_strategy.h index 6c5f2bf40..be9f2fbec 100644 --- a/src/algorithm/ivf_partition/ivf_partition_strategy.h +++ b/src/algorithm/ivf_partition/ivf_partition_strategy.h @@ -15,8 +15,13 @@ #pragma once +#include + +#include #include +#include "impl/basic_searcher.h" +#include "ivf_partition_strategy_parameter.h" #include "stream_reader.h" #include "stream_writer.h" #include "vsag/dataset.h" @@ -50,6 +55,11 @@ class IVFPartitionStrategy { virtual Vector ClassifyDatas(const void* datas, int64_t count, BucketIdType buckets_per_data) = 0; + virtual Vector + ClassifyDatasForSearch(const void* datas, int64_t count, const InnerSearchParam& param) { + return std::move(ClassifyDatas(datas, count, param.scan_bucket_size)); + } + virtual void Serialize(StreamWriter& writer) { 
StreamWriter::WriteObj(writer, this->is_trained_); @@ -64,6 +74,17 @@ class IVFPartitionStrategy { StreamReader::ReadObj(reader, this->dim_); } + virtual void + GetResidual( + size_t n, const float* x, float* residuals, float* centroids, BucketIdType* assign) { + // TODO: Directly implement c = a - b. + memcpy(residuals, x, sizeof(float) * n * dim_); + for (size_t i = 0; i < n; ++i) { + BucketIdType bucket_id = assign[i]; + cblas_saxpy(dim_, -1.0, centroids + bucket_id * dim_, 1, residuals + i * dim_, 1); + } + } + public: bool is_trained_{false}; diff --git a/src/algorithm/ivf_partition/ivf_partition_strategy_parameter.cpp b/src/algorithm/ivf_partition/ivf_partition_strategy_parameter.cpp new file mode 100644 index 000000000..025fff389 --- /dev/null +++ b/src/algorithm/ivf_partition/ivf_partition_strategy_parameter.cpp @@ -0,0 +1,73 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "ivf_partition_strategy_parameter.h" + +#include + +#include + +#include "inner_string_params.h" +#include "vsag/constants.h" + +namespace vsag { + +IVFPartitionStrategyParameters::IVFPartitionStrategyParameters() = default; + +void +IVFPartitionStrategyParameters::FromJson(const JsonType& json) { + if (json[IVF_TRAIN_TYPE_KEY] == IVF_TRAIN_TYPE_KMEANS) { + this->partition_train_type = IVFNearestPartitionTrainerType::KMeansTrainer; + } else if (json[IVF_TRAIN_TYPE_KEY] == IVF_TRAIN_TYPE_RANDOM) { + this->partition_train_type = IVFNearestPartitionTrainerType::RandomTrainer; + } + + if (json[IVF_PARTITION_STRATEGY_TYPE_KEY] == IVF_PARTITION_STRATEGY_TYPE_NEAREST) { + this->partition_strategy_type = IVFPartitionStrategyType::IVF; + } else if (json[IVF_PARTITION_STRATEGY_TYPE_KEY] == IVF_PARTITION_STRATEGY_TYPE_GNO_IMI) { + this->partition_strategy_type = IVFPartitionStrategyType::GNO_IMI; + } + + this->gnoimi_param = std::make_shared(); + if (this->partition_strategy_type == IVFPartitionStrategyType::GNO_IMI) { + CHECK_ARGUMENT( + json.contains(IVF_PARTITION_STRATEGY_TYPE_GNO_IMI), + fmt::format("partition strategy parameters must contains {} when strategy type is {}", + IVF_PARTITION_STRATEGY_TYPE_GNO_IMI, + IVF_PARTITION_STRATEGY_TYPE_GNO_IMI)); + this->gnoimi_param->FromJson(json[IVF_PARTITION_STRATEGY_TYPE_GNO_IMI]); + } +} + +JsonType +IVFPartitionStrategyParameters::ToJson() { + JsonType json; + if (this->partition_train_type == IVFNearestPartitionTrainerType::KMeansTrainer) { + json[IVF_TRAIN_TYPE_KEY] = IVF_TRAIN_TYPE_KMEANS; + } else if (this->partition_train_type == IVFNearestPartitionTrainerType::RandomTrainer) { + json[IVF_TRAIN_TYPE_KEY] = IVF_TRAIN_TYPE_RANDOM; + } + + if (this->partition_strategy_type == IVFPartitionStrategyType::IVF) { + json[IVF_PARTITION_STRATEGY_TYPE_KEY] = IVF_PARTITION_STRATEGY_TYPE_NEAREST; + } else if (this->partition_strategy_type == IVFPartitionStrategyType::GNO_IMI) { + json[IVF_PARTITION_STRATEGY_TYPE_KEY] = 
IVF_PARTITION_STRATEGY_TYPE_GNO_IMI; + } + if (this->partition_strategy_type == IVFPartitionStrategyType::GNO_IMI) { + json[IVF_PARTITION_STRATEGY_TYPE_GNO_IMI] = this->gnoimi_param->ToJson(); + } + return json; +} +} // namespace vsag diff --git a/src/algorithm/ivf_partition/ivf_partition_strategy_parameter.h b/src/algorithm/ivf_partition/ivf_partition_strategy_parameter.h new file mode 100644 index 000000000..ce2700a20 --- /dev/null +++ b/src/algorithm/ivf_partition/ivf_partition_strategy_parameter.h @@ -0,0 +1,54 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include "data_cell/bucket_datacell_parameter.h" +#include "fmt/format-inl.h" +#include "gno_imi_parameter.h" +#include "inner_string_params.h" +#include "parameter.h" +#include "typing.h" + +namespace vsag { + +enum class IVFNearestPartitionTrainerType { + RandomTrainer = 0, + KMeansTrainer = 1, +}; + +enum class IVFPartitionStrategyType { + IVF = 0, + GNO_IMI = 1, +}; + +class IVFPartitionStrategyParameters : public Parameter { +public: + explicit IVFPartitionStrategyParameters(); + + void + FromJson(const JsonType& json) override; + + JsonType + ToJson() override; + +public: + IVFNearestPartitionTrainerType partition_train_type{ + IVFNearestPartitionTrainerType::KMeansTrainer}; + IVFPartitionStrategyType partition_strategy_type{IVFPartitionStrategyType::IVF}; + GNOIMIParameterPtr gnoimi_param{nullptr}; +}; + +using IVFPartitionStrategyParametersPtr = std::shared_ptr; +} // namespace vsag diff --git a/src/algorithm/ivf_partition/ivf_partition_strategy_parameter_test.cpp b/src/algorithm/ivf_partition/ivf_partition_strategy_parameter_test.cpp new file mode 100644 index 000000000..f3119a4b7 --- /dev/null +++ b/src/algorithm/ivf_partition/ivf_partition_strategy_parameter_test.cpp @@ -0,0 +1,38 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "ivf_partition_strategy_parameter.h" + +#include + +#include "parameter_test.h" + +TEST_CASE("IVF Partition Strategy Parameters Test", "[ut][IVFPartitionStrategyParameters]") { + auto param_str = R"({ + "partition_strategy_type": "gno_imi", + "ivf_train_type": "random", + "gno_imi": { + "first_order_buckets_count": 200, + "second_order_buckets_count": 50 + } + })"; + vsag::JsonType param_json = vsag::JsonType::parse(param_str); + auto param = std::make_shared(); + param->FromJson(param_json); + REQUIRE(param->partition_strategy_type == vsag::IVFPartitionStrategyType::GNO_IMI); + REQUIRE(param->partition_train_type == vsag::IVFNearestPartitionTrainerType::RandomTrainer); + REQUIRE(param->gnoimi_param->first_order_buckets_count == 200); + REQUIRE(param->gnoimi_param->second_order_buckets_count == 50); +} diff --git a/src/constants.cpp b/src/constants.cpp index c35c055b2..5bf55212f 100644 --- a/src/constants.cpp +++ b/src/constants.cpp @@ -26,6 +26,7 @@ const char* const INDEX_PYRAMID = "pyramid"; const char* const INDEX_SPARSE = "sparse_index"; const char* const INDEX_BRUTE_FORCE = "brute_force"; const char* const INDEX_IVF = "ivf"; +const char* const INDEX_GNO_IMI = "gno_imi"; const char* const DIM = "dim"; const char* const NUM_ELEMENTS = "num_elements"; @@ -164,6 +165,10 @@ const char* const IVF_BASE_QUANTIZATION_TYPE = "base_quantization_type"; const char* const IVF_BASE_IO_TYPE = "base_io_type"; const char* const IVF_BASE_PQ_DIM = "base_pq_dim"; const char* const IVF_BASE_FILE_PATH = "base_file_path"; + +const char* const GNO_IMI_FIRST_ORDER_BUCKETS_COUNT = "first_order_buckets_count"; +const char* const GNO_IMI_SECOND_ORDER_BUCKETS_COUNT = "second_order_buckets_count"; + const char* const IVF_PRECISE_QUANTIZATION_TYPE = "precise_quantization_type"; const char* const IVF_PRECISE_IO_TYPE = "precise_io_type"; const char* const IVF_PRECISE_FILE_PATH = "precise_file_path"; diff --git a/src/impl/basic_searcher.h b/src/impl/basic_searcher.h index 
4400ddf68..bd761d0b5 100644 --- a/src/impl/basic_searcher.h +++ b/src/impl/basic_searcher.h @@ -50,6 +50,7 @@ class InnerSearchParam { // for ivf int scan_bucket_size{1}; float factor{2.0F}; + float first_order_scan_ratio{1.0F}; InnerSearchParam& operator=(const InnerSearchParam& other) { @@ -64,6 +65,7 @@ class InnerSearchParam { is_inner_id_allowed = other.is_inner_id_allowed; scan_bucket_size = other.scan_bucket_size; factor = other.factor; + first_order_scan_ratio = other.first_order_scan_ratio; } return *this; } diff --git a/src/impl/kmeans_cluster.cpp b/src/impl/kmeans_cluster.cpp index 3e89c7da7..706529981 100644 --- a/src/impl/kmeans_cluster.cpp +++ b/src/impl/kmeans_cluster.cpp @@ -18,6 +18,7 @@ #include #include +#include #include #include "algorithm/inner_index_interface.h" @@ -42,7 +43,13 @@ KMeansCluster::~KMeansCluster() { } Vector -KMeansCluster::Run(uint32_t k, const float* datas, uint64_t count, int iter) { +KMeansCluster::Run(uint32_t k, + const float* datas, + uint64_t count, + int iter, + double* err, + bool use_mse_for_convergence, + float threshold) { if (k_centroids_ != nullptr) { allocator_->Deallocate(k_centroids_); k_centroids_ = nullptr; @@ -60,6 +67,8 @@ KMeansCluster::Run(uint32_t k, const float* datas, uint64_t count, int iter) { } } + double total_err = std::numeric_limits::max(); + double last_err = std::numeric_limits::max(); Vector labels(count, -1, this->allocator_); std::vector mutexes(k); std::vector> futures; @@ -67,7 +76,6 @@ KMeansCluster::Run(uint32_t k, const float* datas, uint64_t count, int iter) { ByteBuffer distances_buffer(static_cast(k) * QUERY_BS * sizeof(float), allocator_); auto* y_sqr = reinterpret_cast(y_sqr_buffer.data); auto* distances = reinterpret_cast(distances_buffer.data); - double error = std::numeric_limits::max(); logger::debug("KMeansCluster::Run k: {}, count: {}, iter: {}", k, count, iter); if (k < THRESHOLD_FOR_HGRAPH) { @@ -81,11 +89,11 @@ KMeansCluster::Run(uint32_t k, const float* datas, uint64_t 
count, int iter) { get_current_time(), it, iter, - error); + total_err); if (k < THRESHOLD_FOR_HGRAPH) { - error = this->find_nearest_one_with_blas(datas, count, k, y_sqr, distances, labels); + total_err = this->find_nearest_one_with_blas(datas, count, k, y_sqr, distances, labels); } else { - error = this->find_nearest_one_with_hgraph(datas, count, k, labels); + total_err = this->find_nearest_one_with_hgraph(datas, count, k, labels); } constexpr uint64_t bs = 1024; @@ -117,6 +125,11 @@ KMeansCluster::Run(uint32_t k, const float* datas, uint64_t count, int iter) { } futures.clear(); + if (it > 0 && use_mse_for_convergence && + std::fabs(last_err - total_err) / static_cast(count) < threshold) { + break; + } + for (int j = 0; j < k; ++j) { if (counts[j] > 0) { cblas_sscal(dim_, @@ -133,6 +146,10 @@ KMeansCluster::Run(uint32_t k, const float* datas, uint64_t count, int iter) { } } } + last_err = total_err; + } + if (err != nullptr) { + *err = total_err; } return labels; } @@ -145,7 +162,7 @@ KMeansCluster::find_nearest_one_with_blas(const float* query, float* distances, Vector& labels) { double error = 0.0; - + std::mutex error_mutex; if (k_centroids_ == nullptr) { throw VsagException(ErrorType::INTERNAL_ERROR, "k_centroids_ is nullptr"); } @@ -194,15 +211,21 @@ KMeansCluster::find_nearest_one_with_blas(const float* query, auto assign_labels_func = [&](uint64_t start, uint64_t end) -> void { omp_set_num_threads(1); + double thread_local_error = 0.0; for (uint64_t i = start; i < end; ++i) { cblas_saxpy(static_cast(k), 1.0, y_sqr, 1, distances + i * k, 1); auto* min_elem = std::min_element(distances + i * k, distances + i * k + k); + auto x_sqr = FP32ComputeIP(query + i * dim_, query + i * dim_, dim_); auto min_index = std::distance(distances + i * k, min_elem); - error += static_cast(*min_elem); + thread_local_error += static_cast(*min_elem + x_sqr); if (min_index != cur_label[i]) { cur_label[i] = static_cast(min_index); } } + { + std::lock_guard lock(error_mutex); + 
error += thread_local_error; + } }; for (uint64_t j = 0; j < cur_query_count; j += bs) { futures.emplace_back(thread_pool->GeneralEnqueue( @@ -222,6 +245,7 @@ KMeansCluster::find_nearest_one_with_hgraph(const float* query, throw VsagException(ErrorType::INTERNAL_ERROR, "k_centroids_ is nullptr"); } double error = 0.0; + std::mutex error_mutex; IndexCommonParam param; param.dim_ = dim_; @@ -242,6 +266,7 @@ KMeansCluster::find_nearest_one_with_hgraph(const float* query, FilterPtr filter = nullptr; constexpr const char* search_param = R"({"hgraph":{"ef_search":10}})"; auto func = [&](const uint64_t begin, const uint64_t end) -> void { + double thread_local_error = 0.0; for (uint64_t j = begin; j < end; ++j) { auto q = Dataset::Make(); q->Owner(false) @@ -250,7 +275,11 @@ KMeansCluster::find_nearest_one_with_hgraph(const float* query, ->Dim(this->dim_); auto ret = hgraph->KnnSearch(q, 1, search_param, filter); labels[j] = static_cast(ret->GetIds()[0]); - error += static_cast(ret->GetDistances()[0]); + thread_local_error += static_cast(ret->GetDistances()[0]); + } + { + std::lock_guard lock(error_mutex); + error += thread_local_error; } }; std::vector> futures; diff --git a/src/impl/kmeans_cluster.h b/src/impl/kmeans_cluster.h index 9330f0d52..6adbdda57 100644 --- a/src/impl/kmeans_cluster.h +++ b/src/impl/kmeans_cluster.h @@ -30,7 +30,7 @@ class KMeansCluster { ~KMeansCluster(); Vector - Run(uint32_t k, const float* datas, uint64_t count, int iter = 25); + Run(uint32_t k, const float* datas, uint64_t count, int iter = 25, double* err = nullptr); public: float* k_centroids_{nullptr}; @@ -42,7 +42,8 @@ class KMeansCluster { const uint64_t k, float* y_sqr, float* distances, - Vector& labels); + Vector& labels, + float* errs); double find_nearest_one_with_hgraph(const float* query, diff --git a/src/inner_string_params.h b/src/inner_string_params.h index 9d747b066..e588f4c78 100644 --- a/src/inner_string_params.h +++ b/src/inner_string_params.h @@ -24,6 +24,7 @@ namespace 
vsag { // Index Type const char* const INDEX_TYPE_HGRAPH = "hgraph"; const char* const INDEX_TYPE_IVF = "ivf"; +const char* const INDEX_TYPE_GNO_IMI = "gno_imi"; // Parameter key for hgraph const char* const HGRAPH_USE_REORDER_KEY = "use_reorder"; @@ -82,6 +83,7 @@ const char* const SPARSE_NEED_SORT = "need_sort"; const char* const GRAPH_TYPE_KEY = "graph_type"; const char* const BUCKET_PARAMS_KEY = "buckets_params"; +const char* const BUCKET_PER_DATA_KEY = "buckets_per_data"; const char* const NO_BUILD_LEVELS = "no_build_levels"; const char* const BUCKETS_COUNT_KEY = "buckets_count"; @@ -93,12 +95,22 @@ const char* const IVF_TRAIN_TYPE_KEY = "ivf_train_type"; const char* const IVF_TRAIN_TYPE_RANDOM = "random"; const char* const IVF_TRAIN_TYPE_KMEANS = "kmeans"; +const char* const IVF_PARTITION_STRATEGY_PARAMS_KEY = "partition_strategy"; +const char* const IVF_PARTITION_STRATEGY_TYPE_KEY = "partition_strategy_type"; +const char* const IVF_PARTITION_STRATEGY_TYPE_NEAREST = "ivf"; +const char* const IVF_PARTITION_STRATEGY_TYPE_GNO_IMI = "gno_imi"; + +const char* const GNO_IMI_FIRST_ORDER_BUCKETS_COUNT_KEY = "first_order_buckets_count"; +const char* const GNO_IMI_SECOND_ORDER_BUCKETS_COUNT_KEY = "second_order_buckets_count"; + +const char* const GNO_IMI_SEARCH_PARAM_FIRST_ORDER_SCAN_RATIO = "first_order_scan_ratio"; const char* const FLATTEN_DATA_CELL = "flatten_data_cell"; const char* const SPARSE_VECTOR_DATA_CELL = "sparse_vector_data_cell"; const std::unordered_map DEFAULT_MAP = { {"INDEX_TYPE_HGRAPH", INDEX_TYPE_HGRAPH}, {"INDEX_TYPE_IVF", INDEX_TYPE_IVF}, + {"INDEX_TYPE_GNO_IMI", INDEX_TYPE_GNO_IMI}, {"HGRAPH_USE_REORDER_KEY", HGRAPH_USE_REORDER_KEY}, {"HGRAPH_USE_ELP_OPTIMIZER_KEY", HGRAPH_USE_ELP_OPTIMIZER_KEY}, {"HGRAPH_IGNORE_REORDER_KEY", HGRAPH_IGNORE_REORDER_KEY}, @@ -138,9 +150,17 @@ const std::unordered_map DEFAULT_MAP = { {"SQ4_UNIFORM_QUANTIZATION_TRUNC_RATE", SQ4_UNIFORM_QUANTIZATION_TRUNC_RATE}, {"PCA_DIM", PCA_DIM}, 
{"IVF_SEARCH_PARAM_SCAN_BUCKETS_COUNT", IVF_SEARCH_PARAM_SCAN_BUCKETS_COUNT}, + {"GNO_IMI_FIRST_ORDER_BUCKETS_COUNT_KEY", GNO_IMI_FIRST_ORDER_BUCKETS_COUNT_KEY}, + {"GNO_IMI_SECOND_ORDER_BUCKETS_COUNT_KEY", GNO_IMI_SECOND_ORDER_BUCKETS_COUNT_KEY}, + {"BUCKETS_COUNT_KEY", BUCKETS_COUNT_KEY}, {"IVF_TRAIN_TYPE_KEY", IVF_TRAIN_TYPE_KEY}, {"HGRAPH_EXTRA_INFO_KEY", HGRAPH_EXTRA_INFO_KEY}, {"IVF_SEARCH_PARAM_FACTOR", IVF_SEARCH_PARAM_FACTOR}, -}; + {"BUCKET_PER_DATA_KEY", BUCKET_PER_DATA_KEY}, + {"IVF_PARTITION_STRATEGY_PARAMS_KEY", IVF_PARTITION_STRATEGY_PARAMS_KEY}, + {"IVF_PARTITION_STRATEGY_TYPE_KEY", IVF_PARTITION_STRATEGY_TYPE_KEY}, + {"IVF_PARTITION_STRATEGY_TYPE_NEAREST", IVF_PARTITION_STRATEGY_TYPE_NEAREST}, + {"IVF_TRAIN_TYPE_KMEANS", IVF_TRAIN_TYPE_KMEANS}, + {"IVF_PARTITION_STRATEGY_TYPE_GNO_IMI", IVF_PARTITION_STRATEGY_TYPE_GNO_IMI}}; } // namespace vsag From 48958418ac03e0537c5b6a5586b213166ee5a307 Mon Sep 17 00:00:00 2001 From: "suguan.dx" Date: Tue, 20 May 2025 21:13:01 +0800 Subject: [PATCH 20/42] rebase main Signed-off-by: suguan.dx --- src/impl/kmeans_cluster.h | 8 +++++++- src/impl/kmeans_cluster_test.cpp | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/impl/kmeans_cluster.h b/src/impl/kmeans_cluster.h index 6adbdda57..ca429a54b 100644 --- a/src/impl/kmeans_cluster.h +++ b/src/impl/kmeans_cluster.h @@ -30,7 +30,13 @@ class KMeansCluster { ~KMeansCluster(); Vector - Run(uint32_t k, const float* datas, uint64_t count, int iter = 25, double* err = nullptr); + Run(uint32_t k, + const float* datas, + uint64_t count, + int iter = 25, + double* err = nullptr, + bool use_mse_for_convergence = true, + float threshold = 1e-6F); public: float* k_centroids_{nullptr}; diff --git a/src/impl/kmeans_cluster_test.cpp b/src/impl/kmeans_cluster_test.cpp index d7928b387..137180242 100644 --- a/src/impl/kmeans_cluster_test.cpp +++ b/src/impl/kmeans_cluster_test.cpp @@ -49,7 +49,7 @@ TEST_CASE("Kmeans Basic Test", "[ut][KMeansCluster]") { auto 
allocator = vsag::SafeAllocator::FactoryDefaultAllocator(); vsag::KMeansCluster cluster(dim, allocator.get()); - auto pos = cluster.Run(k, datas.data(), count); + auto pos = cluster.Run(k, datas.data(), count, 25, nullptr, false); std::vector new_labels(k, 0); for (int i = 0; i < count; ++i) { new_labels[pos[i]]++; From e31dd99100d3f49fe51ae40b740b86afe1c8aca9 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 31 Mar 2025 15:58:24 +0800 Subject: [PATCH 21/42] support gno-imi partition Signed-off-by: suguan.dx --- examples/cpp/108_index_gno_imi.cpp | 1 - src/impl/kmeans_cluster.h | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/cpp/108_index_gno_imi.cpp b/examples/cpp/108_index_gno_imi.cpp index cb055b543..b9dd3e76c 100644 --- a/examples/cpp/108_index_gno_imi.cpp +++ b/examples/cpp/108_index_gno_imi.cpp @@ -103,6 +103,5 @@ main(int argc, char** argv) { std::cout << result->GetIds()[i] << ": " << result->GetDistances()[i] << std::endl; std::cout << result2->GetIds()[i] << ": " << result2->GetDistances()[i] << std::endl; } - return 0; } diff --git a/src/impl/kmeans_cluster.h b/src/impl/kmeans_cluster.h index ca429a54b..6a7383a24 100644 --- a/src/impl/kmeans_cluster.h +++ b/src/impl/kmeans_cluster.h @@ -48,15 +48,15 @@ class KMeansCluster { const uint64_t k, float* y_sqr, float* distances, - Vector& labels, - float* errs); + Vector& labels); double find_nearest_one_with_hgraph(const float* query, const uint64_t query_count, - const uint64_t k, + const uint64_t k, Vector& labels); + private: Allocator* const allocator_{nullptr}; From 7c02b7349c39f676c7b54fc9815650e5e1d0a7e4 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 31 Mar 2025 15:58:24 +0800 Subject: [PATCH 22/42] support gno-imi partition Signed-off-by: suguan.dx --- src/algorithm/ivf.cpp | 100 +++++++++++++++--- src/algorithm/ivf_parameter.h | 1 - .../ivf_partition/gno_imi_partition.cpp | 13 +++ .../ivf_partition/gno_imi_partition.h | 3 + .../ivf_partition/ivf_nearest_partition.cpp | 
17 ++- .../ivf_partition/ivf_nearest_partition.h | 5 + .../ivf_partition/ivf_partition_strategy.h | 3 + .../ivf_partition_strategy_parameter_test.cpp | 2 + src/impl/basic_searcher.h | 1 + src/impl/kmeans_cluster.cpp | 2 + src/impl/kmeans_cluster.h | 5 +- src/inner_string_params.h | 11 ++ 12 files changed, 140 insertions(+), 23 deletions(-) diff --git a/src/algorithm/ivf.cpp b/src/algorithm/ivf.cpp index 9becf5e19..09d259bf0 100644 --- a/src/algorithm/ivf.cpp +++ b/src/algorithm/ivf.cpp @@ -26,7 +26,7 @@ #include "utils/util_functions.h" namespace vsag { - +static constexpr const int64_t MAX_TRAIN_SIZE = 65536L; static constexpr const char* IVF_PARAMS_TEMPLATE = R"( { @@ -42,7 +42,8 @@ static constexpr const char* IVF_PARAMS_TEMPLATE = "{RABITQ_QUANTIZATION_BITS_PER_DIM_QUERY}": 32, "{PRODUCT_QUANTIZATION_DIM}": 0 }, - "{BUCKETS_COUNT_KEY}": 10 + "{BUCKETS_COUNT_KEY}": 10, + "{BUCKET_USE_RESIDUAL}": false }, "{IVF_PARTITION_STRATEGY_PARAMS_KEY}": { "{IVF_PARTITION_STRATEGY_TYPE_KEY}": "{IVF_PARTITION_STRATEGY_TYPE_NEAREST}", @@ -119,6 +120,7 @@ IVF::CheckAndMappingExternalParam(const JsonType& external_param, IVF_USE_REORDER, {IVF_USE_REORDER_KEY}, }, + {IVF_USE_RESIDUAL, {BUCKET_PARAMS_KEY, BUCKET_USE_RESIDUAL}}, { IVF_BASE_PQ_DIM, { @@ -163,6 +165,7 @@ IVF::IVF(const IVFParameterPtr& param, const IndexCommonParam& common_param) if (this->use_reorder_) { this->reorder_codes_ = FlattenInterface::MakeInstance(param->flatten_param, common_param); } + this->use_residual_ = param->bucket_param->use_residual_; } void @@ -225,7 +228,28 @@ IVF::Train(const DatasetPtr& data) { return; } partition_strategy_->Train(data); - this->bucket_->Train(data->GetFloat32Vectors(), data->GetNumElements()); + const auto* data_ptr = data->GetFloat32Vectors(); + Vector train_data_buffer(allocator_); + auto num_element = std::min(data->GetNumElements(), MAX_TRAIN_SIZE); + if (use_residual_) { + train_data_buffer.resize(num_element * dim_); + if (metric_ == MetricType::METRIC_TYPE_COSINE) { 
+ for (int i = 0; i < num_element; ++i) { + Normalize(data_ptr + i * dim_, train_data_buffer.data() + i * dim_, dim_); + } + data_ptr = train_data_buffer.data(); + } + Vector centroid(dim_, allocator_); + auto buckets = partition_strategy_->ClassifyDatas(data_ptr, num_element, 1); + for (int i = 0; i < num_element; ++i) { + partition_strategy_->GetCentroid(buckets[i], centroid); + for (int j = 0; j < dim_; ++j) { + train_data_buffer[i * dim_ + j] = data_ptr[i * dim_ + j] - centroid[j]; + } + } + data_ptr = train_data_buffer.data(); + } + this->bucket_->Train(data_ptr, num_element); if (use_reorder_) { this->reorder_codes_->Train(data->GetFloat32Vectors(), data->GetNumElements()); } @@ -243,14 +267,33 @@ IVF::Add(const DatasetPtr& base) { const auto* ids = base->GetIds(); const auto* vectors = base->GetFloat32Vectors(); auto buckets = partition_strategy_->ClassifyDatas(vectors, num_element, buckets_per_data_); + Vector normalize_data(dim_, allocator_); + Vector residual_data(dim_, allocator_); + Vector centroid(dim_, allocator_); for (int64_t i = 0; i < num_element; ++i) { + const auto* data_ptr = vectors + i * dim_; for (int64_t j = 0; j < buckets_per_data_; ++j) { auto idx = i * buckets_per_data_ + j; - bucket_->InsertVector( - vectors + i * dim_, buckets[idx], idx + total_elements_ * buckets_per_data_); - this->label_table_->Insert(idx + total_elements_ * buckets_per_data_, ids[i]); + + if (use_residual_) { + partition_strategy_->GetCentroid(buckets[idx], centroid); + if (metric_ == MetricType::METRIC_TYPE_COSINE) { + Normalize(data_ptr, normalize_data.data(), dim_); + data_ptr = normalize_data.data(); + } + FP32Sub(data_ptr, centroid.data(), residual_data.data(), dim_); + bucket_->InsertVector(residual_data.data(), + buckets[idx], + idx + total_elements_ * buckets_per_data_, + centroid.data()); + } else { + bucket_->InsertVector( + data_ptr, buckets[idx], idx + total_elements_ * buckets_per_data_); + } } + this->label_table_->Insert(i + total_elements_, ids[i]); 
} + this->bucket_->Package(); if (use_reorder_) { this->reorder_codes_->BatchInsertVector(base->GetFloat32Vectors(), base->GetNumElements()); @@ -406,9 +449,15 @@ template DistHeapPtr IVF::search(const DatasetPtr& query, const InnerSearchParam& param) const { auto search_result = std::make_shared>(allocator_, -1); - auto candidate_buckets = - partition_strategy_->ClassifyDatasForSearch(query->GetFloat32Vectors(), 1, param); - auto computer = bucket_->FactoryComputer(query->GetFloat32Vectors()); + const auto* query_data = query->GetFloat32Vectors(); + Vector normalize_data(dim_, allocator_); + if (use_residual_ && metric_ == MetricType::METRIC_TYPE_COSINE) { + Normalize(query_data, normalize_data.data(), dim_); + query_data = normalize_data.data(); + } + auto candidate_buckets = partition_strategy_->ClassifyDatasForSearch(query_data, 1, param); + auto computer = bucket_->FactoryComputer(query_data); + Vector dist(allocator_); auto cur_heap_top = std::numeric_limits::max(); int64_t topk = param.topk; @@ -418,7 +467,6 @@ IVF::search(const DatasetPtr& query, const InnerSearchParam& param) const { topk = std::numeric_limits::max(); } } - // Scale topk to ensure sufficient candidates after deduplication when buckets_per_data_ > 1 int64_t origin_topk = topk; if (buckets_per_data_ > 1) { @@ -430,15 +478,31 @@ IVF::search(const DatasetPtr& query, const InnerSearchParam& param) const { } const auto& ft = param.is_inner_id_allowed; + Vector centroid(dim_, allocator_); + for (auto& bucket_id : candidate_buckets) { + if (bucket_id == -1) { + break; + } auto bucket_size = bucket_->GetBucketSize(bucket_id); const auto* ids = bucket_->GetInnerIds(bucket_id); if (bucket_size > dist.size()) { dist.resize(bucket_size); } + auto ip_distance = 0.0F; + if (use_residual_) { + partition_strategy_->GetCentroid(bucket_id, centroid); + ip_distance = FP32ComputeIP(query_data, centroid.data(), dim_); + if (metric_ == MetricType::METRIC_TYPE_L2SQR) { + ip_distance *= 2; + } + } + 
bucket_->ScanBucketById(dist.data(), computer, bucket_id); for (int j = 0; j < bucket_size; ++j) { - if (ft == nullptr or ft->CheckValid(ids[j])) { + auto origin_id = ids[j] / buckets_per_data_; + if (ft == nullptr or ft->CheckValid(origin_id)) { + dist[j] -= ip_distance; if constexpr (mode == KNN_SEARCH) { if (search_result->Size() < topk or dist[j] < cur_heap_top) { search_result->Push(dist[j], ids[j]); @@ -462,18 +526,20 @@ IVF::search(const DatasetPtr& query, const InnerSearchParam& param) const { if (buckets_per_data_ > 1) { std::unordered_map id_to_min_dist; while (!search_result->Empty()) { - auto [dist_val, id] = search_result->Top(); - search_result->Pop(); + const auto& [dist_val, id] = search_result->Top(); + auto origin_id = id / buckets_per_data_; // Keep the smallest distance for each id - if (id_to_min_dist.find(id) == id_to_min_dist.end() || dist_val < id_to_min_dist[id]) { - id_to_min_dist[id] = dist_val; + if (id_to_min_dist.find(origin_id) == id_to_min_dist.end() || + dist_val < id_to_min_dist[origin_id]) { + id_to_min_dist[origin_id] = dist_val; } + search_result->Pop(); } auto cur_heap_top = std::numeric_limits::max(); - for (const auto& [id, dist_val] : id_to_min_dist) { + for (const auto& [origin_id, dist_val] : id_to_min_dist) { if (dist_val < cur_heap_top) { - search_result->Push(dist_val, id); + search_result->Push(dist_val, origin_id); } if (search_result->Size() > origin_topk) { search_result->Pop(); diff --git a/src/algorithm/ivf_parameter.h b/src/algorithm/ivf_parameter.h index ee659d7b2..345a6d358 100644 --- a/src/algorithm/ivf_parameter.h +++ b/src/algorithm/ivf_parameter.h @@ -38,7 +38,6 @@ class IVFParameter : public Parameter { BucketDataCellParamPtr bucket_param{nullptr}; IVFPartitionStrategyParametersPtr ivf_partition_strategy_parameter{nullptr}; BucketIdType buckets_per_data{1}; - bool use_residual{false}; bool use_reorder{false}; diff --git a/src/algorithm/ivf_partition/gno_imi_partition.cpp 
b/src/algorithm/ivf_partition/gno_imi_partition.cpp index 40b36aa71..4d892f0f0 100644 --- a/src/algorithm/ivf_partition/gno_imi_partition.cpp +++ b/src/algorithm/ivf_partition/gno_imi_partition.cpp @@ -387,4 +387,17 @@ GNOIMIPartition::inner_joint_classify_datas(const float* datas, } } +void +GNOIMIPartition::GetCentroid(BucketIdType bucket_id, Vector& centroid) { + if (!is_trained_ || bucket_id >= bucket_count_) { + throw std::runtime_error("Invalid bucket_id or partition not trained"); + } + auto bucket_id_s = bucket_id / bucket_count_t_; + auto bucket_id_t = bucket_id % bucket_count_t_; + FP32Add(data_centroids_s_.data() + bucket_id_s * dim_, + data_centroids_t_.data() + bucket_id_t * dim_, + centroid.data(), + dim_); +} + } // namespace vsag diff --git a/src/algorithm/ivf_partition/gno_imi_partition.h b/src/algorithm/ivf_partition/gno_imi_partition.h index f6ad97c3b..c0be992d8 100644 --- a/src/algorithm/ivf_partition/gno_imi_partition.h +++ b/src/algorithm/ivf_partition/gno_imi_partition.h @@ -40,6 +40,9 @@ class GNOIMIPartition : public IVFPartitionStrategy { int64_t count, const InnerSearchParam& param) override; + void + GetCentroid(BucketIdType bucket_id, Vector& centroid) override; + void Serialize(StreamWriter& writer) override; diff --git a/src/algorithm/ivf_partition/ivf_nearest_partition.cpp b/src/algorithm/ivf_partition/ivf_nearest_partition.cpp index f6749c573..8423f3749 100644 --- a/src/algorithm/ivf_partition/ivf_nearest_partition.cpp +++ b/src/algorithm/ivf_partition/ivf_nearest_partition.cpp @@ -37,7 +37,8 @@ IVFNearestPartition::IVFNearestPartition(BucketIdType bucket_count, const IndexCommonParam& common_param, IVFPartitionStrategyParametersPtr param) : IVFPartitionStrategy(common_param, bucket_count), - ivf_partition_strategy_param_(std::move(param)) { + ivf_partition_strategy_param_(std::move(param)), + metric_type_(common_param.metric_) { this->factory_router_index(common_param); } @@ -72,6 +73,11 @@ IVFNearestPartition::Train(const 
DatasetPtr dataset) { dim * sizeof(float)); } } + if (metric_type_ == MetricType::METRIC_TYPE_COSINE) { + for (int i = 0; i < bucket_count_; ++i) { + Normalize(data.data() + i * dim_, data.data() + i * dim_, dim_); + } + } auto build_result = this->route_index_ptr_->Build(centroids); this->is_trained_ = true; @@ -81,7 +87,7 @@ Vector IVFNearestPartition::ClassifyDatas(const void* datas, int64_t count, BucketIdType buckets_per_data) { - Vector result(buckets_per_data * count, this->allocator_); + Vector result(buckets_per_data * count, -1, this->allocator_); for (int64_t i = 0; i < count; ++i) { auto query = Dataset::Make(); query->Dim(this->dim_) @@ -123,4 +129,11 @@ IVFNearestPartition::factory_router_index(const IndexCommonParam& common_param) param_ptr = HGraph::CheckAndMappingExternalParam(hgraph_json, common_param); this->route_index_ptr_ = std::make_shared(param_ptr, common_param); } +void +IVFNearestPartition::GetCentroid(BucketIdType bucket_id, Vector& centroid) { + if (!is_trained_ || bucket_id >= bucket_count_) { + throw std::runtime_error("Invalid bucket_id or partition not trained"); + } + this->route_index_ptr_->GetRawData(bucket_id, (uint8_t*)centroid.data()); +} } // namespace vsag diff --git a/src/algorithm/ivf_partition/ivf_nearest_partition.h b/src/algorithm/ivf_partition/ivf_nearest_partition.h index f6d40ed17..7db5dc2cd 100644 --- a/src/algorithm/ivf_partition/ivf_nearest_partition.h +++ b/src/algorithm/ivf_partition/ivf_nearest_partition.h @@ -34,6 +34,9 @@ class IVFNearestPartition : public IVFPartitionStrategy { Vector ClassifyDatas(const void* datas, int64_t count, BucketIdType buckets_per_data) override; + void + GetCentroid(BucketIdType bucket_id, Vector& centroid) override; + void Serialize(StreamWriter& writer) override; @@ -45,6 +48,8 @@ class IVFNearestPartition : public IVFPartitionStrategy { InnerIndexPtr route_index_ptr_{nullptr}; + MetricType metric_type_{MetricType::METRIC_TYPE_L2SQR}; + private: void factory_router_index(const 
IndexCommonParam& common_param); diff --git a/src/algorithm/ivf_partition/ivf_partition_strategy.h b/src/algorithm/ivf_partition/ivf_partition_strategy.h index be9f2fbec..52bcaa65c 100644 --- a/src/algorithm/ivf_partition/ivf_partition_strategy.h +++ b/src/algorithm/ivf_partition/ivf_partition_strategy.h @@ -60,6 +60,9 @@ class IVFPartitionStrategy { return std::move(ClassifyDatas(datas, count, param.scan_bucket_size)); } + virtual void + GetCentroid(BucketIdType bucket_id, Vector& centroid) = 0; + virtual void Serialize(StreamWriter& writer) { StreamWriter::WriteObj(writer, this->is_trained_); diff --git a/src/algorithm/ivf_partition/ivf_partition_strategy_parameter_test.cpp b/src/algorithm/ivf_partition/ivf_partition_strategy_parameter_test.cpp index f3119a4b7..500bcf3c9 100644 --- a/src/algorithm/ivf_partition/ivf_partition_strategy_parameter_test.cpp +++ b/src/algorithm/ivf_partition/ivf_partition_strategy_parameter_test.cpp @@ -35,4 +35,6 @@ TEST_CASE("IVF Partition Strategy Parameters Test", "[ut][IVFPartitionStrategyPa REQUIRE(param->partition_train_type == vsag::IVFNearestPartitionTrainerType::RandomTrainer); REQUIRE(param->gnoimi_param->first_order_buckets_count == 200); REQUIRE(param->gnoimi_param->second_order_buckets_count == 50); + + vsag::ParameterTest::TestToJson(param); } diff --git a/src/impl/basic_searcher.h b/src/impl/basic_searcher.h index bd761d0b5..79e040f8a 100644 --- a/src/impl/basic_searcher.h +++ b/src/impl/basic_searcher.h @@ -51,6 +51,7 @@ class InnerSearchParam { int scan_bucket_size{1}; float factor{2.0F}; float first_order_scan_ratio{1.0F}; + Allocator* search_alloc{nullptr}; InnerSearchParam& operator=(const InnerSearchParam& other) { diff --git a/src/impl/kmeans_cluster.cpp b/src/impl/kmeans_cluster.cpp index 706529981..b71b9587c 100644 --- a/src/impl/kmeans_cluster.cpp +++ b/src/impl/kmeans_cluster.cpp @@ -85,11 +85,13 @@ KMeansCluster::Run(uint32_t k, } for (int it = 0; it < iter; ++it) { + /* logger::debug("[{}] 
KMeansCluster::Run iter: {}/{}, cur loss is {}", get_current_time(), it, iter, total_err); + */ if (k < THRESHOLD_FOR_HGRAPH) { total_err = this->find_nearest_one_with_blas(datas, count, k, y_sqr, distances, labels); } else { diff --git a/src/impl/kmeans_cluster.h b/src/impl/kmeans_cluster.h index 6a7383a24..bdaa918ba 100644 --- a/src/impl/kmeans_cluster.h +++ b/src/impl/kmeans_cluster.h @@ -35,7 +35,7 @@ class KMeansCluster { uint64_t count, int iter = 25, double* err = nullptr, - bool use_mse_for_convergence = true, + bool use_mse_for_convergence = false, float threshold = 1e-6F); public: @@ -53,10 +53,9 @@ class KMeansCluster { double find_nearest_one_with_hgraph(const float* query, const uint64_t query_count, - const uint64_t k, + const uint64_t k, Vector& labels); - private: Allocator* const allocator_{nullptr}; diff --git a/src/inner_string_params.h b/src/inner_string_params.h index e588f4c78..01d74f9ad 100644 --- a/src/inner_string_params.h +++ b/src/inner_string_params.h @@ -82,11 +82,16 @@ const char* const BUILD_EF_CONSTRUCTION = "ef_construction"; const char* const SPARSE_NEED_SORT = "need_sort"; const char* const GRAPH_TYPE_KEY = "graph_type"; +const char* const GRAPH_STORAGE_TYPE_KEY = "graph_storage_type"; +const char* const GRAPH_STORAGE_TYPE_COMPRESSED = "compressed"; +const char* const GRAPH_STORAGE_TYPE_FLAT = "flat"; + const char* const BUCKET_PARAMS_KEY = "buckets_params"; const char* const BUCKET_PER_DATA_KEY = "buckets_per_data"; const char* const NO_BUILD_LEVELS = "no_build_levels"; const char* const BUCKETS_COUNT_KEY = "buckets_count"; +const char* const BUCKET_USE_RESIDUAL = "use_residual"; const char* const IVF_SEARCH_PARAM_SCAN_BUCKETS_COUNT = "scan_buckets_count"; const char* const IVF_SEARCH_PARAM_FACTOR = "factor"; const char* const IVF_USE_REORDER_KEY = "use_reorder"; @@ -107,6 +112,9 @@ const char* const GNO_IMI_SEARCH_PARAM_FIRST_ORDER_SCAN_RATIO = "first_order_sca const char* const FLATTEN_DATA_CELL = "flatten_data_cell"; const 
char* const SPARSE_VECTOR_DATA_CELL = "sparse_vector_data_cell"; +const char* const GRAPH_SUPPORT_REMOVE = "support_remove"; +const char* const REMOVE_FLAG_BIT = "remove_flag_bit"; + const std::unordered_map DEFAULT_MAP = { {"INDEX_TYPE_HGRAPH", INDEX_TYPE_HGRAPH}, {"INDEX_TYPE_IVF", INDEX_TYPE_IVF}, @@ -135,6 +143,9 @@ const std::unordered_map DEFAULT_MAP = { {"PRODUCT_QUANTIZATION_DIM", PRODUCT_QUANTIZATION_DIM}, {"PRODUCT_QUANTIZATION_BITS", PRODUCT_QUANTIZATION_BITS}, {"GRAPH_TYPE_NSW", GRAPH_TYPE_NSW}, + {"GRAPH_STORAGE_TYPE_KEY", GRAPH_STORAGE_TYPE_KEY}, + {"GRAPH_STORAGE_TYPE_FLAT", GRAPH_STORAGE_TYPE_FLAT}, + {"GRAPH_STORAGE_TYPE_COMPRESSED", GRAPH_STORAGE_TYPE_COMPRESSED}, {"QUANTIZATION_PARAMS_KEY", QUANTIZATION_PARAMS_KEY}, {"GRAPH_PARAM_MAX_DEGREE", GRAPH_PARAM_MAX_DEGREE}, {"GRAPH_PARAM_INIT_MAX_CAPACITY", GRAPH_PARAM_INIT_MAX_CAPACITY}, From 50949d60b1539c873c7fff5229d6638dbc8a83df Mon Sep 17 00:00:00 2001 From: LHT129 Date: Mon, 26 May 2025 16:26:31 +0800 Subject: [PATCH 23/42] add bit operator simd implement (#756) Signed-off-by: LHT129 Signed-off-by: suguan.dx --- src/simd/CMakeLists.txt | 1 + src/simd/avx.cpp | 95 +++++++++++++++++++ src/simd/avx2.cpp | 94 +++++++++++++++++++ src/simd/avx512.cpp | 94 +++++++++++++++++++ src/simd/bit_simd.cpp | 114 ++++++++++++++++++++++ src/simd/bit_simd.h | 88 +++++++++++++++++ src/simd/bit_simd_test.cpp | 187 +++++++++++++++++++++++++++++++++++++ src/simd/generic.cpp | 28 ++++++ src/simd/simd.h | 1 + src/simd/sse.cpp | 94 +++++++++++++++++++ 10 files changed, 796 insertions(+) create mode 100644 src/simd/bit_simd.cpp create mode 100644 src/simd/bit_simd.h create mode 100644 src/simd/bit_simd_test.cpp diff --git a/src/simd/CMakeLists.txt b/src/simd/CMakeLists.txt index 1b4aa1c6d..4167169ab 100644 --- a/src/simd/CMakeLists.txt +++ b/src/simd/CMakeLists.txt @@ -7,6 +7,7 @@ set (SIMD_SRCS simd.cpp simd_status.cpp basic_func.cpp + bit_simd.cpp fp32_simd.cpp fp16_simd.cpp bf16_simd.cpp diff --git a/src/simd/avx.cpp 
b/src/simd/avx.cpp index e1b34e988..5d3d5ba71 100644 --- a/src/simd/avx.cpp +++ b/src/simd/avx.cpp @@ -768,4 +768,99 @@ PQFastScanLookUp32(const uint8_t* lookup_table, #endif } +void +BitAnd(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result) { +#if defined(ENABLE_AVX) + if (num_byte == 0) { + return; + } + if (num_byte < 32) { + return sse::BitAnd(x, y, num_byte, result); + } + int64_t i = 0; + for (; i + 31 < num_byte; i += 32) { + __m256 x_vec = _mm256_loadu_ps(reinterpret_cast(x + i)); + __m256 y_vec = _mm256_loadu_ps(reinterpret_cast(y + i)); + __m256 z_vec = _mm256_and_ps(x_vec, y_vec); + _mm256_storeu_ps(reinterpret_cast(result + i), z_vec); + } + if (i < num_byte) { + sse::BitAnd(x + i, y + i, num_byte - i, result + i); + } +#else + return sse::BitAnd(x, y, num_byte, result); +#endif +} + +void +BitOr(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result) { +#if defined(ENABLE_AVX) + if (num_byte == 0) { + return; + } + if (num_byte < 32) { + return sse::BitOr(x, y, num_byte, result); + } + int64_t i = 0; + for (; i + 31 < num_byte; i += 32) { + __m256 x_vec = _mm256_loadu_ps(reinterpret_cast(x + i)); + __m256 y_vec = _mm256_loadu_ps(reinterpret_cast(y + i)); + __m256 z_vec = _mm256_or_ps(x_vec, y_vec); + _mm256_storeu_ps(reinterpret_cast(result + i), z_vec); + } + if (i < num_byte) { + sse::BitOr(x + i, y + i, num_byte - i, result + i); + } +#else + return sse::BitOr(x, y, num_byte, result); +#endif +} + +void +BitXor(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result) { +#if defined(ENABLE_AVX) + if (num_byte == 0) { + return; + } + if (num_byte < 32) { + return sse::BitXor(x, y, num_byte, result); + } + int64_t i = 0; + for (; i + 31 < num_byte; i += 32) { + __m256 x_vec = _mm256_loadu_ps(reinterpret_cast(x + i)); + __m256 y_vec = _mm256_loadu_ps(reinterpret_cast(y + i)); + __m256 z_vec = _mm256_xor_ps(x_vec, y_vec); + _mm256_storeu_ps(reinterpret_cast(result + i), z_vec); + } + 
if (i < num_byte) { + sse::BitXor(x + i, y + i, num_byte - i, result + i); + } +#else + return sse::BitXor(x, y, num_byte, result); +#endif +} + +void +BitNot(const uint8_t* x, const uint64_t num_byte, uint8_t* result) { +#if defined(ENABLE_AVX) + if (num_byte == 0) { + return; + } + if (num_byte < 32) { + return sse::BitNot(x, num_byte, result); + } + int64_t i = 0; + __m256 all_one = _mm256_castsi256_ps(_mm256_set1_epi32(-1)); + for (; i + 31 < num_byte; i += 32) { + __m256 x_vec = _mm256_loadu_ps(reinterpret_cast(x + i)); + __m256 z_vec = _mm256_xor_ps(x_vec, all_one); + _mm256_storeu_ps(reinterpret_cast(result + i), z_vec); + } + if (i < num_byte) { + sse::BitNot(x + i, num_byte - i, result + i); + } +#else + return sse::BitNot(x, num_byte, result); +#endif +} } // namespace vsag::avx diff --git a/src/simd/avx2.cpp b/src/simd/avx2.cpp index d9d23f434..c3610be7b 100644 --- a/src/simd/avx2.cpp +++ b/src/simd/avx2.cpp @@ -824,4 +824,98 @@ PQFastScanLookUp32(const uint8_t* lookup_table, #endif } +void +BitAnd(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result) { +#if defined(ENABLE_AVX2) + if (num_byte == 0) { + return; + } + if (num_byte < 32) { + return sse::BitAnd(x, y, num_byte, result); + } + int64_t i = 0; + for (; i + 31 < num_byte; i += 32) { + __m256i x_vec = _mm256_loadu_si256(reinterpret_cast(x + i)); + __m256i y_vec = _mm256_loadu_si256(reinterpret_cast(y + i)); + __m256i z_vec = _mm256_and_si256(x_vec, y_vec); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(result + i), z_vec); + } + if (i < num_byte) { + sse::BitAnd(x + i, y + i, num_byte - i, result + i); + } +#else + return sse::BitAnd(x, y, num_byte, result); +#endif +} + +void +BitOr(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result) { +#if defined(ENABLE_AVX2) + if (num_byte == 0) { + return; + } + if (num_byte < 32) { + return sse::BitOr(x, y, num_byte, result); + } + int64_t i = 0; + for (; i + 31 < num_byte; i += 32) { + __m256i x_vec = 
_mm256_loadu_si256(reinterpret_cast(x + i)); + __m256i y_vec = _mm256_loadu_si256(reinterpret_cast(y + i)); + __m256i z_vec = _mm256_or_si256(x_vec, y_vec); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(result + i), z_vec); + } + if (i < num_byte) { + sse::BitOr(x + i, y + i, num_byte - i, result + i); + } +#else + return sse::BitOr(x, y, num_byte, result); +#endif +} + +void +BitXor(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result) { +#if defined(ENABLE_AVX2) + if (num_byte == 0) { + return; + } + if (num_byte < 32) { + return sse::BitXor(x, y, num_byte, result); + } + int64_t i = 0; + for (; i + 31 < num_byte; i += 32) { + __m256i x_vec = _mm256_loadu_si256(reinterpret_cast(x + i)); + __m256i y_vec = _mm256_loadu_si256(reinterpret_cast(y + i)); + __m256i z_vec = _mm256_xor_si256(x_vec, y_vec); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(result + i), z_vec); + } + if (i < num_byte) { + sse::BitXor(x + i, y + i, num_byte - i, result + i); + } +#else + return sse::BitXor(x, y, num_byte, result); +#endif +} + +void +BitNot(const uint8_t* x, const uint64_t num_byte, uint8_t* result) { +#if defined(ENABLE_AVX2) + if (num_byte == 0) { + return; + } + if (num_byte < 32) { + return sse::BitNot(x, num_byte, result); + } + int64_t i = 0; + for (; i + 31 < num_byte; i += 32) { + __m256i x_vec = _mm256_loadu_si256(reinterpret_cast(x + i)); + __m256i z_vec = _mm256_xor_si256(x_vec, _mm256_set1_epi8(0xFF)); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(result + i), z_vec); + } + if (i < num_byte) { + sse::BitNot(x + i, num_byte - i, result + i); + } +#else + return sse::BitNot(x, num_byte, result); +#endif +} } // namespace vsag::avx2 diff --git a/src/simd/avx512.cpp b/src/simd/avx512.cpp index eb3d9c172..25032f724 100644 --- a/src/simd/avx512.cpp +++ b/src/simd/avx512.cpp @@ -894,4 +894,98 @@ PQFastScanLookUp32(const uint8_t* lookup_table, #endif } +void +BitAnd(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result) 
{ +#if defined(ENABLE_AVX512) + if (num_byte == 0) { + return; + } + if (num_byte < 64) { + return avx2::BitAnd(x, y, num_byte, result); + } + int64_t i = 0; + for (; i + 63 < num_byte; i += 64) { + __m512i x_vec = _mm512_loadu_si512(reinterpret_cast(x + i)); + __m512i y_vec = _mm512_loadu_si512(reinterpret_cast(y + i)); + __m512i z_vec = _mm512_and_si512(x_vec, y_vec); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(result + i), z_vec); + } + if (i < num_byte) { + avx2::BitAnd(x + i, y + i, num_byte - i, result + i); + } +#else + return avx2::BitAnd(x, y, num_byte, result); +#endif +} + +void +BitOr(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result) { +#if defined(ENABLE_AVX512) + if (num_byte == 0) { + return; + } + if (num_byte < 64) { + return avx2::BitOr(x, y, num_byte, result); + } + int64_t i = 0; + for (; i + 63 < num_byte; i += 64) { + __m512i x_vec = _mm512_loadu_si512(reinterpret_cast(x + i)); + __m512i y_vec = _mm512_loadu_si512(reinterpret_cast(y + i)); + __m512i z_vec = _mm512_or_si512(x_vec, y_vec); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(result + i), z_vec); + } + if (i < num_byte) { + avx2::BitOr(x + i, y + i, num_byte - i, result + i); + } +#else + return avx2::BitOr(x, y, num_byte, result); +#endif +} + +void +BitXor(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result) { +#if defined(ENABLE_AVX512) + if (num_byte == 0) { + return; + } + if (num_byte < 64) { + return avx2::BitXor(x, y, num_byte, result); + } + int64_t i = 0; + for (; i + 63 < num_byte; i += 64) { + __m512i x_vec = _mm512_loadu_si512(reinterpret_cast(x + i)); + __m512i y_vec = _mm512_loadu_si512(reinterpret_cast(y + i)); + __m512i z_vec = _mm512_xor_si512(x_vec, y_vec); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(result + i), z_vec); + } + if (i < num_byte) { + avx2::BitXor(x + i, y + i, num_byte - i, result + i); + } +#else + return avx2::BitXor(x, y, num_byte, result); +#endif +} + +void +BitNot(const uint8_t* x, 
const uint64_t num_byte, uint8_t* result) { +#if defined(ENABLE_AVX512) + if (num_byte == 0) { + return; + } + if (num_byte < 64) { + return avx2::BitNot(x, num_byte, result); + } + int64_t i = 0; + for (; i + 63 < num_byte; i += 64) { + __m512i x_vec = _mm512_loadu_si512(reinterpret_cast(x + i)); + __m512i z_vec = _mm512_xor_si512(x_vec, _mm512_set1_epi8(0xFF)); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(result + i), z_vec); + } + if (i < num_byte) { + avx2::BitNot(x + i, num_byte - i, result + i); + } +#else + return avx2::BitNot(x, num_byte, result); +#endif +} } // namespace vsag::avx512 diff --git a/src/simd/bit_simd.cpp b/src/simd/bit_simd.cpp new file mode 100644 index 000000000..a2a4e1533 --- /dev/null +++ b/src/simd/bit_simd.cpp @@ -0,0 +1,114 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "bit_simd.h" + +#include "simd_status.h" + +namespace vsag { + +static BitOperatorType +GetBitAnd() { + if (SimdStatus::SupportAVX512()) { +#if defined(ENABLE_AVX512) + return avx512::BitAnd; +#endif + } else if (SimdStatus::SupportAVX2()) { +#if defined(ENABLE_AVX2) + return avx2::BitAnd; +#endif + } else if (SimdStatus::SupportAVX()) { +#if defined(ENABLE_AVX) + return avx::BitAnd; +#endif + } else if (SimdStatus::SupportSSE()) { +#if defined(ENABLE_SSE) + return sse::BitAnd; +#endif + } + return generic::BitAnd; +} +BitOperatorType BitAnd = GetBitAnd(); + +static BitOperatorType +GetBitOr() { + if (SimdStatus::SupportAVX512()) { +#if defined(ENABLE_AVX512) + return avx512::BitOr; +#endif + } else if (SimdStatus::SupportAVX2()) { +#if defined(ENABLE_AVX2) + return avx2::BitOr; +#endif + } else if (SimdStatus::SupportAVX()) { +#if defined(ENABLE_AVX) + return avx::BitOr; +#endif + } else if (SimdStatus::SupportSSE()) { +#if defined(ENABLE_SSE) + return sse::BitOr; +#endif + } + return generic::BitOr; +} +BitOperatorType BitOr = GetBitOr(); + +static BitOperatorType +GetBitXor() { + if (SimdStatus::SupportAVX512()) { +#if defined(ENABLE_AVX512) + return avx512::BitXor; +#endif + } else if (SimdStatus::SupportAVX2()) { +#if defined(ENABLE_AVX2) + return avx2::BitXor; +#endif + } else if (SimdStatus::SupportAVX()) { +#if defined(ENABLE_AVX) + return avx::BitXor; +#endif + } else if (SimdStatus::SupportSSE()) { +#if defined(ENABLE_SSE) + return sse::BitXor; +#endif + } + return generic::BitXor; +} +BitOperatorType BitXor = GetBitXor(); + +static BitNotType +GetBitNot() { + if (SimdStatus::SupportAVX512()) { +#if defined(ENABLE_AVX512) + return avx512::BitNot; +#endif + } else if (SimdStatus::SupportAVX2()) { +#if defined(ENABLE_AVX2) + return avx2::BitNot; +#endif + } else if (SimdStatus::SupportAVX()) { +#if defined(ENABLE_AVX) + return avx::BitNot; +#endif + } else if (SimdStatus::SupportSSE()) { +#if defined(ENABLE_SSE) + return sse::BitNot; +#endif + } 
+ return generic::BitNot; +} +BitNotType BitNot = GetBitNot(); + +} // namespace vsag diff --git a/src/simd/bit_simd.h b/src/simd/bit_simd.h new file mode 100644 index 000000000..b8953b6df --- /dev/null +++ b/src/simd/bit_simd.h @@ -0,0 +1,88 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace vsag { + +namespace generic { +void +BitAnd(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result); +void +BitOr(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result); +void +BitXor(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result); +void +BitNot(const uint8_t* x, const uint64_t num_byte, uint8_t* result); +} // namespace generic + +namespace sse { +void +BitAnd(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result); +void +BitOr(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result); +void +BitXor(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result); +void +BitNot(const uint8_t* x, const uint64_t num_byte, uint8_t* result); +} // namespace sse + +namespace avx { +void +BitAnd(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result); +void +BitOr(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result); +void +BitXor(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* 
result); +void +BitNot(const uint8_t* x, const uint64_t num_byte, uint8_t* result); +} // namespace avx + +namespace avx2 { +void +BitAnd(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result); +void +BitOr(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result); +void +BitXor(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result); +void +BitNot(const uint8_t* x, const uint64_t num_byte, uint8_t* result); +} // namespace avx2 + +namespace avx512 { +void +BitAnd(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result); +void +BitOr(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result); +void +BitXor(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result); +void +BitNot(const uint8_t* x, const uint64_t num_byte, uint8_t* result); +} // namespace avx512 + +using BitOperatorType = void (*)(const uint8_t* x, + const uint8_t* y, + const uint64_t num_byte, + uint8_t* result); +extern BitOperatorType BitAnd; +extern BitOperatorType BitOr; +extern BitOperatorType BitXor; + +using BitNotType = void (*)(const uint8_t* x, const uint64_t num_byte, uint8_t* result); +extern BitNotType BitNot; +} // namespace vsag diff --git a/src/simd/bit_simd_test.cpp b/src/simd/bit_simd_test.cpp new file mode 100644 index 000000000..dab935097 --- /dev/null +++ b/src/simd/bit_simd_test.cpp @@ -0,0 +1,187 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "bit_simd.h" + +#include +#include + +#include "fixtures.h" +#include "simd_status.h" + +using namespace vsag; + +#define TEST_BIT_OPERATOR_ACCURACY(Func) \ + { \ + std::vector gt(num_bytes, 0); \ + generic::Func( \ + vec1.data() + i * num_bytes, vec2.data() + i * num_bytes, num_bytes, gt.data()); \ + std::vector sse_gt(num_bytes, 0.0F); \ + if (SimdStatus::SupportSSE()) { \ + sse::Func(vec1.data() + i * num_bytes, \ + vec2.data() + i * num_bytes, \ + num_bytes, \ + sse_gt.data()); \ + for (uint64_t j = 0; j < num_bytes; ++j) { \ + REQUIRE(fixtures::dist_t(gt[j]) == fixtures::dist_t(sse_gt[j])); \ + } \ + } \ + std::vector avx_gt(num_bytes, 0.0F); \ + if (SimdStatus::SupportAVX()) { \ + avx::Func(vec1.data() + i * num_bytes, \ + vec2.data() + i * num_bytes, \ + num_bytes, \ + avx_gt.data()); \ + for (uint64_t j = 0; j < num_bytes; ++j) { \ + REQUIRE(fixtures::dist_t(gt[j]) == fixtures::dist_t(avx_gt[j])); \ + } \ + } \ + std::vector avx2_gt(num_bytes, 0); \ + if (SimdStatus::SupportAVX2()) { \ + avx2::Func(vec1.data() + i * num_bytes, \ + vec2.data() + i * num_bytes, \ + num_bytes, \ + avx2_gt.data()); \ + for (uint64_t j = 0; j < num_bytes; ++j) { \ + REQUIRE(fixtures::dist_t(gt[j]) == fixtures::dist_t(avx2_gt[j])); \ + } \ + } \ + std::vector avx512_gt(num_bytes, 0); \ + if (SimdStatus::SupportAVX512()) { \ + avx512::Func(vec1.data() + i * num_bytes, \ + vec2.data() + i * num_bytes, \ + num_bytes, \ + avx512_gt.data()); \ + for (uint64_t j = 0; j < num_bytes; ++j) { \ + REQUIRE(fixtures::dist_t(gt[j]) == fixtures::dist_t(avx512_gt[j])); \ + } \ + } \ + }; + +#define TEST_BIT_NOT_ACCURACY(Func) \ + { \ + std::vector gt(num_bytes, 0); \ + generic::Func(vec1.data() + i * num_bytes, num_bytes, gt.data()); \ + std::vector sse_gt(num_bytes, 0); \ + if (SimdStatus::SupportSSE()) { \ + sse::Func(vec1.data() + i * num_bytes, num_bytes, sse_gt.data()); \ + 
for (uint64_t j = 0; j < num_bytes; ++j) { \ + REQUIRE(fixtures::dist_t(gt[j]) == fixtures::dist_t(sse_gt[j])); \ + } \ + } \ + std::vector avx_gt(num_bytes, 0); \ + if (SimdStatus::SupportAVX()) { \ + avx::Func(vec1.data() + i * num_bytes, num_bytes, avx_gt.data()); \ + for (uint64_t j = 0; j < num_bytes; ++j) { \ + REQUIRE(fixtures::dist_t(gt[j]) == fixtures::dist_t(avx_gt[j])); \ + } \ + } \ + std::vector avx2_gt(num_bytes, 0); \ + if (SimdStatus::SupportAVX2()) { \ + avx2::Func(vec1.data() + i * num_bytes, num_bytes, avx2_gt.data()); \ + for (uint64_t j = 0; j < num_bytes; ++j) { \ + REQUIRE(fixtures::dist_t(gt[j]) == fixtures::dist_t(avx2_gt[j])); \ + } \ + } \ + std::vector avx512_gt(num_bytes, 0); \ + if (SimdStatus::SupportAVX512()) { \ + avx512::Func(vec1.data() + i * num_bytes, num_bytes, avx512_gt.data()); \ + for (uint64_t j = 0; j < num_bytes; ++j) { \ + REQUIRE(fixtures::dist_t(gt[j]) == fixtures::dist_t(avx512_gt[j])); \ + } \ + } \ + }; + +TEST_CASE("Bit Operator (NOT)", "[ut][simd]") { + const auto dims = fixtures::get_common_used_dims(); + int64_t count = 100; + for (const auto& num_bytes : dims) { + auto vec1 = fixtures::GenerateVectors(count * 2, num_bytes); + std::vector vec2(vec1.begin() + count, vec1.end()); + for (uint64_t i = 0; i < count; ++i) { + TEST_BIT_NOT_ACCURACY(BitNot); + } + } +} + +TEST_CASE("Bit Operator (AND, OR, XOR)", "[ut][simd]") { + const auto dims = fixtures::get_common_used_dims(); + int64_t count = 100; + for (const auto& num_bytes : dims) { + auto vec1 = fixtures::GenerateVectors(count * 2, num_bytes); + std::vector vec2(vec1.begin() + count, vec1.end()); + for (uint64_t i = 0; i < count; ++i) { + TEST_BIT_OPERATOR_ACCURACY(BitAnd); + TEST_BIT_OPERATOR_ACCURACY(BitOr); + TEST_BIT_OPERATOR_ACCURACY(BitXor); + } + } +} + +#define BENCHMARK_BIT_OPERATOR_COMPUTE(Simd, Comp) \ + BENCHMARK_ADVANCED(#Simd #Comp) { \ + for (int i = 0; i < count; ++i) { \ + Simd::Comp(vec1.data() + i * dim, vec2.data() + i * dim, dim, 
vec3.data()); \ + } \ + return; \ + } + +#define BENCHMARK_BIT_NOT_COMPUTE(Simd, Comp) \ + BENCHMARK_ADVANCED(#Simd #Comp) { \ + for (int i = 0; i < count; ++i) { \ + Simd::Comp(vec1.data() + i * dim, dim, vec3.data()); \ + } \ + return; \ + } + +TEST_CASE("Bit Operator (AND, OR, XOR, NOT)", "[!benchmark][simd]") { + const auto dim = 4096; + int64_t count = 500; + auto vec1 = fixtures::GenerateVectors(count * 3, dim); + std::vector vec2(vec1.begin() + count, vec1.end()); + std::vector vec3(vec1.begin() + count * 2, vec1.end()); + + SECTION("Bit Operator And") { + BENCHMARK_BIT_OPERATOR_COMPUTE(generic, BitAnd); + BENCHMARK_BIT_OPERATOR_COMPUTE(sse, BitAnd); + BENCHMARK_BIT_OPERATOR_COMPUTE(avx, BitAnd); + BENCHMARK_BIT_OPERATOR_COMPUTE(avx2, BitAnd); + BENCHMARK_BIT_OPERATOR_COMPUTE(avx512, BitAnd); + } + + SECTION("Bit Operator Or") { + BENCHMARK_BIT_OPERATOR_COMPUTE(generic, BitOr); + BENCHMARK_BIT_OPERATOR_COMPUTE(sse, BitOr); + BENCHMARK_BIT_OPERATOR_COMPUTE(avx, BitOr); + BENCHMARK_BIT_OPERATOR_COMPUTE(avx2, BitOr); + BENCHMARK_BIT_OPERATOR_COMPUTE(avx512, BitOr); + } + + SECTION("Bit Operator Xor") { + BENCHMARK_BIT_OPERATOR_COMPUTE(generic, BitXor); + BENCHMARK_BIT_OPERATOR_COMPUTE(sse, BitXor); + BENCHMARK_BIT_OPERATOR_COMPUTE(avx, BitXor); + BENCHMARK_BIT_OPERATOR_COMPUTE(avx2, BitXor); + BENCHMARK_BIT_OPERATOR_COMPUTE(avx512, BitXor); + } + + SECTION("Bit Operator Not") { + BENCHMARK_BIT_NOT_COMPUTE(generic, BitNot); + BENCHMARK_BIT_NOT_COMPUTE(sse, BitNot); + BENCHMARK_BIT_NOT_COMPUTE(avx, BitNot); + BENCHMARK_BIT_NOT_COMPUTE(avx2, BitNot); + BENCHMARK_BIT_NOT_COMPUTE(avx512, BitNot); + } +} diff --git a/src/simd/generic.cpp b/src/simd/generic.cpp index 9e2cf4c7a..ffacfa4ee 100644 --- a/src/simd/generic.cpp +++ b/src/simd/generic.cpp @@ -555,4 +555,32 @@ PQFastScanLookUp32(const uint8_t* lookup_table, } } +void +BitAnd(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result) { + for (uint64_t i = 0; i < num_byte; i++) { + result[i] = 
x[i] & y[i]; + } +} + +void +BitOr(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result) { + for (uint64_t i = 0; i < num_byte; i++) { + result[i] = x[i] | y[i]; + } +} + +void +BitXor(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result) { + for (uint64_t i = 0; i < num_byte; i++) { + result[i] = x[i] ^ y[i]; + } +} + +void +BitNot(const uint8_t* x, const uint64_t num_byte, uint8_t* result) { + for (uint64_t i = 0; i < num_byte; i++) { + result[i] = ~x[i]; + } +} + } // namespace vsag::generic diff --git a/src/simd/simd.h b/src/simd/simd.h index ca5c0116a..458b597e2 100644 --- a/src/simd/simd.h +++ b/src/simd/simd.h @@ -20,6 +20,7 @@ #include "basic_func.h" #include "bf16_simd.h" +#include "bit_simd.h" #include "fp16_simd.h" #include "fp32_simd.h" #include "normalize.h" diff --git a/src/simd/sse.cpp b/src/simd/sse.cpp index 869344e55..22686b4eb 100644 --- a/src/simd/sse.cpp +++ b/src/simd/sse.cpp @@ -717,4 +717,98 @@ PQFastScanLookUp32(const uint8_t* lookup_table, #endif } +void +BitAnd(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result) { +#if defined(ENABLE_SSE) + if (num_byte == 0) { + return; + } + if (num_byte < 16) { + return generic::BitAnd(x, y, num_byte, result); + } + int64_t i = 0; + for (; i + 15 < num_byte; i += 16) { + __m128i x_vec = _mm_loadu_si128(reinterpret_cast(x + i)); + __m128i y_vec = _mm_loadu_si128(reinterpret_cast(y + i)); + __m128i result_vec = _mm_and_si128(x_vec, y_vec); + _mm_storeu_si128(reinterpret_cast<__m128i*>(result + i), result_vec); + } + if (i < num_byte) { + generic::BitAnd(x + i, y + i, num_byte - i, result + i); + } +#else + return generic::BitAnd(x, y, num_byte, result); +#endif +} + +void +BitOr(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result) { +#if defined(ENABLE_SSE) + if (num_byte == 0) { + return; + } + if (num_byte < 16) { + return generic::BitOr(x, y, num_byte, result); + } + int64_t i = 0; + for (; i + 15 < 
num_byte; i += 16) { + __m128i x_vec = _mm_loadu_si128(reinterpret_cast(x + i)); + __m128i y_vec = _mm_loadu_si128(reinterpret_cast(y + i)); + __m128i result_vec = _mm_or_si128(x_vec, y_vec); + _mm_storeu_si128(reinterpret_cast<__m128i*>(result + i), result_vec); + } + if (i < num_byte) { + generic::BitOr(x + i, y + i, num_byte - i, result + i); + } +#else + return generic::BitOr(x, y, num_byte, result); +#endif +} + +void +BitXor(const uint8_t* x, const uint8_t* y, const uint64_t num_byte, uint8_t* result) { +#if defined(ENABLE_SSE) + if (num_byte == 0) { + return; + } + if (num_byte < 16) { + return generic::BitXor(x, y, num_byte, result); + } + int64_t i = 0; + for (; i + 15 < num_byte; i += 16) { + __m128i x_vec = _mm_loadu_si128(reinterpret_cast(x + i)); + __m128i y_vec = _mm_loadu_si128(reinterpret_cast(y + i)); + __m128i result_vec = _mm_xor_si128(x_vec, y_vec); + _mm_storeu_si128(reinterpret_cast<__m128i*>(result + i), result_vec); + } + if (i < num_byte) { + generic::BitXor(x + i, y + i, num_byte - i, result + i); + } +#else + return generic::BitXor(x, y, num_byte, result); +#endif +} + +void +BitNot(const uint8_t* x, const uint64_t num_byte, uint8_t* result) { +#if defined(ENABLE_SSE) + if (num_byte == 0) { + return; + } + if (num_byte < 16) { + return generic::BitNot(x, num_byte, result); + } + int64_t i = 0; + for (; i + 15 < num_byte; i += 16) { + __m128i x_vec = _mm_loadu_si128(reinterpret_cast(x + i)); + __m128i result_vec = _mm_xor_si128(x_vec, _mm_set1_epi8(0xFF)); + _mm_storeu_si128(reinterpret_cast<__m128i*>(result + i), result_vec); + } + if (i < num_byte) { + generic::BitNot(x + i, num_byte - i, result + i); + } +#else + return generic::BitNot(x, num_byte, result); +#endif +} } // namespace vsag::sse From b74ba9c95b1648a8e82caaf66c59a461aa748272 Mon Sep 17 00:00:00 2001 From: ShawnShawnYou Date: Tue, 27 May 2025 13:40:11 +0800 Subject: [PATCH 24/42] adapt update with mark delete (#762) * adapt update with mark delete Signed-off-by: 
zhongxiaoyao.zxy Signed-off-by: suguan.dx --- src/algorithm/hnswlib/algorithm_interface.h | 4 ++ src/algorithm/hnswlib/hnswalg.cpp | 41 ++++++++++++----- src/algorithm/hnswlib/hnswalg.h | 13 ++++-- src/algorithm/hnswlib/hnswalg_static.h | 5 ++ src/index/hnsw_test.cpp | 51 +++++++++++++++++++++ 5 files changed, 98 insertions(+), 16 deletions(-) diff --git a/src/algorithm/hnswlib/algorithm_interface.h b/src/algorithm/hnswlib/algorithm_interface.h index fa2270e05..921039c9d 100644 --- a/src/algorithm/hnswlib/algorithm_interface.h +++ b/src/algorithm/hnswlib/algorithm_interface.h @@ -33,6 +33,7 @@ namespace hnswlib { using LabelType = vsag::LabelType; +using InnerIdType = vsag::InnerIdType; template class AlgorithmInterface { @@ -106,6 +107,9 @@ class AlgorithmInterface { virtual size_t getDeletedCount() = 0; + virtual vsag::UnorderedMap + getDeletedElements() = 0; + virtual bool isValidLabel(LabelType label) = 0; diff --git a/src/algorithm/hnswlib/hnswalg.cpp b/src/algorithm/hnswlib/hnswalg.cpp index daafb1b58..96acd79b3 100644 --- a/src/algorithm/hnswlib/hnswalg.cpp +++ b/src/algorithm/hnswlib/hnswalg.cpp @@ -1075,8 +1075,9 @@ HierarchicalNSW::DeserializeImpl(StreamReader& reader, SpaceInterface* s, size_t for (size_t i = 0; i < cur_element_count_; i++) { if (isMarkedDeleted(i)) { num_deleted_ += 1; - if (allow_replace_deleted_) - deleted_elements_.insert(i); + if (allow_replace_deleted_) { + deleted_elements_.insert({getExternalLabel(i), i}); + } } } } @@ -1139,7 +1140,7 @@ HierarchicalNSW::markDeletedInternal(InnerIdType internalId) { num_deleted_ += 1; if (allow_replace_deleted_) { std::unique_lock lock_deleted_elements(deleted_elements_lock_); - deleted_elements_.insert(internalId); + deleted_elements_.insert({getExternalLabel(internalId), internalId}); } } else { throw std::runtime_error("The requested to delete element is already deleted"); @@ -1284,21 +1285,37 @@ HierarchicalNSW::updateVector(LabelType label, const void* data_point) { void 
HierarchicalNSW::updateLabel(LabelType old_label, LabelType new_label) { std::unique_lock lock(label_lookup_lock_); - auto iter_old = label_lookup_.find(old_label); + + // 1. check whether new_label is occupied auto iter_new = label_lookup_.find(new_label); - if (iter_old == label_lookup_.end()) { - throw std::runtime_error(fmt::format("no old label {} in HNSW", old_label)); - } else if (iter_new != label_lookup_.end()) { + if (iter_new != label_lookup_.end()) { throw std::runtime_error(fmt::format("new label {} has been in HNSW", new_label)); - } else { - InnerIdType internal_id = iter_old->second; + } - // reset label + // 2. check whether old_label exists + InnerIdType internal_id = 0; + auto iter_old = label_lookup_.find(old_label); + if (iter_old == label_lookup_.end()) { + // 3. deal the situation of mark delete + auto iter_mark_delete = deleted_elements_.find(old_label); + if (iter_mark_delete == deleted_elements_.end()) { + throw std::runtime_error(fmt::format("no old label {} in HNSW", old_label)); + } + + // 4. update label to id + internal_id = iter_mark_delete->second; + deleted_elements_.erase(iter_mark_delete); + deleted_elements_.insert({new_label, internal_id}); + } else { + // 4. update label to id + internal_id = iter_old->second; label_lookup_.erase(iter_old); label_lookup_[new_label] = internal_id; - std::unique_lock resize_lock(resize_mutex_); - setExternalLabel(internal_id, new_label); } + + // 5. 
reset id to label + std::unique_lock resize_lock(resize_mutex_); + setExternalLabel(internal_id, new_label); } void diff --git a/src/algorithm/hnswlib/hnswalg.h b/src/algorithm/hnswlib/hnswalg.h index 90ded5823..eeaf5651e 100644 --- a/src/algorithm/hnswlib/hnswalg.h +++ b/src/algorithm/hnswlib/hnswalg.h @@ -45,7 +45,6 @@ #include "vsag/iterator_context.h" namespace hnswlib { -using InnerIdType = vsag::InnerIdType; using linklistsizeint = unsigned int; using reverselinklist = vsag::UnorderedSet; struct CompareByFirst { @@ -129,8 +128,9 @@ class HierarchicalNSW : public AlgorithmInterface { // flag to replace deleted elements (marked as deleted) during insertion bool allow_replace_deleted_{false}; - std::mutex deleted_elements_lock_{}; // lock for deleted_elements_ - vsag::UnorderedSet deleted_elements_; // contains internal ids of deleted elements + std::mutex deleted_elements_lock_{}; // lock for deleted_elements_ + vsag::UnorderedMap + deleted_elements_; // contains labels and internal ids of deleted elements public: HierarchicalNSW(SpaceInterface* s, @@ -142,7 +142,7 @@ class HierarchicalNSW : public AlgorithmInterface { bool normalize = false, size_t block_size_limit = 128 * 1024 * 1024, size_t random_seed = 100, - bool allow_replace_deleted = false); + bool allow_replace_deleted = true); ~HierarchicalNSW() override; @@ -246,6 +246,11 @@ class HierarchicalNSW : public AlgorithmInterface { return num_deleted_; } + vsag::UnorderedMap + getDeletedElements() override { + return deleted_elements_; + } + MaxHeap searchBaseLayer(InnerIdType ep_id, const void* data_point, int layer) const; diff --git a/src/algorithm/hnswlib/hnswalg_static.h b/src/algorithm/hnswlib/hnswalg_static.h index 23028f184..20b699214 100644 --- a/src/algorithm/hnswlib/hnswalg_static.h +++ b/src/algorithm/hnswlib/hnswalg_static.h @@ -247,6 +247,11 @@ class StaticHierarchicalNSW : public AlgorithmInterface { return num_deleted_; } + vsag::UnorderedMap + getDeletedElements() override { + throw 
std::runtime_error("Static HNSW doesn't support delete"); + }; + float getDistanceByLabel(LabelType label, const void* data_point) override { std::unique_lock lock_table(label_lookup_lock); diff --git a/src/index/hnsw_test.cpp b/src/index/hnsw_test.cpp index 8e5d3eec1..f3f7aca2c 100644 --- a/src/index/hnsw_test.cpp +++ b/src/index/hnsw_test.cpp @@ -1054,3 +1054,54 @@ TEST_CASE("extract/set data and graph", "[ut][hnsw]") { float recall = correct / (float)num_elements; REQUIRE(recall > 0.99); } + +TEST_CASE("update mark-deleted vector", "[ut][hnsw]") { + logger::set_level(logger::level::debug); + + // parameters + int dim = 128; + int base_size = 100; + int delete_size = 50; + int update_size = 50; + + // create hnsw + hnswlib::L2Space space(dim); + auto allocator = SafeAllocator::FactoryDefaultAllocator(); + auto* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 100, allocator.get()); + alg_hnsw->init_memory_space(); + + // data and build index + auto [base_ids, base_vectors] = fixtures::generate_ids_and_vectors(base_size, dim); + for (auto i = 0; i < base_size; i++) { + alg_hnsw->addPoint(base_vectors.data() + i * dim, base_ids[i]); + } + + // start remove + for (auto i = 0; i < delete_size; i++) { + REQUIRE(alg_hnsw->getCurrentElementCount() == base_size); + REQUIRE(alg_hnsw->getDeletedCount() == i); + REQUIRE(alg_hnsw->getDeletedElements().size() == i); + + alg_hnsw->markDelete(base_ids[i]); + + REQUIRE(alg_hnsw->getDeletedElements().count(base_ids[i]) != 0); + } + + // update + for (auto i = 0; i < update_size; i++) { + auto old_label = base_ids[i] + delete_size / 2; + auto new_label = old_label + base_size; + bool is_deleted = alg_hnsw->isMarkedDeleted(old_label); + alg_hnsw->updateLabel(old_label, new_label); + if (is_deleted) { + REQUIRE(alg_hnsw->getDeletedElements().count(old_label) == 0); + REQUIRE(alg_hnsw->getDeletedElements().count(new_label) != 0); + REQUIRE(alg_hnsw->getDeletedElements()[new_label] == old_label); + } else { + REQUIRE(not 
alg_hnsw->isValidLabel(old_label)); + REQUIRE(alg_hnsw->isValidLabel(new_label)); + } + } + + delete alg_hnsw; +} From 2cd5dba7b0c7d3b90230c9d56da2751104ab6745 Mon Sep 17 00:00:00 2001 From: Deming Chu Date: Tue, 27 May 2025 16:53:47 +0800 Subject: [PATCH 25/42] add compressed graph support in HGraph (#747) * feat(compressed_graph_datacell): add graph_storage param and test, Signed-off-by: Deming Chu * feat(hgraph): add compressed graph support and test Signed-off-by: Deming Chu * style: minor Signed-off-by: Deming Chu * style: format Signed-off-by: Deming Chu * style: graph_storage -> graph_storage_type Signed-off-by: Deming Chu * feat(graph_storage_type): add invalid parameter check and test case Signed-off-by: Deming Chu * refactor(graph_storage_type): refine invalid parameter check Signed-off-by: Deming Chu * style(graph_storage_type): remove useless checks in parameters Signed-off-by: Deming Chu * style(graph_storage_type): "compress" -> "compressed" Signed-off-by: Deming Chu * refactor(graph_storage_type): max_degree_ moved to graph_interface_parameter.h Signed-off-by: Deming Chu * refactor(hgraph): simplify max_degree_ get Signed-off-by: Deming Chu --------- Signed-off-by: Deming Chu Signed-off-by: suguan.dx --- include/vsag/constants.h | 1 + src/algorithm/hgraph.cpp | 13 +++- src/algorithm/hgraph_parameter.cpp | 21 +++++-- src/constants.cpp | 1 + src/data_cell/compressed_graph_datacell.cpp | 7 ++- .../compressed_graph_datacell_parameter.h | 6 +- ...mpressed_graph_datacell_parameter_test.cpp | 38 ++++++++++++ .../compressed_graph_datacell_test.cpp | 18 +++++- src/data_cell/graph_datacell_parameter.cpp | 1 + src/data_cell/graph_datacell_parameter.h | 2 - src/data_cell/graph_interface_parameter.h | 2 + .../sparse_graph_datacell_parameter.h | 3 - tests/test_hgraph.cpp | 60 +++++++++++++++++-- 13 files changed, 152 insertions(+), 21 deletions(-) create mode 100644 src/data_cell/compressed_graph_datacell_parameter_test.cpp diff --git a/include/vsag/constants.h 
b/include/vsag/constants.h index 5075b5dcb..0daf697d9 100644 --- a/include/vsag/constants.h +++ b/include/vsag/constants.h @@ -139,6 +139,7 @@ extern const char* const HGRAPH_GRAPH_MAX_DEGREE; extern const char* const HGRAPH_BUILD_EF_CONSTRUCTION; extern const char* const HGRAPH_INIT_CAPACITY; extern const char* const HGRAPH_GRAPH_TYPE; +extern const char* const HGRAPH_GRAPH_STORAGE_TYPE; extern const char* const HGRAPH_BUILD_THREAD_COUNT; extern const char* const HGRAPH_PRECISE_QUANTIZATION_TYPE; extern const char* const HGRAPH_BASE_IO_TYPE; diff --git a/src/algorithm/hgraph.cpp b/src/algorithm/hgraph.cpp index 2acc8d0c2..0e4d29fb2 100644 --- a/src/algorithm/hgraph.cpp +++ b/src/algorithm/hgraph.cpp @@ -15,6 +15,7 @@ #include "hgraph.h" +#include #include #include @@ -1005,6 +1006,7 @@ static const std::string HGRAPH_PARAMS_TEMPLATE = "{IO_FILE_PATH}": "{DEFAULT_FILE_PATH_VALUE}" }, "{GRAPH_TYPE_KEY}": "{GRAPH_TYPE_NSW}", + "{GRAPH_STORAGE_TYPE_KEY}": "{GRAPH_STORAGE_TYPE_FLAT}", "{ODESCENT_PARAMETER_BUILD_BLOCK_SIZE}": 10000, "{ODESCENT_PARAMETER_MIN_IN_DEGREE}": 1, "{ODESCENT_PARAMETER_ALPHA}": 1.2, @@ -1158,6 +1160,13 @@ HGraph::CheckAndMappingExternalParam(const JsonType& external_param, GRAPH_TYPE_KEY, }, }, + { + HGRAPH_GRAPH_STORAGE_TYPE, + { + HGRAPH_GRAPH_KEY, + GRAPH_STORAGE_TYPE_KEY, + }, + }, { ODESCENT_PARAMETER_ALPHA, { @@ -1245,10 +1254,8 @@ HGraph::CheckAndMappingExternalParam(const JsonType& external_param, auto hgraph_parameter = std::make_shared(); hgraph_parameter->data_type = common_param.data_type_; hgraph_parameter->FromJson(inner_json); + uint64_t max_degree = hgraph_parameter->bottom_graph_param->max_degree_; - auto max_degree = - std::dynamic_pointer_cast(hgraph_parameter->bottom_graph_param) - ->max_degree_; auto max_degree_threshold = std::max(common_param.dim_, 128L); CHECK_ARGUMENT( // NOLINT (4 <= max_degree) and (max_degree <= max_degree_threshold), diff --git a/src/algorithm/hgraph_parameter.cpp b/src/algorithm/hgraph_parameter.cpp 
index 8cf3fc1b5..4580f4414 100644 --- a/src/algorithm/hgraph_parameter.cpp +++ b/src/algorithm/hgraph_parameter.cpp @@ -15,8 +15,6 @@ #include "hgraph_parameter.h" -#include - #include "data_cell/graph_interface_parameter.h" #include "data_cell/sparse_vector_datacell_parameter.h" #include "inner_string_params.h" @@ -74,8 +72,23 @@ HGraphParameter::FromJson(const JsonType& json) { CHECK_ARGUMENT(json.contains(HGRAPH_GRAPH_KEY), fmt::format("hgraph parameters must contains {}", HGRAPH_GRAPH_KEY)); const auto& graph_json = json[HGRAPH_GRAPH_KEY]; - this->bottom_graph_param = GraphInterfaceParameter::GetGraphParameterByJson( - GraphStorageTypes::GRAPH_STORAGE_TYPE_FLAT, graph_json); + + GraphStorageTypes graph_storage_type = GraphStorageTypes::GRAPH_STORAGE_TYPE_FLAT; + if (graph_json.contains(GRAPH_STORAGE_TYPE_KEY)) { + const auto& graph_storage_type_str = graph_json[GRAPH_STORAGE_TYPE_KEY]; + if (graph_storage_type_str == GRAPH_STORAGE_TYPE_COMPRESSED) { + graph_storage_type = GraphStorageTypes::GRAPH_STORAGE_TYPE_COMPRESSED; + } + + if (graph_storage_type_str != GRAPH_STORAGE_TYPE_COMPRESSED && + graph_storage_type_str != GRAPH_STORAGE_TYPE_FLAT) { + throw VsagException( + ErrorType::INVALID_ARGUMENT, + fmt::format("invalid graph_storage_type: {}", graph_storage_type_str.dump())); + } + } + this->bottom_graph_param = + GraphInterfaceParameter::GetGraphParameterByJson(graph_storage_type, graph_json); if (json.contains(BUILD_PARAMS_KEY)) { const auto& build_params = json[BUILD_PARAMS_KEY]; diff --git a/src/constants.cpp b/src/constants.cpp index 5bf55212f..e8ee41582 100644 --- a/src/constants.cpp +++ b/src/constants.cpp @@ -143,6 +143,7 @@ const char* const HGRAPH_GRAPH_MAX_DEGREE = "max_degree"; const char* const HGRAPH_BUILD_EF_CONSTRUCTION = "ef_construction"; const char* const HGRAPH_INIT_CAPACITY = "hgraph_init_capacity"; const char* const HGRAPH_GRAPH_TYPE = "graph_type"; +const char* const HGRAPH_GRAPH_STORAGE_TYPE = "graph_storage_type"; const char* const 
HGRAPH_BUILD_THREAD_COUNT = "build_thread_count"; const char* const HGRAPH_PRECISE_QUANTIZATION_TYPE = "precise_quantization_type"; const char* const HGRAPH_BASE_IO_TYPE = "base_io_type"; diff --git a/src/data_cell/compressed_graph_datacell.cpp b/src/data_cell/compressed_graph_datacell.cpp index 6b7175c57..3766e0fd0 100644 --- a/src/data_cell/compressed_graph_datacell.cpp +++ b/src/data_cell/compressed_graph_datacell.cpp @@ -48,6 +48,12 @@ CompressedGraphDataCell::InsertNeighborsById(InnerIdType id, neighbor_sets_[id] = std::make_unique(allocator_); } neighbor_sets_[id]->Encode(tmp, max_capacity_); + } else { + neighbor_sets_[id].reset(); + } + + InnerIdType current = total_count_.load(); + while (current < id + 1 && !total_count_.compare_exchange_weak(current, id + 1)) { } } @@ -120,7 +126,6 @@ CompressedGraphDataCell::Resize(InnerIdType new_size) { } neighbor_sets_.resize(new_size); this->max_capacity_ = new_size; - this->total_count_ = new_size; } } // namespace vsag diff --git a/src/data_cell/compressed_graph_datacell_parameter.h b/src/data_cell/compressed_graph_datacell_parameter.h index c5872ac26..7b67cd64d 100644 --- a/src/data_cell/compressed_graph_datacell_parameter.h +++ b/src/data_cell/compressed_graph_datacell_parameter.h @@ -15,6 +15,8 @@ #pragma once +#include + #include "graph_interface_parameter.h" namespace vsag { @@ -35,11 +37,9 @@ class CompressedGraphDatacellParameter : public GraphInterfaceParameter { ToJson() override { JsonType json; json[GRAPH_PARAM_MAX_DEGREE] = this->max_degree_; + json[GRAPH_STORAGE_TYPE_KEY] = GRAPH_STORAGE_TYPE_COMPRESSED; return json; } - -public: - uint64_t max_degree_{64}; }; using CompressedGraphDatacellParamPtr = std::shared_ptr; diff --git a/src/data_cell/compressed_graph_datacell_parameter_test.cpp b/src/data_cell/compressed_graph_datacell_parameter_test.cpp new file mode 100644 index 000000000..5f53aef29 --- /dev/null +++ b/src/data_cell/compressed_graph_datacell_parameter_test.cpp @@ -0,0 +1,38 @@ + +// 
Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compressed_graph_datacell_parameter.h" + +#include + +#include "parameter_test.h" + +namespace vsag { + +TEST_CASE("CompressedGraphDatacellParameter ToJson Test", + "[ut][CompressedGraphDatacellParameter]") { + std::string param_str = R"( + { + "max_degree": 100, + "graph_storage_type": "compressed" + } + )"; + auto param = std::make_shared(); + auto json = JsonType::parse(param_str); + param->FromJson(json); + ParameterTest::TestToJson(param); +} + +} // namespace vsag diff --git a/src/data_cell/compressed_graph_datacell_test.cpp b/src/data_cell/compressed_graph_datacell_test.cpp index 8c59aa3e0..6559112b4 100644 --- a/src/data_cell/compressed_graph_datacell_test.cpp +++ b/src/data_cell/compressed_graph_datacell_test.cpp @@ -41,13 +41,29 @@ TEST_CASE("CompressedGraphDataCell Basic Test", "[ut][CompressedGraphDataCell]") auto allocator = SafeAllocator::FactoryDefaultAllocator(); auto dim = GENERATE(32, 64); auto max_degree = GENERATE(5, 12, 32, 64, 128); - auto max_capacity = 10000; + // raw construction IndexCommonParam common_param; common_param.dim_ = dim; common_param.allocator_ = allocator; + auto graph_param = std::make_shared(); graph_param->max_degree_ = max_degree; graph_param->graph_storage_type_ = GraphStorageTypes::GRAPH_STORAGE_TYPE_COMPRESSED; TestCompressedGraphDataCell(graph_param, common_param); + + // parameter construction + constexpr 
const char* graph_param_temp = + R"( + {{ + "max_degree": {}, + "graph_storage_type": "{}" + }} + )"; + + auto param_str = fmt::format(graph_param_temp, max_degree, GRAPH_STORAGE_TYPE_COMPRESSED); + auto param_json = JsonType::parse(param_str); + graph_param->FromJson(param_json); + REQUIRE(graph_param->graph_storage_type_ == GraphStorageTypes::GRAPH_STORAGE_TYPE_COMPRESSED); + TestCompressedGraphDataCell(graph_param, common_param); } diff --git a/src/data_cell/graph_datacell_parameter.cpp b/src/data_cell/graph_datacell_parameter.cpp index f0acd2080..1669d6764 100644 --- a/src/data_cell/graph_datacell_parameter.cpp +++ b/src/data_cell/graph_datacell_parameter.cpp @@ -32,6 +32,7 @@ GraphDataCellParameter::FromJson(const JsonType& json) { this->init_max_capacity_ = json[GRAPH_PARAM_INIT_MAX_CAPACITY]; } } + JsonType GraphDataCellParameter::ToJson() { JsonType json; diff --git a/src/data_cell/graph_datacell_parameter.h b/src/data_cell/graph_datacell_parameter.h index 9f36b1c3d..e72463b6e 100644 --- a/src/data_cell/graph_datacell_parameter.h +++ b/src/data_cell/graph_datacell_parameter.h @@ -33,8 +33,6 @@ class GraphDataCellParameter : public GraphInterfaceParameter { public: IOParamPtr io_parameter_{nullptr}; - uint64_t max_degree_{64}; - uint64_t init_max_capacity_{100}; }; diff --git a/src/data_cell/graph_interface_parameter.h b/src/data_cell/graph_interface_parameter.h index 19294792f..c09b3242e 100644 --- a/src/data_cell/graph_interface_parameter.h +++ b/src/data_cell/graph_interface_parameter.h @@ -36,6 +36,8 @@ class GraphInterfaceParameter : public Parameter { public: GraphStorageTypes graph_storage_type_{GraphStorageTypes::GRAPH_STORAGE_TYPE_FLAT}; + uint64_t max_degree_{64}; + protected: explicit GraphInterfaceParameter(GraphStorageTypes graph_type) : graph_storage_type_(graph_type){}; diff --git a/src/data_cell/sparse_graph_datacell_parameter.h b/src/data_cell/sparse_graph_datacell_parameter.h index fda72cf25..4d00d18bc 100644 --- 
a/src/data_cell/sparse_graph_datacell_parameter.h +++ b/src/data_cell/sparse_graph_datacell_parameter.h @@ -37,9 +37,6 @@ class SparseGraphDatacellParameter : public GraphInterfaceParameter { json[GRAPH_PARAM_MAX_DEGREE] = this->max_degree_; return json; } - -public: - uint64_t max_degree_{64}; }; using SparseGraphDatacellParamPtr = std::shared_ptr; diff --git a/tests/test_hgraph.cpp b/tests/test_hgraph.cpp index ee0907d1b..69a55613b 100644 --- a/tests/test_hgraph.cpp +++ b/tests/test_hgraph.cpp @@ -35,7 +35,8 @@ class HgraphTestIndex : public fixtures::TestIndex { int thread_count = 5, int extra_info_size = 0, const std::string& data_type = "float32", - std::string graph_type = "nsw"); + std::string graph_type = "nsw", + std::string graph_storage = "flat"); static bool IsRaBitQ(const std::string& quantization_str); @@ -86,7 +87,8 @@ HgraphTestIndex::GenerateHGraphBuildParametersString(const std::string& metric_t int thread_count, int extra_info_size, const std::string& data_type, - std::string graph_type) { + std::string graph_type, + std::string graph_storage) { std::string build_parameters_str; constexpr auto parameter_temp_reorder = R"( @@ -106,6 +108,7 @@ HgraphTestIndex::GenerateHGraphBuildParametersString(const std::string& metric_t "precise_io_type": "{}", "precise_file_path": "{}", "graph_type": "{}", + "graph_storage_type": "{}", "graph_iter_turn": 10, "neighbor_sample_rate": 0.3, "alpha": 1.2 @@ -126,6 +129,7 @@ HgraphTestIndex::GenerateHGraphBuildParametersString(const std::string& metric_t "ef_construction": 500, "build_thread_count": {}, "graph_type": "{}", + "graph_storage_type": "{}", "graph_iter_turn": 10, "neighbor_sample_rate": 0.3, "alpha": 1.2 @@ -158,7 +162,8 @@ HgraphTestIndex::GenerateHGraphBuildParametersString(const std::string& metric_t high_quantizer_str, precise_io_type, dir.GenerateRandomFile(), - graph_type); + graph_type, + graph_storage); } else { build_parameters_str = fmt::format(parameter_temp_origin, data_type, @@ -168,7 +173,8 
@@ HgraphTestIndex::GenerateHGraphBuildParametersString(const std::string& metric_t base_quantizer_str, pq_dim, thread_count, - graph_type); + graph_type, + graph_storage); } return build_parameters_str; } @@ -318,6 +324,21 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, auto param = fmt::format(param_temp, param_keys); REQUIRE_THROWS(TestFactory(name, param, false)); } + + SECTION("Invalid hgraph param graph_storage_type") { + auto graph_storage_type = "fsa"; + constexpr const char* param_temp = + R"({{ + "dtype": "float32", + "metric_type": "l2", + "dim": 35, + "index_param": {{ + "graph_storage_type": "{}" + }} + }})"; + auto param = fmt::format(param_temp, graph_storage_type); + REQUIRE_THROWS(TestFactory(name, param, false)); + } } TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, @@ -476,6 +497,37 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, "HGraph ODescent Build", } } +TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, + "HGraph Compressed Graph Build", + "[ft][hgraph]") { + auto origin_size = vsag::Options::Instance().block_size_limit(); + auto size = GENERATE(1024 * 1024 * 2); + auto metric_type = GENERATE("l2", "ip", "cosine"); + + const std::string name = "hgraph"; + auto search_param = fmt::format(search_param_tmp, 200, false); + for (auto dim : dims) { + for (auto& [base_quantization_str, recall] : test_cases) { + if (IsRaBitQ(base_quantization_str)) { + if (std::string(metric_type) != "l2") { + continue; + } + if (dim <= fixtures::RABITQ_MIN_RACALL_DIM) { + dim += fixtures::RABITQ_MIN_RACALL_DIM; + } + } + vsag::Options::Instance().set_block_size_limit(size); + auto param = GenerateHGraphBuildParametersString( + metric_type, dim, base_quantization_str, 0, 0, "float32", "nsw", "compressed"); + auto index = TestFactory(name, param, true); + auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric_type); + TestBuildIndex(index, dataset, true); + TestGeneral(index, dataset, search_param, recall); + 
vsag::Options::Instance().set_block_size_limit(origin_size); + } + } +} + TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, "HGraph Add", "[ft][hgraph]") { auto origin_size = vsag::Options::Instance().block_size_limit(); auto size = GENERATE(1024 * 1024 * 2); From d07b324b9742ded9720c959d49a38a678a00da3f Mon Sep 17 00:00:00 2001 From: LHT129 Date: Wed, 28 May 2025 09:50:28 +0800 Subject: [PATCH 26/42] Support arthimetic operator for fp32 simd (#761) - add reverse codebook for pq8bit quantization - speedup some compute for pq8bit scan Signed-off-by: LHT129 Signed-off-by: suguan.dx --- .../product_quantization/product_quantizer.h | 107 +++++++++++++----- src/simd/avx.cpp | 68 +++++++++++ src/simd/avx2.cpp | 67 +++++++++++ src/simd/avx512.cpp | 68 +++++++++++ src/simd/fp32_simd.cpp | 97 +++++++++++++++- src/simd/fp32_simd.h | 55 ++++++++- src/simd/fp32_simd_test.cpp | 7 +- src/simd/generic.cpp | 30 +++++ src/simd/sse.cpp | 87 ++++++++++++++ 9 files changed, 551 insertions(+), 35 deletions(-) diff --git a/src/quantization/product_quantization/product_quantizer.h b/src/quantization/product_quantization/product_quantizer.h index addfab676..ff3441520 100644 --- a/src/quantization/product_quantization/product_quantizer.h +++ b/src/quantization/product_quantization/product_quantizer.h @@ -96,6 +96,9 @@ class ProductQuantizer : public Quantizer> { centroid_num * subspace_dim_; } + void + transpose_codebooks(); + public: constexpr static int64_t PQ_BITS = 8L; constexpr static int64_t CENTROIDS_PER_SUBSPACE = 256L; @@ -105,11 +108,16 @@ class ProductQuantizer : public Quantizer> { int64_t subspace_dim_{1}; // equal to dim/pq_dim_; Vector codebooks_; + + Vector reverse_codebooks_; }; template ProductQuantizer::ProductQuantizer(int dim, int64_t pq_dim, Allocator* allocator) - : Quantizer>(dim, allocator), pq_dim_(pq_dim), codebooks_(allocator) { + : Quantizer>(dim, allocator), + pq_dim_(pq_dim), + codebooks_(allocator), + reverse_codebooks_(allocator) { if (dim % pq_dim != 
0) { throw VsagException( ErrorType::INVALID_ARGUMENT, @@ -118,6 +126,7 @@ ProductQuantizer::ProductQuantizer(int dim, int64_t pq_dim, Allocator* a this->code_size_ = this->pq_dim_; this->subspace_dim_ = this->dim_ / pq_dim; codebooks_.resize(this->dim_ * CENTROIDS_PER_SUBSPACE); + reverse_codebooks_.resize(this->dim_ * CENTROIDS_PER_SUBSPACE); } template @@ -164,6 +173,7 @@ ProductQuantizer::TrainImpl(const vsag::DataType* data, uint64_t count) cluster.k_centroids_, CENTROIDS_PER_SUBSPACE * subspace_dim_ * sizeof(float)); } + this->transpose_codebooks(); this->is_trained_ = true; return true; @@ -258,30 +268,53 @@ ProductQuantizer::ProcessQueryImpl(const DataType* query, } auto* lookup_table = reinterpret_cast( this->allocator_->Allocate(this->pq_dim_ * CENTROIDS_PER_SUBSPACE * sizeof(float))); - - for (int i = 0; i < pq_dim_; ++i) { - const auto* per_query = cur_query + i * subspace_dim_; - const auto* per_code_book = get_codebook_data(i, 0); - auto* per_result = lookup_table + i * CENTROIDS_PER_SUBSPACE; - if constexpr (metric == MetricType::METRIC_TYPE_IP or - metric == MetricType::METRIC_TYPE_COSINE) { - cblas_sgemv(CblasRowMajor, - CblasNoTrans, - CENTROIDS_PER_SUBSPACE, - subspace_dim_, - 1.0F, - per_code_book, - subspace_dim_, - per_query, - 1, - 0.0F, - per_result, - 1); - } else if constexpr (metric == MetricType::METRIC_TYPE_L2SQR) { - // TODO(LHT): use blas opt - for (int64_t j = 0; j < CENTROIDS_PER_SUBSPACE; ++j) { - per_result[j] = FP32ComputeL2Sqr( - per_query, per_code_book + j * subspace_dim_, subspace_dim_); + if (true) { + for (int i = 0; i < pq_dim_; ++i) { + const auto* per_query = cur_query + i * subspace_dim_; + const auto* per_code_book = get_codebook_data(i, 0); + auto* per_result = lookup_table + i * CENTROIDS_PER_SUBSPACE; + if constexpr (metric == MetricType::METRIC_TYPE_IP or + metric == MetricType::METRIC_TYPE_COSINE) { + cblas_sgemv(CblasRowMajor, + CblasNoTrans, + CENTROIDS_PER_SUBSPACE, + subspace_dim_, + 1.0F, + per_code_book, + 
subspace_dim_, + per_query, + 1, + 0.0F, + per_result, + 1); + } else if constexpr (metric == MetricType::METRIC_TYPE_L2SQR) { + // TODO(LHT): use blas opt + for (int64_t j = 0; j < CENTROIDS_PER_SUBSPACE; ++j) { + per_result[j] = FP32ComputeL2Sqr( + per_query, per_code_book + j * subspace_dim_, subspace_dim_); + } + } + } + } else { + Vector tmp(this->allocator_); + tmp.resize(this->dim_); + for (int i = 0; i < CENTROIDS_PER_SUBSPACE; ++i) { + if constexpr (metric == MetricType::METRIC_TYPE_IP or + metric == MetricType::METRIC_TYPE_COSINE) { + FP32Mul(reverse_codebooks_.data() + i * this->dim_, + cur_query, + tmp.data(), + this->dim_); + } else if constexpr (metric == MetricType::METRIC_TYPE_L2SQR) { + FP32Sub(reverse_codebooks_.data() + i * this->dim_, + cur_query, + tmp.data(), + this->dim_); + FP32Mul(tmp.data(), tmp.data(), tmp.data(), this->dim_); + } + for (int j = 0; j < pq_dim_; ++j) { + lookup_table[j * CENTROIDS_PER_SUBSPACE + i] = + FP32ReduceAdd(tmp.data() + j * subspace_dim_, subspace_dim_); } } } @@ -302,7 +335,7 @@ ProductQuantizer::ComputeDistImpl(Computer& computer, const uint8_t* codes, float* dists) const { auto* lut = reinterpret_cast(computer.buf_); - dists[0] = 0.0F; + float dist = 0.0F; int64_t i = 0; for (; i + 4 < pq_dim_; i += 4) { float dism = 0; @@ -314,15 +347,17 @@ ProductQuantizer::ComputeDistImpl(Computer& computer, lut += CENTROIDS_PER_SUBSPACE; dism += lut[*codes++]; lut += CENTROIDS_PER_SUBSPACE; - dists[0] += dism; + dist += dism; } for (; i < pq_dim_; ++i) { - dists[0] += lut[*codes++]; + dist += lut[*codes++]; lut += CENTROIDS_PER_SUBSPACE; } if constexpr (metric == MetricType::METRIC_TYPE_COSINE or metric == MetricType::METRIC_TYPE_IP) { - dists[0] = 1.0F - dists[0]; + dists[0] = 1.0F - dist; + } else if constexpr (metric == MetricType::METRIC_TYPE_L2SQR) { + dists[0] = dist; } } @@ -352,6 +387,7 @@ ProductQuantizer::DeserializeImpl(StreamReader& reader) { StreamReader::ReadObj(reader, this->pq_dim_); 
StreamReader::ReadObj(reader, this->subspace_dim_); StreamReader::ReadVector(reader, this->codebooks_); + this->transpose_codebooks(); } template @@ -360,4 +396,17 @@ ProductQuantizer::ReleaseComputerImpl(Computer> this->allocator_->Deallocate(computer.buf_); } +template +void +ProductQuantizer::transpose_codebooks() { + for (int64_t i = 0; i < this->pq_dim_; ++i) { + for (int64_t j = 0; j < CENTROIDS_PER_SUBSPACE; ++j) { + memcpy(this->reverse_codebooks_.data() + j * this->dim_ + i * subspace_dim_, + this->codebooks_.data() + i * CENTROIDS_PER_SUBSPACE * subspace_dim_ + + j * subspace_dim_, + subspace_dim_ * sizeof(float)); + } + } +} + } // namespace vsag diff --git a/src/simd/avx.cpp b/src/simd/avx.cpp index 5d3d5ba71..ed4f4d149 100644 --- a/src/simd/avx.cpp +++ b/src/simd/avx.cpp @@ -292,6 +292,74 @@ FP32Sub(const float* x, const float* y, float* z, uint64_t dim) { #endif } +void +FP32Add(const float* x, const float* y, float* z, uint64_t dim) { +#if defined(ENABLE_AVX) + if (dim < 8) { + return sse::FP32Add(x, y, z, dim); + } + int i = 0; + for (; i + 7 < dim; i += 8) { + __m256 a = _mm256_loadu_ps(x + i); + __m256 b = _mm256_loadu_ps(y + i); + __m256 c = _mm256_add_ps(a, b); + _mm256_storeu_ps(z + i, c); + } + if (i < dim) { + sse::FP32Add(x + i, y + i, z + i, dim - i); + } +#else + sse::FP32Add(x, y, z, dim); +#endif +} + +void +FP32Mul(const float* x, const float* y, float* z, uint64_t dim) { +#if defined(ENABLE_AVX) + if (dim < 8) { + return sse::FP32Mul(x, y, z, dim); + } + int i = 0; + for (; i + 7 < dim; i += 8) { + __m256 a = _mm256_loadu_ps(x + i); + __m256 b = _mm256_loadu_ps(y + i); + __m256 c = _mm256_mul_ps(a, b); + _mm256_storeu_ps(z + i, c); + } + if (i < dim) { + sse::FP32Mul(x + i, y + i, z + i, dim - i); + } +#else + sse::FP32Mul(x, y, z, dim); +#endif +} + +void +FP32Div(const float* x, const float* y, float* z, uint64_t dim) { +#if defined(ENABLE_AVX) + if (dim < 8) { + return sse::FP32Div(x, y, z, dim); + } + int i = 0; + for (; i + 7 < 
dim; i += 8) { + __m256 a = _mm256_loadu_ps(x + i); + __m256 b = _mm256_loadu_ps(y + i); + __m256 c = _mm256_div_ps(a, b); + _mm256_storeu_ps(z + i, c); + } + if (i < dim) { + sse::FP32Div(x + i, y + i, z + i, dim - i); + } +#else + sse::FP32Div(x, y, z, dim); +#endif +} + +float +FP32ReduceAdd(const float* x, uint64_t dim) { + return sse::FP32ReduceAdd(x, dim); +} + #if defined(ENABLE_AVX) __inline __m256i __attribute__((__always_inline__)) load_8_short(const uint16_t* data) { return _mm256_set_epi16(data[7], diff --git a/src/simd/avx2.cpp b/src/simd/avx2.cpp index c3610be7b..96f93cc73 100644 --- a/src/simd/avx2.cpp +++ b/src/simd/avx2.cpp @@ -286,6 +286,73 @@ FP32Sub(const float* x, const float* y, float* z, uint64_t dim) { #endif } +void +FP32Add(const float* x, const float* y, float* z, uint64_t dim) { +#if defined(ENABLE_AVX2) + if (dim < 8) { + return sse::FP32Add(x, y, z, dim); + } + int i = 0; + for (; i + 7 < dim; i += 8) { + __m256 a = _mm256_loadu_ps(x + i); + __m256 b = _mm256_loadu_ps(y + i); + __m256 c = _mm256_add_ps(a, b); + _mm256_storeu_ps(z + i, c); + } + if (i < dim) { + sse::FP32Add(x + i, y + i, z + i, dim - i); + } +#else + return sse::FP32Add(x, y, z, dim); +#endif +} + +void +FP32Mul(const float* x, const float* y, float* z, uint64_t dim) { +#if defined(ENABLE_AVX2) + if (dim < 8) { + return sse::FP32Mul(x, y, z, dim); + } + int i = 0; + for (; i + 7 < dim; i += 8) { + __m256 a = _mm256_loadu_ps(x + i); + __m256 b = _mm256_loadu_ps(y + i); + __m256 c = _mm256_mul_ps(a, b); + _mm256_storeu_ps(z + i, c); + } + if (i < dim) { + sse::FP32Mul(x + i, y + i, z + i, dim - i); + } +#else + return sse::FP32Mul(x, y, z, dim); +#endif +} + +void +FP32Div(const float* x, const float* y, float* z, uint64_t dim) { +#if defined(ENABLE_AVX2) + if (dim < 8) { + return sse::FP32Div(x, y, z, dim); + } + int i = 0; + for (; i + 7 < dim; i += 8) { + __m256 a = _mm256_loadu_ps(x + i); + __m256 b = _mm256_loadu_ps(y + i); + __m256 c = _mm256_div_ps(a, b); + 
_mm256_storeu_ps(z + i, c); + } + if (i < dim) { + sse::FP32Div(x + i, y + i, z + i, dim - i); + } +#else + return sse::FP32Div(x, y, z, dim); +#endif +} +float +FP32ReduceAdd(const float* x, uint64_t dim) { + return sse::FP32ReduceAdd(x, dim); +} + #if defined(ENABLE_AVX2) __inline __m256i __attribute__((__always_inline__)) load_8_short(const uint16_t* data) { __m128i bf16 = _mm_loadu_si128(reinterpret_cast(data)); diff --git a/src/simd/avx512.cpp b/src/simd/avx512.cpp index 25032f724..00d1212fb 100644 --- a/src/simd/avx512.cpp +++ b/src/simd/avx512.cpp @@ -329,6 +329,74 @@ FP32Sub(const float* x, const float* y, float* z, uint64_t dim) { #endif } +void +FP32Add(const float* x, const float* y, float* z, uint64_t dim) { +#if defined(ENABLE_AVX512) + if (dim < 16) { + return avx2::FP32Add(x, y, z, dim); + } + uint64_t i = 0; + for (; i + 15 < dim; i += 16) { + __m512 x_vec = _mm512_loadu_ps(x + i); + __m512 y_vec = _mm512_loadu_ps(y + i); + __m512 sum_vec = _mm512_add_ps(x_vec, y_vec); + _mm512_storeu_ps(z + i, sum_vec); + } + if (dim > i) { + avx2::FP32Add(x + i, y + i, z + i, dim - i); + } +#else + return avx2::FP32Add(x, y, z, dim); +#endif +} + +void +FP32Mul(const float* x, const float* y, float* z, uint64_t dim) { +#if defined(ENABLE_AVX512) + if (dim < 16) { + return avx2::FP32Mul(x, y, z, dim); + } + uint64_t i = 0; + for (; i + 15 < dim; i += 16) { + __m512 x_vec = _mm512_loadu_ps(x + i); + __m512 y_vec = _mm512_loadu_ps(y + i); + __m512 mul_vec = _mm512_mul_ps(x_vec, y_vec); + _mm512_storeu_ps(z + i, mul_vec); + } + if (dim > i) { + avx2::FP32Mul(x + i, y + i, z + i, dim - i); + } +#else + return avx2::FP32Mul(x, y, z, dim); +#endif +} + +void +FP32Div(const float* x, const float* y, float* z, uint64_t dim) { +#if defined(ENABLE_AVX512) + if (dim < 16) { + return avx2::FP32Div(x, y, z, dim); + } + uint64_t i = 0; + for (; i + 15 < dim; i += 16) { + __m512 x_vec = _mm512_loadu_ps(x + i); + __m512 y_vec = _mm512_loadu_ps(y + i); + __m512 div_vec = 
_mm512_div_ps(x_vec, y_vec); + _mm512_storeu_ps(z + i, div_vec); + } + if (dim > i) { + avx2::FP32Div(x + i, y + i, z + i, dim - i); + } +#else + return avx2::FP32Div(x, y, z, dim); +#endif +} + +float +FP32ReduceAdd(const float* x, uint64_t dim) { + return sse::FP32ReduceAdd(x, dim); +} + #if defined(ENABLE_AVX512) __inline __m512i __attribute__((__always_inline__)) load_16_short(const uint16_t* data) { __m256i bf16 = _mm256_loadu_si256(reinterpret_cast(data)); diff --git a/src/simd/fp32_simd.cpp b/src/simd/fp32_simd.cpp index 10f2119a7..0677654f2 100644 --- a/src/simd/fp32_simd.cpp +++ b/src/simd/fp32_simd.cpp @@ -111,7 +111,7 @@ GetFP32ComputeL2SqrBatch4() { } FP32ComputeBatch4Type FP32ComputeL2SqrBatch4 = GetFP32ComputeL2SqrBatch4(); -static FP32SubType +static FP32ArithmeticType GetFP32Sub() { if (SimdStatus::SupportAVX512()) { #if defined(ENABLE_AVX512) @@ -132,5 +132,98 @@ GetFP32Sub() { } return generic::FP32Sub; } -FP32SubType FP32Sub = GetFP32Sub(); +FP32ArithmeticType FP32Sub = GetFP32Sub(); + +static FP32ArithmeticType +GetFP32Add() { + if (SimdStatus::SupportAVX512()) { +#if defined(ENABLE_AVX512) + return avx512::FP32Add; +#endif + } else if (SimdStatus::SupportAVX2()) { +#if defined(ENABLE_AVX2) + return avx2::FP32Add; +#endif + } else if (SimdStatus::SupportAVX()) { +#if defined(ENABLE_AVX) + return avx::FP32Add; +#endif + } else if (SimdStatus::SupportSSE()) { +#if defined(ENABLE_SSE) + return sse::FP32Add; +#endif + } + return generic::FP32Add; +} +FP32ArithmeticType FP32Add = GetFP32Add(); + +static FP32ArithmeticType +GetFP32Mul() { + if (SimdStatus::SupportAVX512()) { +#if defined(ENABLE_AVX512) + return avx512::FP32Mul; +#endif + } else if (SimdStatus::SupportAVX2()) { +#if defined(ENABLE_AVX2) + return avx2::FP32Mul; +#endif + } else if (SimdStatus::SupportAVX()) { +#if defined(ENABLE_AVX) + return avx::FP32Mul; +#endif + } else if (SimdStatus::SupportSSE()) { +#if defined(ENABLE_SSE) + return sse::FP32Mul; +#endif + } + return 
generic::FP32Mul; +} +FP32ArithmeticType FP32Mul = GetFP32Mul(); + +static FP32ArithmeticType +GetFP32Div() { + if (SimdStatus::SupportAVX512()) { +#if defined(ENABLE_AVX512) + return avx512::FP32Div; +#endif + } else if (SimdStatus::SupportAVX2()) { +#if defined(ENABLE_AVX2) + return avx2::FP32Div; +#endif + } else if (SimdStatus::SupportAVX()) { +#if defined(ENABLE_AVX) + return avx::FP32Div; +#endif + } else if (SimdStatus::SupportSSE()) { +#if defined(ENABLE_SSE) + return sse::FP32Div; +#endif + } + return generic::FP32Div; +} +FP32ArithmeticType FP32Div = GetFP32Div(); + +static FP32ReduceType +GetFP32ReduceAdd() { + if (SimdStatus::SupportAVX512()) { +#if defined(ENABLE_AVX512) + return avx512::FP32ReduceAdd; +#endif + } else if (SimdStatus::SupportAVX2()) { +#if defined(ENABLE_AVX2) + return avx2::FP32ReduceAdd; +#endif + } else if (SimdStatus::SupportAVX()) { +#if defined(ENABLE_AVX) + return avx::FP32ReduceAdd; +#endif + } else if (SimdStatus::SupportSSE()) { +#if defined(ENABLE_SSE) + return sse::FP32ReduceAdd; +#endif + } + return generic::FP32ReduceAdd; +} +FP32ReduceType FP32ReduceAdd = GetFP32ReduceAdd(); + } // namespace vsag diff --git a/src/simd/fp32_simd.h b/src/simd/fp32_simd.h index 836dd1457..2779e35a5 100644 --- a/src/simd/fp32_simd.h +++ b/src/simd/fp32_simd.h @@ -49,6 +49,15 @@ FP32ComputeL2SqrBatch4(const float* query, float& result4); void FP32Sub(const float* x, const float* y, float* z, uint64_t dim); +void +FP32Add(const float* x, const float* y, float* z, uint64_t dim); +void +FP32Mul(const float* x, const float* y, float* z, uint64_t dim); +void +FP32Div(const float* x, const float* y, float* z, uint64_t dim); + +float +FP32ReduceAdd(const float* x, uint64_t dim); } // namespace generic namespace sse { @@ -80,6 +89,15 @@ FP32ComputeL2SqrBatch4(const float* query, float& result4); void FP32Sub(const float* x, const float* y, float* z, uint64_t dim); +void +FP32Add(const float* x, const float* y, float* z, uint64_t dim); +void 
+FP32Mul(const float* x, const float* y, float* z, uint64_t dim); +void +FP32Div(const float* x, const float* y, float* z, uint64_t dim); + +float +FP32ReduceAdd(const float* x, uint64_t dim); } // namespace sse namespace avx { @@ -111,6 +129,15 @@ FP32ComputeL2SqrBatch4(const float* query, float& result4); void FP32Sub(const float* x, const float* y, float* z, uint64_t dim); +void +FP32Add(const float* x, const float* y, float* z, uint64_t dim); +void +FP32Mul(const float* x, const float* y, float* z, uint64_t dim); +void +FP32Div(const float* x, const float* y, float* z, uint64_t dim); + +float +FP32ReduceAdd(const float* x, uint64_t dim); } // namespace avx namespace avx2 { @@ -142,6 +169,15 @@ FP32ComputeL2SqrBatch4(const float* query, float& result4); void FP32Sub(const float* x, const float* y, float* z, uint64_t dim); +void +FP32Add(const float* x, const float* y, float* z, uint64_t dim); +void +FP32Mul(const float* x, const float* y, float* z, uint64_t dim); +void +FP32Div(const float* x, const float* y, float* z, uint64_t dim); + +float +FP32ReduceAdd(const float* x, uint64_t dim); } // namespace avx2 namespace avx512 { @@ -173,6 +209,15 @@ FP32ComputeL2SqrBatch4(const float* query, float& result4); void FP32Sub(const float* x, const float* y, float* z, uint64_t dim); +void +FP32Add(const float* x, const float* y, float* z, uint64_t dim); +void +FP32Mul(const float* x, const float* y, float* z, uint64_t dim); +void +FP32Div(const float* x, const float* y, float* z, uint64_t dim); + +float +FP32ReduceAdd(const float* x, uint64_t dim); } // namespace avx512 using FP32ComputeType = float (*)(const float* query, const float* codes, uint64_t dim); @@ -192,6 +237,12 @@ using FP32ComputeBatch4Type = void (*)(const float* query, extern FP32ComputeBatch4Type FP32ComputeIPBatch4; extern FP32ComputeBatch4Type FP32ComputeL2SqrBatch4; -using FP32SubType = void (*)(const float* x, const float* y, float* z, uint64_t dim); -extern FP32SubType FP32Sub; +using 
FP32ArithmeticType = void (*)(const float* x, const float* y, float* z, uint64_t dim); +extern FP32ArithmeticType FP32Sub; +extern FP32ArithmeticType FP32Add; +extern FP32ArithmeticType FP32Mul; +extern FP32ArithmeticType FP32Div; + +using FP32ReduceType = float (*)(const float* x, uint64_t dim); +extern FP32ReduceType FP32ReduceAdd; } // namespace vsag diff --git a/src/simd/fp32_simd_test.cpp b/src/simd/fp32_simd_test.cpp index 9e1de729e..237abed61 100644 --- a/src/simd/fp32_simd_test.cpp +++ b/src/simd/fp32_simd_test.cpp @@ -45,7 +45,7 @@ using namespace vsag; } \ }; -#define TEST_FP32_SUB_ACCURACY(Func) \ +#define TEST_FP32_ARTHIMETIC_ACCURACY(Func) \ { \ std::vector gt(dim, 0.0F); \ generic::Func(vec1.data() + i * dim, vec2.data() + i * dim, gt.data(), dim); \ @@ -176,7 +176,10 @@ TEST_CASE("FP32 SIMD Compute", "[ut][simd]") { for (uint64_t i = 0; i < count; ++i) { TEST_FP32_COMPUTE_ACCURACY(FP32ComputeIP); TEST_FP32_COMPUTE_ACCURACY(FP32ComputeL2Sqr); - TEST_FP32_SUB_ACCURACY(FP32Sub); + TEST_FP32_ARTHIMETIC_ACCURACY(FP32Sub); + TEST_FP32_ARTHIMETIC_ACCURACY(FP32Add); + TEST_FP32_ARTHIMETIC_ACCURACY(FP32Mul); + TEST_FP32_ARTHIMETIC_ACCURACY(FP32Div); } for (uint64_t i = 0; i < count; i += 4) { TEST_FP32_COMPUTE_ACCURACY_BATCH4(FP32ComputeIP, FP32ComputeIPBatch4); diff --git a/src/simd/generic.cpp b/src/simd/generic.cpp index ffacfa4ee..41854f900 100644 --- a/src/simd/generic.cpp +++ b/src/simd/generic.cpp @@ -140,6 +140,36 @@ FP32Sub(const float* x, const float* y, float* z, uint64_t dim) { } } +void +FP32Add(const float* x, const float* y, float* z, uint64_t dim) { + for (uint64_t i = 0; i < dim; ++i) { + z[i] = x[i] + y[i]; + } +} + +void +FP32Mul(const float* x, const float* y, float* z, uint64_t dim) { + for (uint64_t i = 0; i < dim; ++i) { + z[i] = x[i] * y[i]; + } +} + +void +FP32Div(const float* x, const float* y, float* z, uint64_t dim) { + for (uint64_t i = 0; i < dim; ++i) { + z[i] = x[i] / y[i]; + } +} + +float +FP32ReduceAdd(const float* x, 
uint64_t dim) { + float result = 0.0F; + for (uint64_t i = 0; i < dim; ++i) { + result += x[i]; + } + return result; +} + union FP32Struct { uint32_t int_value; float float_value; diff --git a/src/simd/sse.cpp b/src/simd/sse.cpp index 22686b4eb..514fddfd4 100644 --- a/src/simd/sse.cpp +++ b/src/simd/sse.cpp @@ -280,6 +280,93 @@ FP32Sub(const float* x, const float* y, float* z, uint64_t dim) { #endif } +void +FP32Add(const float* x, const float* y, float* z, uint64_t dim) { +#if defined(ENABLE_SSE) + if (dim < 4) { + return generic::FP32Add(x, y, z, dim); + } + int64_t i = 0; + for (; i + 3 < dim; i += 4) { + __m128 a = _mm_loadu_ps(x + i); + __m128 b = _mm_loadu_ps(y + i); + __m128 c = _mm_add_ps(a, b); + _mm_storeu_ps(z + i, c); + } + if (i < dim) { + generic::FP32Add(x + i, y + i, z + i, dim - i); + } +#else + return generic::FP32Add(x, y, z, dim); +#endif +} + +void +FP32Mul(const float* x, const float* y, float* z, uint64_t dim) { +#if defined(ENABLE_SSE) + if (dim < 4) { + return generic::FP32Mul(x, y, z, dim); + } + int64_t i = 0; + for (; i + 3 < dim; i += 4) { + __m128 a = _mm_loadu_ps(x + i); + __m128 b = _mm_loadu_ps(y + i); + __m128 c = _mm_mul_ps(a, b); + _mm_storeu_ps(z + i, c); + } + if (i < dim) { + generic::FP32Mul(x + i, y + i, z + i, dim - i); + } +#else + return generic::FP32Mul(x, y, z, dim); +#endif +} + +void +FP32Div(const float* x, const float* y, float* z, uint64_t dim) { +#if defined(ENABLE_SSE) + if (dim < 4) { + return generic::FP32Div(x, y, z, dim); + } + int64_t i = 0; + for (; i + 3 < dim; i += 4) { + __m128 a = _mm_loadu_ps(x + i); + __m128 b = _mm_loadu_ps(y + i); + __m128 c = _mm_div_ps(a, b); + _mm_storeu_ps(z + i, c); + } + if (i < dim) { + generic::FP32Div(x + i, y + i, z + i, dim - i); + } +#else + return generic::FP32Div(x, y, z, dim); +#endif +} + +float +FP32ReduceAdd(const float* x, uint64_t dim) { +#if defined(ENABLE_SSE) + if (dim < 4) { + return generic::FP32ReduceAdd(x, dim); + } + __m128 sum = _mm_setzero_ps(); + int i 
= 0; + for (; i + 3 < dim; i += 4) { + __m128 a = _mm_loadu_ps(x + i); + sum = _mm_add_ps(sum, a); + } + sum = _mm_hadd_ps(sum, sum); + sum = _mm_hadd_ps(sum, sum); + float result = _mm_cvtss_f32(sum); + if (i < dim) { + result += generic::FP32ReduceAdd(x + i, dim - i); + } + return result; +#else + return generic::FP32ReduceAdd(x, dim); +#endif +} + float BF16ComputeIP(const uint8_t* query, const uint8_t* codes, uint64_t dim) { #if defined(ENABLE_SSE) From 8e190ddbf89c06829a9cfcb952564ce08088820b Mon Sep 17 00:00:00 2001 From: inabao <37021995+inabao@users.noreply.github.com> Date: Wed, 28 May 2025 14:41:19 +0800 Subject: [PATCH 27/42] support build ivf index with residual (#754) Signed-off-by: jinjiabao.jjb Signed-off-by: suguan.dx --- src/algorithm/hgraph.cpp | 8 +++ src/algorithm/hgraph.h | 3 ++ src/algorithm/inner_index_interface.h | 6 +++ src/algorithm/ivf.h | 2 + src/data_cell/bucket_datacell.h | 50 ++++++++++++++++--- src/data_cell/bucket_datacell_parameter.cpp | 5 ++ src/data_cell/bucket_datacell_parameter.h | 2 + src/data_cell/bucket_interface.cpp | 9 +++- src/data_cell/bucket_interface.h | 5 +- src/quantization/fp32_quantizer.h | 1 + .../pq_fastscan_quantizer.h | 7 +-- .../product_quantization/product_quantizer.h | 7 +-- .../scalar_quantization/bf16_quantizer.h | 1 + .../scalar_quantization/fp16_quantizer.h | 1 + .../scalar_quantization/sq4_quantizer.h | 1 + .../scalar_quantization/sq8_quantizer.h | 7 +-- tests/test_ivf.cpp | 46 +++++++++++++++-- 17 files changed, 138 insertions(+), 23 deletions(-) diff --git a/src/algorithm/hgraph.cpp b/src/algorithm/hgraph.cpp index 0e4d29fb2..d6a41410d 100644 --- a/src/algorithm/hgraph.cpp +++ b/src/algorithm/hgraph.cpp @@ -1279,4 +1279,12 @@ HGraph::ExportModel(const IndexCommonParam& param) const { } return index; } +void +HGraph::GetRawData(vsag::InnerIdType inner_id, uint8_t* data) const { + if (use_reorder_) { + high_precise_codes_->GetCodesById(inner_id, data); + } else { + 
basic_flatten_codes_->GetCodesById(inner_id, data); + } +} } // namespace vsag diff --git a/src/algorithm/hgraph.h b/src/algorithm/hgraph.h index 9574ababb..78b6f3166 100644 --- a/src/algorithm/hgraph.h +++ b/src/algorithm/hgraph.h @@ -138,6 +138,9 @@ class HGraph : public InnerIndexInterface { this->build_pool_->SetPoolSize(count); } + void + GetRawData(vsag::InnerIdType inner_id, uint8_t* data) const override; + private: const void* get_data(const DatasetPtr& dataset, uint32_t index = 0) const { diff --git a/src/algorithm/inner_index_interface.h b/src/algorithm/inner_index_interface.h index bcdef1a67..219956056 100644 --- a/src/algorithm/inner_index_interface.h +++ b/src/algorithm/inner_index_interface.h @@ -258,6 +258,12 @@ class InnerIndexInterface { return this->label_table_->CheckLabel(id); } + virtual void + GetRawData(InnerIdType inner_id, uint8_t* data) const { + throw VsagException(ErrorType::UNSUPPORTED_INDEX_OPERATION, + "Index doesn't support GetRawData"); + } + public: LabelTablePtr label_table_{nullptr}; diff --git a/src/algorithm/ivf.h b/src/algorithm/ivf.h index 92cd6dd32..92ba184a1 100644 --- a/src/algorithm/ivf.h +++ b/src/algorithm/ivf.h @@ -122,6 +122,8 @@ class IVF : public InnerIndexInterface { bool is_trained_{false}; + bool use_residual_{false}; + FlattenInterfacePtr reorder_codes_{nullptr}; }; } // namespace vsag diff --git a/src/data_cell/bucket_datacell.h b/src/data_cell/bucket_datacell.h index 9b27397d7..77383f97f 100644 --- a/src/data_cell/bucket_datacell.h +++ b/src/data_cell/bucket_datacell.h @@ -29,7 +29,8 @@ class BucketDataCell : public BucketInterface { explicit BucketDataCell(const QuantizerParamPtr& quantization_param, const IOParamPtr& io_param, const IndexCommonParam& common_param, - BucketIdType bucket_count); + BucketIdType bucket_count, + bool use_residual = false); void ScanBucketById(float* result_dists, @@ -54,7 +55,10 @@ class BucketDataCell : public BucketInterface { Train(const void* data, uint64_t count) override; 
void - InsertVector(const void* vector, BucketIdType bucket_id, InnerIdType inner_id) override; + InsertVector(const void* vector, + BucketIdType bucket_id, + InnerIdType inner_id, + const float* centroid = nullptr) override; InnerIdType* GetInnerIds(BucketIdType bucket_id) override { @@ -130,7 +134,8 @@ class BucketDataCell : public BucketInterface { inline void insert_vector_with_locate(const float* vector, const BucketIdType& bucket_id, - const InnerIdType& offset_id); + const InnerIdType& offset_id, + const float* centroid); inline void package_fastscan(); @@ -150,13 +155,18 @@ class BucketDataCell : public BucketInterface { Vector> inner_ids_; Allocator* const allocator_{nullptr}; + + bool use_residual_{false}; + + Vector> residual_bias_; }; template BucketDataCell::BucketDataCell(const QuantizerParamPtr& quantization_param, const IOParamPtr& io_param, const IndexCommonParam& common_param, - BucketIdType bucket_count) + BucketIdType bucket_count, + bool use_residual) : BucketInterface(), datas_(common_param.allocator_.get()), bucket_sizes_(bucket_count, 0, common_param.allocator_.get()), @@ -164,7 +174,9 @@ BucketDataCell::BucketDataCell(const QuantizerParamPtr& quant Vector(common_param.allocator_.get()), common_param.allocator_.get()), bucket_mutexes_(bucket_count, common_param.allocator_.get()), - allocator_(common_param.allocator_.get()) { + allocator_(common_param.allocator_.get()), + residual_bias_(bucket_count, Vector(allocator_), allocator_), + use_residual_(use_residual) { this->bucket_count_ = bucket_count; this->quantizer_ = std::make_shared(quantization_param, common_param); this->code_size_ = quantizer_->GetCodeSize(); @@ -217,6 +229,9 @@ BucketDataCell::scan_bucket_by_id( data_count -= compute_count; offset += compute_count; } + if (use_residual_ && this->quantizer_->Metric() == MetricType::METRIC_TYPE_L2SQR) { + FP32Sub(result_dists, residual_bias_[bucket_id].data(), result_dists, offset); + } } template @@ -238,7 +253,8 @@ template void 
BucketDataCell::InsertVector(const void* vector, BucketIdType bucket_id, - InnerIdType inner_id) { + InnerIdType inner_id, + const float* centroid) { check_valid_bucket_id(bucket_id); InnerIdType locate; { @@ -246,21 +262,33 @@ BucketDataCell::InsertVector(const void* vector, locate = this->bucket_sizes_[bucket_id]; this->bucket_sizes_[bucket_id]++; inner_ids_[bucket_id].emplace_back(inner_id); + if (use_residual_ && this->quantizer_->Metric() == MetricType::METRIC_TYPE_L2SQR) { + residual_bias_[bucket_id].emplace_back(0.0F); + } } - this->insert_vector_with_locate(reinterpret_cast(vector), bucket_id, locate); + this->insert_vector_with_locate( + reinterpret_cast(vector), bucket_id, locate, centroid); } template void BucketDataCell::insert_vector_with_locate(const float* vector, const BucketIdType& bucket_id, - const InnerIdType& offset_id) { + const InnerIdType& offset_id, + const float* centroid) { ByteBuffer codes(static_cast(code_size_), this->allocator_); this->quantizer_->EncodeOne(vector, codes.data); this->datas_[bucket_id]->Write( codes.data, code_size_, static_cast(offset_id) * static_cast(code_size_)); + if (use_residual_ && this->quantizer_->Metric() == MetricType::METRIC_TYPE_L2SQR && centroid) { + Vector compress_vector(this->quantizer_->GetDim(), this->allocator_); + this->quantizer_->DecodeOne(codes.data, compress_vector.data()); + residual_bias_[bucket_id][offset_id] = + -2 * FP32ComputeIP(centroid, compress_vector.data(), this->quantizer_->GetDim()) - + FP32ComputeIP(centroid, centroid, this->quantizer_->GetDim()); + } } template @@ -271,6 +299,9 @@ BucketDataCell::Serialize(StreamWriter& writer) { for (BucketIdType i = 0; i < this->bucket_count_; ++i) { datas_[i]->Serialize(writer); StreamWriter::WriteVector(writer, inner_ids_[i]); + if (use_residual_) { + StreamWriter::WriteVector(writer, residual_bias_[i]); + } } StreamWriter::WriteVector(writer, this->bucket_sizes_); } @@ -283,6 +314,9 @@ BucketDataCell::Deserialize(StreamReader& reader) { for 
(BucketIdType i = 0; i < this->bucket_count_; ++i) { datas_[i]->Deserialize(reader); StreamReader::ReadVector(reader, inner_ids_[i]); + if (use_residual_) { + StreamReader::ReadVector(reader, residual_bias_[i]); + } } StreamReader::ReadVector(reader, this->bucket_sizes_); } diff --git a/src/data_cell/bucket_datacell_parameter.cpp b/src/data_cell/bucket_datacell_parameter.cpp index b5e210973..87582df15 100644 --- a/src/data_cell/bucket_datacell_parameter.cpp +++ b/src/data_cell/bucket_datacell_parameter.cpp @@ -37,12 +37,17 @@ BucketDataCellParameter::FromJson(const JsonType& json) { if (json.contains(BUCKETS_COUNT_KEY)) { this->buckets_count = json[BUCKETS_COUNT_KEY]; } + + if (json.contains(BUCKET_USE_RESIDUAL)) { + this->use_residual_ = json[BUCKET_USE_RESIDUAL]; + } } JsonType BucketDataCellParameter::ToJson() { JsonType json; json[IO_PARAMS_KEY] = this->io_parameter->ToJson(); + json[BUCKET_USE_RESIDUAL] = this->use_residual_; json[QUANTIZATION_PARAMS_KEY] = this->quantizer_parameter->ToJson(); json[BUCKETS_COUNT_KEY] = this->buckets_count; return json; diff --git a/src/data_cell/bucket_datacell_parameter.h b/src/data_cell/bucket_datacell_parameter.h index 4a2ba0e12..979f69649 100644 --- a/src/data_cell/bucket_datacell_parameter.h +++ b/src/data_cell/bucket_datacell_parameter.h @@ -36,6 +36,8 @@ class BucketDataCellParameter : public Parameter { IOParamPtr io_parameter{nullptr}; + bool use_residual_{false}; + int64_t buckets_count{1}; }; diff --git a/src/data_cell/bucket_interface.cpp b/src/data_cell/bucket_interface.cpp index e4b157698..c2af55c7b 100644 --- a/src/data_cell/bucket_interface.cpp +++ b/src/data_cell/bucket_interface.cpp @@ -30,7 +30,11 @@ make_instance(const BucketDataCellParamPtr& param, const IndexCommonParam& commo auto& quantizer_param = param->quantizer_parameter; return std::make_shared>( - quantizer_param, io_param, common_param, static_cast(param->buckets_count)); + quantizer_param, + io_param, + common_param, + 
static_cast(param->buckets_count), + param->use_residual_); } template @@ -78,6 +82,9 @@ make_instance(const BucketDataCellParamPtr& param, const IndexCommonParam& commo return make_instance(param, common_param); } if (metric == MetricType::METRIC_TYPE_COSINE) { + if (param->use_residual_) { + return make_instance(param, common_param); + } return make_instance(param, common_param); } return nullptr; diff --git a/src/data_cell/bucket_interface.h b/src/data_cell/bucket_interface.h index 8b6eb7063..e33ab31a8 100644 --- a/src/data_cell/bucket_interface.h +++ b/src/data_cell/bucket_interface.h @@ -53,7 +53,10 @@ class BucketInterface { Train(const void* data, uint64_t count) = 0; virtual void - InsertVector(const void* vector, BucketIdType bucket_id, InnerIdType inner_id) = 0; + InsertVector(const void* vector, + BucketIdType bucket_id, + InnerIdType inner_id, + const float* centroid = nullptr) = 0; virtual InnerIdType* GetInnerIds(BucketIdType bucket_id) = 0; diff --git a/src/quantization/fp32_quantizer.h b/src/quantization/fp32_quantizer.h index 9807c7729..a4373bfbc 100644 --- a/src/quantization/fp32_quantizer.h +++ b/src/quantization/fp32_quantizer.h @@ -100,6 +100,7 @@ template FP32Quantizer::FP32Quantizer(int dim, Allocator* allocator) : Quantizer>(dim, allocator) { this->code_size_ = dim * sizeof(float); + this->metric_ = metric; } template diff --git a/src/quantization/product_quantization/pq_fastscan_quantizer.h b/src/quantization/product_quantization/pq_fastscan_quantizer.h index a6b6d5f66..1aecf3aff 100644 --- a/src/quantization/product_quantization/pq_fastscan_quantizer.h +++ b/src/quantization/product_quantization/pq_fastscan_quantizer.h @@ -119,9 +119,9 @@ class PQFastScanQuantizer : public Quantizer> { Vector codebooks_; }; -template -PQFastScanQuantizer::PQFastScanQuantizer(int dim, int64_t pq_dim, Allocator* allocator) - : Quantizer>(dim, allocator), +template +PQFastScanQuantizer::PQFastScanQuantizer(int dim, int64_t pq_dim, Allocator* allocator) + : 
Quantizer>(dim, allocator), pq_dim_(pq_dim), codebooks_(allocator) { if (dim % pq_dim != 0) { @@ -131,6 +131,7 @@ PQFastScanQuantizer::PQFastScanQuantizer(int dim, int64_t pq_dim, Alloca } this->code_size_ = (this->pq_dim_ + 1) / 2; this->subspace_dim_ = this->dim_ / pq_dim; + this->metric_ = metric; codebooks_.resize(this->dim_ * CENTROIDS_PER_SUBSPACE); } diff --git a/src/quantization/product_quantization/product_quantizer.h b/src/quantization/product_quantization/product_quantizer.h index ff3441520..c20e339a7 100644 --- a/src/quantization/product_quantization/product_quantizer.h +++ b/src/quantization/product_quantization/product_quantizer.h @@ -112,9 +112,9 @@ class ProductQuantizer : public Quantizer> { Vector reverse_codebooks_; }; -template -ProductQuantizer::ProductQuantizer(int dim, int64_t pq_dim, Allocator* allocator) - : Quantizer>(dim, allocator), +template +ProductQuantizer::ProductQuantizer(int dim, int64_t pq_dim, Allocator* allocator) + : Quantizer>(dim, allocator), pq_dim_(pq_dim), codebooks_(allocator), reverse_codebooks_(allocator) { @@ -124,6 +124,7 @@ ProductQuantizer::ProductQuantizer(int dim, int64_t pq_dim, Allocator* a fmt::format("pq_dim({}) does not divide evenly into dim({})", pq_dim, dim)); } this->code_size_ = this->pq_dim_; + this->metric_ = metric; this->subspace_dim_ = this->dim_ / pq_dim; codebooks_.resize(this->dim_ * CENTROIDS_PER_SUBSPACE); reverse_codebooks_.resize(this->dim_ * CENTROIDS_PER_SUBSPACE); diff --git a/src/quantization/scalar_quantization/bf16_quantizer.h b/src/quantization/scalar_quantization/bf16_quantizer.h index 0bdc38a20..74eb3b087 100644 --- a/src/quantization/scalar_quantization/bf16_quantizer.h +++ b/src/quantization/scalar_quantization/bf16_quantizer.h @@ -89,6 +89,7 @@ template BF16Quantizer::BF16Quantizer(int dim, Allocator* allocator) : Quantizer>(dim, allocator) { this->code_size_ = dim * 2; + this->metric_ = metric; } template diff --git a/src/quantization/scalar_quantization/fp16_quantizer.h 
b/src/quantization/scalar_quantization/fp16_quantizer.h index a752e1610..3ce64a4bb 100644 --- a/src/quantization/scalar_quantization/fp16_quantizer.h +++ b/src/quantization/scalar_quantization/fp16_quantizer.h @@ -89,6 +89,7 @@ template FP16Quantizer::FP16Quantizer(int dim, Allocator* allocator) : Quantizer>(dim, allocator) { this->code_size_ = dim * 2; + this->metric_ = metric; } template diff --git a/src/quantization/scalar_quantization/sq4_quantizer.h b/src/quantization/scalar_quantization/sq4_quantizer.h index 1db5722b4..92ab9154c 100644 --- a/src/quantization/scalar_quantization/sq4_quantizer.h +++ b/src/quantization/scalar_quantization/sq4_quantizer.h @@ -92,6 +92,7 @@ template SQ4Quantizer::SQ4Quantizer(int dim, Allocator* allocator) : Quantizer>(dim, allocator) { this->code_size_ = (dim + (1 << 6) - 1) >> 6 << 6; + this->metric_ = metric; lower_bound_.resize(dim, std::numeric_limits::max()); diff_.resize(dim, std::numeric_limits::lowest()); } diff --git a/src/quantization/scalar_quantization/sq8_quantizer.h b/src/quantization/scalar_quantization/sq8_quantizer.h index ed109eb86..9d894d80a 100644 --- a/src/quantization/scalar_quantization/sq8_quantizer.h +++ b/src/quantization/scalar_quantization/sq8_quantizer.h @@ -91,11 +91,12 @@ class SQ8Quantizer : public Quantizer> { Vector lower_bound_; }; -template -SQ8Quantizer::SQ8Quantizer(int dim, Allocator* allocator) - : Quantizer>(dim, allocator), diff_(allocator), lower_bound_(allocator) { +template +SQ8Quantizer::SQ8Quantizer(int dim, Allocator* allocator) + : Quantizer>(dim, allocator), diff_(allocator), lower_bound_(allocator) { // align 64 bytes (512 bits) to avoid illegal memory access in SIMD this->code_size_ = this->dim_; + this->metric_ = metric; this->diff_.resize(dim, 0); this->lower_bound_.resize(dim, std::numeric_limits::max()); } diff --git a/tests/test_ivf.cpp b/tests/test_ivf.cpp index d4d14414e..276b9217b 100644 --- a/tests/test_ivf.cpp +++ b/tests/test_ivf.cpp @@ -25,7 +25,8 @@ class 
IVFTestIndex : public fixtures::TestIndex { int64_t dim, const std::string& quantization_str = "sq8", int buckets_count = 210, - const std::string& train_type = "kmeans"); + const std::string& train_type = "kmeans", + bool use_residual = false); static void TestGeneral(const IndexPtr& index, const TestDatasetPtr& dataset, @@ -70,7 +71,8 @@ IVFTestIndex::GenerateIVFBuildParametersString(const std::string& metric_type, int64_t dim, const std::string& quantization_str, int buckets_count, - const std::string& train_type) { + const std::string& train_type, + bool use_residual) { std::string build_parameters_str; constexpr auto parameter_temp = R"( @@ -84,7 +86,8 @@ IVFTestIndex::GenerateIVFBuildParametersString(const std::string& metric_type, "ivf_train_type": "{}", "use_reorder": {}, "base_pq_dim": {}, - "precise_quantization_type": "{}" + "precise_quantization_type": "{}", + "use_residual": {} }} }} )"; @@ -109,7 +112,8 @@ IVFTestIndex::GenerateIVFBuildParametersString(const std::string& metric_type, train_type, use_reorder, pq_dim, - precise_quantizer_str); + precise_quantizer_str, + use_residual); INFO(build_parameters_str); return build_parameters_str; @@ -276,6 +280,40 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::IVFTestIndex, "IVF Build & ContinueAdd Te } } +TEST_CASE_PERSISTENT_FIXTURE(fixtures::IVFTestIndex, "IVF Build with Residual", "[ft][ivf]") { + auto origin_size = vsag::Options::Instance().block_size_limit(); + auto size = GENERATE(1024 * 1024 * 2); + auto metric_type = GENERATE("l2", "ip", "cosine"); + std::string train_type = GENERATE("kmeans"); + + const std::string name = "ivf"; + auto search_param = fmt::format(search_param_tmp, 200); + std::vector> tmp_test_cases = { + {"fp32", 0.90}, + {"bf16", 0.88}, + {"fp16", 0.88}, + {"sq8", 0.84}, + // {"sq8_uniform", 0.83}, + // {"sq8_uniform,fp32", 0.89}, + {"pq,fp32", 0.82}, + {"pqfs,fp32", 0.82}, + }; + for (auto& dim : dims) { + for (auto& [base_quantization_str, recall] : tmp_test_cases) { + 
vsag::Options::Instance().set_block_size_limit(size); + auto param = GenerateIVFBuildParametersString( + metric_type, dim, base_quantization_str, 300, train_type, true); + auto index = TestFactory(name, param, true); + auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric_type); + TestContinueAdd(index, dataset, true); + if (index->CheckFeature(vsag::SUPPORT_ADD_AFTER_BUILD)) { + TestGeneral(index, dataset, search_param, recall); + } + vsag::Options::Instance().set_block_size_limit(origin_size); + } + } +} + TEST_CASE_PERSISTENT_FIXTURE(fixtures::IVFTestIndex, "IVF Build", "[ft][ivf]") { auto origin_size = vsag::Options::Instance().block_size_limit(); auto size = GENERATE(1024 * 1024 * 2); From 9b7063fdb86eabb161f82abdcb9b194749e9495a Mon Sep 17 00:00:00 2001 From: inabao <37021995+inabao@users.noreply.github.com> Date: Wed, 28 May 2025 15:46:29 +0800 Subject: [PATCH 28/42] support estimate memory in hnsw (#709) Signed-off-by: jinjiabao.jjb Signed-off-by: suguan.dx --- src/algorithm/hnswlib/algorithm_interface.h | 5 +++++ src/algorithm/hnswlib/hnswalg.cpp | 19 ++++++++++++++++++ src/algorithm/hnswlib/hnswalg.h | 3 +++ src/index/hnsw.cpp | 8 +++++++- src/index/hnsw.h | 8 ++++++++ tests/test_hnsw_new.cpp | 22 +++++++++++++++++++++ tests/test_index.cpp | 9 ++++++--- 7 files changed, 70 insertions(+), 4 deletions(-) diff --git a/src/algorithm/hnswlib/algorithm_interface.h b/src/algorithm/hnswlib/algorithm_interface.h index 921039c9d..ee372c690 100644 --- a/src/algorithm/hnswlib/algorithm_interface.h +++ b/src/algorithm/hnswlib/algorithm_interface.h @@ -116,6 +116,11 @@ class AlgorithmInterface { virtual bool init_memory_space() = 0; + virtual uint64_t + estimateMemory(uint64_t num_elements) { + return 0; + } + virtual ~AlgorithmInterface() { } }; diff --git a/src/algorithm/hnswlib/hnswalg.cpp b/src/algorithm/hnswlib/hnswalg.cpp index 96acd79b3..c179a3e01 100644 --- a/src/algorithm/hnswlib/hnswalg.cpp +++ b/src/algorithm/hnswlib/hnswalg.cpp @@ -133,6 +133,25 
@@ HierarchicalNSW::init_memory_space() { return true; } +uint64_t +HierarchicalNSW::estimateMemory(uint64_t num_elements) { + size_t size = 0; + size += sizeof(unsigned short int) * num_elements; // visited_list_pool_ + size += sizeof(int) * num_elements; // element_levels_ + size += num_elements * size_data_per_element_; // data_level0_memory_ + if (use_reversed_edges_) { + size += sizeof(reverselinklist*) * num_elements; // reversed_level0_link_list_ + size += sizeof(vsag::UnorderedMap*) * + num_elements; // reversed_link_lists_ + } + if (normalize_) { + size += sizeof(float) * num_elements; // molds_ + } + size += sizeof(void*) * num_elements; // link_lists_ + size += sizeof(std::shared_mutex) * num_elements; // points_locks_ + return size; +} + HierarchicalNSW::~HierarchicalNSW() { if (link_lists_ != nullptr) { for (InnerIdType i = 0; i < max_elements_; i++) { diff --git a/src/algorithm/hnswlib/hnswalg.h b/src/algorithm/hnswlib/hnswalg.h index eeaf5651e..8da51b322 100644 --- a/src/algorithm/hnswlib/hnswalg.h +++ b/src/algorithm/hnswlib/hnswalg.h @@ -431,5 +431,8 @@ class HierarchicalNSW : public AlgorithmInterface { bool init_memory_space() override; + + uint64_t + estimateMemory(uint64_t num_elements) override; }; } // namespace hnswlib diff --git a/src/index/hnsw.cpp b/src/index/hnsw.cpp index d8490ee52..05dc52cf2 100644 --- a/src/index/hnsw.cpp +++ b/src/index/hnsw.cpp @@ -1104,7 +1104,8 @@ HNSW::init_feature_list() { // other feature_list_.SetFeatures({IndexFeature::SUPPORT_CAL_DISTANCE_BY_ID, IndexFeature::SUPPORT_CHECK_ID_EXIST, - IndexFeature::SUPPORT_MERGE_INDEX}); + IndexFeature::SUPPORT_MERGE_INDEX, + IndexFeature::SUPPORT_ESTIMATE_MEMORY}); } bool @@ -1286,6 +1287,11 @@ HNSW::get_memory_usage() const { return static_cast(alg_hnsw_->calcSerializeSize()); } +uint64_t +HNSW::estimate_memory(uint64_t num_elements) const { + return alg_hnsw_->estimateMemory(num_elements); +} + template tl::expected HNSW::knn_search_internal(const DatasetPtr& query, 
int64_t k, diff --git a/src/index/hnsw.h b/src/index/hnsw.h index 2521b952d..0f24cb369 100644 --- a/src/index/hnsw.h +++ b/src/index/hnsw.h @@ -265,6 +265,11 @@ class HNSW : public Index { return this->get_memory_usage(); } + uint64_t + EstimateMemory(uint64_t num_elements) const override { + return this->estimate_memory(num_elements); + } + std::string GetStats() const override; @@ -391,6 +396,9 @@ class HNSW : public Index { void init_feature_list(); + uint64_t + estimate_memory(uint64_t num_elements) const; + private: std::shared_ptr> alg_hnsw_; std::shared_ptr space_; diff --git a/tests/test_hnsw_new.cpp b/tests/test_hnsw_new.cpp index f071cacad..5be7d6616 100644 --- a/tests/test_hnsw_new.cpp +++ b/tests/test_hnsw_new.cpp @@ -201,6 +201,28 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, } } +TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Estimate Memory", "[ft][hnsw]") { + auto origin_size = vsag::Options::Instance().block_size_limit(); + auto size = GENERATE(1024 * 1024 * 2); + auto metric_type = GENERATE("l2", "cosine"); + + const std::string name = "hnsw"; + auto search_param = fmt::format(search_param_tmp, 200, false); + uint64_t estimate_count = 1000; + for (auto dim : dims) { + vsag::Options::Instance().set_block_size_limit(size); + auto param = GenerateHNSWBuildParametersString(metric_type, dim); + auto dataset = pool.GetDatasetAndCreate(dim, + estimate_count, + metric_type, + false /*with_path*/, + 0.8 /*valid_ratio*/, + 0 /*extro_info_size*/); + TestEstimateMemory(name, param, dataset); + vsag::Options::Instance().set_block_size_limit(origin_size); + } +} + TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Build & ContinueAdd Test", "[ft][hnsw]") { diff --git a/tests/test_index.cpp b/tests/test_index.cpp index c35dca826..6eb950f86 100644 --- a/tests/test_index.cpp +++ b/tests/test_index.cpp @@ -1236,26 +1236,29 @@ TestIndex::TestEstimateMemory(const std::string& index_name, REQUIRE(index2->GetNumElements() == 0); 
fixtures::TempDir dir("index"); auto path = dir.GenerateRandomFile(); - std::ofstream outf(path, std::ios::binary); if (index1->CheckFeature(vsag::SUPPORT_ESTIMATE_MEMORY)) { auto data_size = dataset->base_->GetNumElements(); auto estimate_memory = index1->EstimateMemory(data_size); auto build_index = index2->Build(dataset->base_); REQUIRE(build_index.has_value()); + std::ofstream outf(path, std::ios::binary); index2->Serialize(outf); + outf.close(); std::ifstream inf(path, std::ios::binary); index1->Deserialize(inf); auto real_memory = allocator->GetCurrentMemory(); if (estimate_memory <= static_cast(real_memory * 0.8) or estimate_memory >= static_cast(real_memory * 1.2)) { - WARN("estimate_memory failed"); + WARN(fmt::format("estimate_memory({}) is not in range [{}, {}]", + estimate_memory, + static_cast(real_memory * 0.8), + static_cast(real_memory * 1.2))); } REQUIRE(estimate_memory >= static_cast(real_memory * 0.2)); REQUIRE(estimate_memory <= static_cast(real_memory * 3.2)); inf.close(); } - outf.close(); } } From 046964010523592a6b8a940faaca2b0bc7c2e74f Mon Sep 17 00:00:00 2001 From: LHT129 Date: Thu, 29 May 2025 11:02:29 +0800 Subject: [PATCH 29/42] reduce test log level to avoid too much debug logs (#771) Signed-off-by: LHT129 Signed-off-by: suguan.dx --- tests/test_main.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_main.cpp b/tests/test_main.cpp index bb6d664d8..c438ad036 100644 --- a/tests/test_main.cpp +++ b/tests/test_main.cpp @@ -23,6 +23,8 @@ main(int argc, char** argv) { // your setup ... vsag::Options::Instance().set_logger(&fixtures::logger::test_logger); + fixtures::logger::test_logger.SetLevel(vsag::Logger::Level::kWARN); + int result = Catch::Session().run(argc, argv); // your clean-up... 
From 13323123c3f835867ccd82ccbeec3c65fcca25f3 Mon Sep 17 00:00:00 2001 From: LHT129 Date: Thu, 29 May 2025 14:44:14 +0800 Subject: [PATCH 30/42] speed up fast_bitset by use simd operators (#764) Signed-off-by: LHT129 Signed-off-by: suguan.dx --- mockimpl/CMakeLists.txt | 4 ++-- src/impl/bitset/fast_bitset.cpp | 31 ++++++++++++++++--------------- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/mockimpl/CMakeLists.txt b/mockimpl/CMakeLists.txt index 27998bd63..fcdf11581 100644 --- a/mockimpl/CMakeLists.txt +++ b/mockimpl/CMakeLists.txt @@ -13,8 +13,8 @@ set (MOCK_SRCS add_library (vsag_mockimpl SHARED ${MOCK_SRCS}) add_library (vsag_mockimpl_static STATIC ${MOCK_SRCS}) -target_link_libraries (vsag_mockimpl roaring fmt::fmt-header-only) -target_link_libraries (vsag_mockimpl_static roaring fmt::fmt-header-only) +target_link_libraries (vsag_mockimpl roaring fmt::fmt-header-only simd) +target_link_libraries (vsag_mockimpl_static roaring fmt::fmt-header-only simd) add_dependencies (vsag_mockimpl version_mockimpl roaring) add_dependencies (vsag_mockimpl_static version_mockimpl roaring) diff --git a/src/impl/bitset/fast_bitset.cpp b/src/impl/bitset/fast_bitset.cpp index 80bed2593..c50ef96e7 100644 --- a/src/impl/bitset/fast_bitset.cpp +++ b/src/impl/bitset/fast_bitset.cpp @@ -15,6 +15,7 @@ #include "fast_bitset.h" +#include "simd/bit_simd.h" #include "vsag_exception.h" namespace vsag { @@ -67,10 +68,10 @@ FastBitset::Or(const Bitset& another) { std::lock_guard lock2(fast_another->mutex_, std::adopt_lock); auto max_size = std::max(data_.size(), fast_another->data_.size()); data_.resize(max_size, 0); - // TODO(LHT): use SIMD - for (uint64_t i = 0; i < max_size; ++i) { - data_[i] |= fast_another->data_[i]; - } + BitOr(reinterpret_cast(data_.data()), + reinterpret_cast(fast_another->data_.data()), + max_size * sizeof(uint64_t), + reinterpret_cast(data_.data())); } void @@ -84,10 +85,10 @@ FastBitset::And(const Bitset& another) { std::lock_guard 
lock2(fast_another->mutex_, std::adopt_lock); auto max_size = std::max(data_.size(), fast_another->data_.size()); data_.resize(max_size, 0); - // TODO(LHT): use SIMD - for (uint64_t i = 0; i < max_size; ++i) { - data_[i] &= fast_another->data_[i]; - } + BitAnd(reinterpret_cast(data_.data()), + reinterpret_cast(fast_another->data_.data()), + max_size * sizeof(uint64_t), + reinterpret_cast(data_.data())); } void @@ -101,10 +102,10 @@ FastBitset::Xor(const Bitset& another) { std::lock_guard lock2(fast_another->mutex_, std::adopt_lock); auto max_size = std::max(data_.size(), fast_another->data_.size()); data_.resize(max_size, 0); - // TODO(LHT): use SIMD - for (uint64_t i = 0; i < max_size; ++i) { - data_[i] ^= fast_another->data_[i]; - } + BitXor(reinterpret_cast(data_.data()), + reinterpret_cast(fast_another->data_.data()), + max_size * sizeof(uint64_t), + reinterpret_cast(data_.data())); } std::string @@ -130,9 +131,9 @@ FastBitset::Dump() { void FastBitset::Not() { std::lock_guard lock(mutex_); - for (auto& word : data_) { - word = ~word; - } + BitNot(reinterpret_cast(data_.data()), + data_.size() * sizeof(uint64_t), + reinterpret_cast(data_.data())); } void From daefc1012c8d16503b767a6b34a55f1748949be4 Mon Sep 17 00:00:00 2001 From: "suguan.dx" Date: Tue, 3 Jun 2025 12:18:13 +0800 Subject: [PATCH 31/42] fix wrong id when buckets_per_data > 1 Signed-off-by: suguan.dx --- tests/test_ivf.cpp | 38 ++++++++++++++++++++++++++++++++++---- 1 file changed, 34 insertions(+), 4 deletions(-) diff --git a/tests/test_ivf.cpp b/tests/test_ivf.cpp index 276b9217b..760a4cc0d 100644 --- a/tests/test_ivf.cpp +++ b/tests/test_ivf.cpp @@ -26,7 +26,8 @@ class IVFTestIndex : public fixtures::TestIndex { const std::string& quantization_str = "sq8", int buckets_count = 210, const std::string& train_type = "kmeans", - bool use_residual = false); + bool use_residual = false, + int buckets_per_data = 1); static void TestGeneral(const IndexPtr& index, const TestDatasetPtr& dataset, @@ -72,7 
+73,8 @@ IVFTestIndex::GenerateIVFBuildParametersString(const std::string& metric_type, const std::string& quantization_str, int buckets_count, const std::string& train_type, - bool use_residual) { + bool use_residual, + int buckets_per_data) { std::string build_parameters_str; constexpr auto parameter_temp = R"( @@ -87,7 +89,8 @@ IVFTestIndex::GenerateIVFBuildParametersString(const std::string& metric_type, "use_reorder": {}, "base_pq_dim": {}, "precise_quantization_type": "{}", - "use_residual": {} + "use_residual": {}, + "buckets_per_data": {} }} }} )"; @@ -113,7 +116,8 @@ IVFTestIndex::GenerateIVFBuildParametersString(const std::string& metric_type, use_reorder, pq_dim, precise_quantizer_str, - use_residual); + use_residual, + buckets_per_data); INFO(build_parameters_str); return build_parameters_str; @@ -545,3 +549,29 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::IVFTestIndex, "IVF Estimate Memory", "[ft } } } + +TEST_CASE_PERSISTENT_FIXTURE(fixtures::IVFTestIndex, + "IVF Build Multi Buckets Per Data", + "[ft][ivf]") { + auto origin_size = vsag::Options::Instance().block_size_limit(); + auto size = GENERATE(1024 * 1024 * 2); + auto metric_type = GENERATE("l2", "ip", "cosine"); + std::string train_type = GENERATE("random", "kmeans"); + + const std::string name = "ivf"; + auto search_param = fmt::format(search_param_tmp, 200); + for (auto& dim : dims) { + for (auto& [base_quantization_str, recall] : test_cases) { + vsag::Options::Instance().set_block_size_limit(size); + auto param = GenerateIVFBuildParametersString( + metric_type, dim, base_quantization_str, 300, train_type, false, 2); + auto index = TestFactory(name, param, true); + auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric_type); + TestBuildIndex(index, dataset, true); + if (index->CheckFeature(vsag::SUPPORT_BUILD)) { + TestGeneral(index, dataset, search_param, recall); + } + vsag::Options::Instance().set_block_size_limit(origin_size); + } + } +} \ No newline at end of file From 
c57317ffd61dc016158347d0d2ae485df9b03a67 Mon Sep 17 00:00:00 2001 From: "suguan.dx" Date: Tue, 3 Jun 2025 15:16:26 +0800 Subject: [PATCH 32/42] add gnoimi param tojson ut Signed-off-by: suguan.dx --- src/algorithm/ivf_partition/gno_imi_parameter_test.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/algorithm/ivf_partition/gno_imi_parameter_test.cpp b/src/algorithm/ivf_partition/gno_imi_parameter_test.cpp index 7d55789b8..ad508fdc3 100644 --- a/src/algorithm/ivf_partition/gno_imi_parameter_test.cpp +++ b/src/algorithm/ivf_partition/gno_imi_parameter_test.cpp @@ -29,4 +29,5 @@ TEST_CASE("GNO-IMI Parameters Test", "[ut][GNOIMIParameter]") { param->FromJson(param_json); REQUIRE(param->first_order_buckets_count == 200); REQUIRE(param->second_order_buckets_count == 50); + vsag::ParameterTest::TestToJson(param); } From 6df9cbd36973848a8072e3c844fde7e4512f33d3 Mon Sep 17 00:00:00 2001 From: "suguan.dx" Date: Wed, 4 Jun 2025 20:40:46 +0800 Subject: [PATCH 33/42] add gnoimi build ut Signed-off-by: suguan.dx --- tests/test_ivf.cpp | 129 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 129 insertions(+) diff --git a/tests/test_ivf.cpp b/tests/test_ivf.cpp index 760a4cc0d..89c34f40b 100644 --- a/tests/test_ivf.cpp +++ b/tests/test_ivf.cpp @@ -28,6 +28,17 @@ class IVFTestIndex : public fixtures::TestIndex { const std::string& train_type = "kmeans", bool use_residual = false, int buckets_per_data = 1); + + static std::string + GenerateGNOIMIBuildParametersString(const std::string& metric_type, + int64_t dim, + const std::string& quantization_str = "sq8", + int first_order_buckets_count = 15, + int second_order_buckets_count = 15, + const std::string& train_type = "kmeans", + bool use_residual = false, + int buckets_per_data = 1); + static void TestGeneral(const IndexPtr& index, const TestDatasetPtr& dataset, @@ -122,6 +133,67 @@ IVFTestIndex::GenerateIVFBuildParametersString(const std::string& metric_type, INFO(build_parameters_str); return 
build_parameters_str; } + +std::string +IVFTestIndex::GenerateGNOIMIBuildParametersString(const std::string& metric_type, + int64_t dim, + const std::string& quantization_str, + int first_order_buckets_count, + int second_order_buckets_count, + const std::string& train_type, + bool use_residual, + int buckets_per_data) { + std::string build_parameters_str; + + constexpr auto parameter_temp = R"( + {{ + "dtype": "float32", + "metric_type": "{}", + "dim": {}, + "index_param": {{ + "first_order_buckets_count": {}, + "second_order_buckets_count": {}, + "base_quantization_type": "{}", + "ivf_train_type": "{}", + "use_reorder": {}, + "base_pq_dim": {}, + "precise_quantization_type": "{}", + "use_residual": {}, + "buckets_per_data": {}, + "partition_strategy_type": "gno_imi" + }} + }} + )"; + + auto strs = fixtures::SplitString(quantization_str, ','); + std::string basic_quantizer_str = strs[0]; + bool use_reorder = false; + std::string precise_quantizer_str = "fp32"; + auto pq_dim = dim; + if (dim % 2 == 0 && basic_quantizer_str == "pq") { + pq_dim = dim / 2; + } + if (strs.size() == 2) { + use_reorder = true; + precise_quantizer_str = strs[1]; + } + build_parameters_str = fmt::format(parameter_temp, + metric_type, + dim, + first_order_buckets_count, + second_order_buckets_count, + basic_quantizer_str, + train_type, + use_reorder, + pq_dim, + precise_quantizer_str, + use_residual, + buckets_per_data); + + INFO(build_parameters_str); + return build_parameters_str; +} + void IVFTestIndex::TestGeneral(const TestIndex::IndexPtr& index, const TestDatasetPtr& dataset, @@ -574,4 +646,61 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::IVFTestIndex, vsag::Options::Instance().set_block_size_limit(origin_size); } } +} + +TEST_CASE_PERSISTENT_FIXTURE(fixtures::IVFTestIndex, "GNO-IMI Build", "[ft][ivf]") { + auto origin_size = vsag::Options::Instance().block_size_limit(); + auto size = GENERATE(1024 * 1024 * 2); + auto metric_type = GENERATE("l2"); + std::string train_type = 
GENERATE("kmeans"); + + const std::string name = "ivf"; + auto search_param = fmt::format(search_param_tmp, 250); + for (auto& dim : dims) { + for (auto& [base_quantization_str, recall] : test_cases) { + vsag::Options::Instance().set_block_size_limit(size); + auto param = GenerateGNOIMIBuildParametersString( + metric_type, dim, base_quantization_str, 20, 20, train_type, false, 1); + auto index = TestFactory(name, param, true); + auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric_type); + TestBuildIndex(index, dataset, true); + if (index->CheckFeature(vsag::SUPPORT_BUILD)) { + TestGeneral(index, dataset, search_param, recall); + } + vsag::Options::Instance().set_block_size_limit(origin_size); + } + } +} + +TEST_CASE_PERSISTENT_FIXTURE(fixtures::IVFTestIndex, "GNO-IMI Build with Residual", "[ft][ivf]") { + auto origin_size = vsag::Options::Instance().block_size_limit(); + auto size = GENERATE(1024 * 1024 * 2); + auto metric_type = GENERATE("l2"); + std::string train_type = GENERATE("kmeans"); + + std::vector> tmp_test_cases = { + {"fp32", 0.90}, + {"bf16", 0.88}, + {"fp16", 0.88}, + {"sq8", 0.84}, + {"pq,fp32", 0.82}, + {"pqfs,fp32", 0.82}, + }; + + const std::string name = "ivf"; + auto search_param = fmt::format(search_param_tmp, 300); + for (auto& dim : dims) { + for (auto& [base_quantization_str, recall] : tmp_test_cases) { + vsag::Options::Instance().set_block_size_limit(size); + auto param = GenerateGNOIMIBuildParametersString( + metric_type, dim, base_quantization_str, 20, 20, train_type, true, 1); + auto index = TestFactory(name, param, true); + auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric_type); + TestBuildIndex(index, dataset, true); + if (index->CheckFeature(vsag::SUPPORT_BUILD)) { + TestGeneral(index, dataset, search_param, recall); + } + vsag::Options::Instance().set_block_size_limit(origin_size); + } + } } \ No newline at end of file From c8e99c29c959e9aa63981d16bfa29109c90c1704 Mon Sep 17 00:00:00 2001 From: azl Date: 
Fri, 30 May 2025 14:59:28 +0800 Subject: [PATCH 34/42] compile openblas dynamic arch (#769) Signed-off-by: aozeliu.azl Co-authored-by: aozeliu.azl Signed-off-by: suguan.dx --- extern/openblas/openblas.cmake | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/extern/openblas/openblas.cmake b/extern/openblas/openblas.cmake index 2e9bc9275..3db31ed4c 100644 --- a/extern/openblas/openblas.cmake +++ b/extern/openblas/openblas.cmake @@ -3,12 +3,6 @@ set(name openblas) set(source_dir ${CMAKE_CURRENT_BINARY_DIR}/${name}/source) set(install_dir ${CMAKE_CURRENT_BINARY_DIR}/${name}/install) -if (CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") - set (openblas_target "TARGET=GENERIC") -else () - set (openblas_target "") -endif () - ExternalProject_Add( ${name} URL https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.23/OpenBLAS-0.3.23.tar.gz @@ -28,9 +22,9 @@ ExternalProject_Add( OMP_NUM_THREADS=1 PATH=/usr/lib/ccache:$ENV{PATH} LD_LIBRARY_PATH=/opt/alibaba-cloud-compiler/lib64/:$ENV{LD_LIBRARY_PATH} - make ${openblas_target} USE_THREAD=0 USE_LOCKING=1 -j${NUM_BUILDING_JOBS} + make USE_THREAD=0 USE_LOCKING=1 DYNAMIC_ARCH=1 -j${NUM_BUILDING_JOBS} INSTALL_COMMAND - make ${openblas_target} PREFIX=${install_dir} install + make DYNAMIC_ARCH=1 PREFIX=${install_dir} install BUILD_IN_SOURCE 1 LOG_CONFIGURE TRUE LOG_BUILD TRUE From 07fc182df6a7a94064864576d6fbed38a95813d4 Mon Sep 17 00:00:00 2001 From: inabao <37021995+inabao@users.noreply.github.com> Date: Tue, 3 Jun 2025 14:45:11 +0800 Subject: [PATCH 35/42] support remove in graph datacell (#695) Signed-off-by: jinjiabao.jjb Signed-off-by: suguan.dx --- src/algorithm/hgraph.cpp | 13 +- src/algorithm/hgraph.h | 2 + src/data_cell/compressed_graph_datacell.cpp | 1 + .../compressed_graph_datacell_test.cpp | 2 +- src/data_cell/graph_datacell.h | 113 +++++++++++++++--- src/data_cell/graph_datacell_parameter.cpp | 9 ++ src/data_cell/graph_datacell_parameter.h | 3 + src/data_cell/graph_datacell_test.cpp | 48 
++++++-- src/data_cell/graph_interface.h | 6 + src/data_cell/graph_interface_test.cpp | 92 ++++++++++---- src/data_cell/graph_interface_test.h | 2 +- src/data_cell/sparse_graph_datacell.cpp | 98 +++++++++++++-- src/data_cell/sparse_graph_datacell.h | 15 ++- .../sparse_graph_datacell_parameter.h | 13 ++ src/data_cell/sparse_graph_datacell_test.cpp | 29 ++++- 15 files changed, 387 insertions(+), 59 deletions(-) diff --git a/src/algorithm/hgraph.cpp b/src/algorithm/hgraph.cpp index d6a41410d..ca182af20 100644 --- a/src/algorithm/hgraph.cpp +++ b/src/algorithm/hgraph.cpp @@ -56,6 +56,16 @@ HGraph::HGraph(const HGraphParameterPtr& hgraph_param, const vsag::IndexCommonPa this->bottom_graph_ = GraphInterface::MakeInstance(hgraph_param->bottom_graph_param, common_param); + auto graph_param = + std::dynamic_pointer_cast(hgraph_param->bottom_graph_param); + sparse_datacell_param_ = std::make_shared(); + sparse_datacell_param_->max_degree_ = hgraph_param->bottom_graph_param->max_degree_ / 2; + if (graph_param != nullptr) { + sparse_datacell_param_->remove_flag_bit_ = graph_param->remove_flag_bit_; + sparse_datacell_param_->support_delete_ = graph_param->support_remove_; + } else { + sparse_datacell_param_->support_delete_ = false; + } mult_ = 1 / log(1.0 * static_cast(this->bottom_graph_->MaximumDegree())); if (extra_info_size_ > 0) { @@ -500,8 +510,7 @@ HGraph::EstimateMemory(uint64_t num_elements) const { GraphInterfacePtr HGraph::generate_one_route_graph() { - return std::make_shared(this->allocator_, - bottom_graph_->MaximumDegree() / 2); + return std::make_shared(sparse_datacell_param_, this->allocator_); } template diff --git a/src/algorithm/hgraph.h b/src/algorithm/hgraph.h index 78b6f3166..6c1b62eac 100644 --- a/src/algorithm/hgraph.h +++ b/src/algorithm/hgraph.h @@ -25,6 +25,7 @@ #include "data_cell/extra_info_interface.h" #include "data_cell/flatten_interface.h" #include "data_cell/graph_interface.h" +#include "data_cell/sparse_graph_datacell_parameter.h" #include 
"default_thread_pool.h" #include "hgraph_parameter.h" #include "impl/basic_searcher.h" @@ -221,6 +222,7 @@ class HGraph : public InnerIndexInterface { FlattenInterfacePtr high_precise_codes_{nullptr}; Vector route_graphs_; GraphInterfacePtr bottom_graph_{nullptr}; + SparseGraphDatacellParamPtr sparse_datacell_param_{nullptr}; mutable bool use_reorder_{false}; bool use_elp_optimizer_{false}; diff --git a/src/data_cell/compressed_graph_datacell.cpp b/src/data_cell/compressed_graph_datacell.cpp index 3766e0fd0..4826fe1ad 100644 --- a/src/data_cell/compressed_graph_datacell.cpp +++ b/src/data_cell/compressed_graph_datacell.cpp @@ -43,6 +43,7 @@ CompressedGraphDataCell::InsertNeighborsById(InnerIdType id, Vector tmp(neighbor_ids.begin(), neighbor_ids.end(), allocator_); std::sort(tmp.begin(), tmp.end()); + // TODO(deming): reset the size of neighbor_sets_[id] while tmp is empty if (not tmp.empty()) { if (neighbor_sets_[id] == nullptr) { neighbor_sets_[id] = std::make_unique(allocator_); diff --git a/src/data_cell/compressed_graph_datacell_test.cpp b/src/data_cell/compressed_graph_datacell_test.cpp index 6559112b4..bb56b091f 100644 --- a/src/data_cell/compressed_graph_datacell_test.cpp +++ b/src/data_cell/compressed_graph_datacell_test.cpp @@ -34,7 +34,7 @@ TestCompressedGraphDataCell(const GraphInterfaceParamPtr& param, auto graph = GraphInterface::MakeInstance(param, common_param); GraphInterfaceTest test(graph, true); auto other = GraphInterface::MakeInstance(param, common_param); - test.BasicTest(max_id, count, other); + test.BasicTest(max_id, count, other, false); } TEST_CASE("CompressedGraphDataCell Basic Test", "[ut][CompressedGraphDataCell]") { diff --git a/src/data_cell/graph_datacell.h b/src/data_cell/graph_datacell.h index 35c82fbe1..d9f958687 100644 --- a/src/data_cell/graph_datacell.h +++ b/src/data_cell/graph_datacell.h @@ -52,6 +52,9 @@ class GraphDataCell : public GraphInterface { void InsertNeighborsById(InnerIdType id, const Vector& neighbor_ids) 
override; + void + DeleteNeighborsById(vsag::InnerIdType id) override; + [[nodiscard]] uint32_t GetNeighborSize(InnerIdType id) const override; @@ -91,16 +94,32 @@ class GraphDataCell : public GraphInterface { private: std::shared_ptr> io_{nullptr}; + Vector node_versions_; + + bool is_support_delete_{true}; + uint32_t remove_flag_bit_{8}; + uint32_t id_bit_{24}; + uint32_t remove_flag_mask_{0x00ffffff}; + uint32_t code_line_size_{0}; }; template GraphDataCell::GraphDataCell(const GraphDataCellParamPtr& param, - const IndexCommonParam& common_param) { + const IndexCommonParam& common_param) + : node_versions_(common_param.allocator_.get()) { this->io_ = std::make_shared(param->io_parameter_, common_param); this->maximum_degree_ = param->max_degree_; this->max_capacity_ = param->init_max_capacity_; + this->is_support_delete_ = param->support_remove_; + this->remove_flag_bit_ = param->remove_flag_bit_; + this->id_bit_ = sizeof(InnerIdType) * 8 - this->remove_flag_bit_; + this->remove_flag_mask_ = (1 << this->id_bit_) - 1; this->code_line_size_ = this->maximum_degree_ * sizeof(InnerIdType) + sizeof(uint32_t); + this->allocator_ = common_param.allocator_.get(); + if (this->is_support_delete_) { + node_versions_.resize(max_capacity_); + } } template @@ -122,21 +141,35 @@ GraphDataCell::InsertNeighborsById(InnerIdType id, while (current < id + 1 && !total_count_.compare_exchange_weak(current, id + 1)) { } auto start = static_cast(id) * static_cast(this->code_line_size_); - uint32_t neighbor_count = std::min((uint32_t)(neighbor_ids.size()), this->maximum_degree_); - this->io_->Write((uint8_t*)(&neighbor_count), sizeof(neighbor_count), start); - start += sizeof(neighbor_count); - this->io_->Write((uint8_t*)(neighbor_ids.data()), - static_cast(neighbor_count) * sizeof(InnerIdType), - start); + if (is_support_delete_) { + uint32_t neighbor_count = std::min((uint32_t)(neighbor_ids.size()), this->maximum_degree_); + this->io_->Write((uint8_t*)(&neighbor_count), 
sizeof(neighbor_count), start); + start += sizeof(neighbor_count); + Vector neighbor_ids_ptr(neighbor_ids.size(), 0, this->allocator_); + for (int i = 0; i < neighbor_ids.size(); ++i) { + auto neighbor_id = neighbor_ids[i]; + neighbor_ids_ptr[i] = neighbor_id | (node_versions_[neighbor_id] << id_bit_); + } + this->io_->Write((uint8_t*)(neighbor_ids_ptr.data()), + static_cast(neighbor_count) * sizeof(InnerIdType), + start); + } else { + uint32_t neighbor_count = std::min((uint32_t)(neighbor_ids.size()), this->maximum_degree_); + this->io_->Write((uint8_t*)(&neighbor_count), sizeof(neighbor_count), start); + start += sizeof(neighbor_count); + this->io_->Write((uint8_t*)(neighbor_ids.data()), + static_cast(neighbor_count) * sizeof(InnerIdType), + start); + } } template uint32_t GraphDataCell::GetNeighborSize(InnerIdType id) const { auto start = static_cast(id) * static_cast(this->code_line_size_); - uint32_t result = 0; - this->io_->Read(sizeof(result), start, (uint8_t*)(&result)); - return result; + uint32_t neighbor_count = 0; + this->io_->Read(sizeof(neighbor_count), start, (uint8_t*)(&neighbor_count)); + return neighbor_count; } template @@ -145,10 +178,27 @@ GraphDataCell::GetNeighbors(InnerIdType id, Vector& neighbo auto start = static_cast(id) * static_cast(this->code_line_size_); uint32_t neighbor_count = 0; this->io_->Read(sizeof(neighbor_count), start, (uint8_t*)(&neighbor_count)); - neighbor_ids.resize(neighbor_count); - start += sizeof(neighbor_count); - this->io_->Read( - neighbor_ids.size() * sizeof(InnerIdType), start, (uint8_t*)(neighbor_ids.data())); + if (is_support_delete_) { + neighbor_count &= remove_flag_mask_; + start += sizeof(neighbor_count); + Vector shared_neighbor_ids(neighbor_count, this->allocator_); + this->io_->Read( + neighbor_count * sizeof(InnerIdType), start, (uint8_t*)(shared_neighbor_ids.data())); + neighbor_ids.clear(); + neighbor_ids.reserve(neighbor_count); + for (int i = 0; i < neighbor_count; ++i) { + uint8_t neighbor_version 
= shared_neighbor_ids[i] >> id_bit_; + InnerIdType neighbor_id = shared_neighbor_ids[i] & remove_flag_mask_; + if (node_versions_[neighbor_id] == neighbor_version) { + neighbor_ids.push_back(neighbor_id); + } + } + } else { + start += sizeof(neighbor_count); + neighbor_ids.resize(neighbor_count); + this->io_->Read( + neighbor_ids.size() * sizeof(InnerIdType), start, (uint8_t*)(neighbor_ids.data())); + } } template @@ -157,6 +207,14 @@ GraphDataCell::Resize(InnerIdType new_size) { if (new_size < this->max_capacity_) { return; } + if (is_support_delete_) { + if (new_size > remove_flag_mask_) { + // remove_flag_mask_ exactly matches the maximum size of the graph in dynamic mode. + throw VsagException(ErrorType::INTERNAL_ERROR, + fmt::format("the size of graph is limit ({})", remove_flag_mask_)); + } + node_versions_.resize(new_size); + } this->max_capacity_ = new_size; uint64_t io_size = static_cast(new_size) * static_cast(code_line_size_); uint8_t end_flag = @@ -170,6 +228,9 @@ GraphDataCell::Serialize(StreamWriter& writer) { GraphInterface::Serialize(writer); this->io_->Serialize(writer); StreamWriter::WriteObj(writer, this->code_line_size_); + if (is_support_delete_) { + StreamWriter::WriteVector(writer, node_versions_); + } } template @@ -178,6 +239,30 @@ GraphDataCell::Deserialize(StreamReader& reader) { GraphInterface::Deserialize(reader); this->io_->Deserialize(reader); StreamReader::ReadObj(reader, this->code_line_size_); + if (is_support_delete_) { + StreamReader::ReadVector(reader, node_versions_); + } +} + +template +void +GraphDataCell::DeleteNeighborsById(vsag::InnerIdType id) { + if (is_support_delete_) { + if (id <= max_capacity_) { + if (node_versions_[id] + 1 == 0) { + throw VsagException( + ErrorType::INTERNAL_ERROR, + "remove point too many times in GraphDatacell, please rebuild index"); + } + } else { + throw VsagException(ErrorType::INTERNAL_ERROR, + fmt::format("remove point {} not exist in GraphDatacell", id)); + } + node_versions_[id]++; + } 
else { + throw VsagException(ErrorType::UNSUPPORTED_INDEX_OPERATION, + "disable delete in graph datacell"); + } } } // namespace vsag diff --git a/src/data_cell/graph_datacell_parameter.cpp b/src/data_cell/graph_datacell_parameter.cpp index 1669d6764..ec14c230b 100644 --- a/src/data_cell/graph_datacell_parameter.cpp +++ b/src/data_cell/graph_datacell_parameter.cpp @@ -18,6 +18,7 @@ #include #include "inner_string_params.h" +#include "vsag/constants.h" namespace vsag { void @@ -31,6 +32,12 @@ GraphDataCellParameter::FromJson(const JsonType& json) { if (json.contains(GRAPH_PARAM_INIT_MAX_CAPACITY)) { this->init_max_capacity_ = json[GRAPH_PARAM_INIT_MAX_CAPACITY]; } + if (json.contains(GRAPH_SUPPORT_REMOVE)) { + this->support_remove_ = json[GRAPH_SUPPORT_REMOVE]; + } + if (json.contains(REMOVE_FLAG_BIT)) { + this->remove_flag_bit_ = json[REMOVE_FLAG_BIT]; + } } JsonType @@ -39,6 +46,8 @@ GraphDataCellParameter::ToJson() { json[IO_PARAMS_KEY] = this->io_parameter_->ToJson(); json[GRAPH_PARAM_MAX_DEGREE] = this->max_degree_; json[GRAPH_PARAM_INIT_MAX_CAPACITY] = this->init_max_capacity_; + json[GRAPH_SUPPORT_REMOVE] = this->support_remove_; + json[REMOVE_FLAG_BIT] = this->remove_flag_bit_; return json; } diff --git a/src/data_cell/graph_datacell_parameter.h b/src/data_cell/graph_datacell_parameter.h index e72463b6e..5c4cdc41c 100644 --- a/src/data_cell/graph_datacell_parameter.h +++ b/src/data_cell/graph_datacell_parameter.h @@ -34,6 +34,9 @@ class GraphDataCellParameter : public GraphInterfaceParameter { IOParamPtr io_parameter_{nullptr}; uint64_t init_max_capacity_{100}; + + bool support_remove_{false}; + uint32_t remove_flag_bit_{8}; }; using GraphDataCellParamPtr = std::shared_ptr; diff --git a/src/data_cell/graph_datacell_test.cpp b/src/data_cell/graph_datacell_test.cpp index c39e3eed1..80ac747a4 100644 --- a/src/data_cell/graph_datacell_test.cpp +++ b/src/data_cell/graph_datacell_test.cpp @@ -25,22 +25,25 @@ using namespace vsag; void -TestGraphDataCell(const 
GraphInterfaceParamPtr& param, const IndexCommonParam& common_param) { +TestGraphDataCell(const GraphInterfaceParamPtr& param, + const IndexCommonParam& common_param, + bool test_delete) { auto count = GENERATE(1000, 2000); auto max_id = 10000; auto graph = GraphInterface::MakeInstance(param, common_param); GraphInterfaceTest test(graph); auto other = GraphInterface::MakeInstance(param, common_param); - test.BasicTest(max_id, count, other); + test.BasicTest(max_id, count, other, test_delete); } TEST_CASE("GraphDataCell Basic Test", "[ut][GraphDataCell]") { auto allocator = SafeAllocator::FactoryDefaultAllocator(); auto dim = GENERATE(32, 64); - auto max_degree = GENERATE(5, 32, 64, 128); - auto max_capacity = GENERATE(100, 10000); + auto max_degree = GENERATE(5, 32, 64); + auto max_capacity = GENERATE(100); auto io_type = GENERATE("memory_io", "block_memory_io"); + auto is_support_delete = GENERATE(true, false); constexpr const char* graph_param_temp = R"( {{ @@ -48,16 +51,47 @@ TEST_CASE("GraphDataCell Basic Test", "[ut][GraphDataCell]") { "type": "{}" }}, "max_degree": {}, - "init_capacity": {} + "init_capacity": {}, + "support_remove": {} }} )"; IndexCommonParam common_param; common_param.dim_ = dim; common_param.allocator_ = allocator; - auto param_str = fmt::format(graph_param_temp, io_type, max_degree, max_capacity); + auto param_str = + fmt::format(graph_param_temp, io_type, max_degree, max_capacity, is_support_delete); auto param_json = JsonType::parse(param_str); auto graph_param = GraphInterfaceParameter::GetGraphParameterByJson( GraphStorageTypes::GRAPH_STORAGE_TYPE_FLAT, param_json); - TestGraphDataCell(graph_param, common_param); + TestGraphDataCell(graph_param, common_param, is_support_delete); +} + +TEST_CASE("GraphDataCell Remove Test", "[ut][GraphDataCell]") { + auto allocator = SafeAllocator::FactoryDefaultAllocator(); + auto dim = GENERATE(32, 64); + auto max_degree = GENERATE(5, 32); + auto io_type = GENERATE("block_memory_io"); + auto 
is_support_delete = GENERATE(true); + auto remove_flag_bit = GENERATE(4, 8); + constexpr const char* graph_param_temp = + R"( + {{ + "io_params": {{ + "type": "{}" + }}, + "max_degree": {}, + "support_remove": true, + "remove_flag_bit": {} + }} + )"; + + IndexCommonParam common_param; + common_param.dim_ = dim; + common_param.allocator_ = allocator; + auto param_str = fmt::format(graph_param_temp, io_type, max_degree, remove_flag_bit); + auto param_json = JsonType::parse(param_str); + auto graph_param = GraphInterfaceParameter::GetGraphParameterByJson( + GraphStorageTypes::GRAPH_STORAGE_TYPE_FLAT, param_json); + TestGraphDataCell(graph_param, common_param, is_support_delete); } diff --git a/src/data_cell/graph_interface.h b/src/data_cell/graph_interface.h index e854d1ba3..0029be2a6 100644 --- a/src/data_cell/graph_interface.h +++ b/src/data_cell/graph_interface.h @@ -46,6 +46,11 @@ class GraphInterface { virtual void InsertNeighborsById(InnerIdType id, const Vector& neighbor_ids) = 0; + virtual void + DeleteNeighborsById(InnerIdType id) { + throw VsagException(ErrorType::INTERNAL_ERROR, "DeleteNeighborsById is not implemented"); + } + virtual uint32_t GetNeighborSize(InnerIdType id) const = 0; @@ -114,6 +119,7 @@ class GraphInterface { protected: std::atomic total_count_{0}; + Allocator* allocator_{nullptr}; }; } // namespace vsag diff --git a/src/data_cell/graph_interface_test.cpp b/src/data_cell/graph_interface_test.cpp index 8e32d89fc..03d4987a4 100644 --- a/src/data_cell/graph_interface_test.cpp +++ b/src/data_cell/graph_interface_test.cpp @@ -25,31 +25,48 @@ using namespace vsag; void -GraphInterfaceTest::BasicTest(uint64_t max_id, uint64_t count, const GraphInterfacePtr& other) { +GraphInterfaceTest::BasicTest(uint64_t max_id, + uint64_t count, + const GraphInterfacePtr& other, + bool test_delete) { auto allocator = SafeAllocator::FactoryDefaultAllocator(); auto max_degree = this->graph_->MaximumDegree(); this->graph_->Resize(max_id); + UnorderedMap>> 
maps(allocator.get()); + std::unordered_set unique_keys; + while (unique_keys.size() < count) { + InnerIdType new_key = random() % max_id; + unique_keys.insert(new_key); + } - auto generate_graph = [&]() { - UnorderedMap>> cur_map(allocator.get()); - for (auto i = 0; i < count; ++i) { - auto length = random() % max_degree + 1; - auto ids = std::make_shared>(length, allocator.get()); - for (auto& id : *ids) { - id = random() % max_id; - } - auto cur_id = random() % max_id; - cur_map[cur_id] = ids; + std::vector keys(unique_keys.begin(), unique_keys.end()); + for (auto key : keys) { + maps[key] = std::make_shared>(allocator.get()); + } + + std::random_device rd; + std::mt19937 rng(rd()); + + for (auto& pair : maps) { + auto& vec_ptr = pair.second; + int max_possible_length = keys.size(); + int length = random() % (max_degree - 1) + 2; + length = std::min(length, max_possible_length); + std::vector temp_keys = keys; + std::shuffle(temp_keys.begin(), temp_keys.end(), rng); + + vec_ptr->resize(length); + for (int i = 0; i < length; ++i) { + (*vec_ptr)[i] = temp_keys[i]; } - if (require_sorted_) { - for (auto& [key, value] : cur_map) { - std::sort(value->begin(), value->end()); - } + } + + if (require_sorted_) { + for (auto& [key, value] : maps) { + std::sort(value->begin(), value->end()); } - return cur_map; - }; + } - UnorderedMap>> maps = generate_graph(); for (auto& [key, value] : maps) { this->graph_->InsertNeighborsById(key, *value); } @@ -96,7 +113,6 @@ GraphInterfaceTest::BasicTest(uint64_t max_id, uint64_t count, const GraphInterf REQUIRE(this->graph_->TotalCount() == other->TotalCount()); REQUIRE(this->graph_->MaxCapacity() == other->MaxCapacity()); REQUIRE(this->graph_->MaximumDegree() == other->MaximumDegree()); - for (auto& [key, value] : maps) { Vector neighbors(allocator.get()); other->GetNeighbors(key, neighbors); @@ -107,8 +123,44 @@ GraphInterfaceTest::BasicTest(uint64_t max_id, uint64_t count, const GraphInterf infile.close(); } - maps = 
generate_graph(); + if (test_delete) { + SECTION("Delete") { + std::unordered_set keys_to_delete; + for (const auto& item : maps) { + if (keys_to_delete.size() > count / 2) { + Vector neighbors(allocator.get()); + this->graph_->GetNeighbors(item.first, neighbors); + for (const auto& neighbor_id : neighbors) { + REQUIRE(keys_to_delete.count(neighbor_id) == 0); + } + } else { + this->graph_->DeleteNeighborsById(item.first); + keys_to_delete.insert(item.first); + } + } + for (const auto& key : keys_to_delete) { + this->graph_->InsertNeighborsById(key, *maps[key]); + } + for (const auto& [key, value] : maps) { + if (keys_to_delete.find(key) == keys_to_delete.end()) { + Vector neighbors(allocator.get()); + this->graph_->GetNeighbors(key, neighbors); + for (const auto& neighbor_id : neighbors) { + REQUIRE(keys_to_delete.count(neighbor_id) == 0); + } + this->graph_->InsertNeighborsById(key, *value); + this->graph_->GetNeighbors(key, neighbors); + REQUIRE(neighbors.size() == value->size()); + REQUIRE(memcmp(neighbors.data(), + value->data(), + value->size() * sizeof(InnerIdType)) == 0); + } + } + } + } + for (auto& [key, value] : maps) { + value->resize(value->size() / 2); this->graph_->InsertNeighborsById(key, *value); } SECTION("Test Update Graph") { diff --git a/src/data_cell/graph_interface_test.h b/src/data_cell/graph_interface_test.h index aeb2069dc..bfbc26bab 100644 --- a/src/data_cell/graph_interface_test.h +++ b/src/data_cell/graph_interface_test.h @@ -28,7 +28,7 @@ class GraphInterfaceTest { } void - BasicTest(uint64_t max_id, uint64_t count, const GraphInterfacePtr& other); + BasicTest(uint64_t max_id, uint64_t count, const GraphInterfacePtr& other, bool test_delete); public: GraphInterfacePtr graph_{nullptr}; diff --git a/src/data_cell/sparse_graph_datacell.cpp b/src/data_cell/sparse_graph_datacell.cpp index aa59b753d..f10086efc 100644 --- a/src/data_cell/sparse_graph_datacell.cpp +++ b/src/data_cell/sparse_graph_datacell.cpp @@ -19,16 +19,27 @@ namespace vsag 
{ -SparseGraphDataCell::SparseGraphDataCell(Allocator* allocator, uint32_t max_degree) - : allocator_(allocator), neighbors_(allocator_) { - this->maximum_degree_ = max_degree; +SparseGraphDataCell::SparseGraphDataCell(const SparseGraphDatacellParamPtr& graph_param, + Allocator* allocator) + : allocator_(allocator), + neighbors_(allocator), + node_version_(allocator), + is_support_delete_(graph_param->support_delete_) { + this->maximum_degree_ = graph_param->max_degree_; + this->remove_flag_bit_ = graph_param->remove_flag_bit_; + this->id_bit_ = sizeof(InnerIdType) * 8 - this->remove_flag_bit_; + this->remove_flag_mask_ = (1 << this->id_bit_) - 1; +} + +SparseGraphDataCell::SparseGraphDataCell(const SparseGraphDatacellParamPtr& graph_param, + const IndexCommonParam& common_param) + : SparseGraphDataCell(graph_param, common_param.allocator_.get()) { } SparseGraphDataCell::SparseGraphDataCell(const GraphInterfaceParamPtr& param, const IndexCommonParam& common_param) - : SparseGraphDataCell( - common_param.allocator_.get(), - std::dynamic_pointer_cast(param)->max_degree_) { + : SparseGraphDataCell(std::dynamic_pointer_cast(param), + common_param) { } void @@ -45,8 +56,26 @@ SparseGraphDataCell::InsertNeighborsById(InnerIdType id, const Vectorneighbors_.emplace(id, std::make_unique>(allocator_)).first; total_count_++; + if (is_support_delete_) { + node_version_[id] = 0; + } + } + if (is_support_delete_) { + iter->second->resize(size); + for (int i = 0; i < size; ++i) { +#if defined(_DEBUG) || defined(DEBUG) + if (neighbor_ids[i] >= node_version_.size()) { + throw VsagException(ErrorType::INTERNAL_ERROR, + "incorrect id {} >= node_version.size()", + neighbor_ids[i], + node_version_.size()); + } +#endif + iter->second->at(i) = (neighbor_ids[i] | (node_version_[neighbor_ids[i]] << id_bit_)); + } + } else { + iter->second->assign(neighbor_ids.begin(), neighbor_ids.begin() + size); } - iter->second->assign(neighbor_ids.begin(), neighbor_ids.begin() + size); } uint32_t @@ 
-63,7 +92,20 @@ SparseGraphDataCell::GetNeighbors(InnerIdType id, Vector& neighbor_ std::shared_lock rlock(this->neighbors_map_mutex_); auto iter = this->neighbors_.find(id); if (iter != this->neighbors_.end()) { - neighbor_ids.assign(iter->second->begin(), iter->second->end()); + const auto& ngbrs = iter->second; + if (is_support_delete_) { + neighbor_ids.clear(); + neighbor_ids.reserve(iter->second->size()); + for (unsigned int& neighbor_id : *ngbrs) { + uint8_t cur_version = neighbor_id >> id_bit_; + uint32_t real_id = neighbor_id & remove_flag_mask_; + if (node_version_.at(real_id) == cur_version) { + neighbor_ids.push_back(real_id); + } + } + } else { + neighbor_ids.assign(iter->second->begin(), iter->second->end()); + } } } void @@ -77,6 +119,12 @@ SparseGraphDataCell::Serialize(StreamWriter& writer) { StreamWriter::WriteObj(writer, key); StreamWriter::WriteVector(writer, *(pair.second)); } + if (is_support_delete_) { + for (const auto& item : this->node_version_) { + StreamWriter::WriteObj(writer, item.first); + StreamWriter::WriteObj(writer, item.second); + } + } } void @@ -92,7 +140,41 @@ SparseGraphDataCell::Deserialize(StreamReader& reader) { StreamReader::ReadVector(reader, *(this->neighbors_[key])); } this->total_count_ = size; + if (is_support_delete_) { + for (uint64_t i = 0; i < size; ++i) { + InnerIdType key; + StreamReader::ReadObj(reader, key); + uint8_t value; + StreamReader::ReadObj(reader, value); + this->node_version_[key] = value; + } + } } + void SparseGraphDataCell::Resize(InnerIdType new_size){}; + +void +SparseGraphDataCell::DeleteNeighborsById(vsag::InnerIdType id) { + if (is_support_delete_) { + std::unique_lock wlock(this->neighbors_map_mutex_); + auto iter = node_version_.find(id); + if (iter != node_version_.end()) { + if (iter->second + 1 == 0) { + throw VsagException( + ErrorType::INTERNAL_ERROR, + "remove point too many times in SparseGraphDatacell, please rebuild index"); + } + iter->second++; + } else { + throw VsagException( + 
ErrorType::INTERNAL_ERROR, + fmt::format("remove point {} not exist in SparseGraphDatacell", id)); + } + } else { + throw VsagException(ErrorType::UNSUPPORTED_INDEX_OPERATION, + "disable delete in sparse graph datacell"); + } +} + } // namespace vsag diff --git a/src/data_cell/sparse_graph_datacell.h b/src/data_cell/sparse_graph_datacell.h index e5e1fea7b..7f7f648e4 100644 --- a/src/data_cell/sparse_graph_datacell.h +++ b/src/data_cell/sparse_graph_datacell.h @@ -19,6 +19,7 @@ #include "graph_interface.h" #include "io/memory_block_io.h" +#include "sparse_graph_datacell_parameter.h" namespace vsag { @@ -28,12 +29,16 @@ class SparseGraphDataCell : public GraphInterface { SparseGraphDataCell(const GraphInterfaceParamPtr& graph_param, const IndexCommonParam& common_param); - - explicit SparseGraphDataCell(Allocator* allocator = nullptr, uint32_t max_degree = 32); + SparseGraphDataCell(const SparseGraphDatacellParamPtr& graph_param, + const IndexCommonParam& common_param); + SparseGraphDataCell(const SparseGraphDatacellParamPtr& graph_param, Allocator* allocator); void InsertNeighborsById(InnerIdType id, const Vector& neighbor_ids) override; + void + DeleteNeighborsById(InnerIdType id) override; + uint32_t GetNeighborSize(InnerIdType id) const override; @@ -64,6 +69,12 @@ class SparseGraphDataCell : public GraphInterface { Allocator* const allocator_{nullptr}; UnorderedMap>> neighbors_; mutable std::shared_mutex neighbors_map_mutex_{}; + + bool is_support_delete_{true}; + uint32_t remove_flag_bit_{8}; + uint32_t id_bit_{24}; + uint32_t remove_flag_mask_{0x00ffffff}; + UnorderedMap node_version_; }; } // namespace vsag diff --git a/src/data_cell/sparse_graph_datacell_parameter.h b/src/data_cell/sparse_graph_datacell_parameter.h index 4d00d18bc..cdd5fe411 100644 --- a/src/data_cell/sparse_graph_datacell_parameter.h +++ b/src/data_cell/sparse_graph_datacell_parameter.h @@ -16,6 +16,7 @@ #pragma once #include "graph_interface_parameter.h" +#include "vsag/constants.h" 
namespace vsag { class SparseGraphDatacellParameter : public GraphInterfaceParameter { @@ -29,14 +30,26 @@ class SparseGraphDatacellParameter : public GraphInterfaceParameter { if (json.contains(GRAPH_PARAM_MAX_DEGREE)) { this->max_degree_ = json[GRAPH_PARAM_MAX_DEGREE]; } + if (json.contains(GRAPH_SUPPORT_REMOVE)) { + this->support_delete_ = json[GRAPH_SUPPORT_REMOVE]; + } + if (json.contains(REMOVE_FLAG_BIT)) { + this->remove_flag_bit_ = json[REMOVE_FLAG_BIT]; + } } JsonType ToJson() override { JsonType json; json[GRAPH_PARAM_MAX_DEGREE] = this->max_degree_; + json[GRAPH_SUPPORT_REMOVE] = support_delete_; + json[REMOVE_FLAG_BIT] = remove_flag_bit_; return json; } + +public: + bool support_delete_{false}; + uint32_t remove_flag_bit_{8}; }; using SparseGraphDatacellParamPtr = std::shared_ptr; diff --git a/src/data_cell/sparse_graph_datacell_test.cpp b/src/data_cell/sparse_graph_datacell_test.cpp index c60af2796..125e67d4f 100644 --- a/src/data_cell/sparse_graph_datacell_test.cpp +++ b/src/data_cell/sparse_graph_datacell_test.cpp @@ -26,25 +26,46 @@ using namespace vsag; void -TestSparseGraphDataCell(const GraphInterfaceParamPtr& param, const IndexCommonParam& common_param) { +TestSparseGraphDataCell(const GraphInterfaceParamPtr& param, + const IndexCommonParam& common_param, + bool test_delete) { auto count = GENERATE(1000, 2000); auto max_id = 10000; auto graph = GraphInterface::MakeInstance(param, common_param); GraphInterfaceTest test(graph); auto other = GraphInterface::MakeInstance(param, common_param); - test.BasicTest(max_id, count, other); + test.BasicTest(max_id, count, other, test_delete); } TEST_CASE("SparseGraphDataCell Basic Test", "[ut][SparseGraphDataCell]") { auto allocator = SafeAllocator::FactoryDefaultAllocator(); auto dim = GENERATE(32, 64); - auto max_degree = GENERATE(5, 12, 32, 64, 128); + auto max_degree = GENERATE(5, 32, 64); + auto is_support_delete = GENERATE(true, false); IndexCommonParam common_param; common_param.dim_ = dim; 
common_param.allocator_ = allocator; auto graph_param = std::make_shared(); graph_param->max_degree_ = max_degree; - TestSparseGraphDataCell(graph_param, common_param); + graph_param->support_delete_ = is_support_delete; + TestSparseGraphDataCell(graph_param, common_param, is_support_delete); } + +TEST_CASE("SparseGraphDataCell Remove Test", "[ut][SparseGraphDataCell]") { + auto allocator = SafeAllocator::FactoryDefaultAllocator(); + auto dim = GENERATE(32, 64); + auto max_degree = GENERATE(5, 32); + auto is_support_delete = GENERATE(true); + auto remove_flag_bit = GENERATE(4, 8); + + IndexCommonParam common_param; + common_param.dim_ = dim; + common_param.allocator_ = allocator; + auto graph_param = std::make_shared(); + graph_param->max_degree_ = max_degree; + graph_param->support_delete_ = is_support_delete; + graph_param->remove_flag_bit_ = remove_flag_bit; + TestSparseGraphDataCell(graph_param, common_param, is_support_delete); +} \ No newline at end of file From 61b4a0fea5f7ba433b797f521cecbf0d186dbfd0 Mon Sep 17 00:00:00 2001 From: Carrot-77 <61344086+Carrot-77@users.noreply.github.com> Date: Wed, 4 Jun 2025 17:57:47 +0800 Subject: [PATCH 36/42] add search allocator (#716) Signed-off-by: zourunxin.zrx Co-authored-by: zourunxin.zrx Signed-off-by: suguan.dx --- examples/cpp/313_feature_search_allocator.cpp | 185 +++++++++++++++++ .../314_feature_hgraph_search_allocator.cpp | 194 ++++++++++++++++++ examples/cpp/CMakeLists.txt | 6 + include/vsag/index.h | 16 ++ include/vsag/search_param.h | 62 ++++++ src/algorithm/hgraph.cpp | 26 ++- src/algorithm/hgraph.h | 8 + src/algorithm/hnswlib/algorithm_interface.h | 1 + src/algorithm/hnswlib/hnswalg.cpp | 30 ++- src/algorithm/hnswlib/hnswalg.h | 2 + src/algorithm/hnswlib/hnswalg_static.h | 1 + src/algorithm/inner_index_interface.h | 15 ++ src/data_cell/flatten_datacell.h | 18 +- src/data_cell/flatten_interface.h | 3 +- src/data_cell/sparse_vector_datacell.h | 3 +- src/impl/basic_searcher.cpp | 40 ++-- src/index/hnsw.cpp 
| 7 +- src/index/hnsw.h | 21 +- src/index/index_impl.h | 21 +- src/index/iterator_filter.cpp | 2 +- src/index/iterator_filter.h | 2 +- tests/test_hgraph.cpp | 1 + tests/test_hnsw_new.cpp | 1 + tests/test_index.cpp | 89 ++++++++ tests/test_index.h | 7 + 25 files changed, 715 insertions(+), 46 deletions(-) create mode 100644 examples/cpp/313_feature_search_allocator.cpp create mode 100644 examples/cpp/314_feature_hgraph_search_allocator.cpp create mode 100644 include/vsag/search_param.h diff --git a/examples/cpp/313_feature_search_allocator.cpp b/examples/cpp/313_feature_search_allocator.cpp new file mode 100644 index 000000000..40c68b69c --- /dev/null +++ b/examples/cpp/313_feature_search_allocator.cpp @@ -0,0 +1,185 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "nlohmann/json.hpp" +#include "vsag/logger.h" +#include "vsag/search_param.h" +#include "vsag/vsag.h" + +class ExampleAllocator : public vsag::Allocator { +public: + std::string + Name() override { + return "example-allocator"; + } + + void* + Allocate(size_t size) override { + vsag::Options::Instance().logger()->Debug("allocate " + std::to_string(size) + " bytes."); + auto addr = (void*)malloc(size); + sizes_[addr] = size; + return addr; + } + + void + Deallocate(void* p) override { + if (sizes_.find(p) == sizes_.end()) + return; + vsag::Options::Instance().logger()->Debug("deallocate " + std::to_string(sizes_[p]) + + " bytes."); + sizes_.erase(p); + return free(p); + } + + void* + Reallocate(void* p, size_t size) override { + vsag::Options::Instance().logger()->Debug("reallocate " + std::to_string(size) + " bytes."); + auto addr = (void*)realloc(p, size); + sizes_.erase(p); + sizes_[addr] = size; + return addr; + } + +private: + std::unordered_map sizes_; +}; + +int +main() { + vsag::Options::Instance().logger()->SetLevel(vsag::Logger::kINFO); + + ExampleAllocator allocator; + vsag::Resource resource(&allocator, nullptr); + vsag::Engine engine(&resource); + + auto paramesters = R"( + { + "dtype": "float32", + "metric_type": "l2", + "dim": 4, + "hnsw": { + "max_degree": 5, + "ef_construction": 20 + } + } + )"; + std::cout << "create index" << std::endl; + auto index = engine.CreateIndex("hnsw", paramesters).value(); + + std::cout << "prepare data" << std::endl; + int64_t num_vectors = 100; + int64_t dim = 4; + + // prepare ids and vectors + std::vector ids(num_vectors); + std::vector vectors(num_vectors * dim); + + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib_real; + for (int64_t i = 0; i < num_vectors; ++i) { + ids[i] = i; + } + for (int64_t i = 0; i < dim * num_vectors; ++i) { + vectors[i] = distrib_real(rng); + } + auto base = vsag::Dataset::Make(); + base->NumElements(num_vectors) + ->Dim(dim) + 
->Ids(ids.data()) + ->Float32Vectors(vectors.data()) + ->Owner(false); + index->Build(base); + + // search on the index + auto query_vector = new float[dim]; // memory will be released by query the dataset + for (int64_t i = 0; i < dim; ++i) { + query_vector[i] = distrib_real(rng); + } + + int64_t topk = 10; + auto query = vsag::Dataset::Make(); + query->NumElements(1)->Dim(dim)->Float32Vectors(query_vector)->Owner(true); + + /******************* HNSW Search *****************/ + { + nlohmann::json search_parameters = { + {"hnsw", {{"ef_search", 100}, {"skip_ratio", 0.7f}}}, + }; + int64_t topk = 10; + auto query = vsag::Dataset::Make(); + query->NumElements(1)->Dim(dim)->Float32Vectors(query_vector)->Owner(true); + + std::string param_str = search_parameters.dump(); + vsag::SearchParam search_param(false, param_str, nullptr, &allocator); + auto result = index->KnnSearch(query, topk, search_param).value(); + + // print the results + std::cout << "results: " << std::endl; + for (int64_t i = 0; i < result->GetDim(); ++i) { + std::cout << result->GetIds()[i] << ": " << result->GetDistances()[i] << std::endl; + } + + allocator.Deallocate((void*)result->GetIds()); + allocator.Deallocate((void*)result->GetDistances()); + } + + /******************* HNSW Iterator Filter *****************/ + { + vsag::IteratorContext* iter_ctx = nullptr; + nlohmann::json search_parameters = { + {"hnsw", {{"ef_search", 100}, {"skip_ratio", 0.7f}}}, + }; + std::string param_str = search_parameters.dump(); + vsag::SearchParam search_param(true, param_str, nullptr, &allocator, iter_ctx, false); + + /* first search */ + { + auto result = index->KnnSearch(query, topk, search_param).value(); + + // print the results + std::cout << "results: " << std::endl; + for (int64_t i = 0; i < result->GetDim(); ++i) { + std::cout << result->GetIds()[i] << ": " << result->GetDistances()[i] << std::endl; + } + + allocator.Deallocate((void*)result->GetIds()); + allocator.Deallocate((void*)result->GetDistances()); 
+ } + + /* last search */ + { + search_param.is_last_search = true; + auto result = index->KnnSearch(query, topk, search_param).value(); + + // print the results + std::cout << "results: " << std::endl; + for (int64_t i = 0; i < result->GetDim(); ++i) { + std::cout << result->GetIds()[i] << ": " << result->GetDistances()[i] << std::endl; + } + + allocator.Deallocate((void*)result->GetIds()); + allocator.Deallocate((void*)result->GetDistances()); + } + } + + std::cout << "delete index" << std::endl; + index = nullptr; + engine.Shutdown(); + + return 0; +} diff --git a/examples/cpp/314_feature_hgraph_search_allocator.cpp b/examples/cpp/314_feature_hgraph_search_allocator.cpp new file mode 100644 index 000000000..6bce54a46 --- /dev/null +++ b/examples/cpp/314_feature_hgraph_search_allocator.cpp @@ -0,0 +1,194 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "nlohmann/json.hpp" +#include "vsag/logger.h" +#include "vsag/search_param.h" +#include "vsag/vsag.h" + +class ExampleAllocator : public vsag::Allocator { +public: + std::string + Name() override { + return "example-allocator"; + } + + void* + Allocate(size_t size) override { + vsag::Options::Instance().logger()->Debug("allocate " + std::to_string(size) + " bytes."); + auto addr = (void*)malloc(size); + sizes_[addr] = size; + return addr; + } + + void + Deallocate(void* p) override { + if (sizes_.find(p) == sizes_.end()) + return; + vsag::Options::Instance().logger()->Debug("deallocate " + std::to_string(sizes_[p]) + + " bytes."); + sizes_.erase(p); + return free(p); + } + + void* + Reallocate(void* p, size_t size) override { + vsag::Options::Instance().logger()->Debug("reallocate " + std::to_string(size) + " bytes."); + auto addr = (void*)realloc(p, size); + sizes_.erase(p); + sizes_[addr] = size; + return addr; + } + +private: + std::unordered_map sizes_; +}; + +int +main() { + vsag::init(); + std::cout << "hgraph index example" << std::endl; + + /******************* Prepare Base Dataset *****************/ + int64_t num_vectors = 1000; + int64_t dim = 128; + std::vector ids(num_vectors); + std::vector datas(num_vectors * dim); + std::mt19937 rng(47); + std::uniform_real_distribution distrib_real; + for (int64_t i = 0; i < num_vectors; ++i) { + ids[i] = i; + } + for (int64_t i = 0; i < dim * num_vectors; ++i) { + datas[i] = distrib_real(rng); + } + auto base = vsag::Dataset::Make(); + base->NumElements(num_vectors) + ->Dim(dim) + ->Ids(ids.data()) + ->Float32Vectors(datas.data()) + ->Owner(false); + + /******************* Create HGraph Index *****************/ + std::string hgraph_build_parameters = R"( + { + "dtype": "float32", + "metric_type": "l2", + "dim": 128, + "index_param": { + "base_quantization_type": "sq8", + "max_degree": 26, + "ef_construction": 100 + } + } + )"; + vsag::Resource resource(vsag::Engine::CreateDefaultAllocator(), 
nullptr); + vsag::Engine engine(&resource); + std::cout << "create index" << std::endl; + auto index = engine.CreateIndex("hgraph", hgraph_build_parameters).value(); + + ExampleAllocator allocator; + + /******************* Build HGraph Index *****************/ + if (auto build_result = index->Build(base); build_result.has_value()) { + std::cout << "After Build(), Index HGraph contains: " << index->GetNumElements() + << std::endl; + } else if (build_result.error().type == vsag::ErrorType::INTERNAL_ERROR) { + std::cerr << "Failed to build index: internalError" << std::endl; + exit(-1); + } + + /******************* Prepare Query Dataset *****************/ + std::cout << "prepare index" << std::endl; + std::vector query_vector(dim); + for (int64_t i = 0; i < dim; ++i) { + query_vector[i] = distrib_real(rng); + } + auto query = vsag::Dataset::Make(); + query->NumElements(1)->Dim(dim)->Float32Vectors(query_vector.data())->Owner(false); + + /******************* KnnSearch For HGraph Index *****************/ + auto hgraph_search_parameters = R"( + { + "hgraph": { + "ef_search": 100 + } + } + )"; + int64_t topk = 10; + + /******************* Hgraph sq8 Search *****************/ + { + nlohmann::json search_parameters = { + {"hgraph", {{"ef_search", 100}, {"skip_ratio", 0.7f}}}, + }; + std::string param_str = search_parameters.dump(); + vsag::SearchParam search_param(false, param_str, nullptr, &allocator); + auto result = index->KnnSearch(query, topk, search_param).value(); + + // print the results + std::cout << "results: " << std::endl; + for (int64_t i = 0; i < result->GetDim(); ++i) { + std::cout << result->GetIds()[i] << ": " << result->GetDistances()[i] << std::endl; + } + + allocator.Deallocate((void*)result->GetIds()); + allocator.Deallocate((void*)result->GetDistances()); + } + + /******************* Hgraph sq8 Iterator Filter *****************/ + { + vsag::IteratorContext* iter_ctx = nullptr; + nlohmann::json search_parameters = { + {"hgraph", {{"ef_search", 100}, 
{"skip_ratio", 0.7f}}}, + }; + std::string param_str = search_parameters.dump(); + vsag::SearchParam search_param(true, param_str, nullptr, &allocator, iter_ctx, false); + + /* first search */ + { + auto result = index->KnnSearch(query, topk, search_param).value(); + + // print the results + std::cout << "results: " << std::endl; + for (int64_t i = 0; i < result->GetDim(); ++i) { + std::cout << result->GetIds()[i] << ": " << result->GetDistances()[i] << std::endl; + } + + allocator.Deallocate((void*)result->GetIds()); + allocator.Deallocate((void*)result->GetDistances()); + } + + /* last search */ + { + search_param.is_last_search = true; + auto result = index->KnnSearch(query, topk, search_param).value(); + + // print the results + std::cout << "results: " << std::endl; + for (int64_t i = 0; i < result->GetDim(); ++i) { + std::cout << result->GetIds()[i] << ": " << result->GetDistances()[i] << std::endl; + } + + allocator.Deallocate((void*)result->GetIds()); + allocator.Deallocate((void*)result->GetDistances()); + } + } + + engine.Shutdown(); + return 0; +} diff --git a/examples/cpp/CMakeLists.txt b/examples/cpp/CMakeLists.txt index bbeaaa8bc..39463c07a 100644 --- a/examples/cpp/CMakeLists.txt +++ b/examples/cpp/CMakeLists.txt @@ -64,6 +64,12 @@ target_link_libraries(310_feature_export_model vsag) add_executable(311_feature_train 311_feature_train.cpp) target_link_libraries(311_feature_train vsag) +add_executable(313_feature_search_allocator 313_feature_search_allocator.cpp) +target_link_libraries(313_feature_search_allocator vsag) + +add_executable(314_feature_hgraph_search_allocator 314_feature_hgraph_search_allocator.cpp) +target_link_libraries(314_feature_hgraph_search_allocator vsag) + add_executable (401_persistent_kv 401_persistent_kv.cpp) target_link_libraries (401_persistent_kv vsag) diff --git a/include/vsag/index.h b/include/vsag/index.h index 35af316d9..b61183ac7 100644 --- a/include/vsag/index.h +++ b/include/vsag/index.h @@ -32,6 +32,7 @@ #include 
"vsag/index_features.h" #include "vsag/iterator_context.h" #include "vsag/readerset.h" +#include "vsag/search_param.h" #include "vsag/search_request.h" namespace vsag { @@ -218,6 +219,21 @@ class Index { throw std::runtime_error("Index doesn't support new filter"); } + /** + * @brief Performing single KNN search on index + * + * @param query should contains dim, num_elements and vectors + * @param k the result size of every query + * @param search_param search param contains filter, iter_ctx and allocator + * @return result contains + * - num_elements: 1 + * - ids, distances: length is (num_elements * k) + */ + virtual tl::expected + KnnSearch(const DatasetPtr& query, int64_t k, SearchParam& search_param) const { + throw std::runtime_error("Index doesn't support new filter"); + } + /** * @brief Performing single range search on index * diff --git a/include/vsag/search_param.h b/include/vsag/search_param.h new file mode 100644 index 000000000..9bd585866 --- /dev/null +++ b/include/vsag/search_param.h @@ -0,0 +1,62 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include + +#include + +namespace vsag { + +struct SearchParam { +public: + SearchParam(bool iter_filter_flag, + const std::string& parameter, + FilterPtr flt, + Allocator* alloc) + : is_iter_filter(iter_filter_flag), + is_last_search(false), + parameters(parameter), + filter(flt), + allocator(alloc) { + } + + SearchParam(bool iter_filter_flag, + const std::string& parameter, + FilterPtr flt, + Allocator* alloc, + IteratorContext* ctx, + bool last_search_flag) + : is_iter_filter(iter_filter_flag), + is_last_search(last_search_flag), + parameters(parameter), + filter(flt), + allocator(alloc), + iter_ctx(ctx) { + } + +public: + bool is_iter_filter{false}; + bool is_last_search{false}; + const std::string& parameters; + FilterPtr filter{nullptr}; + Allocator* allocator{nullptr}; + IteratorContext* iter_ctx{nullptr}; +}; + +}; // namespace vsag \ No newline at end of file diff --git a/src/algorithm/hgraph.cpp b/src/algorithm/hgraph.cpp index ca182af20..0e00ab4a3 100644 --- a/src/algorithm/hgraph.cpp +++ b/src/algorithm/hgraph.cpp @@ -266,7 +266,17 @@ HGraph::KnnSearch(const DatasetPtr& query, int64_t k, const std::string& parameters, const FilterPtr& filter) const { + return KnnSearch(query, k, parameters, filter, nullptr); +} + +DatasetPtr +HGraph::KnnSearch(const DatasetPtr& query, + int64_t k, + const std::string& parameters, + const FilterPtr& filter, + Allocator* allocator) const { int64_t query_dim = query->GetDim(); + Allocator* search_allocator = allocator == nullptr ? 
allocator_ : allocator; if (data_type_ != DataTypes::DATA_TYPE_SPARSE) { CHECK_ARGUMENT( query_dim == dim_, @@ -284,6 +294,7 @@ HGraph::KnnSearch(const DatasetPtr& query, search_param.topk = 1; search_param.ef = 1; search_param.is_inner_id_allowed = nullptr; + search_param.search_alloc = search_allocator; const auto* raw_query = get_data(query); for (auto i = static_cast(this->route_graphs_.size() - 1); i >= 0; --i) { auto result = this->search_one_graph( @@ -326,10 +337,10 @@ HGraph::KnnSearch(const DatasetPtr& query, return DatasetImpl::MakeEmptyDataset(); } auto count = static_cast(search_result->Size()); - auto [dataset_results, dists, ids] = CreateFastDataset(count, allocator_); + auto [dataset_results, dists, ids] = CreateFastDataset(count, search_allocator); char* extra_infos = nullptr; if (extra_info_size_ > 0) { - extra_infos = (char*)allocator_->Allocate(extra_info_size_ * search_result->Size()); + extra_infos = (char*)search_allocator->Allocate(extra_info_size_ * search_result->Size()); dataset_results->ExtraInfos(extra_infos); } for (int64_t j = count - 1; j >= 0; --j) { @@ -349,8 +360,10 @@ HGraph::KnnSearch(const DatasetPtr& query, int64_t k, const std::string& parameters, const FilterPtr& filter, + Allocator* allocator, IteratorContext*& iter_ctx, bool is_last_filter) const { + Allocator* search_allocator = allocator == nullptr ? 
allocator_ : allocator; if (GetNumElements() == 0) { return DatasetImpl::MakeEmptyDataset(); } @@ -381,7 +394,7 @@ HGraph::KnnSearch(const DatasetPtr& query, if (iter_ctx == nullptr) { auto cur_count = this->bottom_graph_->TotalCount(); auto* new_ctx = new IteratorFilterContext(); - if (auto ret = new_ctx->init(cur_count, params.ef_search, allocator_); + if (auto ret = new_ctx->init(cur_count, params.ef_search, search_allocator); not ret.has_value()) { throw vsag::VsagException(ErrorType::INTERNAL_ERROR, "failed to init IteratorFilterContext"); @@ -390,7 +403,7 @@ HGraph::KnnSearch(const DatasetPtr& query, } auto* iter_filter_ctx = static_cast(iter_ctx); - DistHeapPtr search_result = std::make_shared>(allocator_, -1); + DistHeapPtr search_result = std::make_shared>(search_allocator, -1); const auto* query_data = get_data(query); if (is_last_filter) { while (!iter_filter_ctx->Empty()) { @@ -405,6 +418,7 @@ HGraph::KnnSearch(const DatasetPtr& query, search_param.topk = 1; search_param.ef = 1; search_param.is_inner_id_allowed = nullptr; + search_param.search_alloc = search_allocator; if (iter_filter_ctx->IsFirstUsed()) { for (auto i = static_cast(this->route_graphs_.size() - 1); i >= 0; --i) { auto result = this->search_one_graph( @@ -438,10 +452,10 @@ HGraph::KnnSearch(const DatasetPtr& query, return DatasetImpl::MakeEmptyDataset(); } auto count = static_cast(search_result->Size()); - auto [dataset_results, dists, ids] = CreateFastDataset(count, allocator_); + auto [dataset_results, dists, ids] = CreateFastDataset(count, search_allocator); char* extra_infos = nullptr; if (extra_info_size_ > 0) { - extra_infos = (char*)allocator_->Allocate(extra_info_size_ * search_result->Size()); + extra_infos = (char*)search_allocator->Allocate(extra_info_size_ * search_result->Size()); dataset_results->ExtraInfos(extra_infos); } for (int64_t j = count - 1; j >= 0; --j) { diff --git a/src/algorithm/hgraph.h b/src/algorithm/hgraph.h index 6c1b62eac..8fc3f66a1 100644 --- 
a/src/algorithm/hgraph.h +++ b/src/algorithm/hgraph.h @@ -88,6 +88,14 @@ class HGraph : public InnerIndexInterface { int64_t k, const std::string& parameters, const FilterPtr& filter, + Allocator* allocator) const override; + + [[nodiscard]] DatasetPtr + KnnSearch(const DatasetPtr& query, + int64_t k, + const std::string& parameters, + const FilterPtr& filter, + Allocator* allocator, IteratorContext*& iter_ctx, bool is_last_filter) const override; diff --git a/src/algorithm/hnswlib/algorithm_interface.h b/src/algorithm/hnswlib/algorithm_interface.h index ee372c690..a63fbb68b 100644 --- a/src/algorithm/hnswlib/algorithm_interface.h +++ b/src/algorithm/hnswlib/algorithm_interface.h @@ -47,6 +47,7 @@ class AlgorithmInterface { size_t ef, const vsag::FilterPtr is_id_allowed = nullptr, float skip_ratio = 0.9f, + vsag::Allocator* allocator = nullptr, vsag::IteratorFilterContext* iter_ctx = nullptr, bool is_last_filter = false) const = 0; diff --git a/src/algorithm/hnswlib/hnswalg.cpp b/src/algorithm/hnswlib/hnswalg.cpp index c179a3e01..959ee9f14 100644 --- a/src/algorithm/hnswlib/hnswalg.cpp +++ b/src/algorithm/hnswlib/hnswalg.cpp @@ -458,14 +458,16 @@ HierarchicalNSW::searchBaseLayerST(InnerIdType ep_id, size_t ef, const vsag::FilterPtr is_id_allowed, const float skip_ratio, + vsag::Allocator* allocator, vsag::IteratorFilterContext* iter_ctx) const { vsag::LinearCongruentialGenerator generator; VisitedListPtr vl = visited_list_pool_->getFreeVisitedList(); vl_type* visited_array = vl->mass; vl_type visited_array_tag = vl->curV; + vsag::Allocator* search_allocator = allocator == nullptr ? allocator_ : allocator; - MaxHeap top_candidates(allocator_); - MaxHeap candidate_set(allocator_); + MaxHeap top_candidates(search_allocator); + MaxHeap candidate_set(search_allocator); float valid_ratio = is_id_allowed ? is_id_allowed->ValidRatio() : 1.0F; float skip_threshold = valid_ratio == 1.0F ? 
0 : (1 - ((1 - valid_ratio) * skip_ratio)); @@ -1567,6 +1569,7 @@ HierarchicalNSW::searchKnn(const void* query_data, uint64_t ef, const vsag::FilterPtr is_id_allowed, const float skip_ratio, + vsag::Allocator* allocator, vsag::IteratorFilterContext* iter_ctx, bool is_last_filter) const { std::shared_lock resize_lock(resize_mutex_); @@ -1574,9 +1577,10 @@ HierarchicalNSW::searchKnn(const void* query_data, if (cur_element_count_ == 0) return result; + vsag::Allocator* search_allocator = allocator == nullptr ? allocator_ : allocator; std::shared_ptr normalize_query; normalizeVector(query_data, normalize_query); - MaxHeap top_candidates(allocator_); + MaxHeap top_candidates(search_allocator); if (iter_ctx != nullptr && !iter_ctx->IsFirstUsed()) { if (iter_ctx->Empty()) return result; @@ -1594,6 +1598,7 @@ HierarchicalNSW::searchKnn(const void* query_data, std::max(ef, k), is_id_allowed, skip_ratio, + allocator, iter_ctx); } else { int64_t currObj; @@ -1634,11 +1639,21 @@ HierarchicalNSW::searchKnn(const void* query_data, } if (num_deleted_ == 0) { - top_candidates = searchBaseLayerST( - currObj, query_data, std::max(ef, k), is_id_allowed, skip_ratio, iter_ctx); + top_candidates = searchBaseLayerST(currObj, + query_data, + std::max(ef, k), + is_id_allowed, + skip_ratio, + allocator, + iter_ctx); } else { - top_candidates = searchBaseLayerST( - currObj, query_data, std::max(ef, k), is_id_allowed, skip_ratio, iter_ctx); + top_candidates = searchBaseLayerST(currObj, + query_data, + std::max(ef, k), + is_id_allowed, + skip_ratio, + allocator, + iter_ctx); } } @@ -1759,6 +1774,7 @@ HierarchicalNSW::searchBaseLayerST( size_t ef, const vsag::FilterPtr is_id_allowed, const float skip_ratio, + vsag::Allocator* allocator, vsag::IteratorFilterContext* iter_ctx = nullptr) const; template MaxHeap diff --git a/src/algorithm/hnswlib/hnswalg.h b/src/algorithm/hnswlib/hnswalg.h index 8da51b322..aded29ff8 100644 --- a/src/algorithm/hnswlib/hnswalg.h +++ b/src/algorithm/hnswlib/hnswalg.h 
@@ -261,6 +261,7 @@ class HierarchicalNSW : public AlgorithmInterface { size_t ef, const vsag::FilterPtr is_id_allowed = nullptr, const float skip_ratio = 0.9f, + vsag::Allocator* allocator = nullptr, vsag::IteratorFilterContext* iter_ctx = nullptr) const; template @@ -417,6 +418,7 @@ class HierarchicalNSW : public AlgorithmInterface { uint64_t ef, const vsag::FilterPtr is_id_allowed = nullptr, const float skip_ratio = 0.9f, + vsag::Allocator* allocator = nullptr, vsag::IteratorFilterContext* iter_ctx = nullptr, bool is_last_filter = false) const override; diff --git a/src/algorithm/hnswlib/hnswalg_static.h b/src/algorithm/hnswlib/hnswalg_static.h index 20b699214..3eeca710d 100644 --- a/src/algorithm/hnswlib/hnswalg_static.h +++ b/src/algorithm/hnswlib/hnswalg_static.h @@ -1472,6 +1472,7 @@ class StaticHierarchicalNSW : public AlgorithmInterface { uint64_t ef, const vsag::FilterPtr is_id_allowed = nullptr, const float skip_ratio = 0.9f, + vsag::Allocator* allocator = nullptr, vsag::IteratorFilterContext* iter_ctx = nullptr, bool is_last_filter = false) const override { std::priority_queue> result; diff --git a/src/algorithm/inner_index_interface.h b/src/algorithm/inner_index_interface.h index 219956056..396bd299c 100644 --- a/src/algorithm/inner_index_interface.h +++ b/src/algorithm/inner_index_interface.h @@ -101,12 +101,27 @@ class InnerIndexInterface { int64_t k, const std::string& parameters, const FilterPtr& filter, + Allocator* allocator) const { + throw std::runtime_error("Index doesn't support new filter"); + }; + + [[nodiscard]] virtual DatasetPtr + KnnSearch(const DatasetPtr& query, + int64_t k, + const std::string& parameters, + const FilterPtr& filter, + Allocator* allocator, IteratorContext*& iter_ctx, bool is_last_filter) const { throw VsagException(ErrorType::UNSUPPORTED_INDEX_OPERATION, "Index doesn't support new filter"); }; + [[nodiscard]] virtual DatasetPtr + KnnSearch(const DatasetPtr& query, int64_t k, SearchParam& search_param) const { + throw 
std::runtime_error("Index doesn't support new filter"); + } + [[nodiscard]] virtual DatasetPtr RangeSearch(const DatasetPtr& query, float radius, diff --git a/src/data_cell/flatten_datacell.h b/src/data_cell/flatten_datacell.h index 700201242..46a7178e5 100644 --- a/src/data_cell/flatten_datacell.h +++ b/src/data_cell/flatten_datacell.h @@ -43,9 +43,10 @@ class FlattenDataCell : public FlattenInterface { Query(float* result_dists, const ComputerInterfacePtr& computer, const InnerIdType* idx, - InnerIdType id_count) override { + InnerIdType id_count, + Allocator* allocator = nullptr) override { auto comp = std::static_pointer_cast>(computer); - this->query(result_dists, comp, idx, id_count); + this->query(result_dists, comp, idx, id_count, allocator); } ComputerInterfacePtr @@ -153,7 +154,8 @@ class FlattenDataCell : public FlattenInterface { query(float* result_dists, const std::shared_ptr>& computer, const InnerIdType* idx, - InnerIdType id_count); + InnerIdType id_count, + Allocator* allocator); ComputerInterfacePtr factory_computer(const float* query) { @@ -248,15 +250,17 @@ void FlattenDataCell::query(float* result_dists, const std::shared_ptr>& computer, const InnerIdType* idx, - InnerIdType id_count) { + InnerIdType id_count, + Allocator* allocator) { + Allocator* search_alloc = allocator == nullptr ? 
allocator_ : allocator; for (uint32_t i = 0; i < this->prefetch_stride_code_ and i < id_count; i++) { this->io_->Prefetch(static_cast(idx[i]) * static_cast(code_size_), this->prefetch_depth_code_ * 64); } if (not this->io_->InMemory() and id_count > 1) { - ByteBuffer codes(static_cast(id_count) * this->code_size_, allocator_); - Vector sizes(id_count, this->code_size_, allocator_); - Vector offsets(id_count, this->code_size_, allocator_); + ByteBuffer codes(id_count * this->code_size_, search_alloc); + Vector sizes(id_count, this->code_size_, search_alloc); + Vector offsets(id_count, this->code_size_, search_alloc); for (int64_t i = 0; i < id_count; ++i) { offsets[i] = static_cast(idx[i]) * this->code_size_; } diff --git a/src/data_cell/flatten_interface.h b/src/data_cell/flatten_interface.h index ef0c59751..d541da45d 100644 --- a/src/data_cell/flatten_interface.h +++ b/src/data_cell/flatten_interface.h @@ -44,7 +44,8 @@ class FlattenInterface { Query(float* result_dists, const ComputerInterfacePtr& computer, const InnerIdType* idx, - InnerIdType id_count) = 0; + InnerIdType id_count, + Allocator* allocator = nullptr) = 0; virtual ComputerInterfacePtr FactoryComputer(const void* query) = 0; diff --git a/src/data_cell/sparse_vector_datacell.h b/src/data_cell/sparse_vector_datacell.h index ebc4a4cef..51a540884 100644 --- a/src/data_cell/sparse_vector_datacell.h +++ b/src/data_cell/sparse_vector_datacell.h @@ -35,7 +35,8 @@ class SparseVectorDataCell : public FlattenInterface { Query(float* result_dists, const ComputerInterfacePtr& computer, const InnerIdType* idx, - InnerIdType id_count) override { + InnerIdType id_count, + Allocator* allocator = nullptr) override { auto comp = std::static_pointer_cast>(computer); this->query(result_dists, comp, idx, id_count); } diff --git a/src/impl/basic_searcher.cpp b/src/impl/basic_searcher.cpp index 2581ca38e..0a1dd98fa 100644 --- a/src/impl/basic_searcher.cpp +++ b/src/impl/basic_searcher.cpp @@ -96,8 +96,10 @@ 
BasicSearcher::search_impl(const GraphInterfacePtr& graph, const void* query, const InnerSearchParam& inner_search_param, IteratorFilterContext* iter_ctx) const { - auto top_candidates = std::make_shared>(allocator_, -1); - auto candidate_set = std::make_shared>(allocator_, -1); + Allocator* alloc = + inner_search_param.search_alloc == nullptr ? allocator_ : inner_search_param.search_alloc; + auto top_candidates = std::make_shared>(alloc, -1); + auto candidate_set = std::make_shared>(alloc, -1); if (not graph or not flatten) { return top_candidates; @@ -116,10 +118,10 @@ BasicSearcher::search_impl(const GraphInterfacePtr& graph, uint32_t hops = 0; uint32_t dist_cmp = 0; uint32_t count_no_visited = 0; - Vector to_be_visited_rid(graph->MaximumDegree(), allocator_); - Vector to_be_visited_id(graph->MaximumDegree(), allocator_); - Vector neighbors(graph->MaximumDegree(), allocator_); - Vector line_dists(graph->MaximumDegree(), allocator_); + Vector to_be_visited_rid(graph->MaximumDegree(), alloc); + Vector to_be_visited_id(graph->MaximumDegree(), alloc); + Vector neighbors(graph->MaximumDegree(), alloc); + Vector line_dists(graph->MaximumDegree(), alloc); if (!iter_ctx->IsFirstUsed()) { if (iter_ctx->Empty()) { @@ -131,7 +133,7 @@ BasicSearcher::search_impl(const GraphInterfacePtr& graph, if (!vl->Get(cur_inner_id) && iter_ctx->CheckPoint(cur_inner_id)) { vl->Set(cur_inner_id); lower_bound = std::max(lower_bound, cur_dist); - flatten->Query(&cur_dist, computer, &cur_inner_id, 1); + flatten->Query(&cur_dist, computer, &cur_inner_id, 1, alloc); top_candidates->Push(cur_dist, cur_inner_id); candidate_set->Push(cur_dist, cur_inner_id); if constexpr (mode == InnerSearchMode::RANGE_SEARCH) { @@ -143,7 +145,7 @@ BasicSearcher::search_impl(const GraphInterfacePtr& graph, iter_ctx->PopDiscard(); } } else { - flatten->Query(&dist, computer, &ep, 1); + flatten->Query(&dist, computer, &ep, 1, alloc); if (not is_id_allowed || is_id_allowed->CheckValid(ep)) { 
top_candidates->Push(dist, ep); lower_bound = top_candidates->Top().first; @@ -178,7 +180,8 @@ BasicSearcher::search_impl(const GraphInterfacePtr& graph, dist_cmp += count_no_visited; - flatten->Query(line_dists.data(), computer, to_be_visited_id.data(), count_no_visited); + flatten->Query( + line_dists.data(), computer, to_be_visited_id.data(), count_no_visited, alloc); for (uint32_t i = 0; i < count_no_visited; i++) { dist = line_dists[i]; @@ -230,8 +233,10 @@ BasicSearcher::search_impl(const GraphInterfacePtr& graph, const VisitedListPtr& vl, const void* query, const InnerSearchParam& inner_search_param) const { - auto top_candidates = std::make_shared>(allocator_, -1); - auto candidate_set = std::make_shared>(allocator_, -1); + Allocator* alloc = + inner_search_param.search_alloc == nullptr ? allocator_ : inner_search_param.search_alloc; + auto top_candidates = std::make_shared>(alloc, -1); + auto candidate_set = std::make_shared>(alloc, -1); if (not graph or not flatten) { return top_candidates; @@ -249,12 +254,12 @@ BasicSearcher::search_impl(const GraphInterfacePtr& graph, uint32_t hops = 0; uint32_t dist_cmp = 0; uint32_t count_no_visited = 0; - Vector to_be_visited_rid(graph->MaximumDegree(), allocator_); - Vector to_be_visited_id(graph->MaximumDegree(), allocator_); - Vector neighbors(graph->MaximumDegree(), allocator_); - Vector line_dists(graph->MaximumDegree(), allocator_); + Vector to_be_visited_rid(graph->MaximumDegree(), alloc); + Vector to_be_visited_id(graph->MaximumDegree(), alloc); + Vector neighbors(graph->MaximumDegree(), alloc); + Vector line_dists(graph->MaximumDegree(), alloc); - flatten->Query(&dist, computer, &ep, 1); + flatten->Query(&dist, computer, &ep, 1, alloc); if (not is_id_allowed || is_id_allowed->CheckValid(ep)) { top_candidates->Push(dist, ep); lower_bound = top_candidates->Top().first; @@ -293,7 +298,8 @@ BasicSearcher::search_impl(const GraphInterfacePtr& graph, dist_cmp += count_no_visited; - 
flatten->Query(line_dists.data(), computer, to_be_visited_id.data(), count_no_visited); + flatten->Query( + line_dists.data(), computer, to_be_visited_id.data(), count_no_visited, alloc); for (uint32_t i = 0; i < count_no_visited; i++) { dist = line_dists[i]; diff --git a/src/index/hnsw.cpp b/src/index/hnsw.cpp index 05dc52cf2..a8b9a99d7 100644 --- a/src/index/hnsw.cpp +++ b/src/index/hnsw.cpp @@ -222,6 +222,7 @@ HNSW::knn_search(const DatasetPtr& query, int64_t k, const std::string& parameters, const FilterPtr& filter_ptr, + vsag::Allocator* allocator, vsag::IteratorContext** iter_ctx, bool is_last_filter) const { #ifndef ENABLE_TESTS @@ -234,6 +235,7 @@ HNSW::knn_search(const DatasetPtr& query, ret->Dim(0)->NumElements(1); return ret; } + vsag::Allocator* search_allocator = allocator == nullptr ? allocator_.get() : allocator; // check query vector CHECK_ARGUMENT(query->GetNumElements() == 1, "query dataset should contain 1 vector only"); @@ -260,7 +262,7 @@ HNSW::knn_search(const DatasetPtr& query, "ef_search({}) must in range[1, {}]", params.ef_search, ef_search_threshold)); if (iter_ctx != nullptr && *iter_ctx == nullptr) { auto* filter_context = new IteratorFilterContext(); - filter_context->init(alg_hnsw_->getMaxElements(), params.ef_search, allocator_.get()); + filter_context->init(alg_hnsw_->getMaxElements(), params.ef_search, search_allocator); *iter_ctx = filter_context; } IteratorFilterContext* iter_filter_ctx = nullptr; @@ -282,6 +284,7 @@ HNSW::knn_search(const DatasetPtr& query, std::max(params.ef_search, k), filter_ptr, params.skip_ratio, + allocator, iter_filter_ctx, is_last_filter); } catch (const std::runtime_error& e) { @@ -324,7 +327,7 @@ HNSW::knn_search(const DatasetPtr& query, results.pop(); } auto [dataset_results, dists, ids] = - CreateFastDataset(static_cast(results.size()), allocator_.get()); + CreateFastDataset(static_cast(results.size()), search_allocator); for (auto j = static_cast(results.size() - 1); j >= 0; --j) { dists[j] = 
results.top().first; diff --git a/src/index/hnsw.h b/src/index/hnsw.h index 0f24cb369..ef0ccfa65 100644 --- a/src/index/hnsw.h +++ b/src/index/hnsw.h @@ -130,8 +130,24 @@ class HNSW : public Index { const FilterPtr& filter, vsag::IteratorContext*& filter_ctx, bool is_last_search) const override { - SAFE_CALL( - return this->knn_search(query, k, parameters, filter, &filter_ctx, is_last_search)); + SAFE_CALL(return this->knn_search( + query, k, parameters, filter, nullptr, &filter_ctx, is_last_search)); + } + + tl::expected + KnnSearch(const DatasetPtr& query, int64_t k, SearchParam& search_param) const override { + if (search_param.is_iter_filter) { + SAFE_CALL(return this->knn_search(query, + k, + search_param.parameters, + search_param.filter, + search_param.allocator, + &search_param.iter_ctx, + search_param.is_last_search)); + } else { + SAFE_CALL(return this->knn_search( + query, k, search_param.parameters, search_param.filter, search_param.allocator)); + } } tl::expected @@ -318,6 +334,7 @@ class HNSW : public Index { int64_t k, const std::string& parameters, const FilterPtr& filter_ptr, + vsag::Allocator* allocator = nullptr, vsag::IteratorContext** iter_ctx = nullptr, bool is_last_filter = false) const; diff --git a/src/index/index_impl.h b/src/index/index_impl.h index c8c69bfd5..ceb9013e3 100644 --- a/src/index/index_impl.h +++ b/src/index/index_impl.h @@ -117,6 +117,25 @@ class IndexImpl : public Index { SAFE_CALL(return this->inner_index_->KnnSearch(query, k, parameters, filter)); } + tl::expected + KnnSearch(const DatasetPtr& query, int64_t k, SearchParam& search_param) const override { + if (GetNumElements() == 0) { + return DatasetImpl::MakeEmptyDataset(); + } + if (search_param.is_iter_filter) { + SAFE_CALL(return this->inner_index_->KnnSearch(query, + k, + search_param.parameters, + search_param.filter, + search_param.allocator, + search_param.iter_ctx, + search_param.is_last_search)); + } else { + SAFE_CALL(return this->inner_index_->KnnSearch( + 
query, k, search_param.parameters, search_param.filter, search_param.allocator)); + } + } + tl::expected KnnSearch(const DatasetPtr& query, int64_t k, @@ -128,7 +147,7 @@ class IndexImpl : public Index { return DatasetImpl::MakeEmptyDataset(); } SAFE_CALL(return this->inner_index_->KnnSearch( - query, k, parameters, filter, iter_ctx, is_last_filter)); + query, k, parameters, filter, nullptr, iter_ctx, is_last_filter)); } [[nodiscard]] tl::expected diff --git a/src/index/iterator_filter.cpp b/src/index/iterator_filter.cpp index 37aad7a65..899a7dd09 100644 --- a/src/index/iterator_filter.cpp +++ b/src/index/iterator_filter.cpp @@ -33,7 +33,7 @@ IteratorFilterContext::init(InnerIdType max_size, int64_t ef_search, Allocator* ef_search_ = ef_search; allocator_ = allocator; max_size_ = max_size; - discard_ = std::make_unique>>(); + discard_ = std::make_unique(allocator); list_ = reinterpret_cast( allocator_->Allocate((uint64_t)max_size * sizeof(VisitedListType))); memset(list_, 0, max_size * sizeof(VisitedListType)); diff --git a/src/index/iterator_filter.h b/src/index/iterator_filter.h index d699564b0..da821af8d 100644 --- a/src/index/iterator_filter.h +++ b/src/index/iterator_filter.h @@ -73,7 +73,7 @@ class IteratorFilterContext : public IteratorContext { uint32_t max_size_{0}; Allocator* allocator_{nullptr}; VisitedListType* list_{nullptr}; - std::unique_ptr>> discard_; + std::unique_ptr discard_; }; }; // namespace vsag diff --git a/tests/test_hgraph.cpp b/tests/test_hgraph.cpp index 69a55613b..24fb29559 100644 --- a/tests/test_hgraph.cpp +++ b/tests/test_hgraph.cpp @@ -199,6 +199,7 @@ HgraphTestIndex::TestGeneral(const TestIndex::IndexPtr& index, TestCheckIdExist(index, dataset); TestCalcDistanceById(index, dataset); TestBatchCalcDistanceById(index, dataset); + TestSearchAllocator(index, dataset, search_param, recall, true); } } // namespace fixtures diff --git a/tests/test_hnsw_new.cpp b/tests/test_hnsw_new.cpp index 5be7d6616..b842d94cd 100644 --- 
a/tests/test_hnsw_new.cpp +++ b/tests/test_hnsw_new.cpp @@ -244,6 +244,7 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, TestRangeSearch(index, dataset, search_param, 0.99, 10, true); TestRangeSearch(index, dataset, search_param, 0.49, 5, true); TestFilterSearch(index, dataset, search_param, 0.99, true); + TestSearchAllocator(index, dataset, search_param, 0.99, true); } vsag::Options::Instance().set_block_size_limit(origin_size); } diff --git a/tests/test_index.cpp b/tests/test_index.cpp index 6eb950f86..007462e6d 100644 --- a/tests/test_index.cpp +++ b/tests/test_index.cpp @@ -23,6 +23,7 @@ #include "simd/fp32_simd.h" #include "vsag/engine.h" #include "vsag/resource.h" +#include "vsag/search_param.h" namespace fixtures { static int64_t @@ -661,6 +662,94 @@ TestIndex::TestFilterSearch(const TestIndex::IndexPtr& index, REQUIRE(cur_recall > expected_recall * query_count * RECALL_THRESHOLD); } +void +TestIndex::TestSearchAllocator(const TestIndex::IndexPtr& index, + const TestDatasetPtr& dataset, + const std::string& search_param, + float expected_recall, + bool expected_success) { + if (not index->CheckFeature(vsag::SUPPORT_KNN_SEARCH)) { + return; + } + auto queries = dataset->query_; + auto query_count = queries->GetNumElements(); + auto dim = queries->GetDim(); + auto gts = dataset->ground_truth_; + auto gt_topK = dataset->top_k; + float cur_recall = 0.0f; + auto topk = gt_topK; + class ExampleAllocator : public vsag::Allocator { + public: + std::string + Name() override { + return "example-allocator"; + } + + void* + Allocate(size_t size) override { + auto addr = (void*)malloc(size); + sizes_[addr] = size; + return addr; + } + + void + Deallocate(void* p) override { + if (sizes_.find(p) == sizes_.end()) + return; + sizes_.erase(p); + return free(p); + } + + void* + Reallocate(void* p, size_t size) override { + auto addr = (void*)realloc(p, size); + sizes_.erase(p); + sizes_[addr] = size; + return addr; + } + + private: + std::unordered_map sizes_; + }; + 
+ for (auto i = 0; i < query_count; ++i) { + auto query = vsag::Dataset::Make(); + query->NumElements(1) + ->Dim(dim) + ->Float32Vectors(queries->GetFloat32Vectors() + i * dim) + ->SparseVectors(queries->GetSparseVectors() + i) + ->Paths(queries->GetPaths() + i) + ->Owner(false); + ExampleAllocator allocator; + vsag::SearchParam search_params(false, search_param, nullptr, &allocator); + auto res = index->KnnSearch(query, topk, search_params); + if (not expected_success) { + if (res.has_value()) { + REQUIRE(res.value()->GetDim() == 0); + } + } else { + REQUIRE(res.has_value() == expected_success); + } + if (!expected_success) { + return; + } + REQUIRE(res.value()->GetDim() == topk); + auto result = res.value()->GetIds(); + auto dis = res.value()->GetDistances(); + auto gt = gts->GetIds() + gt_topK * i; + auto val = Intersection(gt, gt_topK, result, topk); + cur_recall += static_cast(val) / static_cast(gt_topK); + allocator.Deallocate((void*)result); + allocator.Deallocate((void*)dis); + } + if (cur_recall <= expected_recall * query_count) { + WARN(fmt::format("cur_result({}) <= expected_recall * query_count({})", + cur_recall, + expected_recall * query_count)); + } + REQUIRE(cur_recall > expected_recall * query_count * RECALL_THRESHOLD); +} + void TestIndex::TestCalcDistanceById(const IndexPtr& index, const TestDatasetPtr& dataset, diff --git a/tests/test_index.h b/tests/test_index.h index 05237077c..e7d9dcafa 100644 --- a/tests/test_index.h +++ b/tests/test_index.h @@ -115,6 +115,13 @@ class TestIndex { bool expected_success = true, bool use_ex_filter = false); + static void + TestSearchAllocator(const IndexPtr& index, + const TestDatasetPtr& dataset, + const std::string& search_param, + float expected_recall = 0.99, + bool expected_success = true); + static void TestSearchWithDirtyVector(const IndexPtr& index, const TestDatasetPtr& dataset, From e2d8263931af81d28f0477cc9e505929a99a3085 Mon Sep 17 00:00:00 2001 From: Deming Chu Date: Thu, 5 Jun 2025 10:35:50 +0800 
Subject: [PATCH 37/42] print memory usage detail in eval_performance for HGraph (#770) Signed-off-by: Deming Chu Signed-off-by: suguan.dx --- include/vsag/index.h | 18 ++++++++++++++---- src/algorithm/hgraph.cpp | 23 +++++++++++++++++++++++ src/algorithm/hgraph.h | 10 ++++++---- src/algorithm/inner_index_interface.h | 7 +++++++ src/data_cell/extra_info_interface.h | 8 ++++++++ src/data_cell/flatten_interface.h | 8 ++++++++ src/data_cell/graph_interface.h | 8 ++++++++ src/index/index_impl.h | 5 +++++ tests/test_hgraph.cpp | 12 ++++++++++++ tools/eval/case/build_eval_case.cpp | 6 ++++++ tools/eval/case/search_eval_case.cpp | 6 ++++++ 11 files changed, 103 insertions(+), 8 deletions(-) diff --git a/include/vsag/index.h b/include/vsag/index.h index b61183ac7..ab3ec5c0f 100644 --- a/include/vsag/index.h +++ b/include/vsag/index.h @@ -15,16 +15,13 @@ #pragma once -#include #include #include -#include #include #include "bitset.h" -#include "features.h" +#include "typing.h" #include "vsag/binaryset.h" -#include "vsag/bitset.h" #include "vsag/dataset.h" #include "vsag/errors.h" #include "vsag/expected.hpp" @@ -532,6 +529,19 @@ class Index { [[nodiscard]] virtual int64_t GetMemoryUsage() const = 0; + /** + * @brief Return the memory usage of every component in the index + * + * @return a json object that contains the memory usage of every component in the index + */ + // TODO(deming): implement func for every types of index + // [[nodiscard]] virtual JsonType + // GetMemoryUsageDetail() const = 0; + [[nodiscard]] virtual JsonType + GetMemoryUsageDetail() const { + throw std::runtime_error("Index not support GetMemoryUsageDetail"); + } + /** * @brief estimate the memory used by the index with given element counts * diff --git a/src/algorithm/hgraph.cpp b/src/algorithm/hgraph.cpp index 0e00ab4a3..0e72ad8c0 100644 --- a/src/algorithm/hgraph.cpp +++ b/src/algorithm/hgraph.cpp @@ -700,6 +700,29 @@ HGraph::Deserialize(StreamReader& reader) { } } +JsonType 
+HGraph::GetMemoryUsageDetail() const { + JsonType memory_usage; + if (this->ignore_reorder_) { + this->use_reorder_ = false; + } + memory_usage["basic_flatten_codes"] = this->basic_flatten_codes_->CalcSerializeSize(); + memory_usage["bottom_graph"] = this->bottom_graph_->CalcSerializeSize(); + if (this->use_reorder_) { + memory_usage["high_precise_codes"] = this->high_precise_codes_->CalcSerializeSize(); + } + size_t route_graph_size = 0; + for (const auto& route_graph : this->route_graphs_) { + route_graph_size += route_graph->CalcSerializeSize(); + } + memory_usage["route_graph"] = route_graph_size; + if (this->extra_info_size_ > 0 && this->extra_infos_ != nullptr) { + memory_usage["extra_infos"] = this->extra_infos_->CalcSerializeSize(); + } + memory_usage["__total_size__"] = this->CalSerializeSize(); + return memory_usage; +} + void HGraph::deserialize_basic_info(StreamReader& reader) { StreamReader::ReadObj(reader, this->use_reorder_); diff --git a/src/algorithm/hgraph.h b/src/algorithm/hgraph.h index 8fc3f66a1..a6271f990 100644 --- a/src/algorithm/hgraph.h +++ b/src/algorithm/hgraph.h @@ -114,18 +114,20 @@ class HGraph : public InnerIndexInterface { int64_t GetNumElements() const override { - return this->total_count_; + return static_cast(this->total_count_); } uint64_t EstimateMemory(uint64_t num_elements) const override; - // TODO(LHT): implement - inline int64_t + int64_t GetMemoryUsage() const override { - return 0; + return static_cast(this->CalSerializeSize()); } + JsonType + GetMemoryUsageDetail() const override; + float CalcDistanceById(const float* query, int64_t id) const override; diff --git a/src/algorithm/inner_index_interface.h b/src/algorithm/inner_index_interface.h index 396bd299c..a390412c5 100644 --- a/src/algorithm/inner_index_interface.h +++ b/src/algorithm/inner_index_interface.h @@ -250,6 +250,13 @@ class InnerIndexInterface { "Index doesn't support GetMemoryUsage"); } + [[nodiscard]] virtual JsonType + GetMemoryUsageDetail() const { + 
// TODO(deming): implement func for every types of inner index + throw VsagException(ErrorType::UNSUPPORTED_INDEX_OPERATION, + "Index doesn't support GetMemoryUsageDetail"); + } + [[nodiscard]] virtual uint64_t EstimateMemory(uint64_t num_elements) const { throw VsagException(ErrorType::UNSUPPORTED_INDEX_OPERATION, diff --git a/src/data_cell/extra_info_interface.h b/src/data_cell/extra_info_interface.h index 5e27e211c..751adf993 100644 --- a/src/data_cell/extra_info_interface.h +++ b/src/data_cell/extra_info_interface.h @@ -99,6 +99,14 @@ class ExtraInfoInterface { StreamReader::ReadObj(reader, this->extra_info_size_); } + uint64_t + CalcSerializeSize() { + auto calSizeFunc = [](uint64_t cursor, uint64_t size, void* buf) { return; }; + WriteFuncStreamWriter writer(calSizeFunc, 0); + this->Serialize(writer); + return writer.cursor_; + } + [[nodiscard]] virtual bool InMemory() const { return true; diff --git a/src/data_cell/flatten_interface.h b/src/data_cell/flatten_interface.h index d541da45d..a46841c11 100644 --- a/src/data_cell/flatten_interface.h +++ b/src/data_cell/flatten_interface.h @@ -135,6 +135,14 @@ class FlattenInterface { StreamReader::ReadObj(reader, this->code_size_); } + uint64_t + CalcSerializeSize() { + auto calSizeFunc = [](uint64_t cursor, uint64_t size, void* buf) { return; }; + WriteFuncStreamWriter writer(calSizeFunc, 0); + this->Serialize(writer); + return writer.cursor_; + } + [[nodiscard]] virtual bool InMemory() const { return true; diff --git a/src/data_cell/graph_interface.h b/src/data_cell/graph_interface.h index 0029be2a6..d4d3637f7 100644 --- a/src/data_cell/graph_interface.h +++ b/src/data_cell/graph_interface.h @@ -78,6 +78,14 @@ class GraphInterface { StreamReader::ReadObj(reader, this->maximum_degree_); } + uint64_t + CalcSerializeSize() { + auto calSizeFunc = [](uint64_t cursor, uint64_t size, void* buf) { return; }; + WriteFuncStreamWriter writer(calSizeFunc, 0); + this->Serialize(writer); + return writer.cursor_; + } + 
[[nodiscard]] virtual InnerIdType TotalCount() const { return this->total_count_; diff --git a/src/index/index_impl.h b/src/index/index_impl.h index ceb9013e3..a143c8250 100644 --- a/src/index/index_impl.h +++ b/src/index/index_impl.h @@ -298,6 +298,11 @@ class IndexImpl : public Index { return this->inner_index_->GetMemoryUsage(); } + [[nodiscard]] JsonType + GetMemoryUsageDetail() const override { + return this->inner_index_->GetMemoryUsageDetail(); + } + [[nodiscard]] uint64_t EstimateMemory(uint64_t num_elements) const override { return this->inner_index_->EstimateMemory(num_elements); diff --git a/tests/test_hgraph.cpp b/tests/test_hgraph.cpp index 24fb29559..5aea82f78 100644 --- a/tests/test_hgraph.cpp +++ b/tests/test_hgraph.cpp @@ -47,6 +47,9 @@ class HgraphTestIndex : public fixtures::TestIndex { const std::string& search_param, float recall); + static void + TestMemoryUsageDetail(const IndexPtr& index); + static TestDatasetPool pool; static std::vector dims; @@ -200,6 +203,15 @@ HgraphTestIndex::TestGeneral(const TestIndex::IndexPtr& index, TestCalcDistanceById(index, dataset); TestBatchCalcDistanceById(index, dataset); TestSearchAllocator(index, dataset, search_param, recall, true); + TestMemoryUsageDetail(index); +} + +void +HgraphTestIndex::TestMemoryUsageDetail(const IndexPtr& index) { + auto memory_detail = index->GetMemoryUsageDetail(); + REQUIRE(memory_detail.contains("basic_flatten_codes")); + REQUIRE(memory_detail.contains("bottom_graph")); + REQUIRE(memory_detail.contains("route_graph")); } } // namespace fixtures diff --git a/tools/eval/case/build_eval_case.cpp b/tools/eval/case/build_eval_case.cpp index 0006eebab..538622574 100644 --- a/tools/eval/case/build_eval_case.cpp +++ b/tools/eval/case/build_eval_case.cpp @@ -101,6 +101,12 @@ BuildEvalCase::process_result() { result["index_info"] = JsonType::parse(config_.build_param); result["action"] = "build"; result["index"] = config_.index_name; + // TODO(deming): remove try-catch after implement 
GetMemoryUsageDetail + try { + result["memory_detail(B)"] = this->index_->GetMemoryUsageDetail(); + } catch (std::exception& e) { + logger_->Debug(e.what()); + } return result; } diff --git a/tools/eval/case/search_eval_case.cpp b/tools/eval/case/search_eval_case.cpp index 90d9ec043..1f2db4af4 100644 --- a/tools/eval/case/search_eval_case.cpp +++ b/tools/eval/case/search_eval_case.cpp @@ -246,6 +246,12 @@ SearchEvalCase::process_result() { result["index_info"] = JsonType::parse(config_.build_param); result["search_param"] = config_.search_param; result["index"] = config_.index_name; + // TODO(deming): remove try-catch after implement GetMemoryUsageDetail + try { + result["memory_detail(B)"] = this->index_->GetMemoryUsageDetail(); + } catch (std::exception& e) { + logger_->Debug(e.what()); + } EvalCase::MergeJsonType(this->basic_info_, result); return result; } From fc6bd3e8cb51f8f3fc5c7b6c87e1c0780a8bbacb Mon Sep 17 00:00:00 2001 From: "suguan.dx" Date: Thu, 5 Jun 2025 20:44:48 +0800 Subject: [PATCH 38/42] add ut IVF Build With Large K Signed-off-by: suguan.dx --- tests/test_ivf.cpp | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/test_ivf.cpp b/tests/test_ivf.cpp index 89c34f40b..0eed3aaf0 100644 --- a/tests/test_ivf.cpp +++ b/tests/test_ivf.cpp @@ -414,6 +414,35 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::IVFTestIndex, "IVF Build", "[ft][ivf]") { } } +TEST_CASE_PERSISTENT_FIXTURE(fixtures::IVFTestIndex, "IVF Build With Large K", "[ft][ivf]") { + auto origin_size = vsag::Options::Instance().block_size_limit(); + auto size = GENERATE(1024 * 1024 * 2); + auto metric_type = GENERATE("l2"); + std::string train_type = GENERATE("kmeans"); + + std::vector> tmp_test_cases = { + {"fp32", 0.84}, + }; + + const std::string name = "ivf"; + auto search_param = fmt::format(search_param_tmp, 100); + std::vector dims_tmp = {32}; + for (auto& dim : dims_tmp) { + for (auto& [base_quantization_str, recall] : tmp_test_cases) { + 
vsag::Options::Instance().set_block_size_limit(size); + auto param = GenerateIVFBuildParametersString( + metric_type, dim, base_quantization_str, 10000, train_type); + auto index = TestFactory(name, param, true); + auto dataset = pool.GetDatasetAndCreate(dim, 20000, metric_type); + TestBuildIndex(index, dataset, true); + if (index->CheckFeature(vsag::SUPPORT_BUILD)) { + TestGeneral(index, dataset, search_param, recall); + } + vsag::Options::Instance().set_block_size_limit(origin_size); + } + } +} + TEST_CASE_PERSISTENT_FIXTURE(fixtures::IVFTestIndex, "IVF Export Model", "[ft][ivf]") { auto origin_size = vsag::Options::Instance().block_size_limit(); auto size = GENERATE(1024 * 1024 * 2); From 391264c15bb855687094310392163960054e341a Mon Sep 17 00:00:00 2001 From: Deming Chu Date: Fri, 6 Jun 2025 10:59:03 +0800 Subject: [PATCH 39/42] fix(hgraph): issues of segment fault and recall degradation (#784) Signed-off-by: Deming Chu Signed-off-by: suguan.dx --- src/data_cell/compressed_graph_datacell.cpp | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/data_cell/compressed_graph_datacell.cpp b/src/data_cell/compressed_graph_datacell.cpp index 4826fe1ad..b0c84140e 100644 --- a/src/data_cell/compressed_graph_datacell.cpp +++ b/src/data_cell/compressed_graph_datacell.cpp @@ -43,15 +43,12 @@ CompressedGraphDataCell::InsertNeighborsById(InnerIdType id, Vector tmp(neighbor_ids.begin(), neighbor_ids.end(), allocator_); std::sort(tmp.begin(), tmp.end()); - // TODO(deming): reset the size of neighbor_sets_[id] while tmp is empty + std::unique_ptr new_node; if (not tmp.empty()) { - if (neighbor_sets_[id] == nullptr) { - neighbor_sets_[id] = std::make_unique(allocator_); - } - neighbor_sets_[id]->Encode(tmp, max_capacity_); - } else { - neighbor_sets_[id].reset(); + new_node = std::make_unique(allocator_); + new_node->Encode(tmp, max_capacity_); } + neighbor_sets_[id] = std::move(new_node); InnerIdType current = total_count_.load(); while (current < id 
+ 1 && !total_count_.compare_exchange_weak(current, id + 1)) { @@ -65,6 +62,7 @@ CompressedGraphDataCell::GetNeighborSize(InnerIdType id) const { void CompressedGraphDataCell::GetNeighbors(InnerIdType id, Vector& neighbor_ids) const { + neighbor_ids.clear(); if (GetNeighborSize(id) > 0) { neighbor_sets_[id]->DecompressAll(neighbor_ids); } From 7a291bcdd552d1c38006a431a4aa64430230cd0c Mon Sep 17 00:00:00 2001 From: Xiangyu Wang Date: Fri, 6 Jun 2025 11:26:22 +0800 Subject: [PATCH 40/42] add spdlog to the deps of mockimpl (#775) Signed-off-by: Xiangyu Wang Signed-off-by: suguan.dx --- mockimpl/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mockimpl/CMakeLists.txt b/mockimpl/CMakeLists.txt index fcdf11581..d33513bf6 100644 --- a/mockimpl/CMakeLists.txt +++ b/mockimpl/CMakeLists.txt @@ -15,8 +15,8 @@ add_library (vsag_mockimpl_static STATIC ${MOCK_SRCS}) target_link_libraries (vsag_mockimpl roaring fmt::fmt-header-only simd) target_link_libraries (vsag_mockimpl_static roaring fmt::fmt-header-only simd) -add_dependencies (vsag_mockimpl version_mockimpl roaring) -add_dependencies (vsag_mockimpl_static version_mockimpl roaring) +add_dependencies (vsag_mockimpl version_mockimpl roaring spdlog) +add_dependencies (vsag_mockimpl_static version_mockimpl roaring spdlog) set_target_properties(vsag_mockimpl_static PROPERTIES OUTPUT_NAME "vsag_mockimpl") install(TARGETS vsag_mockimpl vsag_mockimpl_static From 05407d060f5235fa5f12fda8a2f051ad40324eda Mon Sep 17 00:00:00 2001 From: Xiangyu Wang Date: Mon, 9 Jun 2025 14:10:23 +0800 Subject: [PATCH 41/42] optimize header files including (#787) Signed-off-by: Xiangyu Wang Signed-off-by: suguan.dx --- include/vsag/index.h | 5 ++--- mockimpl/vsag/simpleflat.cpp | 1 + src/algorithm/hgraph.cpp | 4 ++-- src/algorithm/hgraph.h | 3 ++- src/algorithm/inner_index_interface.h | 2 +- src/index/index_impl.h | 2 +- tests/test_hgraph.cpp | 3 ++- 7 files changed, 11 insertions(+), 9 deletions(-) diff --git 
a/include/vsag/index.h b/include/vsag/index.h index ab3ec5c0f..b614fd829 100644 --- a/include/vsag/index.h +++ b/include/vsag/index.h @@ -19,9 +19,8 @@ #include #include -#include "bitset.h" -#include "typing.h" #include "vsag/binaryset.h" +#include "vsag/bitset.h" #include "vsag/dataset.h" #include "vsag/errors.h" #include "vsag/expected.hpp" @@ -537,7 +536,7 @@ class Index { // TODO(deming): implement func for every types of index // [[nodiscard]] virtual JsonType // GetMemoryUsageDetail() const = 0; - [[nodiscard]] virtual JsonType + [[nodiscard]] virtual std::string GetMemoryUsageDetail() const { throw std::runtime_error("Index not support GetMemoryUsageDetail"); } diff --git a/mockimpl/vsag/simpleflat.cpp b/mockimpl/vsag/simpleflat.cpp index 2a36cafe0..ec0aa660b 100644 --- a/mockimpl/vsag/simpleflat.cpp +++ b/mockimpl/vsag/simpleflat.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include "vsag/errors.h" #include "vsag/expected.hpp" diff --git a/src/algorithm/hgraph.cpp b/src/algorithm/hgraph.cpp index 0e72ad8c0..807ef118a 100644 --- a/src/algorithm/hgraph.cpp +++ b/src/algorithm/hgraph.cpp @@ -700,7 +700,7 @@ HGraph::Deserialize(StreamReader& reader) { } } -JsonType +std::string HGraph::GetMemoryUsageDetail() const { JsonType memory_usage; if (this->ignore_reorder_) { @@ -720,7 +720,7 @@ HGraph::GetMemoryUsageDetail() const { memory_usage["extra_infos"] = this->extra_infos_->CalcSerializeSize(); } memory_usage["__total_size__"] = this->CalSerializeSize(); - return memory_usage; + return memory_usage.dump(); } void diff --git a/src/algorithm/hgraph.h b/src/algorithm/hgraph.h index a6271f990..97af4ff73 100644 --- a/src/algorithm/hgraph.h +++ b/src/algorithm/hgraph.h @@ -18,6 +18,7 @@ #include #include #include +#include #include "algorithm/hnswlib/algorithm_interface.h" #include "algorithm/hnswlib/visited_list_pool.h" @@ -125,7 +126,7 @@ class HGraph : public InnerIndexInterface { return static_cast(this->CalSerializeSize()); } - JsonType + 
std::string GetMemoryUsageDetail() const override; float diff --git a/src/algorithm/inner_index_interface.h b/src/algorithm/inner_index_interface.h index a390412c5..a4b6ffd3f 100644 --- a/src/algorithm/inner_index_interface.h +++ b/src/algorithm/inner_index_interface.h @@ -250,7 +250,7 @@ class InnerIndexInterface { "Index doesn't support GetMemoryUsage"); } - [[nodiscard]] virtual JsonType + [[nodiscard]] virtual std::string GetMemoryUsageDetail() const { // TODO(deming): implement func for every types of inner index throw VsagException(ErrorType::UNSUPPORTED_INDEX_OPERATION, diff --git a/src/index/index_impl.h b/src/index/index_impl.h index a143c8250..1b79f997d 100644 --- a/src/index/index_impl.h +++ b/src/index/index_impl.h @@ -298,7 +298,7 @@ class IndexImpl : public Index { return this->inner_index_->GetMemoryUsage(); } - [[nodiscard]] JsonType + [[nodiscard]] std::string GetMemoryUsageDetail() const override { return this->inner_index_->GetMemoryUsageDetail(); } diff --git a/tests/test_hgraph.cpp b/tests/test_hgraph.cpp index 5aea82f78..6c8fd568b 100644 --- a/tests/test_hgraph.cpp +++ b/tests/test_hgraph.cpp @@ -23,6 +23,7 @@ #include "fixtures/test_logger.h" #include "inner_string_params.h" #include "test_index.h" +#include "typing.h" #include "vsag/options.h" namespace fixtures { @@ -208,7 +209,7 @@ HgraphTestIndex::TestGeneral(const TestIndex::IndexPtr& index, void HgraphTestIndex::TestMemoryUsageDetail(const IndexPtr& index) { - auto memory_detail = index->GetMemoryUsageDetail(); + auto memory_detail = vsag::JsonType::parse(index->GetMemoryUsageDetail()); REQUIRE(memory_detail.contains("basic_flatten_codes")); REQUIRE(memory_detail.contains("bottom_graph")); REQUIRE(memory_detail.contains("route_graph")); From 0b4659094c7150b97b0479b1e64de5f42fb3f2ab Mon Sep 17 00:00:00 2001 From: "suguan.dx" Date: Mon, 9 Jun 2025 18:38:56 +0800 Subject: [PATCH 42/42] fix ut Signed-off-by: suguan.dx --- tests/test_ivf.cpp | 16 ---------------- 1 file changed, 16 
deletions(-) diff --git a/tests/test_ivf.cpp b/tests/test_ivf.cpp index 21b931403..0eed3aaf0 100644 --- a/tests/test_ivf.cpp +++ b/tests/test_ivf.cpp @@ -26,7 +26,6 @@ class IVFTestIndex : public fixtures::TestIndex { const std::string& quantization_str = "sq8", int buckets_count = 210, const std::string& train_type = "kmeans", -<<<<<<< HEAD bool use_residual = false, int buckets_per_data = 1); @@ -40,9 +39,6 @@ class IVFTestIndex : public fixtures::TestIndex { bool use_residual = false, int buckets_per_data = 1); -======= - bool use_residual = false); ->>>>>>> upstream/main static void TestGeneral(const IndexPtr& index, const TestDatasetPtr& dataset, @@ -88,12 +84,8 @@ IVFTestIndex::GenerateIVFBuildParametersString(const std::string& metric_type, const std::string& quantization_str, int buckets_count, const std::string& train_type, -<<<<<<< HEAD bool use_residual, int buckets_per_data) { -======= - bool use_residual) { ->>>>>>> upstream/main std::string build_parameters_str; constexpr auto parameter_temp = R"( @@ -108,12 +100,8 @@ IVFTestIndex::GenerateIVFBuildParametersString(const std::string& metric_type, "use_reorder": {}, "base_pq_dim": {}, "precise_quantization_type": "{}", -<<<<<<< HEAD "use_residual": {}, "buckets_per_data": {} -======= - "use_residual": {} ->>>>>>> upstream/main }} }} )"; @@ -139,7 +127,6 @@ IVFTestIndex::GenerateIVFBuildParametersString(const std::string& metric_type, use_reorder, pq_dim, precise_quantizer_str, -<<<<<<< HEAD use_residual, buckets_per_data); @@ -202,9 +189,6 @@ IVFTestIndex::GenerateGNOIMIBuildParametersString(const std::string& metric_type precise_quantizer_str, use_residual, buckets_per_data); -======= - use_residual); ->>>>>>> upstream/main INFO(build_parameters_str); return build_parameters_str;