@@ -18,7 +18,7 @@ import (
1818 "container/heap"
1919 "context"
2020 "fmt"
21- "sort "
21+ "slices "
2222 "time"
2323
2424 "github.com/matrixorigin/matrixone/pkg/common/moerr"
@@ -390,148 +390,90 @@ func HandleOrderByLimitOnIVFFlatIndex(
390390 }
391391 }
392392
393- var sels []int64
394- var dists []float64
395-
396393 nullsBm := vecCol .GetNulls ()
394+ selectRows = slices .DeleteFunc (selectRows , func (row int64 ) bool {
395+ return nullsBm .Contains (uint64 (row ))
396+ })
397397
398- if int (orderByLimit .Limit ) < len (selectRows ) {
399- // apply topn if needed
400- hp := make (vectorindex.SearchResultMaxHeap , 0 , orderByLimit .Limit )
401-
402- switch orderByLimit .Typ {
403- case types .T_array_float32 :
404- distFunc , err := metric.ResolveDistanceFn [float32 ](orderByLimit .Metric )
405- if err != nil {
406- return nil , nil , err
407- }
408-
409- rhs := types.BytesToArray [float32 ](orderByLimit .NumVec )
398+ searchResults := make ([]vectorindex.SearchResult , 0 , len (selectRows ))
410399
411- for _ , row := range selectRows {
412- if nullsBm .Contains (uint64 (row )) {
413- continue
414- }
415- dist , err := distFunc (types.BytesToArray [float32 ](vecCol .GetBytesAt (int (row ))), rhs )
416- if err != nil {
417- return nil , nil , err
418- }
419- dist64 := float64 (dist )
400+ switch orderByLimit .Typ {
401+ case types .T_array_float32 :
402+ distFunc , err := metric.ResolveDistanceFn [float32 ](orderByLimit .Metric )
403+ if err != nil {
404+ return nil , nil , err
405+ }
420406
421- heapItem := & vectorindex.SearchResult {
422- Id : row ,
423- Distance : dist64 ,
424- }
425- if len (hp ) >= int (orderByLimit .Limit ) {
426- if dist64 < hp [0 ].GetDistance () {
427- hp [0 ] = heapItem
428- heap .Fix (& hp , 0 )
429- }
430- } else {
431- heap .Push (& hp , heapItem )
432- }
433- }
407+ rhs := types.BytesToArray [float32 ](orderByLimit .NumVec )
434408
435- case types . T_array_float64 :
436- distFunc , err := metric. ResolveDistanceFn [ float64 ]( orderByLimit . Metric )
409+ for _ , row := range selectRows {
410+ dist , err := distFunc (types. BytesToArray [ float32 ]( vecCol . GetBytesAt ( int ( row ))), rhs )
437411 if err != nil {
438412 return nil , nil , err
439413 }
414+ dist64 := float64 (dist )
440415
441- rhs := types.BytesToArray [float64 ](orderByLimit .NumVec )
442-
443- for _ , row := range selectRows {
444- if nullsBm .Contains (uint64 (row )) {
445- continue
446- }
447- dist , err := distFunc (types.BytesToArray [float64 ](vecCol .GetBytesAt (int (row ))), rhs )
448- if err != nil {
449- return nil , nil , err
450- }
451-
452- heapItem := & vectorindex.SearchResult {
453- Id : row ,
454- Distance : dist ,
455- }
456- if len (hp ) >= int (orderByLimit .Limit ) {
457- if dist < hp [0 ].GetDistance () {
458- hp [0 ] = heapItem
459- heap .Fix (& hp , 0 )
460- }
416+ if len (orderByLimit .DistHeap ) >= int (orderByLimit .Limit ) {
417+ if dist64 < orderByLimit .DistHeap [0 ] {
418+ orderByLimit .DistHeap [0 ] = dist64
419+ heap .Fix (& orderByLimit .DistHeap , 0 )
461420 } else {
462- heap . Push ( & hp , heapItem )
421+ continue
463422 }
423+ } else {
424+ heap .Push (& orderByLimit .DistHeap , dist64 )
464425 }
465426
466- default :
467- return nil , nil , moerr .NewInternalError (ctx , fmt .Sprintf ("only support float32/float64 type for topn: %s" , orderByLimit .Typ ))
427+ searchResults = append (searchResults , vectorindex.SearchResult {
428+ Id : row ,
429+ Distance : dist64 ,
430+ })
468431 }
469432
470- sRes := make ([]vectorindex.SearchResult , len (hp ))
471- for i := range sRes {
472- sRes [i ] = * hp [i ].(* vectorindex.SearchResult )
433+ case types .T_array_float64 :
434+ distFunc , err := metric.ResolveDistanceFn [float64 ](orderByLimit .Metric )
435+ if err != nil {
436+ return nil , nil , err
473437 }
474- sort .Slice (sRes , func (i , j int ) bool {
475- return sRes [i ].Id < sRes [j ].Id
476- })
477438
478- sels = make ([]int64 , len (sRes ))
479- dists = make ([]float64 , len (sRes ))
480-
481- for i := range sRes {
482- sels [i ] = sRes [i ].Id
483- dists [i ] = sRes [i ].Distance
484- }
485- } else {
486- sels = make ([]int64 , 0 , len (selectRows ))
487- dists = make ([]float64 , 0 , len (selectRows ))
439+ rhs := types.BytesToArray [float64 ](orderByLimit .NumVec )
488440
489- switch orderByLimit .Typ {
490- case types .T_array_float32 :
491- distFunc , err := metric.ResolveDistanceFn [float32 ](orderByLimit .Metric )
441+ for _ , row := range selectRows {
442+ dist64 , err := distFunc (types.BytesToArray [float64 ](vecCol .GetBytesAt (int (row ))), rhs )
492443 if err != nil {
493444 return nil , nil , err
494445 }
495446
496- rhs := types.BytesToArray [float32 ](orderByLimit .NumVec )
497-
498- for _ , row := range selectRows {
499- if nullsBm .Contains (uint64 (row )) {
447+ if len (orderByLimit .DistHeap ) >= int (orderByLimit .Limit ) {
448+ if dist64 < orderByLimit .DistHeap [0 ] {
449+ orderByLimit .DistHeap [0 ] = dist64
450+ heap .Fix (& orderByLimit .DistHeap , 0 )
451+ } else {
500452 continue
501453 }
502- dist , err := distFunc (types.BytesToArray [float32 ](vecCol .GetBytesAt (int (row ))), rhs )
503- if err != nil {
504- return nil , nil , err
505- }
506-
507- sels = append (sels , row )
508- dists = append (dists , float64 (dist ))
509- }
510-
511- case types .T_array_float64 :
512- distFunc , err := metric.ResolveDistanceFn [float64 ](orderByLimit .Metric )
513- if err != nil {
514- return nil , nil , err
454+ } else {
455+ heap .Push (& orderByLimit .DistHeap , dist64 )
515456 }
516457
517- rhs := types.BytesToArray [float64 ](orderByLimit .NumVec )
458+ searchResults = append (searchResults , vectorindex.SearchResult {
459+ Id : row ,
460+ Distance : dist64 ,
461+ })
462+ }
518463
519- for _ , row := range selectRows {
520- if nullsBm .Contains (uint64 (row )) {
521- continue
522- }
523- dist , err := distFunc (types.BytesToArray [float64 ](vecCol .GetBytesAt (int (row ))), rhs )
524- if err != nil {
525- return nil , nil , err
526- }
464+ default :
465+ return nil , nil , moerr .NewInternalError (ctx , fmt .Sprintf ("only support float32/float64 type for topn: %s" , orderByLimit .Typ ))
466+ }
527467
528- sels = append ( sels , row )
529- dists = append ( dists , float64 ( dist ))
530- }
468+ searchResults = slices . DeleteFunc ( searchResults , func ( res vectorindex. SearchResult ) bool {
469+ return res . Distance > orderByLimit . DistHeap [ 0 ]
470+ })
531471
532- default :
533- return nil , nil , moerr .NewInternalError (ctx , fmt .Sprintf ("only support float32/float64 type for topn: %s" , orderByLimit .Typ ))
534- }
472+ sels := make ([]int64 , len (searchResults ))
473+ dists := make ([]float64 , len (searchResults ))
474+ for i , res := range searchResults {
475+ sels [i ] = res .Id
476+ dists [i ] = res .Distance
535477 }
536478
537479 return sels , dists , nil
0 commit comments