@@ -298,6 +298,112 @@ func (s *SearchIteratorSuite) TestNext() {
298298 s .ErrorIs (err , io .EOF )
299299}
300300
301+ func (s * SearchIteratorSuite ) TestNextWithLimit () {
302+ ctx := context .Background ()
303+ collectionName := fmt .Sprintf ("coll_%s" , s .randString (6 ))
304+
305+ token := fmt .Sprintf ("iter_token_%s" , s .randString (8 ))
306+
307+ s .mock .EXPECT ().DescribeCollection (mock .Anything , mock .Anything ).Return (& milvuspb.DescribeCollectionResponse {
308+ CollectionID : 1 ,
309+ Schema : s .schema .ProtoMessage (),
310+ }, nil ).Once ()
311+ s .mock .EXPECT ().Search (mock .Anything , mock .Anything ).RunAndReturn (func (ctx context.Context , sr * milvuspb.SearchRequest ) (* milvuspb.SearchResults , error ) {
312+ s .Equal (collectionName , sr .GetCollectionName ())
313+ checkSearchParam := func (kvs []* commonpb.KeyValuePair , key string , value string ) bool {
314+ for _ , kv := range kvs {
315+ if kv .GetKey () == key && kv .GetValue () == value {
316+ return true
317+ }
318+ }
319+ return false
320+ }
321+
322+ s .True (checkSearchParam (sr .GetSearchParams (), IteratorKey , "true" ))
323+ s .True (checkSearchParam (sr .GetSearchParams (), IteratorSearchV2Key , "true" ))
324+ return & milvuspb.SearchResults {
325+ Status : merr .Success (),
326+ Results : & schemapb.SearchResultData {
327+ NumQueries : 1 ,
328+ TopK : 1 ,
329+ FieldsData : []* schemapb.FieldData {
330+ s .getInt64FieldData ("ID" , []int64 {1 }),
331+ },
332+ Ids : & schemapb.IDs {
333+ IdField : & schemapb.IDs_IntId {
334+ IntId : & schemapb.LongArray {
335+ Data : []int64 {1 },
336+ },
337+ },
338+ },
339+ Scores : make ([]float32 , 1 ),
340+ Topks : []int64 {5 },
341+ Recalls : []float32 {1 },
342+ SearchIteratorV2Results : & schemapb.SearchIteratorV2Results {
343+ Token : token ,
344+ },
345+ },
346+ }, nil
347+ }).Once ()
348+
349+ iter , err := s .client .SearchIterator (ctx , NewSearchIteratorOption (collectionName , entity .FloatVector (lo .RepeatBy (128 , func (_ int ) float32 {
350+ return rand .Float32 ()
351+ }))).WithIteratorLimit (6 ).WithBatchSize (5 ))
352+ s .Require ().NoError (err )
353+ s .Require ().NotNil (iter )
354+
355+ s .mock .EXPECT ().Search (mock .Anything , mock .Anything ).RunAndReturn (func (ctx context.Context , sr * milvuspb.SearchRequest ) (* milvuspb.SearchResults , error ) {
356+ s .Equal (collectionName , sr .GetCollectionName ())
357+ checkSearchParam := func (kvs []* commonpb.KeyValuePair , key string , value string ) bool {
358+ for _ , kv := range kvs {
359+ if kv .GetKey () == key && kv .GetValue () == value {
360+ return true
361+ }
362+ }
363+ return false
364+ }
365+
366+ s .True (checkSearchParam (sr .GetSearchParams (), IteratorKey , "true" ))
367+ s .True (checkSearchParam (sr .GetSearchParams (), IteratorSearchV2Key , "true" ))
368+ return & milvuspb.SearchResults {
369+ Status : merr .Success (),
370+ Results : & schemapb.SearchResultData {
371+ NumQueries : 1 ,
372+ TopK : 1 ,
373+ FieldsData : []* schemapb.FieldData {
374+ s .getInt64FieldData ("ID" , []int64 {1 , 2 , 3 , 4 , 5 }),
375+ },
376+ Ids : & schemapb.IDs {
377+ IdField : & schemapb.IDs_IntId {
378+ IntId : & schemapb.LongArray {
379+ Data : []int64 {1 , 2 , 3 , 4 , 5 },
380+ },
381+ },
382+ },
383+ Scores : []float32 {0.5 , 0.4 , 0.3 , 0.2 , 0.1 },
384+ Topks : []int64 {5 },
385+ Recalls : []float32 {1 },
386+ SearchIteratorV2Results : & schemapb.SearchIteratorV2Results {
387+ Token : token ,
388+ LastBound : 0.5 ,
389+ },
390+ },
391+ }, nil
392+ }).Times (2 )
393+
394+ rs , err := iter .Next (ctx )
395+ s .NoError (err )
396+ s .EqualValues (5 , rs .IDs .Len (), "first batch, return all results" )
397+
398+ rs , err = iter .Next (ctx )
399+ s .NoError (err )
400+ s .EqualValues (1 , rs .IDs .Len (), "second batch, return sliced results" )
401+
402+ _ , err = iter .Next (ctx )
403+ s .Error (err )
404+ s .ErrorIs (err , io .EOF , "limit reached, return EOF" )
405+ }
406+
301407func TestSearchIterator (t * testing.T ) {
302408 suite .Run (t , new (SearchIteratorSuite ))
303409}
0 commit comments