Skip to content

Commit 6b82f3f

Browse files
committed
pass entities
1 parent c1b9b79 commit 6b82f3f

File tree

3 files changed

+165
-0
lines changed

3 files changed

+165
-0
lines changed

hnswlib/ats_dummy.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#pragma once
2+
#include <iostream>
3+
4+
class ATSDummy {
5+
public:
6+
static void ping() {
7+
// std::cout << "[ATS] ats_dummy.h included successfully" << std::endl;
8+
}
9+
};

hnswlib/hnswalg.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
#include <list>
1111
#include <memory>
1212

13+
#include "ats_dummy.h"
14+
1315
namespace hnswlib {
1416
typedef unsigned int tableint;
1517
typedef unsigned int linklistsizeint;
@@ -70,6 +72,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
7072
std::mutex deleted_elements_lock; // lock for deleted_elements
7173
std::unordered_set<tableint> deleted_elements; // contains internal ids of deleted elements
7274

75+
std::vector<std::vector<tableint>> node_entities_;
76+
7377

7478
HierarchicalNSW(SpaceInterface<dist_t> *s) {
7579
}
@@ -96,6 +100,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
96100
: label_op_locks_(MAX_LABEL_OPERATION_LOCKS),
97101
link_list_locks_(max_elements),
98102
element_levels_(max_elements),
103+
node_entities_(max_elements),
99104
allow_replace_deleted_(allow_replace_deleted) {
100105
max_elements_ = max_elements;
101106
num_deleted_ = 0;
@@ -169,6 +174,18 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
169174
}
170175
};
171176

177+
double getJaccardSimilarity(const std::unordered_set<tableint>& a,
178+
const std::unordered_set<tableint>& b) {
179+
180+
size_t intersection = 0;
181+
for (const auto& x : a) {
182+
if (b.count(x)) intersection++;
183+
}
184+
185+
size_t union_count = a.size() + b.size() - intersection;
186+
return union_count == 0 ? 0.0 : static_cast<double>(intersection) / union_count;
187+
}
188+
172189

173190
void setEf(size_t ef) {
174191
ef_ = ef;
@@ -224,6 +241,12 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
224241

225242
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
226243
searchBaseLayer(tableint ep_id, const void *data_point, int layer) {
244+
ATSDummy::ping();
245+
// std::unordered_set<tableint> setA = {1, 2, 3, 4};
246+
// std::unordered_set<tableint> setB = {3, 4, 5, 6};
247+
// double sim = getJaccardSimilarity(setA, setB);
248+
// std::cout << "Jaccard similarity: " << sim << std::endl;
249+
227250
VisitedList *vl = visited_list_pool_->getFreeVisitedList();
228251
vl_type *visited_array = vl->mass;
229252
vl_type visited_array_tag = vl->curV;
@@ -683,6 +706,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
683706
}
684707

685708
void saveIndex(const std::string &location) {
709+
std::cout<< "======================SAVING=============================\n";
686710
std::ofstream output(location, std::ios::binary);
687711
std::streampos position;
688712

@@ -714,7 +738,9 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
714738

715739

716740
void loadIndex(const std::string &location, SpaceInterface<dist_t> *s, size_t max_elements_i = 0) {
741+
std::cout<< "======================LOADING=============================\n";
717742
std::ifstream input(location, std::ios::binary);
743+
node_entities_.resize(cur_element_count);
718744

719745
if (!input.is_open())
720746
throw std::runtime_error("Cannot open file");
@@ -1265,6 +1291,24 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
12651291
}
12661292
return cur_c;
12671293
}
1294+
1295+
1296+
tableint addPointWithEntities(
1297+
const void* data_point,
1298+
labeltype label,
1299+
const std::vector<tableint>& entities,
1300+
int level = -1) {
1301+
tableint id = addPoint(data_point, label, level);
1302+
1303+
// Ensure capacity
1304+
if (id >= node_entities_.size()) {
1305+
node_entities_.resize(max_elements_);
1306+
}
1307+
1308+
node_entities_[id] = entities;
1309+
return id;
1310+
}
1311+
12681312

12691313

12701314
std::priority_queue<std::pair<dist_t, labeltype >>

python_bindings/bindings.cpp

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,17 @@ class Index {
248248
}
249249

250250

251+
// std::vector<tableint> getEntities(tableint internal_id) {
252+
// if (!appr_alg) {
253+
// throw std::runtime_error("Index not initialized");
254+
// }
255+
// if (internal_id >= appr_alg->getCurrentElementCount()) {
256+
// throw std::out_of_range("Invalid internal id");
257+
// }
258+
// return appr_alg->node_entities_[internal_id]
259+
// }
260+
261+
251262
void addItems(py::object input, py::object ids_ = py::none(), int num_threads = -1, bool replace_deleted = false) {
252263
py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input);
253264
auto buffer = items.request();
@@ -303,6 +314,100 @@ class Index {
303314
}
304315
}
305316

