Skip to content

Commit 6da48a2

Browse files
Carrot-77zourunxin.zrx
and
zourunxin.zrx
authored
add search allocator (#716)
Signed-off-by: zourunxin.zrx <[email protected]> Co-authored-by: zourunxin.zrx <[email protected]>
1 parent 69f889d commit 6da48a2

26 files changed

+716
-46
lines changed
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
2+
// Copyright 2024-present the vsag project
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
#include <iostream>
17+
18+
#include "nlohmann/json.hpp"
19+
#include "vsag/logger.h"
20+
#include "vsag/search_param.h"
21+
#include "vsag/vsag.h"
22+
23+
class ExampleAllocator : public vsag::Allocator {
24+
public:
25+
std::string
26+
Name() override {
27+
return "example-allocator";
28+
}
29+
30+
void*
31+
Allocate(size_t size) override {
32+
vsag::Options::Instance().logger()->Debug("allocate " + std::to_string(size) + " bytes.");
33+
auto addr = (void*)malloc(size);
34+
sizes_[addr] = size;
35+
return addr;
36+
}
37+
38+
void
39+
Deallocate(void* p) override {
40+
if (sizes_.find(p) == sizes_.end())
41+
return;
42+
vsag::Options::Instance().logger()->Debug("deallocate " + std::to_string(sizes_[p]) +
43+
" bytes.");
44+
sizes_.erase(p);
45+
return free(p);
46+
}
47+
48+
void*
49+
Reallocate(void* p, size_t size) override {
50+
vsag::Options::Instance().logger()->Debug("reallocate " + std::to_string(size) + " bytes.");
51+
auto addr = (void*)realloc(p, size);
52+
sizes_.erase(p);
53+
sizes_[addr] = size;
54+
return addr;
55+
}
56+
57+
private:
58+
std::unordered_map<void*, size_t> sizes_;
59+
};
60+
61+
int
62+
main() {
63+
vsag::Options::Instance().logger()->SetLevel(vsag::Logger::kINFO);
64+
65+
ExampleAllocator allocator;
66+
vsag::Resource resource(&allocator, nullptr);
67+
vsag::Engine engine(&resource);
68+
69+
auto paramesters = R"(
70+
{
71+
"dtype": "float32",
72+
"metric_type": "l2",
73+
"dim": 4,
74+
"hnsw": {
75+
"max_degree": 5,
76+
"ef_construction": 20
77+
}
78+
}
79+
)";
80+
std::cout << "create index" << std::endl;
81+
auto index = engine.CreateIndex("hnsw", paramesters).value();
82+
83+
std::cout << "prepare data" << std::endl;
84+
int64_t num_vectors = 100;
85+
int64_t dim = 4;
86+
87+
// prepare ids and vectors
88+
std::vector<int64_t> ids(num_vectors);
89+
std::vector<float> vectors(num_vectors * dim);
90+
91+
std::mt19937 rng;
92+
rng.seed(47);
93+
std::uniform_real_distribution<> distrib_real;
94+
for (int64_t i = 0; i < num_vectors; ++i) {
95+
ids[i] = i;
96+
}
97+
for (int64_t i = 0; i < dim * num_vectors; ++i) {
98+
vectors[i] = distrib_real(rng);
99+
}
100+
auto base = vsag::Dataset::Make();
101+
base->NumElements(num_vectors)
102+
->Dim(dim)
103+
->Ids(ids.data())
104+
->Float32Vectors(vectors.data())
105+
->Owner(false);
106+
index->Build(base);
107+
108+
// search on the index
109+
auto query_vector = new float[dim]; // memory will be released by query the dataset
110+
for (int64_t i = 0; i < dim; ++i) {
111+
query_vector[i] = distrib_real(rng);
112+
}
113+
114+
int64_t topk = 10;
115+
auto query = vsag::Dataset::Make();
116+
query->NumElements(1)->Dim(dim)->Float32Vectors(query_vector)->Owner(true);
117+
118+
/******************* HNSW Search *****************/
119+
{
120+
nlohmann::json search_parameters = {
121+
{"hnsw", {{"ef_search", 100}, {"skip_ratio", 0.7f}}},
122+
};
123+
int64_t topk = 10;
124+
auto query = vsag::Dataset::Make();
125+
query->NumElements(1)->Dim(dim)->Float32Vectors(query_vector)->Owner(true);
126+
127+
std::string param_str = search_parameters.dump();
128+
vsag::SearchParam search_param(false, param_str, nullptr, &allocator);
129+
auto result = index->KnnSearch(query, topk, search_param).value();
130+
131+
// print the results
132+
std::cout << "results: " << std::endl;
133+
for (int64_t i = 0; i < result->GetDim(); ++i) {
134+
std::cout << result->GetIds()[i] << ": " << result->GetDistances()[i] << std::endl;
135+
}
136+
137+
allocator.Deallocate((void*)result->GetIds());
138+
allocator.Deallocate((void*)result->GetDistances());
139+
}
140+
141+
/******************* HNSW Iterator Filter *****************/
142+
{
143+
vsag::IteratorContext* iter_ctx = nullptr;
144+
nlohmann::json search_parameters = {
145+
{"hnsw", {{"ef_search", 100}, {"skip_ratio", 0.7f}}},
146+
};
147+
std::string param_str = search_parameters.dump();
148+
vsag::SearchParam search_param(true, param_str, nullptr, &allocator, iter_ctx, false);
149+
150+
/* first search */
151+
{
152+
auto result = index->KnnSearch(query, topk, search_param).value();
153+
154+
// print the results
155+
std::cout << "results: " << std::endl;
156+
for (int64_t i = 0; i < result->GetDim(); ++i) {
157+
std::cout << result->GetIds()[i] << ": " << result->GetDistances()[i] << std::endl;
158+
}
159+
160+
allocator.Deallocate((void*)result->GetIds());
161+
allocator.Deallocate((void*)result->GetDistances());
162+
}
163+
164+
/* last search */
165+
{
166+
search_param.is_last_search = true;
167+
auto result = index->KnnSearch(query, topk, search_param).value();
168+
169+
// print the results
170+
std::cout << "results: " << std::endl;
171+
for (int64_t i = 0; i < result->GetDim(); ++i) {
172+
std::cout << result->GetIds()[i] << ": " << result->GetDistances()[i] << std::endl;
173+
}
174+
175+
allocator.Deallocate((void*)result->GetIds());
176+
allocator.Deallocate((void*)result->GetDistances());
177+
}
178+
}
179+
180+
std::cout << "delete index" << std::endl;
181+
index = nullptr;
182+
engine.Shutdown();
183+
184+
return 0;
185+
}
Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
2+
// Copyright 2024-present the vsag project
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
#include <iostream>
17+
18+
#include "nlohmann/json.hpp"
19+
#include "vsag/logger.h"
20+
#include "vsag/search_param.h"
21+
#include "vsag/vsag.h"
22+
23+
class ExampleAllocator : public vsag::Allocator {
24+
public:
25+
std::string
26+
Name() override {
27+
return "example-allocator";
28+
}
29+
30+
void*
31+
Allocate(size_t size) override {
32+
vsag::Options::Instance().logger()->Debug("allocate " + std::to_string(size) + " bytes.");
33+
auto addr = (void*)malloc(size);
34+
sizes_[addr] = size;
35+
return addr;
36+
}
37+
38+
void
39+
Deallocate(void* p) override {
40+
if (sizes_.find(p) == sizes_.end())
41+
return;
42+
vsag::Options::Instance().logger()->Debug("deallocate " + std::to_string(sizes_[p]) +
43+
" bytes.");
44+
sizes_.erase(p);
45+
return free(p);
46+
}
47+
48+
void*
49+
Reallocate(void* p, size_t size) override {
50+
vsag::Options::Instance().logger()->Debug("reallocate " + std::to_string(size) + " bytes.");
51+
auto addr = (void*)realloc(p, size);
52+
sizes_.erase(p);
53+
sizes_[addr] = size;
54+
return addr;
55+
}
56+
57+
private:
58+
std::unordered_map<void*, size_t> sizes_;
59+
};
60+
61+
int
62+
main() {
63+
vsag::init();
64+
std::cout << "hgraph index example" << std::endl;
65+
66+
/******************* Prepare Base Dataset *****************/
67+
int64_t num_vectors = 1000;
68+
int64_t dim = 128;
69+
std::vector<int64_t> ids(num_vectors);
70+
std::vector<float> datas(num_vectors * dim);
71+
std::mt19937 rng(47);
72+
std::uniform_real_distribution<float> distrib_real;
73+
for (int64_t i = 0; i < num_vectors; ++i) {
74+
ids[i] = i;
75+
}
76+
for (int64_t i = 0; i < dim * num_vectors; ++i) {
77+
datas[i] = distrib_real(rng);
78+
}
79+
auto base = vsag::Dataset::Make();
80+
base->NumElements(num_vectors)
81+
->Dim(dim)
82+
->Ids(ids.data())
83+
->Float32Vectors(datas.data())
84+
->Owner(false);
85+
86+
/******************* Create HGraph Index *****************/
87+
std::string hgraph_build_parameters = R"(
88+
{
89+
"dtype": "float32",
90+
"metric_type": "l2",
91+
"dim": 128,
92+
"index_param": {
93+
"base_quantization_type": "sq8",
94+
"max_degree": 26,
95+
"ef_construction": 100
96+
}
97+
}
98+
)";
99+
vsag::Resource resource(vsag::Engine::CreateDefaultAllocator(), nullptr);
100+
vsag::Engine engine(&resource);
101+
std::cout << "create index" << std::endl;
102+
auto index = engine.CreateIndex("hgraph", hgraph_build_parameters).value();
103+
104+
ExampleAllocator allocator;
105+
106+
/******************* Build HGraph Index *****************/
107+
if (auto build_result = index->Build(base); build_result.has_value()) {
108+
std::cout << "After Build(), Index HGraph contains: " << index->GetNumElements()
109+
<< std::endl;
110+
} else if (build_result.error().type == vsag::ErrorType::INTERNAL_ERROR) {
111+
std::cerr << "Failed to build index: internalError" << std::endl;
112+
exit(-1);
113+
}
114+
115+
/******************* Prepare Query Dataset *****************/
116+
std::cout << "prepare index" << std::endl;
117+
std::vector<float> query_vector(dim);
118+
for (int64_t i = 0; i < dim; ++i) {
119+
query_vector[i] = distrib_real(rng);
120+
}
121+
auto query = vsag::Dataset::Make();
122+
query->NumElements(1)->Dim(dim)->Float32Vectors(query_vector.data())->Owner(false);
123+
124+
/******************* KnnSearch For HGraph Index *****************/
125+
auto hgraph_search_parameters = R"(
126+
{
127+
"hgraph": {
128+
"ef_search": 100
129+
}
130+
}
131+
)";
132+
int64_t topk = 10;
133+
134+
/******************* Hgraph sq8 Search *****************/
135+
{
136+
nlohmann::json search_parameters = {
137+
{"hgraph", {{"ef_search", 100}, {"skip_ratio", 0.7f}}},
138+
};
139+
std::string param_str = search_parameters.dump();
140+
vsag::SearchParam search_param(false, param_str, nullptr, &allocator);
141+
auto result = index->KnnSearch(query, topk, search_param).value();
142+
143+
// print the results
144+
std::cout << "results: " << std::endl;
145+
for (int64_t i = 0; i < result->GetDim(); ++i) {
146+
std::cout << result->GetIds()[i] << ": " << result->GetDistances()[i] << std::endl;
147+
}
148+
149+
allocator.Deallocate((void*)result->GetIds());
150+
allocator.Deallocate((void*)result->GetDistances());
151+
}
152+
153+
/******************* Hgraph sq8 Iterator Filter *****************/
154+
{
155+
vsag::IteratorContext* iter_ctx = nullptr;
156+
nlohmann::json search_parameters = {
157+
{"hgraph", {{"ef_search", 100}, {"skip_ratio", 0.7f}}},
158+
};
159+
std::string param_str = search_parameters.dump();
160+
vsag::SearchParam search_param(true, param_str, nullptr, &allocator, iter_ctx, false);
161+
162+
/* first search */
163+
{
164+
auto result = index->KnnSearch(query, topk, search_param).value();
165+
166+
// print the results
167+
std::cout << "results: " << std::endl;
168+
for (int64_t i = 0; i < result->GetDim(); ++i) {
169+
std::cout << result->GetIds()[i] << ": " << result->GetDistances()[i] << std::endl;
170+
}
171+
172+
allocator.Deallocate((void*)result->GetIds());
173+
allocator.Deallocate((void*)result->GetDistances());
174+
}
175+
176+
/* last search */
177+
{
178+
search_param.is_last_search = true;
179+
auto result = index->KnnSearch(query, topk, search_param).value();
180+
181+
// print the results
182+
std::cout << "results: " << std::endl;
183+
for (int64_t i = 0; i < result->GetDim(); ++i) {
184+
std::cout << result->GetIds()[i] << ": " << result->GetDistances()[i] << std::endl;
185+
}
186+
187+
allocator.Deallocate((void*)result->GetIds());
188+
allocator.Deallocate((void*)result->GetDistances());
189+
}
190+
}
191+
192+
engine.Shutdown();
193+
return 0;
194+
}

examples/cpp/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,12 @@ target_link_libraries(310_feature_export_model vsag)
6161
add_executable(311_feature_train 311_feature_train.cpp)
6262
target_link_libraries(311_feature_train vsag)
6363

64+
add_executable(313_feature_search_allocator 313_feature_search_allocator.cpp)
65+
target_link_libraries(313_feature_search_allocator vsag)
66+
67+
add_executable(314_feature_hgraph_search_allocator 314_feature_hgraph_search_allocator.cpp)
68+
target_link_libraries(314_feature_hgraph_search_allocator vsag)
69+
6470
add_executable (401_persistent_kv 401_persistent_kv.cpp)
6571
target_link_libraries (401_persistent_kv vsag)
6672

0 commit comments

Comments
 (0)