@@ -248,6 +248,17 @@ class Index {
248248 }
249249
250250
251+ // std::vector<tableint> getEntities(tableint internal_id) {
252+ // if (!appr_alg) {
253+ // throw std::runtime_error("Index not initialized");
254+ // }
255+ // if (internal_id >= appr_alg->getCurrentElementCount()) {
256+ // throw std::out_of_range("Invalid internal id");
257+ // }
258+ // return appr_alg->node_entities_[internal_id];
259+ // }
260+
261+
251262 void addItems (py::object input, py::object ids_ = py::none(), int num_threads = -1, bool replace_deleted = false) {
252263 py::array_t < dist_t , py::array::c_style | py::array::forcecast > items (input);
253264 auto buffer = items.request ();
@@ -303,6 +314,100 @@ class Index {
303314 }
304315 }
305316
317+ void addItemsWithEntities (py::object input, py::object ids_ = py::none(), py::object entities_ = py::none(),
318+ int num_threads = -1, bool replace_deleted = false)
319+ {
320+
321+ std::cout << " CUSTOM FUNCTION CALLED" << std::endl;
322+ if (!entities_.is_none ()) {
323+ std::cout << " Entities passed:" << std::endl;
324+
325+ // Convert Python object to iterable
326+ py::list entity_list = entities_;
327+ for (size_t i = 0 ; i < entity_list.size (); i++) {
328+ py::object ent = entity_list[i];
329+ // Convert to string for printing
330+ std::string ent_str = py::str (ent);
331+ std::cout << " " << i << " : " << ent_str << std::endl;
332+ }
333+ } else {
334+ std::cout << " No entities provided." << std::endl;
335+ }
336+
337+ py::array_t < dist_t , py::array::c_style | py::array::forcecast > items (input);
338+ auto buffer = items.request ();
339+ if (num_threads <= 0 )
340+ num_threads = num_threads_default;
341+
342+ size_t rows, features;
343+ get_input_array_shapes (buffer, &rows, &features);
344+
345+ if (features != dim)
346+ throw std::runtime_error (" Wrong dimensionality of the vectors" );
347+
348+ // avoid using threads when the number of additions is small:
349+ if (rows <= num_threads * 4 ) {
350+ num_threads = 1 ;
351+ }
352+
353+ std::vector<size_t > ids = get_input_ids_and_check_shapes (ids_, rows);
354+
355+ py::array_t <hnswlib::tableint, py::array::c_style | py::array::forcecast> entities_arr =
356+ entities_.cast <py::array_t <hnswlib::tableint>>();
357+
358+ auto buf = entities_arr.request ();
359+
360+ size_t entity_rows = buf.shape [0 ];
361+ size_t cols = buf.shape [1 ];
362+
363+ auto * data = static_cast <hnswlib::tableint*>(buf.ptr );
364+
365+ for (size_t id : ids) {
366+ std::cout << " id is: " << id << " \n " ;
367+ std::cout << " entities[" << id << " ]: " ;
368+
369+ for (size_t j = 0 ; j < cols; j++) {
370+ std::cout << data[id * cols + j] << " " ;
371+ }
372+ std::cout << " \n " ;
373+ }
374+
375+ {
376+ int start = 0 ;
377+ if (!ep_added) {
378+ size_t id = ids.size () ? ids.at (0 ) : (cur_l);
379+ float * vector_data = (float *)items.data (0 );
380+ std::vector<float > norm_array (dim);
381+ if (normalize) {
382+ normalize_vector (vector_data, norm_array.data ());
383+ vector_data = norm_array.data ();
384+ }
385+ appr_alg->addPoint ((void *)vector_data, (size_t )id, replace_deleted);
386+ start = 1 ;
387+ ep_added = true ;
388+ }
389+
390+ py::gil_scoped_release l;
391+ if (normalize == false ) {
392+ ParallelFor (start, rows, num_threads, [&](size_t row, size_t threadId) {
393+ size_t id = ids.size () ? ids.at (row) : (cur_l + row);
394+ appr_alg->addPoint ((void *)items.data (row), (size_t )id, replace_deleted);
395+ });
396+ } else {
397+ std::vector<float > norm_array (num_threads * dim);
398+ ParallelFor (start, rows, num_threads, [&](size_t row, size_t threadId) {
399+ // normalize vector:
400+ size_t start_idx = threadId * dim;
401+ normalize_vector ((float *)items.data (row), (norm_array.data () + start_idx));
402+
403+ size_t id = ids.size () ? ids.at (row) : (cur_l + row);
404+ appr_alg->addPoint ((void *)(norm_array.data () + start_idx), (size_t )id, replace_deleted);
405+ });
406+ }
407+ cur_l += rows;
408+ }
409+ }
410+
306411
307412 py::object getData (py::object ids_ = py::none(), std::string return_type = "numpy") {
308413 std::vector<std::string> return_types{" numpy" , " list" };
@@ -934,6 +1039,13 @@ PYBIND11_PLUGIN(hnswlib) {
9341039 py::arg (" ids" ) = py::none (),
9351040 py::arg (" num_threads" ) = -1 ,
9361041 py::arg (" replace_deleted" ) = false )
1042+ .def (" add_items_with_entities" ,
1043+ &Index<float >::addItemsWithEntities,
1044+ py::arg (" data" ),
1045+ py::arg (" ids_" ) = py::none (),
1046+ py::arg (" entities_" ) = py::none (),
1047+ py::arg (" num_threads" ) = -1 ,
1048+ py::arg (" replace_deleted" ) = false )
9371049 .def (" get_items" , &Index<float >::getData, py::arg (" ids" ) = py::none (), py::arg (" return_type" ) = " numpy" )
9381050 .def (" get_ids_list" , &Index<float >::getIdsList)
9391051 .def (" set_ef" , &Index<float >::set_ef, py::arg (" ef" ))
0 commit comments