Skip to content

Commit 68c23b0

Browse files
authored
Fix: Exact search maxResults in GoLang
2 parents 7685a67 + abcbf09 commit 68c23b0

File tree

3 files changed

+167
-79
lines changed

3 files changed

+167
-79
lines changed

c/lib.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ USEARCH_EXPORT void usearch_exact_search( //
478478

479479
metric_punned_t metric(dimensions, metric_kind_to_cpp(metric_kind), scalar_kind_to_cpp(scalar_kind));
480480
executor_default_t executor(threads);
481-
static exact_search_t search;
481+
exact_search_t search;
482482
exact_search_results_t result = search( //
483483
(byte_t const*)dataset, dataset_count, dataset_stride, //
484484
(byte_t const*)queries, queries_count, queries_stride, //

golang/lib.go

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -689,7 +689,7 @@ func (index *Index) SearchUnsafe(query unsafe.Pointer, limit uint) (keys []Key,
689689
// - numThreads: Number of threads to use (0 for auto-detection)
690690
func ExactSearch(dataset []float32, queries []float32, datasetSize uint, queryCount uint,
691691
datasetStride uint, queryStride uint, vectorDimensions uint, metric Metric,
692-
maxResults uint, numThreads uint, resultKeysStride uint, resultDistancesStride uint) (keys []Key, distances []float32, err error) {
692+
maxResults uint, numThreads uint) (keys []Key, distances []float32, err error) {
693693

694694
if len(dataset) == 0 || len(queries) == 0 {
695695
return nil, nil, errors.New("dataset and queries cannot be empty")
@@ -703,9 +703,15 @@ func ExactSearch(dataset []float32, queries []float32, datasetSize uint, queryCo
703703
if (len(queries) % int(vectorDimensions)) != 0 {
704704
return nil, nil, errors.New("queries length must be a multiple of the dimensions")
705705
}
706+
if maxResults == 0 {
707+
return nil, nil, errors.New("maxResults must be greater than zero")
708+
}
709+
710+
keys = make([]Key, queryCount*maxResults)
711+
distances = make([]float32, queryCount*maxResults)
712+
resultKeysStride := uint32(maxResults * 8) // int64 - 8 bytes
713+
resultDistancesStride := uint32(maxResults * 4) // float32 - 4 bytes
706714

707-
keys = make([]Key, maxResults)
708-
distances = make([]float32, maxResults)
709715
var errorMessage *C.char
710716
C.usearch_exact_search(unsafe.Pointer(&dataset[0]), C.size_t(datasetSize), C.size_t(datasetStride), unsafe.Pointer(&queries[0]), C.size_t(queryCount), C.size_t(queryStride),
711717
C.usearch_scalar_f32_k, C.size_t(vectorDimensions), metric.CValue(), C.size_t(maxResults), C.size_t(numThreads),
@@ -718,8 +724,6 @@ func ExactSearch(dataset []float32, queries []float32, datasetSize uint, queryCo
718724
return nil, nil, errors.New(C.GoString(errorMessage))
719725
}
720726

721-
keys = keys[:maxResults]
722-
distances = distances[:maxResults]
723727
return keys, distances, nil
724728
}
725729

@@ -736,17 +740,19 @@ func ExactSearch(dataset []float32, queries []float32, datasetSize uint, queryCo
736740
// For contiguous data, use vectorDimensions * sizeof(element_type).
737741
func ExactSearchUnsafe(dataset unsafe.Pointer, queries unsafe.Pointer, datasetSize uint, queryCount uint,
738742
datasetStride uint, queryStride uint, vectorDimensions uint, metric Metric, quantization Quantization,
739-
maxResults uint, numThreads uint, resultKeysStride uint, resultDistancesStride uint) (keys []Key, distances []float32, err error) {
743+
maxResults uint, numThreads uint) (keys []Key, distances []float32, err error) {
740744

741745
if dataset == nil || queries == nil {
742746
return nil, nil, errors.New("dataset and queries pointers cannot be nil")
743747
}
744-
if vectorDimensions == 0 || datasetSize == 0 || queryCount == 0 {
745-
return nil, nil, errors.New("dimensions and sizes must be greater than zero")
748+
if vectorDimensions == 0 || datasetSize == 0 || queryCount == 0 || maxResults == 0 {
749+
return nil, nil, errors.New("dimensions, query count, max results and sizes must be greater than zero")
746750
}
747751

748-
keys = make([]Key, maxResults)
749-
distances = make([]float32, maxResults)
752+
keys = make([]Key, queryCount*maxResults)
753+
distances = make([]float32, queryCount*maxResults)
754+
resultKeysStride := uint32(maxResults * 8) // int64 - 8 bytes
755+
resultDistancesStride := uint32(maxResults * 4) // float32 - 4 bytes
750756
var errorMessage *C.char
751757
C.usearch_exact_search(dataset, C.size_t(datasetSize), C.size_t(datasetStride), queries, C.size_t(queryCount), C.size_t(queryStride),
752758
quantization.CValue(), C.size_t(vectorDimensions), metric.CValue(), C.size_t(maxResults), C.size_t(numThreads),
@@ -757,8 +763,6 @@ func ExactSearchUnsafe(dataset unsafe.Pointer, queries unsafe.Pointer, datasetSi
757763
return nil, nil, errors.New(C.GoString(errorMessage))
758764
}
759765

760-
keys = keys[:maxResults]
761-
distances = distances[:maxResults]
762766
return keys, distances, nil
763767
}
764768

@@ -850,17 +854,21 @@ func DistanceI8(vec1 []int8, vec2 []int8, vectorDimensions uint, metric Metric)
850854
// For contiguous int8 data, use vectorDimensions * 1 byte.
851855
func ExactSearchI8(dataset []int8, queries []int8, datasetSize uint, queryCount uint,
852856
datasetStride uint, queryStride uint, vectorDimensions uint, metric Metric,
853-
maxResults uint, numThreads uint, resultKeysStride uint, resultDistancesStride uint) (keys []Key, distances []float32, err error) {
857+
maxResults uint, numThreads uint) (keys []Key, distances []float32, err error) {
854858

855859
if len(dataset) == 0 || len(queries) == 0 {
856860
return nil, nil, errors.New("dataset and queries cannot be empty")
857861
}
858862
if vectorDimensions == 0 {
859863
return nil, nil, errors.New("dimensions must be greater than zero")
860864
}
861-
862-
keys = make([]Key, maxResults)
863-
distances = make([]float32, maxResults)
865+
if maxResults == 0 {
866+
return nil, nil, errors.New("maxResults must be greater than zero")
867+
}
868+
keys = make([]Key, queryCount*maxResults)
869+
distances = make([]float32, queryCount*maxResults)
870+
resultKeysStride := uint32(maxResults * 8) // int64 - 8 bytes
871+
resultDistancesStride := uint32(maxResults * 4) // float32 - 4 bytes
864872
var errorMessage *C.char
865873
C.usearch_exact_search(unsafe.Pointer(&dataset[0]), C.size_t(datasetSize), C.size_t(datasetStride), unsafe.Pointer(&queries[0]), C.size_t(queryCount), C.size_t(queryStride),
866874
C.usearch_scalar_i8_k, C.size_t(vectorDimensions), metric.CValue(), C.size_t(maxResults), C.size_t(numThreads),
@@ -872,8 +880,6 @@ func ExactSearchI8(dataset []int8, queries []int8, datasetSize uint, queryCount
872880
if errorMessage != nil {
873881
return nil, nil, errors.New(C.GoString(errorMessage))
874882
}
875-
keys = keys[:maxResults]
876-
distances = distances[:maxResults]
877883
return keys, distances, nil
878884
}
879885

0 commit comments

Comments
 (0)