11
11
#include < unordered_map> // `std::unordered_map`
12
12
#include < vector> // `std::vector`
13
13
14
+ #define SZ_USE_X86_AVX512 0 // Sanitizers hate AVX512
15
+ #include < stringzilla/stringzilla.hpp> // Levenshtein distance implementation
16
+
14
17
#include < usearch/index.hpp>
15
18
#include < usearch/index_dense.hpp>
16
19
#include < usearch/index_plugins.hpp>
@@ -559,6 +562,73 @@ void test_sets(std::size_t collection_size, std::size_t min_set_length, std::siz
559
562
}
560
563
}
561
564
565
+ /* *
566
+ * Tests similarity search over strings using Levenshtein distances
567
+ * implementation from StringZilla.
568
+ *
569
+ * Adds a predefined number of long strings, comparing them.
570
+ *
571
+ * @param index A reference to the index instance to be tested.
572
+ * @tparam index_at Type of the index being tested.
573
+ */
574
+ template <typename key_at, typename slot_at> void test_strings () {
575
+
576
+ namespace sz = ashvardanian::stringzilla;
577
+
578
+ // / Levenshtein distance is an integer
579
+ using levenshtein_distance_t = std::uint64_t ;
580
+
581
+ // Aliasis for the index overload
582
+ using vector_key_t = key_at;
583
+ using slot_t = slot_at;
584
+ using index_t = index_gt<levenshtein_distance_t , vector_key_t , slot_t >;
585
+
586
+ std::string_view str0 = " ACGTACGTACGTACGTACGTACGTACGTACGTACGT" ;
587
+ std::string_view str1 = " ACG_ACTC_TAC-TACGTA_GTACACG_ACGT" ;
588
+ std::string_view str2 = " A_GTACTACGTA-GTAC_TACGTACGTA-GTAGT" ;
589
+ std::string_view str3 = " GTACGTAGT-ACGTACGACGTACGTACG-TACGTAC" ;
590
+ std::vector<std::string_view> strings ({str0, str1, str2, str3});
591
+
592
+ // Wrap the data into a proxy object
593
+ struct metric_t {
594
+ using member_cref_t = typename index_t ::member_cref_t ;
595
+ using member_citerator_t = typename index_t ::member_citerator_t ;
596
+
597
+ std::vector<std::string_view> const * strings_ptr = nullptr ;
598
+
599
+ std::string_view str_at (std::size_t i) const noexcept { return (*strings_ptr)[i]; }
600
+ levenshtein_distance_t between (std::string_view a, std::string_view b) const {
601
+ return sz::edit_distance (sz::string_view (a), sz::string_view (b));
602
+ }
603
+
604
+ levenshtein_distance_t operator ()(member_cref_t const & a, member_cref_t const & b) const {
605
+ return between (str_at (get_slot (b)), str_at (get_slot (a)));
606
+ }
607
+ levenshtein_distance_t operator ()(std::string_view some_vector, member_cref_t const & member) const {
608
+ return between (some_vector, str_at (get_slot (member)));
609
+ }
610
+ levenshtein_distance_t operator ()(member_citerator_t const & a, member_citerator_t const & b) const {
611
+ return between (str_at (get_slot (b)), str_at (get_slot (a)));
612
+ }
613
+ levenshtein_distance_t operator ()(std::string_view some_vector, member_citerator_t const & member) const {
614
+ return between (some_vector, str_at (get_slot (member)));
615
+ }
616
+ };
617
+
618
+ // Perform indexing
619
+ aligned_wrapper_gt<index_t > aligned_index;
620
+ aligned_index.index ->reserve (strings.size ());
621
+ for (std::size_t i = 0 ; i < strings.size (); i++)
622
+ aligned_index.index ->add (i, strings[i], metric_t {&strings});
623
+ expect (aligned_index.index ->size () == strings.size ());
624
+
625
+ // Perform the search queries
626
+ for (std::size_t i = 0 ; i < strings.size (); i++) {
627
+ auto results = aligned_index.index ->search (strings[i], 5 , metric_t {&strings});
628
+ expect (results.size () > 0 );
629
+ }
630
+ }
631
+
562
632
int main (int , char **) {
563
633
564
634
// Exact search without constructing indexes.
@@ -570,10 +640,10 @@ int main(int, char**) {
570
640
571
641
// Make sure the initializers and the algorithms can work with inadequately small values.
572
642
// Be warned - this combinatorial explosion of tests produces close to __500'000__ tests!
573
- for (std::size_t connectivity : {0 , 1 , 2 , 3 , 16 })
574
- for (std::size_t dimensions : {1 , 2 , 3 , 16 }) // TODO: Add zero
575
- for (std::size_t expansion_add : {0 , 1 , 2 , 3 , 16 })
576
- for (std::size_t expansion_search : {0 , 1 , 2 , 3 , 16 })
643
+ for (std::size_t connectivity : {0 , 1 , 2 , 3 })
644
+ for (std::size_t dimensions : {1 , 2 , 3 }) // TODO: Add zero?
645
+ for (std::size_t expansion_add : {0 , 1 , 2 , 3 })
646
+ for (std::size_t expansion_search : {0 , 1 , 2 , 3 })
577
647
for (std::size_t count_vectors : {0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 })
578
648
for (std::size_t count_wanted : {0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 })
579
649
test_absurd<std::int64_t , std::uint32_t >(dimensions, connectivity, expansion_add,
@@ -596,6 +666,7 @@ int main(int, char**) {
596
666
// Beyond dense equi-dimensional vectors - integer sets
597
667
for (std::size_t set_size : {1 , 100 , 1000 })
598
668
test_sets<std::int64_t , std::uint32_t >(set_size, 20 , 30 );
669
+ test_strings<std::int64_t , std::uint32_t >();
599
670
600
671
return 0 ;
601
672
}
0 commit comments