317+
void addItemsWithEntities(py::object input, py::object ids_ = py::none(), py::object entities_ = py::none(),
318+
int num_threads = -1, bool replace_deleted = false)
319+
{
320+
321+
std::cout << "CUSTOM FUNCTION CALLED" << std::endl;
322+
if (!entities_.is_none()) {
323+
std::cout << "Entities passed:" << std::endl;
324+
325+
// Convert Python object to iterable
326+
py::list entity_list = entities_;
327+
for (size_t i = 0; i < entity_list.size(); i++) {
328+
py::object ent = entity_list[i];
329+
// Convert to string for printing
330+
std::string ent_str = py::str(ent);
331+
std::cout << " " << i << ": " << ent_str << std::endl;
332+
}
333+
} else {
334+
std::cout << "No entities provided." << std::endl;
335+
}
336+
337+
py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input);
338+
auto buffer = items.request();
339+
if (num_threads <= 0)
340+
num_threads = num_threads_default;
341+
342+
size_t rows, features;
343+
get_input_array_shapes(buffer, &rows, &features);
344+
345+
if (features != dim)
346+
throw std::runtime_error("Wrong dimensionality of the vectors");
347+
348+
// avoid using threads when the number of additions is small:
349+
if (rows <= num_threads * 4) {
350+
num_threads = 1;
351+
}
352+
353+
std::vector<size_t> ids = get_input_ids_and_check_shapes(ids_, rows);
354+
355+
py::array_t<hnswlib::tableint, py::array::c_style | py::array::forcecast> entities_arr =
356+
entities_.cast<py::array_t<hnswlib::tableint>>();
357+
358+
auto buf = entities_arr.request();
359+
360+
size_t entity_rows = buf.shape[0];
361+
size_t cols = buf.shape[1];
362+
363+
auto* data = static_cast<hnswlib::tableint*>(buf.ptr);
364+
365+
for (size_t id : ids) {
366+
std::cout << "id is: " << id << "\n";
367+
std::cout << "entities[" << id << "]: ";
368+
369+
for (size_t j = 0; j < cols; j++) {
370+
std::cout << data[id * cols + j] << " ";
371+
}
372+
std::cout << "\n";
373+
}
374+
375+
{
376+
int start = 0;
377+
if (!ep_added) {
378+
size_t id = ids.size() ? ids.at(0) : (cur_l);
379+
float* vector_data = (float*)items.data(0);
380+
std::vector<float> norm_array(dim);
381+
if (normalize) {
382+
normalize_vector(vector_data, norm_array.data());
383+
vector_data = norm_array.data();
384+
}
385+
appr_alg->addPoint((void*)vector_data, (size_t)id, replace_deleted);
386+
start = 1;
387+
ep_added = true;
388+
}
389+
390+
py::gil_scoped_release l;
391+
if (normalize == false) {
392+
ParallelFor(start, rows, num_threads, [&](size_t row, size_t threadId) {
393+
size_t id = ids.size() ? ids.at(row) : (cur_l + row);
394+
appr_alg->addPoint((void*)items.data(row), (size_t)id, replace_deleted);
395+
});
396+
} else {
397+
std::vector<float> norm_array(num_threads * dim);
398+
ParallelFor(start, rows, num_threads, [&](size_t row, size_t threadId) {
399+
// normalize vector:
400+
size_t start_idx = threadId * dim;
401+
normalize_vector((float*)items.data(row), (norm_array.data() + start_idx));
402+
403+
size_t id = ids.size() ? ids.at(row) : (cur_l + row);
404+
appr_alg->addPoint((void*)(norm_array.data() + start_idx), (size_t)id, replace_deleted);
405+
});
406+
}
407+
cur_l += rows;
408+
}
409+
}
410+
306411

307412
py::object getData(py::object ids_ = py::none(), std::string return_type = "numpy") {
308413
std::vector<std::string> return_types{"numpy", "list"};
@@ -934,6 +1039,13 @@ PYBIND11_PLUGIN(hnswlib) {
9341039
py::arg("ids") = py::none(),
9351040
py::arg("num_threads") = -1,
9361041
py::arg("replace_deleted") = false)
1042+
.def("add_items_with_entities",
1043+
&Index<float>::addItemsWithEntities,
1044+
py::arg("data"),
1045+
py::arg("ids_") = py::none(),
1046+
py::arg("entities_") = py::none(),
1047+
py::arg("num_threads") = -1,
1048+
py::arg("replace_deleted") = false)
9371049
.def("get_items", &Index<float>::getData, py::arg("ids") = py::none(), py::arg("return_type") = "numpy")
9381050
.def("get_ids_list", &Index<float>::getIdsList)
9391051
.def("set_ef", &Index<float>::set_ef, py::arg("ef"))

0 commit comments

Comments
 (0)