Skip to content

Commit e99fe4a

Browse files
committed
Merge: Filtered Search for GoLang
2 parents 4f02166 + 55c1ea8 commit e99fe4a

File tree

4 files changed

+221
-2
lines changed

4 files changed

+221
-2
lines changed

c/lib.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,7 @@ USEARCH_EXPORT size_t usearch_search(
413413
USEARCH_EXPORT size_t usearch_filtered_search( //
414414
usearch_index_t index, //
415415
void const* query, usearch_scalar_kind_t query_kind, size_t results_limit, //
416-
int (*filter)(usearch_key_t key, void* filter_state), void* filter_state, //
416+
usearch_filtered_search_callback_t filter, void* filter_state, //
417417
usearch_key_t* found_keys, usearch_distance_t* found_distances, usearch_error_t* error) {
418418

419419
USEARCH_ASSERT(index && query && filter && error && "Missing arguments");

c/usearch.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,10 @@ USEARCH_EXPORT typedef struct usearch_init_options_t {
109109
bool multi;
110110
} usearch_init_options_t;
111111

112+
extern int goFilteredSearchCallback(usearch_key_t, void*);
113+
114+
USEARCH_EXPORT typedef int (*usearch_filtered_search_callback_t)(usearch_key_t, void*);
115+
112116
/**
113117
* @brief Retrieves the version of the library.
114118
* @return The version of the library.
@@ -391,7 +395,7 @@ USEARCH_EXPORT size_t usearch_search( //
391395
USEARCH_EXPORT size_t usearch_filtered_search( //
392396
usearch_index_t index, //
393397
void const* query_vector, usearch_scalar_kind_t query_kind, size_t count, //
394-
int (*filter)(usearch_key_t key, void* filter_state), void* filter_state, //
398+
usearch_filtered_search_callback_t filter, void* filter_state, //
395399
usearch_key_t* keys, usearch_distance_t* distances, usearch_error_t* error);
396400

397401
/**

golang/lib.go

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,12 @@ func DefaultConfig(dimensions uint) IndexConfig {
206206
return c
207207
}
208208

209+
// FilteredSearchHandler include the callback functiona and user data
210+
type FilteredSearchHandler struct {
211+
Callback func(key Key, handler *FilteredSearchHandler) int
212+
Data any
213+
}
214+
209215
// Index represents a USearch approximate nearest neighbor index.
210216
// It implements io.Closer for idiomatic resource cleanup.
211217
//
@@ -638,6 +644,56 @@ func (index *Index) Search(query []float32, limit uint) (keys []Key, distances [
638644
return keys, distances, nil
639645
}
640646

647+
// Search finds the k nearest neighbors to the query vector.
648+
//
649+
// Parameters:
650+
// - query: Must have exactly Dimensions() elements
651+
// - limit: Maximum number of results to return
652+
//
653+
// Returns:
654+
// - keys: IDs of the nearest vectors (up to limit)
655+
// - distances: Distance to each result (same length as keys)
656+
// - err: Error if query is invalid or search fails
657+
//
658+
// The actual number of results may be less than limit if the index
659+
// contains fewer vectors.
660+
func (index *Index) FilteredSearch(query []float32, limit uint, handler *FilteredSearchHandler) (keys []Key, distances []float32, err error) {
661+
if index.handle == nil {
662+
panic("index is uninitialized")
663+
}
664+
665+
if len(query) == 0 {
666+
return nil, nil, errors.New("query vector cannot be empty")
667+
}
668+
if uint(len(query)) != index.config.Dimensions {
669+
return nil, nil, fmt.Errorf("query dimension mismatch: got %d, expected %d", len(query), index.config.Dimensions)
670+
}
671+
if handler == nil {
672+
return nil, nil, errors.New("filtered search handler cannot be nil")
673+
}
674+
if limit == 0 {
675+
return []Key{}, []float32{}, nil
676+
}
677+
678+
keys = make([]Key, limit)
679+
distances = make([]float32, limit)
680+
var errorMessage *C.char
681+
resultCount := uint(C.usearch_filtered_search(index.handle, unsafe.Pointer(&query[0]), C.usearch_scalar_f32_k, (C.size_t)(limit),
682+
(C.usearch_filtered_search_callback_t)(C.goFilteredSearchCallback), unsafe.Pointer(handler),
683+
(*C.usearch_key_t)(&keys[0]), (*C.usearch_distance_t)(&distances[0]), (*C.usearch_error_t)(&errorMessage)))
684+
runtime.KeepAlive(query)
685+
runtime.KeepAlive(keys)
686+
runtime.KeepAlive(distances)
687+
runtime.KeepAlive(handler)
688+
if errorMessage != nil {
689+
return nil, nil, errors.New(C.GoString(errorMessage))
690+
}
691+
692+
keys = keys[:resultCount]
693+
distances = distances[:resultCount]
694+
return keys, distances, nil
695+
}
696+
641697
// SearchUnsafe performs k-Approximate Nearest Neighbors Search using an unsafe pointer.
642698
//
643699
// SAFETY REQUIREMENTS:
@@ -675,6 +731,48 @@ func (index *Index) SearchUnsafe(query unsafe.Pointer, limit uint) (keys []Key,
675731
return keys, distances, nil
676732
}
677733

734+
//export goFilteredSearchCallback
735+
func goFilteredSearchCallback(key C.usearch_key_t, ptr unsafe.Pointer) C.int {
736+
handler := (*FilteredSearchHandler)(ptr)
737+
return C.int(handler.Callback(Key(key), handler))
738+
}
739+
740+
// Filtred Search performs k-Approximate Nearest Neighbors Search for the closest vectors to the query vector with filtering.
741+
func (index *Index) FilteredSearchUnsafe(query unsafe.Pointer, limit uint, handler *FilteredSearchHandler) (keys []Key, distances []float32, err error) {
742+
if index.handle == nil {
743+
panic("index is uninitialized")
744+
}
745+
746+
if query == nil {
747+
return nil, nil, errors.New("query pointer cannot be nil")
748+
}
749+
750+
if handler == nil {
751+
return nil, nil, errors.New("filtered search handler cannot be nil")
752+
}
753+
754+
if limit == 0 {
755+
return []Key{}, []float32{}, nil
756+
}
757+
758+
keys = make([]Key, limit)
759+
distances = make([]float32, limit)
760+
var errorMessage *C.char
761+
resultCount := uint(C.usearch_filtered_search(index.handle, query, index.config.Quantization.CValue(), (C.size_t)(limit),
762+
(C.usearch_filtered_search_callback_t)(C.goFilteredSearchCallback), unsafe.Pointer(handler),
763+
(*C.usearch_key_t)(&keys[0]), (*C.usearch_distance_t)(&distances[0]), (*C.usearch_error_t)(&errorMessage)))
764+
runtime.KeepAlive(keys)
765+
runtime.KeepAlive(distances)
766+
runtime.KeepAlive(handler)
767+
if errorMessage != nil {
768+
return nil, nil, errors.New(C.GoString(errorMessage))
769+
}
770+
771+
keys = keys[:resultCount]
772+
distances = distances[:resultCount]
773+
return keys, distances, nil
774+
}
775+
678776
// ExactSearch performs multithreaded exact nearest neighbors search.
679777
// Unlike the index-based search, this computes distances to all vectors in the dataset.
680778
//
@@ -823,6 +921,43 @@ func (index *Index) SearchI8(query []int8, limit uint) (keys []Key, distances []
823921
return keys, distances, nil
824922
}
825923

924+
func (index *Index) FilteredSearchI8(query []int8, limit uint, handler *FilteredSearchHandler) (keys []Key, distances []float32, err error) {
925+
if index.handle == nil {
926+
panic("index is uninitialized")
927+
}
928+
929+
if len(query) == 0 {
930+
return nil, nil, errors.New("query vector cannot be empty")
931+
}
932+
if uint(len(query)) != index.config.Dimensions {
933+
return nil, nil, fmt.Errorf("query dimension mismatch: got %d, expected %d", len(query), index.config.Dimensions)
934+
}
935+
if handler == nil {
936+
return nil, nil, errors.New("filtered search handler cannot be nil")
937+
}
938+
if limit == 0 {
939+
return []Key{}, []float32{}, nil
940+
}
941+
942+
keys = make([]Key, limit)
943+
distances = make([]float32, limit)
944+
var errorMessage *C.char
945+
resultCount := uint(C.usearch_filtered_search(index.handle, unsafe.Pointer(&query[0]), C.usearch_scalar_i8_k, (C.size_t)(limit),
946+
(C.usearch_filtered_search_callback_t)(C.goFilteredSearchCallback), unsafe.Pointer(handler),
947+
(*C.usearch_key_t)(&keys[0]), (*C.usearch_distance_t)(&distances[0]), (*C.usearch_error_t)(&errorMessage)))
948+
runtime.KeepAlive(query)
949+
runtime.KeepAlive(keys)
950+
runtime.KeepAlive(distances)
951+
runtime.KeepAlive(handler)
952+
if errorMessage != nil {
953+
return nil, nil, errors.New(C.GoString(errorMessage))
954+
}
955+
956+
keys = keys[:resultCount]
957+
distances = distances[:resultCount]
958+
return keys, distances, nil
959+
}
960+
826961
// DistanceI8 computes the distance between two int8 vectors.
827962
//
828963
// Example:

golang/lib_test.go

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,26 @@ func TestQuantizationTypes(t *testing.T) {
538538
if len(keys) == 0 || keys[0] != 1 {
539539
t.Fatalf("F32 search results incorrect")
540540
}
541+
542+
// Test FilteredSearch
543+
handler := &FilteredSearchHandler{
544+
Callback: func(key Key, handler *FilteredSearchHandler) int {
545+
if key%2 == 0 {
546+
return 1
547+
}
548+
return 0
549+
},
550+
Data: int64(1),
551+
}
552+
553+
keys, _, err = index.FilteredSearch(vector, 1, handler)
554+
if err != nil {
555+
t.Fatalf("F32 FilteredSearch failed: %v", err)
556+
}
557+
558+
if len(keys) > 0 {
559+
t.Fatalf("F32 FilteredSearch returned incorrect results")
560+
}
541561
})
542562

543563
t.Run("F64 operations", func(t *testing.T) {
@@ -569,6 +589,26 @@ func TestQuantizationTypes(t *testing.T) {
569589
if len(keys) == 0 || keys[0] != 1 {
570590
t.Fatalf("F64 search results incorrect")
571591
}
592+
593+
// Test F64 FilteredSearchUnsafe
594+
handler := &FilteredSearchHandler{
595+
Callback: func(key Key, handler *FilteredSearchHandler) int {
596+
if key%2 == 0 {
597+
return 1
598+
}
599+
return 0
600+
},
601+
Data: int64(1),
602+
}
603+
604+
keys, _, err = index.FilteredSearchUnsafe(unsafe.Pointer(&vector[0]), 5, handler)
605+
if err != nil {
606+
t.Fatalf("F64 FilteredSearchUnsafe failed: %v", err)
607+
}
608+
609+
if len(keys) > 0 {
610+
t.Fatalf("F64 FilteredSearchUnsafe returned incorrect results")
611+
}
572612
})
573613

574614
t.Run("I8 operations", func(t *testing.T) {
@@ -596,6 +636,26 @@ func TestQuantizationTypes(t *testing.T) {
596636
if len(keys) == 0 || keys[0] != 1 {
597637
t.Fatalf("I8 search results incorrect")
598638
}
639+
640+
// Test FilteredSearchI8
641+
handler := &FilteredSearchHandler{
642+
Callback: func(key Key, handler *FilteredSearchHandler) int {
643+
if key%2 == 0 {
644+
return 1
645+
}
646+
return 0
647+
},
648+
Data: int64(1),
649+
}
650+
651+
keys, _, err = index.FilteredSearchI8(vector, 1, handler)
652+
if err != nil {
653+
t.Fatalf("FilteredSearchI8 failed: %v", err)
654+
}
655+
656+
if len(keys) > 0 {
657+
t.Fatalf("FilteredSearchI8 returned incorrect results")
658+
}
599659
})
600660
}
601661

@@ -645,6 +705,26 @@ func TestUnsafeOperations(t *testing.T) {
645705
if math.Abs(float64(distances[0])) > distanceTolerance {
646706
t.Fatalf("Expected near-zero distance for exact match, got %f", distances[0])
647707
}
708+
709+
// Test FilteredSearchUnsafe
710+
handler := &FilteredSearchHandler{
711+
Callback: func(key Key, handler *FilteredSearchHandler) int {
712+
if key%2 == 0 {
713+
return 0
714+
}
715+
return 1
716+
},
717+
Data: int64(1),
718+
}
719+
720+
keys, _, err = index.FilteredSearchUnsafe(ptr, 5, handler)
721+
if err != nil {
722+
t.Fatalf("FilteredSearchUnsafe failed: %v", err)
723+
}
724+
725+
if len(keys) > 0 {
726+
t.Fatalf("FilteredSearchUnsafe returned incorrect results")
727+
}
648728
})
649729
}
650730

0 commit comments

Comments
 (0)