Skip to content

Commit 919111b

Browse files
committed
use universal heap in Reader for IVFFlat
1 parent 2377651 commit 919111b

File tree

5 files changed

+88
-118
lines changed

5 files changed

+88
-118
lines changed

pkg/objectio/types.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,35 @@ func (f BlockReadFilter) DecideSearchFunc(isSortedBlk bool) ReadFilterSearchFunc
6464
return nil
6565
}
6666

67+
// Float64Heap is a slice of float64 values that provides the Len, Less, Swap,
// Push and Pop methods used by the container/heap package. Because Less
// reports h[i] > h[j], the heap is ordered as a max-heap: after heap
// operations the largest value is at index 0.
type Float64Heap []float64
68+
69+
// Len returns the number of values currently stored.
func (h Float64Heap) Len() int { return len(h) }
70+
// Less orders larger values first, which is what makes this a max-heap.
func (h Float64Heap) Less(i, j int) bool { return h[i] > h[j] }
71+
// Swap exchanges the values at positions i and j.
func (h Float64Heap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
72+
73+
// Push appends x to the underlying slice. x must be a float64; any other
// dynamic type makes the type assertion panic. Callers are expected to go
// through heap.Push rather than call this directly, so that heap order is
// restored afterwards.
func (h *Float64Heap) Push(x any) {
74+
*h = append(*h, x.(float64))
75+
}
76+
77+
// Pop removes and returns the last element of the underlying slice. It does
// not check for emptiness: calling it on an empty heap panics with an index
// out of range. Callers are expected to go through heap.Pop, which first moves
// the root to the end of the slice.
func (h *Float64Heap) Pop() any {
78+
old := *h
79+
n := len(old)
80+
x := old[n-1]
81+
// Shrink the slice in place; the backing array (and its capacity) is kept.
*h = old[0 : n-1]
82+
return x
83+
}
84+
6785
// BlockReadTopOp carries the parameters of an "order by distance ... limit"
// operation that is evaluated while reading blocks.
type BlockReadTopOp struct {
6886
// Typ is the type of the vector column; the reader handles
// types.T_array_float32 and types.T_array_float64 and rejects anything else.
Typ types.T
6987
// Metric selects the distance function (resolved through
// metric.ResolveDistanceFn).
Metric metric.MetricType
7088
// ColPos is copied from the order-by column's ColPos.
// NOTE(review): presumably the position of the vector column — confirm.
ColPos int32
7189
// NumVec holds the raw bytes of the query vector; it is decoded with
// types.BytesToArray and used as the right-hand side of the distance function.
NumVec []byte
7290
// Limit is the number of rows to keep; DistHeap is considered full once it
// holds Limit distances.
Limit uint64
91+
92+
//LowerBound float64
93+
//UpperBound float64
94+
95+
// DistHeap is a max-heap of the smallest distances seen so far. It is
// allocated with capacity Limit in SetBlockTop and updated by the block
// reader as rows are evaluated, so its root is the current cut-off distance.
DistHeap Float64Heap
7396
}
7497

7598
type WriteOptions struct {

pkg/sql/compile/operator.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1024,7 +1024,7 @@ func constructTableFunction(n *plan.Node, qry *plan.Query) *table_function.Table
10241024
arg.FuncName = n.TableDef.TblFunc.Name
10251025
arg.Params = n.TableDef.TblFunc.Param
10261026
arg.IsSingle = n.TableDef.TblFunc.IsSingle
1027-
arg.Limit = n.Limit
1027+
arg.Limit = n.BlockLimit
10281028
// probe side runtime filter specs
10291029
arg.RuntimeFilterSpecs = n.RuntimeFilterProbeList
10301030
return arg

pkg/sql/plan/apply_indices_ivfflat.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ func (builder *QueryBuilder) applyIndicesForSortUsingIvfflat(nodeID int32, projN
182182
overFetchFactor := calculatePostFilterOverFetchFactor(originalLimit)
183183

184184
newLimit := max(uint64(float64(originalLimit)*overFetchFactor), originalLimit+10)
185-
tableFuncNode.Limit = &Expr{
185+
tableFuncNode.BlockLimit = &Expr{
186186
Typ: limit.Typ,
187187
Expr: &plan.Expr_Lit{
188188
Lit: &plan.Literal{
@@ -195,11 +195,11 @@ func (builder *QueryBuilder) applyIndicesForSortUsingIvfflat(nodeID int32, projN
195195
}
196196
} else {
197197
// If limit is not a constant, just copy it
198-
tableFuncNode.Limit = DeepCopyExpr(limit)
198+
tableFuncNode.BlockLimit = DeepCopyExpr(limit)
199199
}
200200
} else {
201201
// No filters, use original limit
202-
tableFuncNode.Limit = DeepCopyExpr(limit)
202+
tableFuncNode.BlockLimit = DeepCopyExpr(limit)
203203
}
204204

205205
// Determine join structure based on rankOption.mode:

pkg/vm/engine/readutil/reader.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,11 @@ func (r *reader) SetBlockTop(orderBy []*plan.OrderBySpec, limit uint64) {
476476
r.orderByLimit.ColPos = col.ColPos
477477
r.orderByLimit.NumVec = []byte(numVec)
478478
r.orderByLimit.Limit = limit
479+
480+
//r.orderByLimit.LowerBound = math.Inf(-1)
481+
//r.orderByLimit.UpperBound = math.Inf(1)
482+
483+
r.orderByLimit.DistHeap = make(objectio.Float64Heap, 0, limit)
479484
}
480485

481486
func (r *reader) GetOrderBy() []*plan.OrderBySpec {

pkg/vm/engine/tae/blockio/read.go

Lines changed: 56 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ import (
1818
"container/heap"
1919
"context"
2020
"fmt"
21-
"sort"
21+
"slices"
2222
"time"
2323

2424
"github.com/matrixorigin/matrixone/pkg/common/moerr"
@@ -390,148 +390,90 @@ func HandleOrderByLimitOnIVFFlatIndex(
390390
}
391391
}
392392

393-
var sels []int64
394-
var dists []float64
395-
396393
nullsBm := vecCol.GetNulls()
394+
selectRows = slices.DeleteFunc(selectRows, func(row int64) bool {
395+
return nullsBm.Contains(uint64(row))
396+
})
397397

398-
if int(orderByLimit.Limit) < len(selectRows) {
399-
// apply topn if needed
400-
hp := make(vectorindex.SearchResultMaxHeap, 0, orderByLimit.Limit)
401-
402-
switch orderByLimit.Typ {
403-
case types.T_array_float32:
404-
distFunc, err := metric.ResolveDistanceFn[float32](orderByLimit.Metric)
405-
if err != nil {
406-
return nil, nil, err
407-
}
408-
409-
rhs := types.BytesToArray[float32](orderByLimit.NumVec)
398+
searchResults := make([]vectorindex.SearchResult, 0, len(selectRows))
410399

411-
for _, row := range selectRows {
412-
if nullsBm.Contains(uint64(row)) {
413-
continue
414-
}
415-
dist, err := distFunc(types.BytesToArray[float32](vecCol.GetBytesAt(int(row))), rhs)
416-
if err != nil {
417-
return nil, nil, err
418-
}
419-
dist64 := float64(dist)
400+
switch orderByLimit.Typ {
401+
case types.T_array_float32:
402+
distFunc, err := metric.ResolveDistanceFn[float32](orderByLimit.Metric)
403+
if err != nil {
404+
return nil, nil, err
405+
}
420406

421-
heapItem := &vectorindex.SearchResult{
422-
Id: row,
423-
Distance: dist64,
424-
}
425-
if len(hp) >= int(orderByLimit.Limit) {
426-
if dist64 < hp[0].GetDistance() {
427-
hp[0] = heapItem
428-
heap.Fix(&hp, 0)
429-
}
430-
} else {
431-
heap.Push(&hp, heapItem)
432-
}
433-
}
407+
rhs := types.BytesToArray[float32](orderByLimit.NumVec)
434408

435-
case types.T_array_float64:
436-
distFunc, err := metric.ResolveDistanceFn[float64](orderByLimit.Metric)
409+
for _, row := range selectRows {
410+
dist, err := distFunc(types.BytesToArray[float32](vecCol.GetBytesAt(int(row))), rhs)
437411
if err != nil {
438412
return nil, nil, err
439413
}
414+
dist64 := float64(dist)
440415

441-
rhs := types.BytesToArray[float64](orderByLimit.NumVec)
442-
443-
for _, row := range selectRows {
444-
if nullsBm.Contains(uint64(row)) {
445-
continue
446-
}
447-
dist, err := distFunc(types.BytesToArray[float64](vecCol.GetBytesAt(int(row))), rhs)
448-
if err != nil {
449-
return nil, nil, err
450-
}
451-
452-
heapItem := &vectorindex.SearchResult{
453-
Id: row,
454-
Distance: dist,
455-
}
456-
if len(hp) >= int(orderByLimit.Limit) {
457-
if dist < hp[0].GetDistance() {
458-
hp[0] = heapItem
459-
heap.Fix(&hp, 0)
460-
}
416+
if len(orderByLimit.DistHeap) >= int(orderByLimit.Limit) {
417+
if dist64 < orderByLimit.DistHeap[0] {
418+
orderByLimit.DistHeap[0] = dist64
419+
heap.Fix(&orderByLimit.DistHeap, 0)
461420
} else {
462-
heap.Push(&hp, heapItem)
421+
continue
463422
}
423+
} else {
424+
heap.Push(&orderByLimit.DistHeap, dist64)
464425
}
465426

466-
default:
467-
return nil, nil, moerr.NewInternalError(ctx, fmt.Sprintf("only support float32/float64 type for topn: %s", orderByLimit.Typ))
427+
searchResults = append(searchResults, vectorindex.SearchResult{
428+
Id: row,
429+
Distance: dist64,
430+
})
468431
}
469432

470-
sRes := make([]vectorindex.SearchResult, len(hp))
471-
for i := range sRes {
472-
sRes[i] = *hp[i].(*vectorindex.SearchResult)
433+
case types.T_array_float64:
434+
distFunc, err := metric.ResolveDistanceFn[float64](orderByLimit.Metric)
435+
if err != nil {
436+
return nil, nil, err
473437
}
474-
sort.Slice(sRes, func(i, j int) bool {
475-
return sRes[i].Id < sRes[j].Id
476-
})
477438

478-
sels = make([]int64, len(sRes))
479-
dists = make([]float64, len(sRes))
480-
481-
for i := range sRes {
482-
sels[i] = sRes[i].Id
483-
dists[i] = sRes[i].Distance
484-
}
485-
} else {
486-
sels = make([]int64, 0, len(selectRows))
487-
dists = make([]float64, 0, len(selectRows))
439+
rhs := types.BytesToArray[float64](orderByLimit.NumVec)
488440

489-
switch orderByLimit.Typ {
490-
case types.T_array_float32:
491-
distFunc, err := metric.ResolveDistanceFn[float32](orderByLimit.Metric)
441+
for _, row := range selectRows {
442+
dist64, err := distFunc(types.BytesToArray[float64](vecCol.GetBytesAt(int(row))), rhs)
492443
if err != nil {
493444
return nil, nil, err
494445
}
495446

496-
rhs := types.BytesToArray[float32](orderByLimit.NumVec)
497-
498-
for _, row := range selectRows {
499-
if nullsBm.Contains(uint64(row)) {
447+
if len(orderByLimit.DistHeap) >= int(orderByLimit.Limit) {
448+
if dist64 < orderByLimit.DistHeap[0] {
449+
orderByLimit.DistHeap[0] = dist64
450+
heap.Fix(&orderByLimit.DistHeap, 0)
451+
} else {
500452
continue
501453
}
502-
dist, err := distFunc(types.BytesToArray[float32](vecCol.GetBytesAt(int(row))), rhs)
503-
if err != nil {
504-
return nil, nil, err
505-
}
506-
507-
sels = append(sels, row)
508-
dists = append(dists, float64(dist))
509-
}
510-
511-
case types.T_array_float64:
512-
distFunc, err := metric.ResolveDistanceFn[float64](orderByLimit.Metric)
513-
if err != nil {
514-
return nil, nil, err
454+
} else {
455+
heap.Push(&orderByLimit.DistHeap, dist64)
515456
}
516457

517-
rhs := types.BytesToArray[float64](orderByLimit.NumVec)
458+
searchResults = append(searchResults, vectorindex.SearchResult{
459+
Id: row,
460+
Distance: dist64,
461+
})
462+
}
518463

519-
for _, row := range selectRows {
520-
if nullsBm.Contains(uint64(row)) {
521-
continue
522-
}
523-
dist, err := distFunc(types.BytesToArray[float64](vecCol.GetBytesAt(int(row))), rhs)
524-
if err != nil {
525-
return nil, nil, err
526-
}
464+
default:
465+
return nil, nil, moerr.NewInternalError(ctx, fmt.Sprintf("only support float32/float64 type for topn: %s", orderByLimit.Typ))
466+
}
527467

528-
sels = append(sels, row)
529-
dists = append(dists, float64(dist))
530-
}
468+
searchResults = slices.DeleteFunc(searchResults, func(res vectorindex.SearchResult) bool {
469+
return res.Distance > orderByLimit.DistHeap[0]
470+
})
531471

532-
default:
533-
return nil, nil, moerr.NewInternalError(ctx, fmt.Sprintf("only support float32/float64 type for topn: %s", orderByLimit.Typ))
534-
}
472+
sels := make([]int64, len(searchResults))
473+
dists := make([]float64, len(searchResults))
474+
for i, res := range searchResults {
475+
sels[i] = res.Id
476+
dists[i] = res.Distance
535477
}
536478

537479
return sels, dists, nil

0 commit comments

Comments
 (0)