Skip to content

Commit 4def025

Browse files
authored
enhance: [GoSDK] Support limit for search iterator (#43732)
Related to #37548. Add a `WithIteratorLimit` option that caps the total number of results a search iterator returns. --------- Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
1 parent d4abb58 commit 4def025

File tree

4 files changed

+173
-4
lines changed

4 files changed

+173
-4
lines changed

client/milvusclient/iterator.go

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ const (
3838
IteratorSearchLastBoundKey = "search_iter_last_bound"
3939
IteratorSearchIDKey = "search_iter_id"
4040
CollectionIDKey = `collection_id`
41+
42+
// Unlimited indicates that the search iterator has no overall limit on returned entries.
43+
Unlimited int64 = -1
4144
)
4245

4346
var ErrServerVersionIncompatible = errors.New("server version incompatible")
@@ -53,10 +56,29 @@ type searchIteratorV2 struct {
5356
client *Client
5457
option SearchIteratorOption
5558
schema *entity.Schema
59+
limit int64
5660
}
5761

5862
func (it *searchIteratorV2) Next(ctx context.Context) (ResultSet, error) {
59-
return it.next(ctx)
63+
// limit reached, return EOF
64+
if it.limit == 0 {
65+
return ResultSet{}, io.EOF
66+
}
67+
68+
rs, err := it.next(ctx)
69+
if err != nil {
70+
return rs, err
71+
}
72+
73+
if it.limit == Unlimited {
74+
return rs, err
75+
}
76+
77+
if int64(rs.Len()) > it.limit {
78+
rs = rs.Slice(0, int(it.limit))
79+
}
80+
it.limit -= int64(rs.Len())
81+
return rs, nil
6082
}
6183

6284
func (it *searchIteratorV2) next(ctx context.Context) (ResultSet, error) {
@@ -144,6 +166,7 @@ func newSearchIteratorV2(ctx context.Context, client *Client, option SearchItera
144166
iter := &searchIteratorV2{
145167
client: client,
146168
option: option,
169+
limit: option.Limit(),
147170
}
148171
if err := iter.setupCollectionID(ctx); err != nil {
149172
return nil, err

client/milvusclient/iterator_option.go

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,16 @@ import (
2424
)
2525

2626
type SearchIteratorOption interface {
27+
// SearchOption returns the search option when iterate search
2728
SearchOption() *searchOption
29+
// Limit returns the overall limit of entries to iterate
30+
Limit() int64
2831
}
2932

3033
type searchIteratorOption struct {
3134
*searchOption
32-
batchSize int
35+
batchSize int
36+
iteratorLimit int64
3337
}
3438

3539
func (opt *searchIteratorOption) SearchOption() *searchOption {
@@ -38,6 +42,10 @@ func (opt *searchIteratorOption) SearchOption() *searchOption {
3842
return opt.searchOption
3943
}
4044

45+
func (opt *searchIteratorOption) Limit() int64 {
46+
return opt.iteratorLimit
47+
}
48+
4149
func (opt *searchIteratorOption) WithBatchSize(batchSize int) *searchIteratorOption {
4250
opt.batchSize = batchSize
4351
return opt
@@ -109,11 +117,17 @@ func (opt *searchIteratorOption) WithSearchParam(key, value string) *searchItera
109117
return opt
110118
}
111119

120+
func (opt *searchIteratorOption) WithIteratorLimit(limit int64) *searchIteratorOption {
121+
opt.iteratorLimit = limit
122+
return opt
123+
}
124+
112125
func NewSearchIteratorOption(collectionName string, vector entity.Vector) *searchIteratorOption {
113126
return &searchIteratorOption{
114-
searchOption: NewSearchOption(collectionName, 100, []entity.Vector{vector}).
127+
searchOption: NewSearchOption(collectionName, 1000, []entity.Vector{vector}).
115128
WithSearchParam(IteratorKey, "true").
116129
WithSearchParam(IteratorSearchV2Key, "true"),
117-
batchSize: 1000,
130+
batchSize: 1000,
131+
iteratorLimit: Unlimited,
118132
}
119133
}

client/milvusclient/iterator_test.go

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,112 @@ func (s *SearchIteratorSuite) TestNext() {
298298
s.ErrorIs(err, io.EOF)
299299
}
300300

301+
func (s *SearchIteratorSuite) TestNextWithLimit() {
302+
ctx := context.Background()
303+
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
304+
305+
token := fmt.Sprintf("iter_token_%s", s.randString(8))
306+
307+
s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
308+
CollectionID: 1,
309+
Schema: s.schema.ProtoMessage(),
310+
}, nil).Once()
311+
s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
312+
s.Equal(collectionName, sr.GetCollectionName())
313+
checkSearchParam := func(kvs []*commonpb.KeyValuePair, key string, value string) bool {
314+
for _, kv := range kvs {
315+
if kv.GetKey() == key && kv.GetValue() == value {
316+
return true
317+
}
318+
}
319+
return false
320+
}
321+
322+
s.True(checkSearchParam(sr.GetSearchParams(), IteratorKey, "true"))
323+
s.True(checkSearchParam(sr.GetSearchParams(), IteratorSearchV2Key, "true"))
324+
return &milvuspb.SearchResults{
325+
Status: merr.Success(),
326+
Results: &schemapb.SearchResultData{
327+
NumQueries: 1,
328+
TopK: 1,
329+
FieldsData: []*schemapb.FieldData{
330+
s.getInt64FieldData("ID", []int64{1}),
331+
},
332+
Ids: &schemapb.IDs{
333+
IdField: &schemapb.IDs_IntId{
334+
IntId: &schemapb.LongArray{
335+
Data: []int64{1},
336+
},
337+
},
338+
},
339+
Scores: make([]float32, 1),
340+
Topks: []int64{5},
341+
Recalls: []float32{1},
342+
SearchIteratorV2Results: &schemapb.SearchIteratorV2Results{
343+
Token: token,
344+
},
345+
},
346+
}, nil
347+
}).Once()
348+
349+
iter, err := s.client.SearchIterator(ctx, NewSearchIteratorOption(collectionName, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
350+
return rand.Float32()
351+
}))).WithIteratorLimit(6).WithBatchSize(5))
352+
s.Require().NoError(err)
353+
s.Require().NotNil(iter)
354+
355+
s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
356+
s.Equal(collectionName, sr.GetCollectionName())
357+
checkSearchParam := func(kvs []*commonpb.KeyValuePair, key string, value string) bool {
358+
for _, kv := range kvs {
359+
if kv.GetKey() == key && kv.GetValue() == value {
360+
return true
361+
}
362+
}
363+
return false
364+
}
365+
366+
s.True(checkSearchParam(sr.GetSearchParams(), IteratorKey, "true"))
367+
s.True(checkSearchParam(sr.GetSearchParams(), IteratorSearchV2Key, "true"))
368+
return &milvuspb.SearchResults{
369+
Status: merr.Success(),
370+
Results: &schemapb.SearchResultData{
371+
NumQueries: 1,
372+
TopK: 1,
373+
FieldsData: []*schemapb.FieldData{
374+
s.getInt64FieldData("ID", []int64{1, 2, 3, 4, 5}),
375+
},
376+
Ids: &schemapb.IDs{
377+
IdField: &schemapb.IDs_IntId{
378+
IntId: &schemapb.LongArray{
379+
Data: []int64{1, 2, 3, 4, 5},
380+
},
381+
},
382+
},
383+
Scores: []float32{0.5, 0.4, 0.3, 0.2, 0.1},
384+
Topks: []int64{5},
385+
Recalls: []float32{1},
386+
SearchIteratorV2Results: &schemapb.SearchIteratorV2Results{
387+
Token: token,
388+
LastBound: 0.5,
389+
},
390+
},
391+
}, nil
392+
}).Times(2)
393+
394+
rs, err := iter.Next(ctx)
395+
s.NoError(err)
396+
s.EqualValues(5, rs.IDs.Len(), "first batch, return all results")
397+
398+
rs, err = iter.Next(ctx)
399+
s.NoError(err)
400+
s.EqualValues(1, rs.IDs.Len(), "second batch, return sliced results")
401+
402+
_, err = iter.Next(ctx)
403+
s.Error(err)
404+
s.ErrorIs(err, io.EOF, "limit reached, return EOF")
405+
}
406+
301407
func TestSearchIterator(t *testing.T) {
302408
suite.Run(t, new(SearchIteratorSuite))
303409
}

