Skip to content

Commit 6dc574f

Browse files
committed
Add: String-search example
1 parent da84fa2 commit 6dc574f

File tree

3 files changed

+82
-5
lines changed

3 files changed

+82
-5
lines changed

cpp/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@ if (USEARCH_BUILD_TEST_CPP)
44
include(CTest)
55
enable_testing()
66
add_test(NAME test_cpp COMMAND test_cpp)
7+
8+
target_include_directories(test_cpp PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../stringzilla/include)
9+
10+
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
11+
target_compile_options(test_cpp PRIVATE -Wno-vla -Wno-unused-function -Wno-cast-function-type)
12+
endif ()
713
endif ()
814

915
if (USEARCH_BUILD_BENCH_CPP)

cpp/test.cpp

Lines changed: 75 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
#include <unordered_map> // `std::unordered_map`
1212
#include <vector> // `std::vector`
1313

14+
#define SZ_USE_X86_AVX512 0 // Sanitizers hate AVX512
15+
#include <stringzilla/stringzilla.hpp> // Levenshtein distance implementation
16+
1417
#include <usearch/index.hpp>
1518
#include <usearch/index_dense.hpp>
1619
#include <usearch/index_plugins.hpp>
@@ -559,6 +562,73 @@ void test_sets(std::size_t collection_size, std::size_t min_set_length, std::siz
559562
}
560563
}
561564

565+
/**
566+
* Tests similarity search over strings using Levenshtein distances
567+
* implementation from StringZilla.
568+
*
569+
* Adds a predefined number of long strings, comparing them.
570+
*
571+
* @param index A reference to the index instance to be tested.
572+
* @tparam index_at Type of the index being tested.
573+
*/
574+
template <typename key_at, typename slot_at> void test_strings() {
575+
576+
namespace sz = ashvardanian::stringzilla;
577+
578+
/// Levenshtein distance is an integer
579+
using levenshtein_distance_t = std::uint64_t;
580+
581+
// Aliasis for the index overload
582+
using vector_key_t = key_at;
583+
using slot_t = slot_at;
584+
using index_t = index_gt<levenshtein_distance_t, vector_key_t, slot_t>;
585+
586+
std::string_view str0 = "ACGTACGTACGTACGTACGTACGTACGTACGTACGT";
587+
std::string_view str1 = "ACG_ACTC_TAC-TACGTA_GTACACG_ACGT";
588+
std::string_view str2 = "A_GTACTACGTA-GTAC_TACGTACGTA-GTAGT";
589+
std::string_view str3 = "GTACGTAGT-ACGTACGACGTACGTACG-TACGTAC";
590+
std::vector<std::string_view> strings({str0, str1, str2, str3});
591+
592+
// Wrap the data into a proxy object
593+
struct metric_t {
594+
using member_cref_t = typename index_t::member_cref_t;
595+
using member_citerator_t = typename index_t::member_citerator_t;
596+
597+
std::vector<std::string_view> const* strings_ptr = nullptr;
598+
599+
std::string_view str_at(std::size_t i) const noexcept { return (*strings_ptr)[i]; }
600+
levenshtein_distance_t between(std::string_view a, std::string_view b) const {
601+
return sz::edit_distance(sz::string_view(a), sz::string_view(b));
602+
}
603+
604+
levenshtein_distance_t operator()(member_cref_t const& a, member_cref_t const& b) const {
605+
return between(str_at(get_slot(b)), str_at(get_slot(a)));
606+
}
607+
levenshtein_distance_t operator()(std::string_view some_vector, member_cref_t const& member) const {
608+
return between(some_vector, str_at(get_slot(member)));
609+
}
610+
levenshtein_distance_t operator()(member_citerator_t const& a, member_citerator_t const& b) const {
611+
return between(str_at(get_slot(b)), str_at(get_slot(a)));
612+
}
613+
levenshtein_distance_t operator()(std::string_view some_vector, member_citerator_t const& member) const {
614+
return between(some_vector, str_at(get_slot(member)));
615+
}
616+
};
617+
618+
// Perform indexing
619+
aligned_wrapper_gt<index_t> aligned_index;
620+
aligned_index.index->reserve(strings.size());
621+
for (std::size_t i = 0; i < strings.size(); i++)
622+
aligned_index.index->add(i, strings[i], metric_t{&strings});
623+
expect(aligned_index.index->size() == strings.size());
624+
625+
// Perform the search queries
626+
for (std::size_t i = 0; i < strings.size(); i++) {
627+
auto results = aligned_index.index->search(strings[i], 5, metric_t{&strings});
628+
expect(results.size() > 0);
629+
}
630+
}
631+
562632
int main(int, char**) {
563633

564634
// Exact search without constructing indexes.
@@ -570,10 +640,10 @@ int main(int, char**) {
570640

571641
// Make sure the initializers and the algorithms can work with inadequately small values.
572642
// Be warned - this combinatorial explosion of tests produces close to __500'000__ tests!
573-
for (std::size_t connectivity : {0, 1, 2, 3, 16})
574-
for (std::size_t dimensions : {1, 2, 3, 16}) // TODO: Add zero
575-
for (std::size_t expansion_add : {0, 1, 2, 3, 16})
576-
for (std::size_t expansion_search : {0, 1, 2, 3, 16})
643+
for (std::size_t connectivity : {0, 1, 2, 3})
644+
for (std::size_t dimensions : {1, 2, 3}) // TODO: Add zero?
645+
for (std::size_t expansion_add : {0, 1, 2, 3})
646+
for (std::size_t expansion_search : {0, 1, 2, 3})
577647
for (std::size_t count_vectors : {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
578648
for (std::size_t count_wanted : {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
579649
test_absurd<std::int64_t, std::uint32_t>(dimensions, connectivity, expansion_add,
@@ -596,6 +666,7 @@ int main(int, char**) {
596666
// Beyond dense equi-dimensional vectors - integer sets
597667
for (std::size_t set_size : {1, 100, 1000})
598668
test_sets<std::int64_t, std::uint32_t>(set_size, 20, 30);
669+
test_strings<std::int64_t, std::uint32_t>();
599670

600671
return 0;
601672
}

0 commit comments

Comments
 (0)