client/milvusclient/results.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"runtime/debug"
2222

2323
"github.com/cockroachdb/errors"
24+
"github.com/samber/lo"
2425

2526
"github.com/milvus-io/milvus/client/v2/column"
2627
"github.com/milvus-io/milvus/client/v2/entity"
@@ -51,6 +52,31 @@ func (rs *ResultSet) GetColumn(fieldName string) column.Column {
5152
return nil
5253
}
5354

55+
func (rs ResultSet) Len() int {
56+
return rs.ResultCount
57+
}
58+
59+
func (rs ResultSet) Slice(start, end int) ResultSet {
60+
result := ResultSet{
61+
sch: rs.sch,
62+
IDs: rs.IDs.Slice(start, end),
63+
Fields: lo.Map(rs.Fields, func(column column.Column, _ int) column.Column {
64+
return column.Slice(start, end)
65+
}),
66+
// Recall will not be sliced
67+
Err: rs.Err,
68+
}
69+
70+
if rs.GroupByValue != nil {
71+
result.GroupByValue = rs.GroupByValue.Slice(start, end)
72+
}
73+
74+
result.ResultCount = result.IDs.Len()
75+
result.Scores = rs.Scores[start : start+result.ResultCount]
76+
77+
return result
78+
}
79+
5480
// Unmarshal puts dataset into receiver in row based way.
5581
// `receiver` shall be a slice of pointer of model struct
5682
// eg, []*Records, in which type `Record` defines the row data.

0 commit comments

Comments
 (0)