diff --git a/docs/design-docs/design_docs/20260602-struct_hybrid_search.md b/docs/design-docs/design_docs/20260602-struct_hybrid_search.md new file mode 100644 index 0000000000000..6f613d3a4197c --- /dev/null +++ b/docs/design-docs/design_docs/20260602-struct_hybrid_search.md @@ -0,0 +1,324 @@ +# Struct Element-Level Hybrid Search + +This document describes the intended end state for hybrid search when a vector +sub-field inside a struct array field is searched at element level. + +This document does not change embedding-list search semantics. Embedding-list +search on a struct-array vector sub-field is treated like normal row-level +vector search. + +## Concepts + +A struct array field stores multiple struct elements per row. A vector sub-field +inside that struct array can be searched in two forms: + +```text +element-level search One query vector is matched against individual struct elements. +embedding-list search A list of query vectors is matched as one row-level request. +``` + +Only element-level search produces element-level candidates. + +For example: + +```text +structA: array +normal_vector: float_vector +``` + +Element-level search on `structA[image_vec]` produces hits identified by: + +```text +(primary_key, parent_struct_field, element_index) +``` + +Embedding-list search on `structA[image_vec]` and normal vector search on +`normal_vector` both produce row-level hits identified by: + +```text +(primary_key) +``` + +Hybrid search must decide whether element-level hits from element-level +struct-array search remain element-level for rerank, or whether they are +collapsed to row-level candidates before rerank. + +## Request Model + +Row-level collapse behavior is configured per sub-search request, not on the +top-level hybrid search request. + +This is required because each sub-search has its own `anns_field`, metric, +filter, limit, and collapse behavior. 
A single hybrid request can search +multiple struct sub-fields with different row-level collapse strategies. + +User-facing row-collapse API example: + +```python +AnnSearchRequest( + data=[query_image], + anns_field="structA[image_vec]", + param={ + "metric_type": "COSINE", + "params": { + "ef": 100, + "element_scope": { + "collapse": { + "strategy": "topk_sum", + "topk": 3, + }, + }, + }, + }, + limit=100, +) +``` + +Equivalent SDKs may expose typed options, but they should still serialize to the +sub-search request: + +```go +annReq := client.NewAnnRequest("structA[image_vec]", limit, vectors). + WithElementCollapse(client.ElementCollapseTopKSum, client.WithTopK(3)) +``` + +The top-level hybrid request still owns only hybrid-level options such as final +`limit`, `offset`, output fields, consistency, and reranker configuration. + +Embedding-list search on `structA[image_vec]` must not use `element_scope`; it is +already row-level and follows the same hybrid behavior as `normal_vector`. + +If `element_scope` is missing, the row-level collapse strategy defaults to `max` +whenever row-level collapse is needed. + +## Candidate Scope + +Hybrid search infers final candidate scope from the sub-search types. 
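As a rough illustration of that inference, the Go sketch below decides the final candidate scope from a list of sub-searches. It is not the proxy implementation in this change: `SubSearch`, `SubSearchKind`, `Scope`, and `InferScope` are hypothetical names introduced only for this example, and treating an empty request as row-level is an assumption the document does not state.

```go
package main

import "fmt"

// SubSearchKind is a hypothetical classification of one hybrid sub-search.
type SubSearchKind int

const (
	NormalVector  SubSearchKind = iota // top-level vector field
	StructEmbList                      // embedding-list search on a struct sub-field
	StructElement                      // element-level search on a struct sub-field
)

// SubSearch is a hypothetical, minimal view of a sub-search request.
type SubSearch struct {
	Kind   SubSearchKind
	Parent string // parent struct array field; empty for normal vector fields
}

// Scope is the final candidate scope of the hybrid request.
type Scope string

const (
	RowLevel     Scope = "row-level"
	ElementLevel Scope = "element-level"
)

// InferScope returns ElementLevel only when every sub-search is element-level
// and all of them share the same parent struct array; otherwise RowLevel.
func InferScope(subs []SubSearch) Scope {
	if len(subs) == 0 {
		// Assumption: an empty request is treated as row-level.
		return RowLevel
	}
	parent := subs[0].Parent
	for _, s := range subs {
		if s.Kind != StructElement || s.Parent != parent {
			return RowLevel
		}
	}
	return ElementLevel
}

func main() {
	sameParent := []SubSearch{
		{Kind: StructElement, Parent: "structA"},
		{Kind: StructElement, Parent: "structA"},
	}
	mixedParents := []SubSearch{
		{Kind: StructElement, Parent: "structA"},
		{Kind: StructElement, Parent: "structB"},
	}
	withNormal := []SubSearch{
		{Kind: StructElement, Parent: "structA"},
		{Kind: NormalVector},
	}
	fmt.Println(InferScope(sameParent))
	fmt.Println(InferScope(mixedParents))
	fmt.Println(InferScope(withNormal))
}
```

In this sketch, a row-level result means that every element-level sub-search in the request would be collapsed before rerank, using its own `element_scope.collapse` strategy or `max` when none is given.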
+ +```text +all sub-searches are element-level and use the same parent struct array + -> element-level hybrid, no collapse + +otherwise + -> row-level hybrid + -> every element-level sub-search is collapsed to row candidates + -> collapse strategy defaults to max unless element_scope.collapse overrides it +``` + +Element-level hybrid example: + +```python +image_req = AnnSearchRequest( + data=[query_image], + anns_field="structA[image_vec]", + param={ + "metric_type": "COSINE", + "params": {"ef": 100}, + }, + limit=100, +) + +text_req = AnnSearchRequest( + data=[query_text], + anns_field="structA[text_vec]", + param={ + "metric_type": "COSINE", + "params": {"ef": 100}, + }, + limit=100, +) + +client.hybrid_search( + collection_name, + [image_req, text_req], + ranker=RRFRanker(), + limit=20, +) +``` + +Both sub-searches are element-level and use sub-fields of `structA`, so final +results are element-level. + +## Compatibility Matrix + +Hybrid search can combine row-level and element-level sub-searches only when the +candidate identity is well-defined. + +Sub-search types: + +```text +normal vector A top-level vector field, such as normal_vector. +struct emb-list Embedding-list search on a struct-array vector sub-field. +struct element Element-level search on a struct-array vector sub-field. +``` + +Compatibility: + +```text +left \ right normal vector struct emb-list struct element +normal vector row-level row-level row-level +struct emb-list row-level row-level row-level +struct element row-level row-level element-level if same parent, else row-level +``` + +Behavior: + +```text +row-level + Final candidates are keyed by primary key. + Element-level sub-searches are collapsed before rerank. + +element-level if same parent + Allowed only when all element-level sub-searches use sub-fields of the same + parent struct array. Final candidates are keyed by + (primary_key, parent_struct_field, element_index). 
+``` + +For two `struct element` sub-searches with different parent struct arrays, +element offsets do not share identity. The request is still valid, but the final +candidate scope is row-level and both element-level sub-searches are collapsed. + +## Row-Level Collapse + +When inferred candidate scope is row-level, all element hits from the same row +are aggregated into one row-level candidate before hybrid rerank. + +The collapse strategy is provided in that same sub-search request: + +```json +{ + "element_scope": { + "collapse": { + "strategy": "max" + } + } +} +``` + +Supported initial strategies: + +```text +max +sum +avg +topk_sum +topk_avg +``` + +Strategy behavior: + +```text +max Keep the best element score for the row. +sum Sum all returned element scores for the row. +avg Average all returned element scores for the row. +topk_sum Sum the best K returned element scores for the row. +topk_avg Average the best K returned element scores for the row. +``` + +`topk` is required for `topk_sum` and `topk_avg`, and invalid for strategies that +do not use it. + +Collapse operates on the returned element hits from that sub-search. It does not +scan every element in a row after ANN search. Therefore, the sub-search `limit` +controls both recall and the number of elements available for aggregation. + +Metric direction must be respected: + +```text +positively related metrics: larger score is better +negatively related metrics: smaller score is better +``` + +## Element-Level Hybrid Rerank + +Element-level hybrid rerank is used only when every sub-search is element-level +and all sub-searches refer to vector sub-fields under the same parent struct +array. + +Valid: + +```text +structA[image_vec] + structA[text_vec] +``` + +These two sub-fields share the same element identity: + +```text +(primary_key, "structA", element_index) +``` + +The hybrid reranker should rank element candidates using that key. 
Final results +may remain element-level and expose the matched `element_index`. + +Row-level fallback: + +```text +structA[image_vec] + structB[text_vec] +``` + +Even if both hits have `element_index = 3`, those offsets refer to different +arrays. They must not be treated as the same element. The hybrid search falls +back to row-level scope and collapses both element-level sub-searches before +rerank. + +## Validation Rules + +1. `element_scope.collapse` is valid only on element-level search over + struct-array vector sub-fields when the inferred candidate scope is row-level. +2. Normal vector fields are always row-level. +3. Embedding-list search on struct-array vector sub-fields is always row-level. +4. Normal vector sub-searches and embedding-list sub-searches must reject + non-default element collapse settings. +5. If row-level scope requires collapsing element-level hits and collapse config + is omitted, use `max`. +6. If inferred candidate scope is element-level, reject `element_scope.collapse` + because no row-level collapse is performed. +7. Hybrid search supports only plain top-K for struct-array vector sub-searches. + Element-level and embedding-list sub-searches reject group-by, range search, + and search iterator. +8. `sum` and `topk_sum` collapse strategies are valid only for positively + related metrics such as `IP` and `COSINE`. Negative distance metrics such as + `L2` must use `max`, `avg`, or `topk_avg`. + +## Result Semantics + +For row-level hybrid search: + +```text +result key: primary_key +duplicates: no duplicate primary keys in final results +element_index: not returned +``` + +For element-level hybrid search: + +```text +result key: (primary_key, parent_struct_field, element_index) +duplicates: no duplicate element keys in final results +element_index: returned +``` + +## Execution Order + +The intended pipeline is: + +```text +1. Execute each sub-search. +2. Reduce each sub-search result. +3. 
Infer final candidate scope from all sub-searches. +4. If scope is row-level, collapse every element-level sub-search to row + candidates using that sub-search's collapse strategy. + Normal vector sub-searches and embedding-list sub-searches are already + row-level. + If scope is element-level, keep element candidates. +5. Apply hybrid rerank. +6. Assemble output fields according to the final result level. +``` + +This keeps collapse local to the sub-search that produced element-level hits, +while keeping the hybrid reranker responsible only for combining already +normalized candidate lists. diff --git a/internal/proxy/search_pipeline.go b/internal/proxy/search_pipeline.go index 16206769cd877..b0cbc4bbee380 100644 --- a/internal/proxy/search_pipeline.go +++ b/internal/proxy/search_pipeline.go @@ -115,6 +115,7 @@ const ( requeryOp = "requery" organizeOp = "organize" elementBestCollapseOp = "element_best_collapse" + elementKeyRestoreOp = "element_key_restore" hybridAssembleOp = "hybrid_assemble" endOp = "end" lambdaOp = "lambda" @@ -133,6 +134,7 @@ var opFactory = map[string]func(t *searchTask, params map[string]any) (operator, rerankOp: newRerankOperator, organizeOp: newOrganizeOperator, elementBestCollapseOp: newElementBestCollapseOperator, + elementKeyRestoreOp: newElementKeyRestoreOperator, hybridAssembleOp: newHybridAssembleOperator, requeryOp: newRequeryOperator, lambdaOp: newLambdaOperator, @@ -266,15 +268,36 @@ func (op *hybridSearchReduceOperator) run(ctx context.Context, span trace.Span, return []any{multipleMilvusResults, searchMetrics}, nil } -type elementBestCollapseOperator struct{} +type elementBestCollapseOperator struct { + configs []elementCollapseConfig + elementLevelHybrid bool +} -func newElementBestCollapseOperator(_ *searchTask, _ map[string]any) (operator, error) { - return &elementBestCollapseOperator{}, nil +func newElementBestCollapseOperator(t *searchTask, _ map[string]any) (operator, error) { + return &elementBestCollapseOperator{ + 
configs: t.hybridCollapseConfigs(), + elementLevelHybrid: t.hybridElementLevel, + }, nil } -// elementBestCollapseOperator normalizes element-level hybrid sub-search -// results into row-level results before rerank. For each query chunk, duplicate -// PKs keep the best element score under that sub-search metric direction. +func (t *searchTask) hybridCollapseConfigs() []elementCollapseConfig { + if len(t.hybridSubSearchInfos) == 0 { + return nil + } + configs := make([]elementCollapseConfig, len(t.hybridSubSearchInfos)) + for i, info := range t.hybridSubSearchInfos { + configs[i] = info.Collapse + if configs[i].Strategy == "" { + configs[i] = defaultElementCollapseConfig() + } + } + return configs +} + +// elementBestCollapseOperator normalizes element-level hybrid sub-search results +// into row-level results before rerank, or prepares same-struct element-level +// hybrid results with proxy-internal element keys so rerank can distinguish +// different elements from the same row. func (op *elementBestCollapseOperator) run(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) { if len(inputs) < 2 { return nil, merr.WrapErrServiceInternal("element best collapse: missing inputs") @@ -293,6 +316,14 @@ func (op *elementBestCollapseOperator) run(ctx context.Context, span trace.Span, collapsed := make([]*milvuspb.SearchResults, len(results)) for i, result := range results { + if op.elementLevelHybrid { + var err error + collapsed[i], err = prepareElementLevelHybridResult(result) + if err != nil { + return nil, err + } + continue + } metricType := metrics[i] if result != nil && result.GetResults() != nil && result.GetResults().GetElementIndices() != nil && strings.TrimSpace(metricType) == "" { totalRows := int64(0) @@ -304,7 +335,11 @@ func (op *elementBestCollapseOperator) run(ctx context.Context, span trace.Span, } } var err error - collapsed[i], err = collapseElementLevelResultByBestScore(result, metric.PositivelyRelated(metricType)) + config := 
defaultElementCollapseConfig() + if i < len(op.configs) && op.configs[i].Strategy != "" { + config = op.configs[i] + } + collapsed[i], err = collapseElementLevelResult(result, metric.PositivelyRelated(metricType), config) if err != nil { return nil, err } @@ -313,9 +348,11 @@ func (op *elementBestCollapseOperator) run(ctx context.Context, span trace.Span, } type bestElementHit struct { - rowIdx int64 - score float32 - order int + rowIdx int64 + score float32 + order int + aggregate float32 + groupCount int } type rowIdxComputeItem struct { @@ -343,6 +380,10 @@ func computeFieldIdxsByOriginalOrder(rowIdxs []int64, compute func(int64) []int6 } func collapseElementLevelResultByBestScore(result *milvuspb.SearchResults, largerScoreIsBetter bool) (*milvuspb.SearchResults, error) { + return collapseElementLevelResult(result, largerScoreIsBetter, defaultElementCollapseConfig()) +} + +func collapseElementLevelResult(result *milvuspb.SearchResults, largerScoreIsBetter bool, config elementCollapseConfig) (*milvuspb.SearchResults, error) { if result == nil || result.GetResults() == nil || result.GetResults().GetElementIndices() == nil { return result, nil } @@ -367,6 +408,12 @@ func collapseElementLevelResultByBestScore(result *milvuspb.SearchResults, large }), nil } + if isElementCollapseSumFamily(config.Strategy) && !largerScoreIsBetter { + return nil, merr.WrapErrParameterInvalidMsg( + "%s.collapse.strategy %s is only supported for positively related metrics", + elementScopeKey, config.Strategy) + } + if typeutil.GetSizeOfIDs(data.GetIds()) < int(totalRows) { return nil, merr.WrapErrServiceInternal(fmt.Sprintf("element best collapse: ids length (%d) is less than total rows (%d)", typeutil.GetSizeOfIDs(data.GetIds()), totalRows)) @@ -409,7 +456,8 @@ func collapseElementLevelResultByBestScore(result *milvuspb.SearchResults, large idxComputer := typeutil.NewFieldDataIdxComputer(data.GetFieldsData()) offset := int64(0) for _, topk := range topks { - selected := 
make(map[any]bestElementHit) + grouped := make(map[any][]bestElementHit) + groupOrder := make(map[any]int) for i := int64(0); i < topk; i++ { rowIdx := offset + i pk := typeutil.GetPK(data.GetIds(), rowIdx) @@ -417,23 +465,25 @@ func collapseElementLevelResultByBestScore(result *milvuspb.SearchResults, large continue } score := data.GetScores()[rowIdx] - hit, ok := selected[pk] - if !ok || isBetterElementScore(score, hit.score, largerScoreIsBetter) { - selected[pk] = bestElementHit{ - rowIdx: rowIdx, - score: score, - order: int(i), - } + if _, ok := grouped[pk]; !ok { + groupOrder[pk] = int(i) } + grouped[pk] = append(grouped[pk], bestElementHit{ + rowIdx: rowIdx, + score: score, + order: int(i), + }) } - hits := make([]bestElementHit, 0, len(selected)) - for _, hit := range selected { + hits := make([]bestElementHit, 0, len(grouped)) + for pk, pkHits := range grouped { + hit := aggregateElementHits(pkHits, config, largerScoreIsBetter) + hit.order = groupOrder[pk] hits = append(hits, hit) } sort.SliceStable(hits, func(i, j int) bool { - if hits[i].score != hits[j].score { - return isBetterElementScore(hits[i].score, hits[j].score, largerScoreIsBetter) + if hits[i].aggregate != hits[j].aggregate { + return isBetterElementScore(hits[i].aggregate, hits[j].aggregate, largerScoreIsBetter) } return hits[i].order < hits[j].order }) @@ -454,7 +504,9 @@ func collapseElementLevelResultByBestScore(result *milvuspb.SearchResults, large for i, hit := range hits { typeutil.AppendIDs(output.Ids, data.GetIds(), int(hit.rowIdx)) - output.Scores = append(output.Scores, data.GetScores()[hit.rowIdx]) + output.Scores = append(output.Scores, hit.aggregate) + // For aggregate collapse strategies, Score is the row aggregate while + // Distance/Recall keep the representative best element's values. 
if len(data.GetDistances()) > 0 { output.Distances = append(output.Distances, data.GetDistances()[hit.rowIdx]) } @@ -471,6 +523,230 @@ func collapseElementLevelResultByBestScore(result *milvuspb.SearchResults, large return copySearchResultsWithData(result, output), nil } +func aggregateElementHits(hits []bestElementHit, config elementCollapseConfig, largerScoreIsBetter bool) bestElementHit { + if len(hits) == 0 { + return bestElementHit{} + } + + bestHits := append([]bestElementHit(nil), hits...) + sort.SliceStable(bestHits, func(i, j int) bool { + if bestHits[i].score != bestHits[j].score { + return isBetterElementScore(bestHits[i].score, bestHits[j].score, largerScoreIsBetter) + } + return bestHits[i].order < bestHits[j].order + }) + + switch config.Strategy { + case elementCollapseSum, elementCollapseAvg: + sum := float32(0) + for _, hit := range hits { + sum += hit.score + } + selected := bestHits[0] + selected.aggregate = sum + selected.groupCount = len(hits) + if config.Strategy == elementCollapseAvg { + selected.aggregate = sum / float32(len(hits)) + } + return selected + case elementCollapseTopKSum, elementCollapseTopKAvg: + k := config.TopK + if k <= 0 || k > len(bestHits) { + k = len(bestHits) + } + sum := float32(0) + for _, hit := range bestHits[:k] { + sum += hit.score + } + selected := bestHits[0] + selected.aggregate = sum + selected.groupCount = k + if config.Strategy == elementCollapseTopKAvg { + selected.aggregate = sum / float32(k) + } + return selected + case elementCollapseMax: + fallthrough + default: + selected := bestHits[0] + selected.aggregate = selected.score + selected.groupCount = 1 + return selected + } +} + +func prepareElementLevelHybridResult(result *milvuspb.SearchResults) (*milvuspb.SearchResults, error) { + if result == nil || result.GetResults() == nil { + return result, nil + } + data := result.GetResults() + totalRows := int64(0) + for _, topk := range data.GetTopks() { + totalRows += topk + } + if totalRows == 0 { + output := 
&schemapb.SearchResultData{ + NumQueries: data.GetNumQueries(), + TopK: data.GetTopK(), + Topks: append([]int64(nil), data.GetTopks()...), + FieldsData: data.GetFieldsData(), + Scores: append([]float32(nil), data.GetScores()...), + Ids: &schemapb.IDs{IdField: &schemapb.IDs_StrId{StrId: &schemapb.StringArray{}}}, + OutputFields: append([]string(nil), data.GetOutputFields()...), + AllSearchCount: data.GetAllSearchCount(), + PrimaryFieldName: data.GetPrimaryFieldName(), + ElementIndices: &schemapb.LongArray{}, + SearchIteratorV2Results: data.GetSearchIteratorV2Results(), + } + if len(data.GetDistances()) > 0 { + output.Distances = append([]float32(nil), data.GetDistances()...) + } + if len(data.GetRecalls()) > 0 { + output.Recalls = append([]float32(nil), data.GetRecalls()...) + } + return copySearchResultsWithData(result, output), nil + } + if typeutil.GetSizeOfIDs(data.GetIds()) < int(totalRows) { + return nil, merr.WrapErrServiceInternal(fmt.Sprintf("element-level hybrid: ids length (%d) is less than total rows (%d)", + typeutil.GetSizeOfIDs(data.GetIds()), totalRows)) + } + if data.GetElementIndices() == nil { + return nil, merr.WrapErrServiceInternal("element-level hybrid: missing element_indices") + } + if int64(len(data.GetElementIndices().GetData())) < totalRows { + return nil, merr.WrapErrServiceInternal(fmt.Sprintf("element-level hybrid: element_indices length (%d) is less than total rows (%d)", + len(data.GetElementIndices().GetData()), totalRows)) + } + + keys := make([]string, 0, totalRows) + for i := int64(0); i < totalRows; i++ { + keys = append(keys, makeHybridElementKey(typeutil.GetPK(data.GetIds(), i), data.GetElementIndices().GetData()[i])) + } + output := &schemapb.SearchResultData{ + NumQueries: data.GetNumQueries(), + TopK: data.GetTopK(), + Topks: append([]int64(nil), data.GetTopks()...), + FieldsData: data.GetFieldsData(), + Scores: append([]float32(nil), data.GetScores()...), + Ids: &schemapb.IDs{IdField: &schemapb.IDs_StrId{StrId: 
&schemapb.StringArray{Data: keys}}}, + OutputFields: append([]string(nil), data.GetOutputFields()...), + AllSearchCount: data.GetAllSearchCount(), + PrimaryFieldName: data.GetPrimaryFieldName(), + ElementIndices: data.GetElementIndices(), + SearchIteratorV2Results: data.GetSearchIteratorV2Results(), + } + if len(data.GetDistances()) > 0 { + output.Distances = append([]float32(nil), data.GetDistances()...) + } + if len(data.GetRecalls()) > 0 { + output.Recalls = append([]float32(nil), data.GetRecalls()...) + } + return copySearchResultsWithData(result, output), nil +} + +type elementKeyRestoreOperator struct { + enabled bool +} + +func newElementKeyRestoreOperator(t *searchTask, _ map[string]any) (operator, error) { + return &elementKeyRestoreOperator{enabled: t.hybridElementLevel}, nil +} + +func (op *elementKeyRestoreOperator) run(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) { + if len(inputs) < 1 { + return nil, merr.WrapErrServiceInternal("element key restore: missing inputs") + } + + target := inputs[len(inputs)-1] + if !op.enabled { + return []any{target}, nil + } + + switch v := target.(type) { + case *milvuspb.SearchResults: + if v == nil || v.GetResults() == nil { + return []any{v}, nil + } + restored, err := restoreElementLevelHybridRankResult(v) + if err != nil { + return nil, err + } + return []any{restored}, nil + case []*milvuspb.SearchResults: + restored := make([]*milvuspb.SearchResults, len(v)) + for i, result := range v { + if result == nil || result.GetResults() == nil { + restored[i] = result + continue + } + var err error + restored[i], err = restoreElementLevelHybridRankResult(result) + if err != nil { + return nil, err + } + } + return []any{restored}, nil + default: + return nil, merr.WrapErrParameterInvalidMsg("element key restore: input must be *SearchResults or []*SearchResults, got %T", target) + } +} + +func restoreElementLevelHybridRankResult(rankResult *milvuspb.SearchResults) (*milvuspb.SearchResults, error) { + 
data := rankResult.GetResults() + size := typeutil.GetSizeOfIDs(data.GetIds()) + outputIDs := &schemapb.IDs{} + elementIndices := make([]int64, 0, size) + for i := 0; i < size; i++ { + rawKey := typeutil.GetPK(data.GetIds(), int64(i)) + key, ok := rawKey.(string) + if !ok { + return nil, merr.WrapErrServiceInternal(fmt.Sprintf("element key restore: expected string element key, got %T", rawKey)) + } + pk, elementIndex, ok := parseHybridElementKey(key) + if !ok { + return nil, merr.WrapErrServiceInternal(fmt.Sprintf("element key restore: invalid element key %q", key)) + } + appendPK(outputIDs, pk) + elementIndices = append(elementIndices, elementIndex) + } + + output := &schemapb.SearchResultData{ + NumQueries: data.GetNumQueries(), + TopK: data.GetTopK(), + Topks: append([]int64(nil), data.GetTopks()...), + FieldsData: data.GetFieldsData(), + Scores: append([]float32(nil), data.GetScores()...), + Ids: outputIDs, + OutputFields: append([]string(nil), data.GetOutputFields()...), + AllSearchCount: data.GetAllSearchCount(), + PrimaryFieldName: data.GetPrimaryFieldName(), + ElementIndices: &schemapb.LongArray{Data: elementIndices}, + SearchIteratorV2Results: data.GetSearchIteratorV2Results(), + } + if len(data.GetDistances()) > 0 { + output.Distances = append([]float32(nil), data.GetDistances()...) + } + if len(data.GetRecalls()) > 0 { + output.Recalls = append([]float32(nil), data.GetRecalls()...) 
+ } + return copySearchResultsWithData(rankResult, output), nil +} + +func appendPK(ids *schemapb.IDs, pk any) { + switch v := pk.(type) { + case int64: + if ids.GetIntId() == nil { + ids.IdField = &schemapb.IDs_IntId{IntId: &schemapb.LongArray{}} + } + ids.GetIntId().Data = append(ids.GetIntId().Data, v) + case string: + if ids.GetStrId() == nil { + ids.IdField = &schemapb.IDs_StrId{StrId: &schemapb.StringArray{}} + } + ids.GetStrId().Data = append(ids.GetStrId().Data, v) + } +} + func isBetterElementScore(candidate, current float32, largerScoreIsBetter bool) bool { if largerScoreIsBetter { return candidate > current @@ -814,12 +1090,14 @@ func pickFieldData(ids *schemapb.IDs, pkOffset map[any]int, fields []*schemapb.F // hybridAssembleOperator picks field data for reranked IDs directly from // multiple sub-search results, avoiding a full field-data merge. type hybridAssembleOperator struct { - collectionID int64 + collectionID int64 + elementLevelHybrid bool } func newHybridAssembleOperator(t *searchTask, _ map[string]any) (operator, error) { return &hybridAssembleOperator{ - collectionID: t.GetCollectionID(), + collectionID: t.GetCollectionID(), + elementLevelHybrid: t.hybridElementLevel, }, nil } @@ -841,11 +1119,19 @@ func (op *hybridAssembleOperator) run(ctx context.Context, span trace.Span, inpu type pkLoc struct{ resultIdx, rowIdx int } + // Build candidate-key -> (resultIdx, rowIdx) index across all sub-search results. + // Row-level hybrid keys by PK; element-level hybrid keys by (PK, element_index). 
pkIndex := make(map[any]pkLoc) for rIdx, result := range reducedResults { ids := result.GetResults().GetIds() for i := 0; i < typeutil.GetSizeOfIDs(ids); i++ { - pkIndex[typeutil.GetPK(ids, int64(i))] = pkLoc{rIdx, i} + key := typeutil.GetPK(ids, int64(i)) + if op.elementLevelHybrid { + if rawKey, ok := key.(string); ok { + key = rawKey + } + } + pkIndex[key] = pkLoc{rIdx, i} } } @@ -870,16 +1156,24 @@ func (op *hybridAssembleOperator) run(ctx context.Context, span trace.Span, inpu itemsByResult := make([][]rowIdxComputeItem, len(reducedResults)) for i := 0; i < numReranked; i++ { - pk := typeutil.GetPK(rerankedIDs, int64(i)) - loc, ok := pkIndex[pk] + candidateKey := typeutil.GetPK(rerankedIDs, int64(i)) + if op.elementLevelHybrid { + elementIndices := rankResult.GetResults().GetElementIndices().GetData() + if i >= len(elementIndices) { + return nil, merr.WrapErrServiceInternal(fmt.Sprintf("hybrid assemble: missing element index for reranked row %d, collection=%d", i, op.collectionID)) + } + candidateKey = makeHybridElementKey(candidateKey, elementIndices[i]) + } + loc, ok := pkIndex[candidateKey] if !ok { return nil, merr.WrapErrInconsistentRequery( - fmt.Sprintf("hybrid assemble: missing id %v, collection=%d", pk, op.collectionID)) + fmt.Sprintf("hybrid assemble: missing id %v, collection=%d", candidateKey, op.collectionID)) } if computers[loc.resultIdx] == nil { return nil, merr.WrapErrServiceInternal(fmt.Sprintf( - "hybrid assemble: sub-result[%d] has empty FieldsData but contributed reranked id %v; collection=%d", - loc.resultIdx, pk, op.collectionID)) + "hybrid assemble: sub-result[%d] has empty FieldsData but contributed reranked id %v; "+ + "all sub-results that contribute ids must share the same FieldsData layout, "+ + "collection=%d", loc.resultIdx, candidateKey, op.collectionID)) } locs[i] = loc itemsByResult[loc.resultIdx] = append(itemsByResult[loc.resultIdx], rowIdxComputeItem{ @@ -1300,6 +1594,12 @@ var hybridSearchPipe = &pipelineDef{ outputs: 
[]string{"rank_result"}, opName: rerankOp, }, + { + name: "restore_element_keys", + inputs: []string{"collapsed", "rank_result"}, + outputs: []string{"rank_result"}, + opName: elementKeyRestoreOp, + }, { name: "assemble", inputs: []string{"collapsed", "rank_result"}, @@ -1325,8 +1625,14 @@ var hybridSearchWithRequeryAndRerankByFieldDataPipe = &pipelineDef{ opName: elementBestCollapseOp, }, { - name: "merge_ids", + name: "restore_element_keys_for_requery", inputs: []string{"collapsed"}, + outputs: []string{"requery_data"}, + opName: elementKeyRestoreOp, + }, + { + name: "merge_ids", + inputs: []string{"requery_data"}, outputs: []string{"ids"}, opName: lambdaOp, params: map[string]any{ @@ -1341,7 +1647,7 @@ var hybridSearchWithRequeryAndRerankByFieldDataPipe = &pipelineDef{ }, { name: "parse_ids", - inputs: []string{"collapsed"}, + inputs: []string{"requery_data"}, outputs: []string{"id_list"}, opName: lambdaOp, params: map[string]any{ @@ -1382,6 +1688,12 @@ var hybridSearchWithRequeryAndRerankByFieldDataPipe = &pipelineDef{ outputs: []string{"rank_result"}, opName: rerankOp, }, + { + name: "restore_element_keys", + inputs: []string{"rank_data", "rank_result"}, + outputs: []string{"rank_result"}, + opName: elementKeyRestoreOp, + }, { name: "pick_ids", inputs: []string{"rank_result"}, @@ -1437,6 +1749,12 @@ var hybridSearchWithRequeryPipe = &pipelineDef{ outputs: []string{"rank_result"}, opName: rerankOp, }, + { + name: "restore_element_keys", + inputs: []string{"collapsed", "rank_result"}, + outputs: []string{"rank_result"}, + opName: elementKeyRestoreOp, + }, { name: "pick_ids", inputs: []string{"rank_result"}, diff --git a/internal/proxy/search_pipeline_test.go b/internal/proxy/search_pipeline_test.go index 9d03a6622adec..53fb983655c54 100644 --- a/internal/proxy/search_pipeline_test.go +++ b/internal/proxy/search_pipeline_test.go @@ -318,6 +318,359 @@ func (s *SearchPipelineSuite) TestElementBestCollapseOpPassesRowLevelResultsThro s.Same(input, results[0]) } +func 
(s *SearchPipelineSuite) TestElementBestCollapseOpAllowsEmptyElementLevelResultWithoutMetric() {
+	input := &milvuspb.SearchResults{
+		Status: merr.Success(),
+		Results: &schemapb.SearchResultData{
+			NumQueries: 1,
+			TopK: 0,
+			Topks: []int64{0},
+			ElementIndices: &schemapb.LongArray{},
+			AllSearchCount: 10,
+		},
+	}
+
+	tests := []struct {
+		name string
+		config elementCollapseConfig
+	}{
+		{name: "default max"},
+		{name: "topk sum", config: elementCollapseConfig{Strategy: elementCollapseTopKSum, TopK: 2}},
+	}
+	for _, test := range tests {
+		s.Run(test.name, func() {
+			op := &elementBestCollapseOperator{}
+			if test.config.Strategy != "" {
+				op.configs = []elementCollapseConfig{test.config}
+			}
+			out, err := op.run(context.Background(), s.span, []*milvuspb.SearchResults{input}, []string{""})
+			s.Require().NoError(err)
+
+			result := out[0].([]*milvuspb.SearchResults)[0].GetResults()
+			s.Nil(result.GetElementIndices())
+			s.Equal(int64(1), result.GetNumQueries())
+			s.Equal(int64(0), result.GetTopK())
+			s.Equal([]int64{0}, result.GetTopks())
+			s.Empty(result.GetScores())
+			s.Equal(int64(10), result.GetAllSearchCount())
+		})
+	}
+}
+
+func (s *SearchPipelineSuite) TestElementBestCollapseOpDeduplicatesEqualScoreElementsByRowID() {
+	input := &milvuspb.SearchResults{
+		Status: merr.Success(),
+		Results: &schemapb.SearchResultData{
+			NumQueries: 1,
+			TopK: 3,
+			Topks: []int64{3},
+			Ids: &schemapb.IDs{
+				IdField: &schemapb.IDs_IntId{
+					IntId: &schemapb.LongArray{Data: []int64{1, 1, 2}},
+				},
+			},
+			Scores: []float32{0.50, 0.50, 0.40},
+			ElementIndices: &schemapb.LongArray{Data: []int64{0, 1, 0}},
+		},
+	}
+
+	op := &elementBestCollapseOperator{}
+	out, err := op.run(context.Background(), s.span, []*milvuspb.SearchResults{input}, []string{"IP"})
+	s.Require().NoError(err)
+
+	result := out[0].([]*milvuspb.SearchResults)[0].GetResults()
+	s.Nil(result.GetElementIndices())
+	s.Equal([]int64{1, 2}, result.GetIds().GetIntId().GetData())
+	s.Equal([]float32{0.50, 0.40}, result.GetScores())
+}
+
+func (s *SearchPipelineSuite) TestElementBestCollapseOpUsesMetricDirection() {
+	input := &milvuspb.SearchResults{
+		Status: merr.Success(),
+		Results: &schemapb.SearchResultData{
+			NumQueries: 1,
+			TopK: 3,
+			Topks: []int64{3},
+			Ids: &schemapb.IDs{
+				IdField: &schemapb.IDs_IntId{
+					IntId: &schemapb.LongArray{Data: []int64{1, 1, 2}},
+				},
+			},
+			Scores: []float32{0.8, 0.2, 0.5},
+			ElementIndices: &schemapb.LongArray{Data: []int64{0, 3, 1}},
+		},
+	}
+
+	op := &elementBestCollapseOperator{}
+	out, err := op.run(context.Background(), s.span, []*milvuspb.SearchResults{input}, []string{"L2"})
+	s.Require().NoError(err)
+
+	result := out[0].([]*milvuspb.SearchResults)[0].GetResults()
+	s.Nil(result.GetElementIndices())
+	s.Equal([]int64{1, 2}, result.GetIds().GetIntId().GetData())
+	s.Equal([]float32{0.2, 0.5}, result.GetScores())
+}
+
+func (s *SearchPipelineSuite) TestElementBestCollapseOpUsesConfiguredCollapseStrategies() {
+	makeInput := func() *milvuspb.SearchResults {
+		return &milvuspb.SearchResults{
+			Status: merr.Success(),
+			Results: &schemapb.SearchResultData{
+				NumQueries: 1,
+				TopK: 6,
+				Topks: []int64{6},
+				Ids: &schemapb.IDs{
+					IdField: &schemapb.IDs_IntId{
+						IntId: &schemapb.LongArray{Data: []int64{1, 1, 1, 2, 2, 3}},
+					},
+				},
+				Scores: []float32{0.9, 0.6, 0.3, 0.5, 0.1, 0.55},
+				Distances: []float32{0.9, 0.6, 0.3, 0.5, 0.1, 0.55},
+				ElementIndices: &schemapb.LongArray{Data: []int64{0, 1, 2, 0, 1, 0}},
+			},
+		}
+	}
+
+	tests := []struct {
+		name string
+		config elementCollapseConfig
+		expectedIDs []int64
+		expectedScores []float32
+		expectedDists []float32
+	}{
+		{name: "sum", config: elementCollapseConfig{Strategy: elementCollapseSum}, expectedIDs: []int64{1, 2, 3}, expectedScores: []float32{1.8, 0.6, 0.55}, expectedDists: []float32{0.9, 0.5, 0.55}},
+		{name: "avg", config: elementCollapseConfig{Strategy: elementCollapseAvg}, expectedIDs: []int64{1, 3, 2}, expectedScores: []float32{0.6, 0.55, 0.3}, expectedDists: []float32{0.9, 0.55, 0.5}},
+		{name: "topk_sum", config: elementCollapseConfig{Strategy: elementCollapseTopKSum, TopK: 2}, expectedIDs: []int64{1, 2, 3}, expectedScores: []float32{1.5, 0.6, 0.55}, expectedDists: []float32{0.9, 0.5, 0.55}},
+		{name: "topk_avg", config: elementCollapseConfig{Strategy: elementCollapseTopKAvg, TopK: 2}, expectedIDs: []int64{1, 3, 2}, expectedScores: []float32{0.75, 0.55, 0.3}, expectedDists: []float32{0.9, 0.55, 0.5}},
+	}
+
+	for _, test := range tests {
+		s.Run(test.name, func() {
+			op := &elementBestCollapseOperator{configs: []elementCollapseConfig{test.config}}
+			out, err := op.run(context.Background(), s.span, []*milvuspb.SearchResults{makeInput()}, []string{"IP"})
+			s.Require().NoError(err)
+
+			result := out[0].([]*milvuspb.SearchResults)[0].GetResults()
+			s.Nil(result.GetElementIndices())
+			s.Equal(test.expectedIDs, result.GetIds().GetIntId().GetData())
+			s.InDeltaSlice(test.expectedScores, result.GetScores(), 0.00001)
+			s.InDeltaSlice(test.expectedDists, result.GetDistances(), 0.00001)
+		})
+	}
+}
+
+func (s *SearchPipelineSuite) TestElementBestCollapseOpRejectsSumCollapseForNegativeMetrics() {
+	input := &milvuspb.SearchResults{
+		Status: merr.Success(),
+		Results: &schemapb.SearchResultData{
+			NumQueries: 1,
+			TopK: 2,
+			Topks: []int64{2},
+			Ids: &schemapb.IDs{
+				IdField: &schemapb.IDs_IntId{
+					IntId: &schemapb.LongArray{Data: []int64{1, 1}},
+				},
+			},
+			Scores: []float32{0.8, 0.2},
+			ElementIndices: &schemapb.LongArray{Data: []int64{0, 1}},
+		},
+	}
+
+	op := &elementBestCollapseOperator{configs: []elementCollapseConfig{{Strategy: elementCollapseTopKSum, TopK: 2}}}
+	_, err := op.run(context.Background(), s.span, []*milvuspb.SearchResults{input}, []string{"L2"})
+	s.Require().Error(err)
+	s.ErrorIs(err, merr.ErrParameterInvalid)
+	s.Contains(err.Error(), "only supported for positively related metrics")
+}
+
+func (s *SearchPipelineSuite) TestElementLevelHybridPrepareAndRestoreKeys() {
+	input := &milvuspb.SearchResults{
+		Status: merr.Success(),
+		Results: &schemapb.SearchResultData{
+			NumQueries: 1,
+			TopK: 3,
+			Topks: []int64{3},
+			Ids: &schemapb.IDs{
+				IdField: &schemapb.IDs_IntId{
+					IntId: &schemapb.LongArray{Data: []int64{10, 10, 20}},
+				},
+			},
+			Scores: []float32{0.8, 0.9, 0.7},
+			ElementIndices: &schemapb.LongArray{Data: []int64{0, 2, 1}},
+		},
+	}
+
+	prepared, err := prepareElementLevelHybridResult(input)
+	s.Require().NoError(err)
+	preparedIDs := prepared.GetResults().GetIds().GetStrId().GetData()
+	s.Require().Len(preparedIDs, 3)
+
+	rankResult := &milvuspb.SearchResults{
+		Status: merr.Success(),
+		Results: &schemapb.SearchResultData{
+			NumQueries: 1,
+			TopK: 2,
+			Topks: []int64{2},
+			Ids: &schemapb.IDs{
+				IdField: &schemapb.IDs_StrId{
+					StrId: &schemapb.StringArray{Data: []string{preparedIDs[1], preparedIDs[2]}},
+				},
+			},
+			Scores: []float32{0.99, 0.88},
+		},
+	}
+
+	restored, err := restoreElementLevelHybridRankResult(rankResult)
+	s.Require().NoError(err)
+	s.Equal([]int64{10, 20}, restored.GetResults().GetIds().GetIntId().GetData())
+	s.Equal([]int64{2, 1}, restored.GetResults().GetElementIndices().GetData())
+	s.Equal([]float32{0.99, 0.88}, restored.GetResults().GetScores())
+}
+
+func (s *SearchPipelineSuite) TestElementLevelHybridRerankAcceptsSyntheticStringKeysForInt64PK() {
+	schema := &schemapb.CollectionSchema{
+		Name: "test_collection",
+		Fields: []*schemapb.FieldSchema{
+			{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
+			{FieldID: 101, Name: "vec", DataType: schemapb.DataType_FloatVector},
+		},
+	}
+	functionScore, err := rerank.NewFunctionScoreWithlegacyAndPKType(schema, nil, schemapb.DataType_VarChar)
+	s.Require().NoError(err)
+
+	input := &milvuspb.SearchResults{
+		Status: merr.Success(),
+		Results: &schemapb.SearchResultData{
+			NumQueries: 1,
+			TopK: 3,
+			Topks: []int64{3},
+			Ids: &schemapb.IDs{
+				IdField: &schemapb.IDs_IntId{
+					IntId: &schemapb.LongArray{Data: []int64{10, 10, 20}},
+				},
+			},
+			Scores: []float32{0.8, 0.9, 0.7},
+			ElementIndices: &schemapb.LongArray{Data: []int64{0, 2, 1}},
+		},
+	}
+	prepared, err := prepareElementLevelHybridResult(input)
+	s.Require().NoError(err)
+
+	op := &rerankOperator{
+		nq: 1,
+		topK: 2,
+		offset: 0,
+		roundDecimal: -1,
+		functionScore: functionScore,
+	}
+	out, err := op.run(context.Background(), s.span, []*milvuspb.SearchResults{prepared}, []string{"IP"})
+	s.Require().NoError(err)
+
+	rankResult := out[0].(*milvuspb.SearchResults)
+	s.Require().NotNil(rankResult.GetResults().GetIds().GetStrId())
+
+	restored, err := restoreElementLevelHybridRankResult(rankResult)
+	s.Require().NoError(err)
+	s.Equal([]int64{10, 10}, restored.GetResults().GetIds().GetIntId().GetData())
+	s.Equal([]int64{0, 2}, restored.GetResults().GetElementIndices().GetData())
+}
+
+func (s *SearchPipelineSuite) TestElementLevelHybridDecayRerankAcceptsInt64PKInputField() {
+	schema := &schemapb.CollectionSchema{
+		Name: "test_collection",
+		Fields: []*schemapb.FieldSchema{
+			{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
+			{FieldID: 101, Name: "vec", DataType: schemapb.DataType_FloatVector},
+		},
+	}
+	functionSchema := &schemapb.FunctionSchema{
+		Name: "test_decay",
+		Type: schemapb.FunctionType_Rerank,
+		InputFieldNames: []string{"pk"},
+		OutputFieldNames: []string{},
+		Params: []*commonpb.KeyValuePair{
+			{Key: "reranker", Value: "decay"},
+			{Key: "origin", Value: "10"},
+			{Key: "scale", Value: "4"},
+			{Key: "offset", Value: "0"},
+			{Key: "decay", Value: "0.5"},
+			{Key: "function", Value: "gauss"},
+		},
+	}
+	functionScore, err := rerank.NewFunctionScoreWithPKType(schema, &schemapb.FunctionScore{
+		Functions: []*schemapb.FunctionSchema{functionSchema},
+	}, &models.ModelExtraInfo{ClusterID: "test-cluster", DBName: "test-db"}, schemapb.DataType_VarChar)
+	s.Require().NoError(err)
+
+	input := &milvuspb.SearchResults{
+		Status: merr.Success(),
+		Results: &schemapb.SearchResultData{
+			NumQueries: 1,
+			TopK: 3,
+			Topks: []int64{3},
+			Ids: &schemapb.IDs{
+				IdField: &schemapb.IDs_IntId{
+					IntId: &schemapb.LongArray{Data: []int64{10, 10, 20}},
+				},
+			},
+			FieldsData: []*schemapb.FieldData{{
+				FieldName: "pk",
+				FieldId: 100,
+				Type: schemapb.DataType_Int64,
+				Field: &schemapb.FieldData_Scalars{
+					Scalars: &schemapb.ScalarField{
+						Data: &schemapb.ScalarField_LongData{
+							LongData: &schemapb.LongArray{Data: []int64{10, 10, 20}},
+						},
+					},
+				},
+			}},
+			Scores: []float32{0.8, 0.9, 0.7},
+			ElementIndices: &schemapb.LongArray{Data: []int64{0, 2, 1}},
+		},
+	}
+	prepared, err := prepareElementLevelHybridResult(input)
+	s.Require().NoError(err)
+
+	op := &rerankOperator{
+		nq: 1,
+		topK: 2,
+		offset: 0,
+		roundDecimal: -1,
+		functionScore: functionScore,
+	}
+	out, err := op.run(context.Background(), s.span, []*milvuspb.SearchResults{prepared}, []string{"IP"})
+	s.Require().NoError(err)
+
+	rankResult := out[0].(*milvuspb.SearchResults)
+	s.Require().NotNil(rankResult.GetResults().GetIds().GetStrId())
+	_, err = restoreElementLevelHybridRankResult(rankResult)
+	s.Require().NoError(err)
+}
+
+func (s *SearchPipelineSuite) TestElementBestCollapseOpRejectsEmptyMetricForElementLevelResult() {
+	input := &milvuspb.SearchResults{
+		Status: merr.Success(),
+		Results: &schemapb.SearchResultData{
+			NumQueries: 1,
+			TopK: 1,
+			Topks: []int64{1},
+			Ids: &schemapb.IDs{
+				IdField: &schemapb.IDs_IntId{
+					IntId: &schemapb.LongArray{Data: []int64{1}},
+				},
+			},
+			Scores: []float32{0.8},
+			ElementIndices: &schemapb.LongArray{Data: []int64{0}},
+		},
+	}
+
+	op := &elementBestCollapseOperator{}
+	_, err := op.run(context.Background(), s.span, []*milvuspb.SearchResults{input}, []string{""})
+	s.Error(err)
+	s.Contains(err.Error(), "missing metric type")
+}
+
 func (s *SearchPipelineSuite) TestHybridAssembleOpPicksFieldsFromCollapsedSubResults() {
 	reduced := []*milvuspb.SearchResults{
 		{
@@ -368,6 +721,64 @@ func (s *SearchPipelineSuite) TestHybridAssembleOpPicksFieldsFromCollapsedSubRes
 	s.Equal([]int64{30, 10}, result.GetFieldsData()[0].GetScalars().GetLongData().GetData())
 }

+func (s *SearchPipelineSuite) TestHybridAssembleOpElementLevelHybridUsesElementKey() {
+	reduced := &milvuspb.SearchResults{
+		Results: &schemapb.SearchResultData{
+			NumQueries: 1,
+			TopK: 3,
+			Topks: []int64{3},
+			Ids: &schemapb.IDs{
+				IdField: &schemapb.IDs_StrId{
+					StrId: &schemapb.StringArray{Data: []string{
+						makeHybridElementKey(int64(10), 0),
+						makeHybridElementKey(int64(10), 2),
+						makeHybridElementKey(int64(20), 1),
+					}},
+				},
+			},
+			Scores: []float32{0.8, 0.9, 0.7},
+			FieldsData: []*schemapb.FieldData{
+				{
+					Type: schemapb.DataType_Int64,
+					FieldName: "value",
+					FieldId: 101,
+					Field: &schemapb.FieldData_Scalars{
+						Scalars: &schemapb.ScalarField{
+							Data: &schemapb.ScalarField_LongData{
+								LongData: &schemapb.LongArray{Data: []int64{100, 200, 300}},
+							},
+						},
+					},
+				},
+			},
+		},
+	}
+	rankResult := &milvuspb.SearchResults{
+		Results: &schemapb.SearchResultData{
+			NumQueries: 1,
+			TopK: 2,
+			Topks: []int64{2},
+			Ids: &schemapb.IDs{
+				IdField: &schemapb.IDs_IntId{
+					IntId: &schemapb.LongArray{Data: []int64{10, 20}},
+				},
+			},
+			Scores: []float32{0.99, 0.88},
+			ElementIndices: &schemapb.LongArray{Data: []int64{2, 1}},
+		},
+	}
+
+	op := &hybridAssembleOperator{collectionID: 12345, elementLevelHybrid: true}
+	out, err := op.run(context.Background(), s.span, []*milvuspb.SearchResults{reduced}, rankResult)
+	s.Require().NoError(err)
+
+	result := out[0].(*milvuspb.SearchResults).GetResults()
+	s.Equal([]int64{10, 20}, result.GetIds().GetIntId().GetData())
+	s.Equal([]int64{2, 1}, result.GetElementIndices().GetData())
+	s.Equal([]float32{0.99, 0.88}, result.GetScores())
+	s.Equal([]int64{200, 300}, result.GetFieldsData()[0].GetScalars().GetLongData().GetData())
+}
+
 func (s *SearchPipelineSuite) TestComputeFieldIdxsByOriginalOrderUsesAscendingRowsAndPreservesOutputOrder() {
 	rowIdxs := []int64{5, 1, 4, 2}
 	calls := make([]int64, 0, len(rowIdxs))
diff --git a/internal/proxy/search_util.go b/internal/proxy/search_util.go
index 297ed6d0f6a11..1fc1a95b7350f 100644
--- a/internal/proxy/search_util.go
+++ b/internal/proxy/search_util.go
@@ -296,26 +296,8 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
 		return nil, fmt.Errorf("parse iterator v2 info failed: %w", err)
 	}

-	// 7. check search for embedding list
-	annsFieldName, _ := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, searchParamsPair)
-	if annsFieldName != "" {
-		annField := typeutil.GetFieldByName(schema, annsFieldName)
-		if annField != nil && annField.GetDataType() == schemapb.DataType_ArrayOfVector {
-			if strings.Contains(searchParamStr, radiusKey) {
-				return nil, merr.WrapErrParameterInvalid("", "",
-					"range search is not supported for vector array (embedding list) fields, fieldName:", annsFieldName)
-			}
-
-			if groupByFieldId > 0 {
-				return nil, merr.WrapErrParameterInvalid("", "",
-					"group by search is not supported for vector array (embedding list) fields, fieldName:", annsFieldName)
-			}
-
-			if isIterator {
-				return nil, merr.WrapErrParameterInvalid("", "",
-					"search iterator is not supported for vector array (embedding list) fields, fieldName:", annsFieldName)
-			}
-		}
+	if err := validateVectorArraySearchInfo(searchParamsPair, schema, searchParamStr, groupByFieldId, isIterator || planSearchIteratorV2Info != nil); err != nil {
+		return nil, err
 	}

 	return &SearchInfo{
@@ -339,6 +321,36 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
 	}, nil
 }

+func validateVectorArraySearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema, searchParamStr string, groupByFieldId int64, isIterator bool) error {
+	if schema == nil {
+		return nil
+	}
+
+	annsFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, searchParamsPair)
+	if err != nil || annsFieldName == "" {
+		return nil
+	}
+
+	annsField := typeutil.GetFieldByName(schema, annsFieldName)
+	if annsField == nil || annsField.GetDataType() != schemapb.DataType_ArrayOfVector {
+		return nil
+	}
+
+	if strings.Contains(searchParamStr, radiusKey) {
+		return merr.WrapErrParameterInvalid("", "",
+			"range search is not supported for vector array fields, fieldName:"+annsField.GetName())
+	}
+	if groupByFieldId > 0 {
+		return merr.WrapErrParameterInvalid("", "",
+			"group by search is not supported for vector array fields, fieldName:"+annsField.GetName())
+	}
+	if isIterator {
+		return merr.WrapErrParameterInvalid("", "",
+			"search iterator is not supported for vector array fields, fieldName:"+annsField.GetName())
+	}
+	return nil
+}
+
 func getOutputFieldIDs(schema *schemaInfo, outputFields []string) (outputFieldIDs []UniqueID, err error) {
 	outputFieldIDs = make([]UniqueID, 0, len(outputFields))
 	for _, name := range outputFields {
diff --git a/internal/proxy/struct_hybrid_search.go b/internal/proxy/struct_hybrid_search.go
new file mode 100644
index 0000000000000..29cf7e3a5f940
--- /dev/null
+++ b/internal/proxy/struct_hybrid_search.go
@@ -0,0 +1,289 @@
+package proxy
+
+import (
+	"encoding/base64"
+	"encoding/json"
+	"fmt"
+	"strconv"
+	"strings"
+
+	"github.com/tidwall/gjson"
+	"google.golang.org/protobuf/proto"
+
+	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
+	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+	"github.com/milvus-io/milvus/pkg/v2/common"
+	"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
+	"github.com/milvus-io/milvus/pkg/v2/util/merr"
+	"github.com/milvus-io/milvus/pkg/v2/util/metric"
+	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
+)
+
+const (
+	elementScopeKey = "element_scope"
+	elementCollapseMax = "max"
+	elementCollapseSum = "sum"
+	elementCollapseAvg = "avg"
+	elementCollapseTopKSum = "topk_sum"
+	elementCollapseTopKAvg = "topk_avg"
+	hybridElementKeyPrefix = "__milvus_element_key"
+	hybridElementKeySep = "\x1f"
+	hybridElementKeyIntPK = "i"
+	hybridElementKeyStringPK = "s"
+)
+
+type elementCollapseConfig struct {
+	Strategy string
+	TopK int
+}
+
+type hybridSubSearchKind int
+
+const (
+	hybridSubSearchNormal hybridSubSearchKind = iota
+	hybridSubSearchStructEmbList
+	hybridSubSearchStructElement
+)
+
+type hybridSubSearchInfo struct {
+	Kind hybridSubSearchKind
+	ParentStructFieldName string
+	ElementScopeProvided bool
+	Collapse elementCollapseConfig
+}
+
+func defaultElementCollapseConfig() elementCollapseConfig {
+	return elementCollapseConfig{Strategy: elementCollapseMax}
+}
+
+func parseAndRemoveElementScope(searchParamStr string) (elementCollapseConfig, bool, string, error) {
+	if !gjson.Get(searchParamStr, elementScopeKey).Exists() {
+		return elementCollapseConfig{}, false, searchParamStr, nil
+	}
+	if !gjson.Valid(searchParamStr) {
+		return elementCollapseConfig{}, false, "", merr.WrapErrParameterInvalidMsg("%s must be valid JSON", ParamsKey)
+	}
+
+	var root map[string]json.RawMessage
+	if err := json.Unmarshal([]byte(searchParamStr), &root); err != nil {
+		return elementCollapseConfig{}, false, "", merr.WrapErrParameterInvalidMsg("%s must be a JSON object: %v", ParamsKey, err)
+	}
+	scopeRaw, ok := root[elementScopeKey]
+	if !ok {
+		return elementCollapseConfig{}, false, searchParamStr, nil
+	}
+
+	cfg, err := parseElementScope(scopeRaw)
+	if err != nil {
+		return elementCollapseConfig{}, false, "", err
+	}
+	delete(root, elementScopeKey)
+	sanitized, err := json.Marshal(root)
+	if err != nil {
+		return elementCollapseConfig{}, false, "", merr.WrapErrServiceInternal(fmt.Sprintf("failed to rewrite search params without %s: %v", elementScopeKey, err))
+	}
+	return cfg, true, string(sanitized), nil
+}
+
+func parseElementScope(scopeRaw json.RawMessage) (elementCollapseConfig, error) {
+	var scope map[string]json.RawMessage
+	if err := json.Unmarshal(scopeRaw, &scope); err != nil {
+		return elementCollapseConfig{}, merr.WrapErrParameterInvalidMsg("%s must be a JSON object: %v", elementScopeKey, err)
+	}
+	for key := range scope {
+		if key != "collapse" {
+			return elementCollapseConfig{}, merr.WrapErrParameterInvalidMsg("unsupported %s key: %s", elementScopeKey, key)
+		}
+	}
+	collapseRaw, ok := scope["collapse"]
+	if !ok {
+		return elementCollapseConfig{}, merr.WrapErrParameterInvalidMsg("%s.collapse is required", elementScopeKey)
+	}
+
+	var collapse map[string]json.RawMessage
+	if err := json.Unmarshal(collapseRaw, &collapse); err != nil {
+		return elementCollapseConfig{}, merr.WrapErrParameterInvalidMsg("%s.collapse must be a JSON object: %v", elementScopeKey, err)
+	}
+	for key := range collapse {
+		if key != "strategy" && key != "topk" {
+			return elementCollapseConfig{}, merr.WrapErrParameterInvalidMsg("unsupported %s.collapse key: %s", elementScopeKey, key)
+		}
+	}
+
+	var strategy string
+	if raw, ok := collapse["strategy"]; ok {
+		if err := json.Unmarshal(raw, &strategy); err != nil {
+			return elementCollapseConfig{}, merr.WrapErrParameterInvalidMsg("%s.collapse.strategy must be a string", elementScopeKey)
+		}
+	}
+	strategy = strings.TrimSpace(strategy)
+	if strategy == "" {
+		return elementCollapseConfig{}, merr.WrapErrParameterInvalidMsg("%s.collapse.strategy is required", elementScopeKey)
+	}
+	if !isSupportedElementCollapseStrategy(strategy) {
+		return elementCollapseConfig{}, merr.WrapErrParameterInvalidMsg("unsupported %s.collapse.strategy: %s", elementScopeKey, strategy)
+	}
+
+	cfg := elementCollapseConfig{Strategy: strategy}
+	if raw, ok := collapse["topk"]; ok {
+		if err := json.Unmarshal(raw, &cfg.TopK); err != nil {
+			return elementCollapseConfig{}, merr.WrapErrParameterInvalidMsg("%s.collapse.topk must be an integer", elementScopeKey)
+		}
+		if cfg.TopK <= 0 {
+			return elementCollapseConfig{}, merr.WrapErrParameterInvalidMsg("%s.collapse.topk must be positive", elementScopeKey)
+		}
+	}
+
+	switch strategy {
+	case elementCollapseTopKSum, elementCollapseTopKAvg:
+		if cfg.TopK <= 0 {
+			return elementCollapseConfig{}, merr.WrapErrParameterInvalidMsg("%s.collapse.topk is required for strategy %s", elementScopeKey, strategy)
+		}
+	default:
+		if cfg.TopK != 0 {
+			return elementCollapseConfig{}, merr.WrapErrParameterInvalidMsg("%s.collapse.topk is only valid for topk strategies", elementScopeKey)
+		}
+	}
+	return cfg, nil
+}
+
+func isSupportedElementCollapseStrategy(strategy string) bool {
+	switch strategy {
+	case elementCollapseMax, elementCollapseSum, elementCollapseAvg, elementCollapseTopKSum, elementCollapseTopKAvg:
+		return true
+	default:
+		return false
+	}
+}
+
+func isElementCollapseSumFamily(strategy string) bool {
+	return strategy == elementCollapseSum || strategy == elementCollapseTopKSum
+}
+
+func validateElementCollapseMetricType(config elementCollapseConfig, metricType string) error {
+	if config.Strategy == "" ||
+		!isElementCollapseSumFamily(config.Strategy) ||
+		strings.TrimSpace(metricType) == "" ||
+		metric.PositivelyRelated(metricType) {
+		return nil
+	}
+	return merr.WrapErrParameterInvalidMsg(
+		"%s.collapse.strategy %s is only supported for positively related metrics",
+		elementScopeKey, config.Strategy)
+}
+
+func resolveElementCollapseMetricType(requestMetricType string, field *schemapb.FieldSchema) string {
+	if strings.TrimSpace(requestMetricType) != "" || field == nil {
+		return requestMetricType
+	}
+	indexMetricType, err := funcutil.GetAttrByKeyFromRepeatedKV(common.MetricTypeKey, field.GetIndexParams())
+	if err != nil {
+		return ""
+	}
+	return indexMetricType
+}
+
+func isEmbeddingListPlaceholderType(pt commonpb.PlaceholderType) bool {
+	switch pt {
+	case commonpb.PlaceholderType_EmbListFloatVector,
+		commonpb.PlaceholderType_EmbListFloat16Vector,
+		commonpb.PlaceholderType_EmbListBFloat16Vector,
+		commonpb.PlaceholderType_EmbListBinaryVector,
+		commonpb.PlaceholderType_EmbListInt8Vector:
+		return true
+	default:
+		return false
+	}
+}
+
+func getPlaceholderType(phgBytes []byte) commonpb.PlaceholderType {
+	holder := &commonpb.PlaceholderGroup{}
+	if err := proto.Unmarshal(phgBytes, holder); err != nil || len(holder.GetPlaceholders()) == 0 {
+		return 0
+	}
+	return holder.GetPlaceholders()[0].GetType()
+}
+
+func getStructParentFieldName(schema *schemapb.CollectionSchema, fieldID int64) (string, bool) {
+	for _, structField := range schema.GetStructArrayFields() {
+		for _, subField := range structField.GetFields() {
+			if subField.GetFieldID() == fieldID {
+				return structField.GetName(), true
+			}
+		}
+	}
+	return "", false
+}
+
+func classifyHybridSubSearch(schema *schemapb.CollectionSchema, fieldID int64, placeholderType commonpb.PlaceholderType) hybridSubSearchInfo {
+	field := typeutil.GetField(schema, fieldID)
+	if field == nil || field.GetDataType() != schemapb.DataType_ArrayOfVector {
+		return hybridSubSearchInfo{Kind: hybridSubSearchNormal}
+	}
+	parent, ok := getStructParentFieldName(schema, fieldID)
+	if !ok {
+		return hybridSubSearchInfo{Kind: hybridSubSearchNormal}
+	}
+	if placeholderType == 0 || isEmbeddingListPlaceholderType(placeholderType) {
+		return hybridSubSearchInfo{Kind: hybridSubSearchStructEmbList, ParentStructFieldName: parent}
+	}
+	return hybridSubSearchInfo{Kind: hybridSubSearchStructElement, ParentStructFieldName: parent}
+}
+
+func inferElementLevelHybrid(infos []hybridSubSearchInfo) bool {
+	if len(infos) == 0 {
+		return false
+	}
+	parent := ""
+	for _, info := range infos {
+		if info.Kind != hybridSubSearchStructElement {
+			return false
+		}
+		if parent == "" {
+			parent = info.ParentStructFieldName
+			continue
+		}
+		if info.ParentStructFieldName != parent {
+			return false
+		}
+	}
+	return true
+}
+
+func makeHybridElementKey(pk any, elementIndex int64) string {
+	switch v := pk.(type) {
+	case int64:
+		return fmt.Sprintf("%s%s%s%s%d%s%d", hybridElementKeyPrefix, hybridElementKeySep, hybridElementKeyIntPK, hybridElementKeySep, v, hybridElementKeySep, elementIndex)
+	case string:
+		return fmt.Sprintf("%s%s%s%s%s%s%d", hybridElementKeyPrefix, hybridElementKeySep, hybridElementKeyStringPK, hybridElementKeySep, base64.RawStdEncoding.EncodeToString([]byte(v)), hybridElementKeySep, elementIndex)
+	default:
+		return fmt.Sprintf("%s%s%T%s%v%s%d", hybridElementKeyPrefix, hybridElementKeySep, pk, hybridElementKeySep, pk, hybridElementKeySep, elementIndex)
+	}
+}
+
+func parseHybridElementKey(key string) (any, int64, bool) {
+	parts := strings.Split(key, hybridElementKeySep)
+	if len(parts) != 4 || parts[0] != hybridElementKeyPrefix {
+		return nil, 0, false
+	}
+	elementIndex, err := strconv.ParseInt(parts[3], 10, 64)
+	if err != nil {
+		return nil, 0, false
+	}
+	switch parts[1] {
+	case hybridElementKeyIntPK:
+		pk, err := strconv.ParseInt(parts[2], 10, 64)
+		if err != nil {
+			return nil, 0, false
+		}
+		return pk, elementIndex, true
+	case hybridElementKeyStringPK:
+		decoded, err := base64.RawStdEncoding.DecodeString(parts[2])
+		if err != nil {
+			return nil, 0, false
+		}
+		return string(decoded), elementIndex, true
+	default:
+		return nil, 0, false
+	}
+}
diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go
index ff3339baaa657..0fc694e73173e 100644
--- a/internal/proxy/task_search.go
+++ b/internal/proxy/task_search.go
@@ -110,6 +110,9 @@ type searchTask struct {
 	userRequestedPkFieldExplicitly bool

 	storageCost segcore.StorageCost
+
+	hybridSubSearchInfos []hybridSubSearchInfo
+	hybridElementLevel bool
 }

 func (t *searchTask) CanSkipAllocTimestamp() bool {
@@ -442,6 +445,8 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {

 	t.SubReqs = make([]*internalpb.SubSearchRequest, len(t.request.GetSubReqs()))
 	t.queryInfos = make([]*planpb.QueryInfo, len(t.request.GetSubReqs()))
+	t.hybridSubSearchInfos = make([]hybridSubSearchInfo, len(t.request.GetSubReqs()))
+	t.hybridElementLevel = false
 	queryFieldIDs := []int64{}
 	for index, subReq := range t.request.GetSubReqs() {
 		plan, queryInfo, offset, subIsIterator, err := t.tryGeneratePlan(subReq.GetSearchParams(), subReq.GetDsl(), subReq.GetExprTemplateValues())
@@ -449,23 +454,50 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
 			return err
 		}

-		// Hybrid search supports plain top-K on ArrayOfVector fields. Advanced
-		// element-level features are intentionally not enabled in this 2.6 path.
+		placeholderType := getPlaceholderType(subReq.GetPlaceholderGroup())
+		subSearchInfo := classifyHybridSubSearch(t.schema.CollectionSchema, queryInfo.GetQueryFieldId(), placeholderType)
 		annsField := typeutil.GetField(t.schema.CollectionSchema, queryInfo.GetQueryFieldId())
-		if annsField != nil && annsField.GetDataType() == schemapb.DataType_ArrayOfVector {
-			if gjson.Get(queryInfo.GetSearchParams(), radiusKey).Exists() {
-				return merr.WrapErrParameterInvalid("", "",
-					"range search is not supported for vector array (embedding list) fields in hybrid search, fieldName:"+annsField.GetName())
+		collapseConfig, elementScopeProvided, sanitizedSearchParams, err := parseAndRemoveElementScope(queryInfo.GetSearchParams())
+		if err != nil {
+			return err
+		}
+		if elementScopeProvided {
+			if subSearchInfo.Kind != hybridSubSearchStructElement {
+				return merr.WrapErrParameterInvalidMsg("%s is only supported for element-level search on struct array vector sub-fields", elementScopeKey)
 			}
-			if t.rankParams.GetGroupByFieldId() > 0 {
-				return merr.WrapErrParameterInvalid("", "",
-					"group by search is not supported for vector array (embedding list) fields in hybrid search, fieldName:"+annsField.GetName())
+			if err := validateElementCollapseMetricType(collapseConfig, resolveElementCollapseMetricType(queryInfo.GetMetricType(), annsField)); err != nil {
+				return err
 			}
-			if subIsIterator {
-				return merr.WrapErrParameterInvalid("", "",
-					"search iterator is not supported for vector array (embedding list) fields in hybrid search, fieldName:"+annsField.GetName())
+			queryInfo.SearchParams = sanitizedSearchParams
+			subSearchInfo.ElementScopeProvided = true
+			subSearchInfo.Collapse = collapseConfig
+		}
+
+		// Hybrid search only supports plain top-K on ArrayOfVector fields. Both
+		// element-level and embedding-list searches reject advanced controls here.
+		if annsField != nil && annsField.GetDataType() == schemapb.DataType_ArrayOfVector {
+			isStructElementSubSearch := subSearchInfo.Kind == hybridSubSearchStructElement
+			isStructEmbListSubSearch := subSearchInfo.Kind == hybridSubSearchStructEmbList
+			if isStructElementSubSearch || isStructEmbListSubSearch {
+				searchKind := "element-level"
+				if isStructEmbListSubSearch {
+					searchKind = "embedding-list"
+				}
+				if gjson.Get(queryInfo.GetSearchParams(), radiusKey).Exists() {
+					return merr.WrapErrParameterInvalid("", "",
+						"range search is not supported for vector array ("+searchKind+") fields in hybrid search, fieldName:"+annsField.GetName())
+				}
+				if t.rankParams.GetGroupByFieldId() > 0 {
+					return merr.WrapErrParameterInvalid("", "",
+						"group by search is not supported for vector array ("+searchKind+") fields in hybrid search, fieldName:"+annsField.GetName())
+				}
+				if subIsIterator {
+					return merr.WrapErrParameterInvalid("", "",
+						"search iterator is not supported for vector array ("+searchKind+") fields in hybrid search, fieldName:"+annsField.GetName())
+				}
 			}
 		}
+		t.hybridSubSearchInfos[index] = subSearchInfo

 		ignoreGrowing := t.IgnoreGrowing
 		if !ignoreGrowing {
@@ -551,6 +583,21 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
 			zap.Stringer("plan", plan)) // may be very large if large term passed.
} + t.hybridElementLevel = inferElementLevelHybrid(t.hybridSubSearchInfos) + for index, info := range t.hybridSubSearchInfos { + if t.hybridElementLevel && info.ElementScopeProvided { + return merr.WrapErrParameterInvalidMsg("%s is not allowed for same-struct element-level hybrid search", elementScopeKey) + } + if !t.hybridElementLevel && info.Kind == hybridSubSearchStructElement && !info.ElementScopeProvided { + t.hybridSubSearchInfos[index].Collapse = defaultElementCollapseConfig() + } + } + if t.hybridElementLevel { + if err := t.useElementLevelHybridFunctionScore(); err != nil { + return err + } + } + if embedding.HasNonBM25Functions(t.schema.Functions, queryFieldIDs) { ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-AdvancedSearch-call-function-udf") defer sp.End() @@ -575,6 +622,24 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error { return nil } +func (t *searchTask) useElementLevelHybridFunctionScore() error { + extraInfo := &models.ModelExtraInfo{ + ClusterID: paramtable.Get().CommonCfg.ClusterPrefix.GetValue(), + DBName: t.request.GetDbName(), + } + + var err error + if t.request.FunctionScore != nil { + t.functionScore, err = rerank.NewFunctionScoreWithPKType(t.schema.CollectionSchema, t.request.FunctionScore, extraInfo, schemapb.DataType_VarChar) + } else { + t.functionScore, err = rerank.NewFunctionScoreWithlegacyAndPKType(t.schema.CollectionSchema, t.request.GetSearchParams(), schemapb.DataType_VarChar) + } + if err != nil { + return err + } + return nil +} + func (t *searchTask) fillResult() { limit := t.GetTopk() - t.GetOffset() resultSizeInsufficient := false @@ -755,10 +820,46 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error { metrics.ProxySearchSparseNumNonZeros.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), t.collectionName, metrics.SearchLabel, strconv.FormatInt(t.FieldId, 
10)).Observe(float64(typeutil.EstimateSparseVectorNNZFromPlaceholderGroup(t.request.GetPlaceholderGroup(), int(t.request.GetNq())))) } // Convert placeholder group vector type if needed (e.g., fp32 -> fp16/bf16) + placeholderType := getPlaceholderType(t.request.GetPlaceholderGroup()) t.PlaceholderGroup, err = t.convertPlaceholderIfNeeded(t.request.GetPlaceholderGroup(), t.FieldId) if err != nil { return err } + + annsField := typeutil.GetField(t.schema.CollectionSchema, t.FieldId) + if annsField != nil && annsField.GetDataType() == schemapb.DataType_ArrayOfVector { + isEmbList := isEmbeddingListPlaceholderType(placeholderType) + if gjson.Get(queryInfo.GetSearchParams(), radiusKey).Exists() { + if isEmbList { + return merr.WrapErrParameterInvalid("", "", + "range search is not supported for multi-search-multi on embedding list fields") + } + return merr.WrapErrParameterInvalid("", "", + "range search is not supported for vector array fields, fieldName:"+annsField.GetName()) + } + if t.isIterator { + if isEmbList { + return merr.WrapErrParameterInvalid("", "", + "search iterator is not supported for multi-search-multi on embedding list fields") + } + return merr.WrapErrParameterInvalid("", "", + "search iterator is not supported for vector array fields, fieldName:"+annsField.GetName()) + } + + var groupByFieldIDs []int64 + if queryInfo.GetGroupByFieldId() > 0 { + groupByFieldIDs = []int64{queryInfo.GetGroupByFieldId()} + } + if len(groupByFieldIDs) > 0 { + if isEmbList { + return merr.WrapErrParameterInvalid("", "", + "group by is not supported for multi-search-multi on embedding list fields") + } + return merr.WrapErrParameterInvalid("", "", + "group by search is not supported for vector array fields, fieldName:"+annsField.GetName()) + } + } + t.Topk = queryInfo.GetTopk() t.MetricType = queryInfo.GetMetricType() t.queryInfos = append(t.queryInfos, queryInfo) diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index 
8b4ee567ee802..0709f4b51eff9 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -3700,7 +3700,7 @@ func TestSearchTask_parseSearchInfo(t *testing.T) { assert.Nil(t, searchInfo) assert.ErrorIs(t, err, merr.ErrParameterInvalid) fmt.Println(err.Error()) - assert.Contains(t, err.Error(), "range search is not supported for vector array (embedding list) fields") + assert.Contains(t, err.Error(), "range search is not supported for vector array fields") }) t.Run("vector array with group by", func(t *testing.T) { @@ -3717,7 +3717,7 @@ func TestSearchTask_parseSearchInfo(t *testing.T) { assert.Error(t, err) assert.Nil(t, searchInfo) assert.ErrorIs(t, err, merr.ErrParameterInvalid) - assert.Contains(t, err.Error(), "group by search is not supported for vector array (embedding list) fields") + assert.Contains(t, err.Error(), "group by search is not supported for vector array fields") assert.Contains(t, err.Error(), "embeddings_list") }) @@ -3735,7 +3735,7 @@ func TestSearchTask_parseSearchInfo(t *testing.T) { assert.Error(t, err) assert.Nil(t, searchInfo) assert.ErrorIs(t, err, merr.ErrParameterInvalid) - assert.Contains(t, err.Error(), "search iterator is not supported for vector array (embedding list) fields") + assert.Contains(t, err.Error(), "search iterator is not supported for vector array fields") assert.Contains(t, err.Error(), "embeddings_list") }) @@ -3763,7 +3763,7 @@ func TestSearchTask_parseSearchInfo(t *testing.T) { assert.Error(t, err) assert.Nil(t, searchInfo) assert.ErrorIs(t, err, merr.ErrParameterInvalid) - assert.Contains(t, err.Error(), "search iterator is not supported for vector array (embedding list) fields") + assert.Contains(t, err.Error(), "search iterator is not supported for vector array fields") assert.Contains(t, err.Error(), "embeddings_list") }) @@ -3832,7 +3832,7 @@ func TestSearchTask_parseSearchInfo(t *testing.T) { assert.Error(t, err) assert.Nil(t, searchInfo) // Should fail on range search first - 
assert.Contains(t, err.Error(), "range search is not supported for vector array (embedding list) fields")
+			assert.Contains(t, err.Error(), "range search is not supported for vector array fields")
 		})

 		t.Run("no anns field specified", func(t *testing.T) {
@@ -3881,7 +3881,7 @@ func TestSearchTask_parseSearchInfo(t *testing.T) {
 			assert.Error(t, err)
 			assert.Nil(t, searchInfo)
 			assert.ErrorIs(t, err, merr.ErrParameterInvalid)
-			assert.Contains(t, err.Error(), "group by search is not supported for vector array (embedding list) fields")
+			assert.Contains(t, err.Error(), "group by search is not supported for vector array fields")
 		})
 	})

@@ -5451,6 +5451,570 @@ func TestSearchTask_InitSearchRequestWithHighlighter(t *testing.T) {
 	})
 }

+func TestIsEmbeddingListPlaceholderType(t *testing.T) {
+	embListTypes := []commonpb.PlaceholderType{
+		commonpb.PlaceholderType_EmbListFloatVector,
+		commonpb.PlaceholderType_EmbListFloat16Vector,
+		commonpb.PlaceholderType_EmbListBFloat16Vector,
+		commonpb.PlaceholderType_EmbListBinaryVector,
+		commonpb.PlaceholderType_EmbListInt8Vector,
+	}
+	for _, pt := range embListTypes {
+		assert.True(t, isEmbeddingListPlaceholderType(pt), "expected true for %s", pt.String())
+	}
+
+	nonEmbListTypes := []commonpb.PlaceholderType{
+		commonpb.PlaceholderType_FloatVector,
+		commonpb.PlaceholderType_BinaryVector,
+		commonpb.PlaceholderType_Float16Vector,
+		commonpb.PlaceholderType_BFloat16Vector,
+		commonpb.PlaceholderType_SparseFloatVector,
+		commonpb.PlaceholderType_Int8Vector,
+		commonpb.PlaceholderType_VarChar,
+		commonpb.PlaceholderType(0),
+	}
+	for _, pt := range nonEmbListTypes {
+		assert.False(t, isEmbeddingListPlaceholderType(pt), "expected false for %s", pt.String())
+	}
+}
+
+func TestSearchTask_ArrayOfVectorSimpleSearch(t *testing.T) {
+	paramtable.Init()
+	ctx := context.Background()
+
+	schema := &schemapb.CollectionSchema{
+		Name: "test_collection",
+		Fields: []*schemapb.FieldSchema{
+			{FieldID: 100, Name: "pk", DataType:
schemapb.DataType_Int64, IsPrimaryKey: true},
+			{FieldID: 101, Name: "regular_vec", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{{Key: common.DimKey, Value: "4"}}},
+			{FieldID: 102, Name: "scalar_field", DataType: schemapb.DataType_VarChar},
+		},
+		StructArrayFields: []*schemapb.StructArrayFieldSchema{
+			{
+				FieldID: 103,
+				Name: "struct_array",
+				Fields: []*schemapb.FieldSchema{
+					{FieldID: 104, Name: "emb_vec", DataType: schemapb.DataType_ArrayOfVector, ElementType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{{Key: common.DimKey, Value: "4"}}},
+				},
+			},
+		},
+	}
+	schemaInfo := newSchemaInfo(schema)
+
+	makePlaceholderGroup := func(phType commonpb.PlaceholderType) []byte {
+		phg := &commonpb.PlaceholderGroup{
+			Placeholders: []*commonpb.PlaceholderValue{{
+				Tag: "$0",
+				Type: phType,
+				Values: [][]byte{{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}},
+			}},
+		}
+		bs, _ := proto.Marshal(phg)
+		return bs
+	}
+
+	makeTask := func(annsField string, phType commonpb.PlaceholderType, paramsJSON string, withIterator bool, withIteratorV2 bool, groupByField string) *searchTask {
+		params := []*commonpb.KeyValuePair{
+			{Key: AnnsFieldKey, Value: annsField},
+			{Key: TopKKey, Value: "10"},
+			{Key: common.MetricTypeKey, Value: metric.L2},
+			{Key: ParamsKey, Value: paramsJSON},
+		}
+		if withIterator {
+			params = append(params, &commonpb.KeyValuePair{Key: IteratorField, Value: "True"})
+		}
+		if withIteratorV2 {
+			params = append(params,
+				&commonpb.KeyValuePair{Key: SearchIterV2Key, Value: "True"},
+				&commonpb.KeyValuePair{Key: SearchIterBatchSizeKey, Value: "10"},
+			)
+		}
+		if groupByField != "" {
+			params = append(params, &commonpb.KeyValuePair{Key: GroupByFieldKey, Value: groupByField})
+		}
+
+		return &searchTask{
+			ctx: ctx,
+			collectionName: "test_collection",
+			SearchRequest: &internalpb.SearchRequest{
+				CollectionID: 1,
+				PartitionIDs: []int64{1},
+				OutputFieldsId: []int64{100},
+				DslType:
commonpb.DslType_BoolExprV1,
+			},
+			request: &milvuspb.SearchRequest{
+				CollectionName: "test_collection",
+				OutputFields: []string{"pk"},
+				SearchParams: params,
+				SearchInput: &milvuspb.SearchRequest_PlaceholderGroup{
+					PlaceholderGroup: makePlaceholderGroup(phType),
+				},
+				Nq: 1,
+				ConsistencyLevel: commonpb.ConsistencyLevel_Session,
+			},
+			schema: schemaInfo,
+			translatedOutputFields: []string{"pk"},
+			tr: timerecord.NewTimeRecorder("test"),
+			queryInfos: []*planpb.QueryInfo{{}},
+		}
+	}
+
+	const rangeParams = `{"nprobe": 10, "radius": 0.2}`
+	const plainParams = `{"nprobe": 10}`
+
+	t.Run("element-level range search should fail", func(t *testing.T) {
+		task := makeTask("emb_vec", commonpb.PlaceholderType_FloatVector, rangeParams, false, false, "")
+		err := task.initSearchRequest(ctx)
+		assert.Error(t, err)
+		assert.ErrorIs(t, err, merr.ErrParameterInvalid)
+		assert.Contains(t, err.Error(), "range search is not supported for vector array fields")
+	})
+
+	t.Run("element-level iterator should fail", func(t *testing.T) {
+		task := makeTask("emb_vec", commonpb.PlaceholderType_FloatVector, plainParams, true, false, "")
+		err := task.initSearchRequest(ctx)
+		assert.Error(t, err)
+		assert.ErrorIs(t, err, merr.ErrParameterInvalid)
+		assert.Contains(t, err.Error(), "search iterator is not supported for vector array fields")
+	})
+
+	t.Run("element-level iterator v2 should fail", func(t *testing.T) {
+		task := makeTask("emb_vec", commonpb.PlaceholderType_FloatVector, plainParams, true, true, "")
+		err := task.initSearchRequest(ctx)
+		assert.Error(t, err)
+		assert.ErrorIs(t, err, merr.ErrParameterInvalid)
+		assert.Contains(t, err.Error(), "search iterator is not supported for vector array fields")
+	})
+
+	t.Run("element-level group by pk should fail", func(t *testing.T) {
+		task := makeTask("emb_vec", commonpb.PlaceholderType_FloatVector, plainParams, false, false, "pk")
+		err := task.initSearchRequest(ctx)
+		assert.Error(t, err)
+		assert.ErrorIs(t, err,
merr.ErrParameterInvalid)
+		assert.Contains(t, err.Error(), "group by search is not supported for vector array fields")
+	})
+
+	t.Run("element-level group by non-pk should fail", func(t *testing.T) {
+		task := makeTask("emb_vec", commonpb.PlaceholderType_FloatVector, plainParams, false, false, "scalar_field")
+		err := task.initSearchRequest(ctx)
+		assert.Error(t, err)
+		assert.ErrorIs(t, err, merr.ErrParameterInvalid)
+		assert.Contains(t, err.Error(), "group by search is not supported for vector array fields")
+	})
+
+	t.Run("emblist range search should fail", func(t *testing.T) {
+		task := makeTask("emb_vec", commonpb.PlaceholderType_EmbListFloatVector, rangeParams, false, false, "")
+		err := task.initSearchRequest(ctx)
+		assert.Error(t, err)
+		assert.ErrorIs(t, err, merr.ErrParameterInvalid)
+		assert.Contains(t, err.Error(), "range search is not supported for vector array fields")
+	})
+
+	t.Run("emblist iterator should fail", func(t *testing.T) {
+		task := makeTask("emb_vec", commonpb.PlaceholderType_EmbListFloatVector, plainParams, true, false, "")
+		err := task.initSearchRequest(ctx)
+		assert.Error(t, err)
+		assert.ErrorIs(t, err, merr.ErrParameterInvalid)
+		assert.Contains(t, err.Error(), "search iterator is not supported for vector array fields")
+	})
+
+	t.Run("emblist group by should fail", func(t *testing.T) {
+		task := makeTask("emb_vec", commonpb.PlaceholderType_EmbListFloatVector, plainParams, false, false, "pk")
+		err := task.initSearchRequest(ctx)
+		assert.Error(t, err)
+		assert.ErrorIs(t, err, merr.ErrParameterInvalid)
+		assert.Contains(t, err.Error(), "group by search is not supported for vector array fields")
+	})
+
+	t.Run("regular vector advanced controls should succeed", func(t *testing.T) {
+		tests := []struct {
+			name string
+			paramsJSON string
+			withIterator bool
+			withIteratorV2 bool
+			groupByField string
+		}{
+			{name: "range", paramsJSON: rangeParams},
+			{name: "iterator", paramsJSON: plainParams, withIterator: true},
+			{name: "group by",
paramsJSON: plainParams, groupByField: "scalar_field"},
+		}
+		for _, test := range tests {
+			t.Run(test.name, func(t *testing.T) {
+				task := makeTask("regular_vec", commonpb.PlaceholderType_FloatVector, test.paramsJSON, test.withIterator, test.withIteratorV2, test.groupByField)
+				assert.NoError(t, task.initSearchRequest(ctx))
+			})
+		}
+	})
+}
+
+func TestSearchTask_ArrayOfVectorHybridSearch(t *testing.T) {
+	paramtable.Init()
+	ctx := context.Background()
+
+	schema := &schemapb.CollectionSchema{
+		Name: "test_collection",
+		Fields: []*schemapb.FieldSchema{
+			{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
+			{FieldID: 101, Name: "regular_vec", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{{Key: common.DimKey, Value: "4"}}},
+			{FieldID: 102, Name: "scalar_field", DataType: schemapb.DataType_VarChar},
+		},
+		StructArrayFields: []*schemapb.StructArrayFieldSchema{
+			{
+				FieldID: 103,
+				Name: "struct_array",
+				Fields: []*schemapb.FieldSchema{
+					{FieldID: 104, Name: "emb_vec", DataType: schemapb.DataType_ArrayOfVector, ElementType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{{Key: common.DimKey, Value: "4"}}},
+				},
+			},
+		},
+	}
+	schemaInfo := newSchemaInfo(schema)
+
+	makePlaceholderGroup := func(phType commonpb.PlaceholderType) []byte {
+		phg := &commonpb.PlaceholderGroup{
+			Placeholders: []*commonpb.PlaceholderValue{{
+				Tag: "$0",
+				Type: phType,
+				Values: [][]byte{{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}},
+			}},
+		}
+		bs, _ := proto.Marshal(phg)
+		return bs
+	}
+
+	buildHybridTaskWithMetric := func(annsField string, metricType string, phType commonpb.PlaceholderType, rangeRadius string, withIterator bool, groupByField string) *searchTask {
+		paramsJSON := `{"nprobe": 10}`
+		if rangeRadius != "" {
+			paramsJSON = `{"nprobe": 10, "radius": ` + rangeRadius + `}`
+		}
+		subParams := []*commonpb.KeyValuePair{
+			{Key: common.MetricTypeKey, Value: metricType},
+			{Key:
ParamsKey, Value: paramsJSON},
+			{Key: AnnsFieldKey, Value: annsField},
+			{Key: TopKKey, Value: "10"},
+		}
+		if withIterator {
+			subParams = append(subParams, &commonpb.KeyValuePair{Key: IteratorField, Value: "True"})
+		}
+
+		outerParams := []*commonpb.KeyValuePair{{Key: LimitKey, Value: "10"}}
+		if groupByField != "" {
+			outerParams = append(outerParams, &commonpb.KeyValuePair{Key: GroupByFieldKey, Value: groupByField})
+		}
+
+		return &searchTask{
+			ctx: ctx,
+			SearchRequest: &internalpb.SearchRequest{
+				Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_Search, Timestamp: uint64(time.Now().UnixNano())},
+			},
+			request: &milvuspb.SearchRequest{
+				CollectionName: "test_collection",
+				SearchParams: outerParams,
+				SubReqs: []*milvuspb.SubSearchRequest{
+					{PlaceholderGroup: makePlaceholderGroup(phType), SearchParams: subParams},
+				},
+			},
+			schema: schemaInfo,
+			tr: timerecord.NewTimeRecorder("test"),
+		}
+	}
+
+	buildEmbListHybridTask := func(annsField string, rangeRadius string, withIterator bool, groupByField string) *searchTask {
+		return buildHybridTaskWithMetric(annsField, metric.MaxSimL2, commonpb.PlaceholderType_EmbListFloatVector, rangeRadius, withIterator, groupByField)
+	}
+	buildElementHybridTask := func(annsField string, rangeRadius string, withIterator bool, groupByField string) *searchTask {
+		return buildHybridTaskWithMetric(annsField, metric.L2, commonpb.PlaceholderType_FloatVector, rangeRadius, withIterator, groupByField)
+	}
+
+	t.Run("hybrid with ArrayOfVector EmbList metric plain topK should succeed", func(t *testing.T) {
+		qt := buildHybridTaskWithMetric("emb_vec", metric.MaxSimCosine, commonpb.PlaceholderType_EmbListFloatVector, "", false, "")
+		assert.NoError(t, qt.initAdvancedSearchRequest(ctx))
+	})
+
+	t.Run("hybrid with ArrayOfVector range search should fail", func(t *testing.T) {
+		qt := buildEmbListHybridTask("emb_vec", "0.2", false, "")
+		err := qt.initAdvancedSearchRequest(ctx)
+		assert.Error(t, err)
+		assert.ErrorIs(t, err,
merr.ErrParameterInvalid)
+		assert.Contains(t, err.Error(), "range search is not supported for vector array fields")
+	})
+
+	t.Run("hybrid with ArrayOfVector iterator should fail", func(t *testing.T) {
+		qt := buildEmbListHybridTask("emb_vec", "", true, "")
+		err := qt.initAdvancedSearchRequest(ctx)
+		assert.Error(t, err)
+		assert.ErrorIs(t, err, merr.ErrParameterInvalid)
+		assert.Contains(t, err.Error(), "search iterator is not supported for vector array fields")
+	})
+
+	t.Run("hybrid with ArrayOfVector group by should fail", func(t *testing.T) {
+		qt := buildEmbListHybridTask("emb_vec", "", false, "scalar_field")
+		err := qt.initAdvancedSearchRequest(ctx)
+		assert.Error(t, err)
+		assert.ErrorIs(t, err, merr.ErrParameterInvalid)
+		assert.Contains(t, err.Error(), "group by search is not supported for vector array fields")
+	})
+
+	t.Run("hybrid with element-level ArrayOfVector plain topK should succeed", func(t *testing.T) {
+		qt := buildElementHybridTask("emb_vec", "", false, "")
+		assert.NoError(t, qt.initAdvancedSearchRequest(ctx))
+	})
+
+	t.Run("hybrid with element-level ArrayOfVector advanced controls should fail", func(t *testing.T) {
+		tests := []struct {
+			name string
+			rangeRadius string
+			withIterator bool
+			groupByField string
+			errMsg string
+		}{
+			{name: "range", rangeRadius: "0.2", errMsg: "range search is not supported for vector array fields"},
+			{name: "iterator", withIterator: true, errMsg: "search iterator is not supported for vector array fields"},
+			{name: "group by", groupByField: "pk", errMsg: "group by search is not supported for vector array fields"},
+		}
+		for _, test := range tests {
+			t.Run(test.name, func(t *testing.T) {
+				qt := buildElementHybridTask("emb_vec", test.rangeRadius, test.withIterator, test.groupByField)
+				err := qt.initAdvancedSearchRequest(ctx)
+				assert.Error(t, err)
+				assert.ErrorIs(t, err, merr.ErrParameterInvalid)
+				assert.Contains(t, err.Error(), test.errMsg)
+			})
+		}
+	})
+
+	t.Run("hybrid with normal vector
advanced controls should succeed", func(t *testing.T) {
+		tests := []struct {
+			name string
+			rangeRadius string
+			withIterator bool
+			groupByField string
+		}{
+			{name: "range", rangeRadius: "0.2"},
+			{name: "iterator", withIterator: true},
+			{name: "group by", groupByField: "scalar_field"},
+		}
+		for _, test := range tests {
+			t.Run(test.name, func(t *testing.T) {
+				qt := buildElementHybridTask("regular_vec", test.rangeRadius, test.withIterator, test.groupByField)
+				assert.NoError(t, qt.initAdvancedSearchRequest(ctx))
+			})
+		}
+	})
+}
+
+func TestSearchTask_StructHybridElementScopeValidation(t *testing.T) {
+	paramtable.Init()
+	ctx := context.Background()
+
+	indexMetricParams := func(metricType string) []*commonpb.KeyValuePair {
+		return []*commonpb.KeyValuePair{{Key: common.MetricTypeKey, Value: metricType}}
+	}
+	schema := &schemapb.CollectionSchema{
+		Name: "test_collection",
+		Fields: []*schemapb.FieldSchema{
+			{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
+			{FieldID: 101, Name: "regular_vec", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{{Key: common.DimKey, Value: "4"}}},
+		},
+		StructArrayFields: []*schemapb.StructArrayFieldSchema{
+			{
+				FieldID: 200,
+				Name: "struct_a",
+				Fields: []*schemapb.FieldSchema{
+					{FieldID: 201, Name: "a_vec", DataType: schemapb.DataType_ArrayOfVector, ElementType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{{Key: common.DimKey, Value: "4"}}, IndexParams: indexMetricParams(metric.IP)},
+					{FieldID: 202, Name: "a_text_vec", DataType: schemapb.DataType_ArrayOfVector, ElementType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{{Key: common.DimKey, Value: "4"}}, IndexParams: indexMetricParams(metric.IP)},
+				},
+			},
+			{
+				FieldID: 300,
+				Name: "struct_b",
+				Fields: []*schemapb.FieldSchema{
+					{FieldID: 301, Name: "b_vec", DataType: schemapb.DataType_ArrayOfVector, ElementType: schemapb.DataType_FloatVector, TypeParams:
[]*commonpb.KeyValuePair{{Key: common.DimKey, Value: "4"}}, IndexParams: indexMetricParams(metric.L2)},
+				},
+			},
+		},
+	}
+	schemaInfo := newSchemaInfo(schema)
+
+	makePlaceholderGroup := func(phType commonpb.PlaceholderType) []byte {
+		phg := &commonpb.PlaceholderGroup{
+			Placeholders: []*commonpb.PlaceholderValue{{
+				Tag: "$0",
+				Type: phType,
+				Values: [][]byte{{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}},
+			}},
+		}
+		bs, _ := proto.Marshal(phg)
+		return bs
+	}
+
+	type subSpec struct {
+		annsField string
+		params string
+		phType commonpb.PlaceholderType
+		metricType string
+		omitMetric bool
+	}
+	makeTask := func(specs ...subSpec) *searchTask {
+		subReqs := make([]*milvuspb.SubSearchRequest, 0, len(specs))
+		for _, spec := range specs {
+			metricType := spec.metricType
+			if metricType == "" {
+				metricType = metric.L2
+			}
+			searchParams := []*commonpb.KeyValuePair{
+				{Key: ParamsKey, Value: spec.params},
+				{Key: AnnsFieldKey, Value: spec.annsField},
+				{Key: TopKKey, Value: "10"},
+			}
+			if !spec.omitMetric {
+				searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.MetricTypeKey, Value: metricType})
+			}
+			subReqs = append(subReqs, &milvuspb.SubSearchRequest{
+				PlaceholderGroup: makePlaceholderGroup(spec.phType),
+				Nq: 1,
+				SearchParams: searchParams,
+			})
+		}
+		return &searchTask{
+			ctx: ctx,
+			collectionName: "test_collection",
+			SearchRequest: &internalpb.SearchRequest{
+				CollectionID: 1,
+				PartitionIDs: []int64{1},
+				OutputFieldsId: []int64{100},
+			},
+			request: &milvuspb.SearchRequest{
+				CollectionName: "test_collection",
+				OutputFields: []string{"pk"},
+				SearchParams: []*commonpb.KeyValuePair{{Key: LimitKey, Value: "10"}},
+				SubReqs: subReqs,
+			},
+			schema: schemaInfo,
+			translatedOutputFields: []string{"pk"},
+			tr: timerecord.NewTimeRecorder("test"),
+		}
+	}
+
+	const noScope = `{"nprobe": 10}`
+	const maxScope = `{"nprobe": 10, "element_scope": {"collapse": {"strategy": "max"}}}`
+	const topKScope = `{"nprobe": 10, "element_scope":
{"collapse": {"strategy": "topk_sum", "topk": 2}}}`
+
+	t.Run("rejects element_scope on normal vector sub request", func(t *testing.T) {
+		task := makeTask(subSpec{annsField: "regular_vec", params: maxScope, phType: commonpb.PlaceholderType_FloatVector})
+		err := task.initAdvancedSearchRequest(ctx)
+		assert.Error(t, err)
+		assert.ErrorIs(t, err, merr.ErrParameterInvalid)
+		assert.Contains(t, err.Error(), "element_scope is only supported")
+	})
+
+	t.Run("rejects element_scope on struct embedding-list sub request", func(t *testing.T) {
+		task := makeTask(subSpec{annsField: "a_vec", params: maxScope, phType: commonpb.PlaceholderType_EmbListFloatVector})
+		err := task.initAdvancedSearchRequest(ctx)
+		assert.Error(t, err)
+		assert.ErrorIs(t, err, merr.ErrParameterInvalid)
+		assert.Contains(t, err.Error(), "element_scope is only supported")
+	})
+
+	t.Run("rejects element_scope for same-struct element-level hybrid", func(t *testing.T) {
+		task := makeTask(
+			subSpec{annsField: "a_vec", params: maxScope, phType: commonpb.PlaceholderType_FloatVector},
+			subSpec{annsField: "a_text_vec", params: noScope, phType: commonpb.PlaceholderType_FloatVector},
+		)
+		err := task.initAdvancedSearchRequest(ctx)
+		assert.Error(t, err)
+		assert.ErrorIs(t, err, merr.ErrParameterInvalid)
+		assert.Contains(t, err.Error(), "same-struct element-level hybrid")
+	})
+
+	t.Run("accepts element_scope for row-level hybrid with normal vector", func(t *testing.T) {
+		task := makeTask(
+			subSpec{annsField: "a_vec", params: topKScope, phType: commonpb.PlaceholderType_FloatVector, metricType: metric.IP},
+			subSpec{annsField: "regular_vec", params: noScope, phType: commonpb.PlaceholderType_FloatVector},
+		)
+		require.NoError(t, task.initAdvancedSearchRequest(ctx))
+		require.Len(t, task.hybridSubSearchInfos, 2)
+		assert.False(t, task.hybridElementLevel)
+		assert.Equal(t, elementCollapseTopKSum, task.hybridSubSearchInfos[0].Collapse.Strategy)
+		assert.Equal(t, 2, task.hybridSubSearchInfos[0].Collapse.TopK)
+	})
+
+	t.Run("accepts sum collapse with omitted metric type", func(t *testing.T) {
+		task := makeTask(
+			subSpec{annsField: "a_vec", params: topKScope, phType: commonpb.PlaceholderType_FloatVector, omitMetric: true},
+			subSpec{annsField: "regular_vec", params: noScope, phType: commonpb.PlaceholderType_FloatVector},
+		)
+		require.NoError(t, task.initAdvancedSearchRequest(ctx))
+		require.Len(t, task.hybridSubSearchInfos, 2)
+		assert.False(t, task.hybridElementLevel)
+		assert.Equal(t, elementCollapseTopKSum, task.hybridSubSearchInfos[0].Collapse.Strategy)
+		assert.Empty(t, task.queryInfos[0].GetMetricType())
+	})
+
+	t.Run("rejects sum collapse with omitted negative index metric", func(t *testing.T) {
+		task := makeTask(
+			subSpec{annsField: "b_vec", params: topKScope, phType: commonpb.PlaceholderType_FloatVector, omitMetric: true},
+			subSpec{annsField: "regular_vec", params: noScope, phType: commonpb.PlaceholderType_FloatVector},
+		)
+		err := task.initAdvancedSearchRequest(ctx)
+		assert.Error(t, err)
+		assert.ErrorIs(t, err, merr.ErrParameterInvalid)
+		assert.Contains(t, err.Error(), "only supported for positively related metrics")
+	})
+
+	t.Run("rejects sum collapse on negative metric", func(t *testing.T) {
+		task := makeTask(
+			subSpec{annsField: "a_vec", params: topKScope, phType: commonpb.PlaceholderType_FloatVector, metricType: metric.L2},
+			subSpec{annsField: "regular_vec", params: noScope, phType: commonpb.PlaceholderType_FloatVector},
+		)
+		err := task.initAdvancedSearchRequest(ctx)
+		assert.Error(t, err)
+		assert.ErrorIs(t, err, merr.ErrParameterInvalid)
+		assert.Contains(t, err.Error(), "only supported for positively related metrics")
+	})
+
+	t.Run("accepts element_scope for row-level hybrid across different structs", func(t *testing.T) {
+		task := makeTask(
+			subSpec{annsField: "a_vec", params: maxScope, phType: commonpb.PlaceholderType_FloatVector},
+			subSpec{annsField: "b_vec", params: noScope, phType: commonpb.PlaceholderType_FloatVector},
+		)
+
require.NoError(t, task.initAdvancedSearchRequest(ctx))
+		assert.False(t, task.hybridElementLevel)
+		assert.Equal(t, elementCollapseMax, task.hybridSubSearchInfos[0].Collapse.Strategy)
+		assert.Equal(t, elementCollapseMax, task.hybridSubSearchInfos[1].Collapse.Strategy)
+	})
+}
+
+func TestParseElementScope(t *testing.T) {
+	tests := []struct {
+		name string
+		params string
+		errMsg string
+	}{
+		{name: "unknown strategy", params: `{"element_scope": {"collapse": {"strategy": "median"}}}`, errMsg: "unsupported element_scope.collapse.strategy"},
+		{name: "topk strategy requires topk", params: `{"element_scope": {"collapse": {"strategy": "topk_avg"}}}`, errMsg: "topk is required"},
+		{name: "topk must be positive", params: `{"element_scope": {"collapse": {"strategy": "topk_sum", "topk": 0}}}`, errMsg: "topk must be positive"},
+		{name: "topk invalid for max", params: `{"element_scope": {"collapse": {"strategy": "max", "topk": 2}}}`, errMsg: "topk is only valid"},
+		{name: "unknown scope key", params: `{"element_scope": {"collapse": {"strategy": "max"}, "mode": "row"}}`, errMsg: "unsupported element_scope key"},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			_, _, _, err := parseAndRemoveElementScope(test.params)
+			require.Error(t, err)
+			assert.Contains(t, err.Error(), test.errMsg)
+		})
+	}
+
+	cfg, provided, sanitized, err := parseAndRemoveElementScope(`{"nprobe": 10, "element_scope": {"collapse": {"strategy": "sum"}}}`)
+	require.NoError(t, err)
+	assert.True(t, provided)
+	assert.Equal(t, elementCollapseSum, cfg.Strategy)
+	assert.Equal(t, 0, cfg.TopK)
+	assert.NotContains(t, sanitized, elementScopeKey)
+	assert.Contains(t, sanitized, "nprobe")
+}
+
 func TestSearchTask_SearchRequeryPolicy(t *testing.T) {
 	paramtable.Init()
 	ctx := context.Background()
diff --git a/internal/util/function/rerank/decay_function.go b/internal/util/function/rerank/decay_function.go
index 2468733dcdb21..66b6dd60ea8e4 100644
---
a/internal/util/function/rerank/decay_function.go
+++ b/internal/util/function/rerank/decay_function.go
@@ -58,8 +58,8 @@ type DecayFunction[T PKType, R int32 | int64 | float32 | float64] struct {
 	reScorer decayReScorer
 }

-func newDecayFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) {
-	base, err := newRerankBase(collSchema, funcSchema, DecayFunctionName, true)
+func newDecayFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema, pkTypeOverride ...schemapb.DataType) (Reranker, error) {
+	base, err := newRerankBase(collSchema, funcSchema, DecayFunctionName, true, pkTypeOverride...)
 	if err != nil {
 		return nil, err
 	}
diff --git a/internal/util/function/rerank/function_score.go b/internal/util/function/rerank/function_score.go
index 83e45c205acba..27af4687ce325 100644
--- a/internal/util/function/rerank/function_score.go
+++ b/internal/util/function/rerank/function_score.go
@@ -128,7 +128,7 @@ type FunctionScore struct {
 	reranker Reranker
 }

-func createFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema, extraInfo *models.ModelExtraInfo) (Reranker, error) {
+func createFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema, extraInfo *models.ModelExtraInfo, pkTypeOverride ...schemapb.DataType) (Reranker, error) {
 	if funcSchema.GetType() != schemapb.FunctionType_Rerank {
 		return nil, fmt.Errorf("%s is not rerank function", funcSchema.GetType().String())
 	}
@@ -141,13 +141,13 @@ func createFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.
 	var newRerankErr error
 	switch rerankerName {
 	case DecayFunctionName:
-		rerankFunc, newRerankErr = newDecayFunction(collSchema, funcSchema)
+		rerankFunc, newRerankErr = newDecayFunction(collSchema, funcSchema, pkTypeOverride...)
 	case ModelFunctionName:
-		rerankFunc, newRerankErr = newModelFunction(collSchema, funcSchema, extraInfo)
+		rerankFunc, newRerankErr = newModelFunction(collSchema, funcSchema, extraInfo, pkTypeOverride...)
 	case RRFName:
-		rerankFunc, newRerankErr = newRRFFunction(collSchema, funcSchema)
+		rerankFunc, newRerankErr = newRRFFunction(collSchema, funcSchema, pkTypeOverride...)
 	case WeightedName:
-		rerankFunc, newRerankErr = newWeightedFunction(collSchema, funcSchema)
+		rerankFunc, newRerankErr = newWeightedFunction(collSchema, funcSchema, pkTypeOverride...)
 	case BoostName:
 		return nil, nil
 	default:
@@ -161,10 +161,18 @@ func createFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.
 }

 func NewFunctionScore(collSchema *schemapb.CollectionSchema, funcScoreSchema *schemapb.FunctionScore, extraInfo *models.ModelExtraInfo) (*FunctionScore, error) {
+	return newFunctionScore(collSchema, funcScoreSchema, extraInfo)
+}
+
+func NewFunctionScoreWithPKType(collSchema *schemapb.CollectionSchema, funcScoreSchema *schemapb.FunctionScore, extraInfo *models.ModelExtraInfo, pkType schemapb.DataType) (*FunctionScore, error) {
+	return newFunctionScore(collSchema, funcScoreSchema, extraInfo, pkType)
+}
+
+func newFunctionScore(collSchema *schemapb.CollectionSchema, funcScoreSchema *schemapb.FunctionScore, extraInfo *models.ModelExtraInfo, pkTypeOverride ...schemapb.DataType) (*FunctionScore, error) {
 	funcScore := &FunctionScore{}

 	for _, function := range funcScoreSchema.Functions {
-		reranker, err := createFunction(collSchema, function, extraInfo)
+		reranker, err := createFunction(collSchema, function, extraInfo, pkTypeOverride...)
 		if err != nil {
 			return nil, err
 		}
@@ -186,6 +194,14 @@ func NewFunctionScore(collSchema *schemapb.CollectionSchema, funcScoreSchema *sc
 }

 func NewFunctionScoreWithlegacy(collSchema *schemapb.CollectionSchema, rankParams []*commonpb.KeyValuePair) (*FunctionScore, error) {
+	return newFunctionScoreWithlegacy(collSchema, rankParams)
+}
+
+func NewFunctionScoreWithlegacyAndPKType(collSchema *schemapb.CollectionSchema, rankParams []*commonpb.KeyValuePair, pkType schemapb.DataType) (*FunctionScore, error) {
+	return newFunctionScoreWithlegacy(collSchema, rankParams, pkType)
+}
+
+func newFunctionScoreWithlegacy(collSchema *schemapb.CollectionSchema, rankParams []*commonpb.KeyValuePair, pkTypeOverride ...schemapb.DataType) (*FunctionScore, error) {
 	var params map[string]interface{}
 	rankTypeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(legacyRankTypeKey, rankParams)
 	if err != nil {
@@ -241,7 +257,7 @@ func NewFunctionScoreWithlegacy(collSchema *schemapb.CollectionSchema, rankParam
 		return nil, fmt.Errorf("unsupported rank type %s", rankTypeStr)
 	}
 	funcScore := &FunctionScore{}
-	if funcScore.reranker, err = createFunction(collSchema, &fSchema, &models.ModelExtraInfo{}); err != nil {
+	if funcScore.reranker, err = createFunction(collSchema, &fSchema, &models.ModelExtraInfo{}, pkTypeOverride...); err != nil {
 		return nil, err
 	}
 	return funcScore, nil
diff --git a/internal/util/function/rerank/model_function.go b/internal/util/function/rerank/model_function.go
index fcb6fe7ccca6e..a1ed02731995c 100644
--- a/internal/util/function/rerank/model_function.go
+++ b/internal/util/function/rerank/model_function.go
@@ -109,8 +109,8 @@ type ModelFunction[T PKType] struct {
 	queries []string
 }

-func newModelFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema, extraInfo *models.ModelExtraInfo) (Reranker, error) {
-	base, err := newRerankBase(collSchema, funcSchema, DecayFunctionName, true)
+func newModelFunction(collSchema *schemapb.CollectionSchema,
funcSchema *schemapb.FunctionSchema, extraInfo *models.ModelExtraInfo, pkTypeOverride ...schemapb.DataType) (Reranker, error) {
+	base, err := newRerankBase(collSchema, funcSchema, DecayFunctionName, true, pkTypeOverride...)
 	if err != nil {
 		return nil, err
 	}
diff --git a/internal/util/function/rerank/rerank_base.go b/internal/util/function/rerank/rerank_base.go
index 99e127bce9451..58244dadb4ece 100644
--- a/internal/util/function/rerank/rerank_base.go
+++ b/internal/util/function/rerank/rerank_base.go
@@ -52,11 +52,17 @@ type RerankBase struct {
 	searchParams *searchParams
 }

-func newRerankBase(coll *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema, rerankerName string, isSupportGroup bool) (*RerankBase, error) {
+func newRerankBase(coll *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema, rerankerName string, isSupportGroup bool, pkTypeOverride ...schemapb.DataType) (*RerankBase, error) {
 	pkType, err := getPKType(coll)
 	if err != nil {
 		return nil, err
 	}
+	if len(pkTypeOverride) > 0 && pkTypeOverride[0] != schemapb.DataType_None {
+		if pkTypeOverride[0] != schemapb.DataType_Int64 && pkTypeOverride[0] != schemapb.DataType_VarChar {
+			return nil, fmt.Errorf("unsupported pk type override: %s", pkTypeOverride[0].String())
+		}
+		pkType = pkTypeOverride[0]
+	}

 	base := RerankBase{
 		inputFieldNames: funcSchema.InputFieldNames,
diff --git a/internal/util/function/rerank/rrf_function.go b/internal/util/function/rerank/rrf_function.go
index 7e4169a14f8eb..73c9634538481 100644
--- a/internal/util/function/rerank/rrf_function.go
+++ b/internal/util/function/rerank/rrf_function.go
@@ -39,8 +39,8 @@ type RRFFunction[T PKType] struct {
 	k float32
 }

-func newRRFFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) {
-	base, err := newRerankBase(collSchema, funcSchema, RRFName, true)
+func newRRFFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema, pkTypeOverride
...schemapb.DataType) (Reranker, error) {
+	base, err := newRerankBase(collSchema, funcSchema, RRFName, true, pkTypeOverride...)
 	if err != nil {
 		return nil, err
 	}
diff --git a/internal/util/function/rerank/weighted_function.go b/internal/util/function/rerank/weighted_function.go
index 95d1c7302632a..b1cce9f6d8838 100644
--- a/internal/util/function/rerank/weighted_function.go
+++ b/internal/util/function/rerank/weighted_function.go
@@ -41,8 +41,8 @@ type WeightedFunction[T PKType] struct {
 	needNorm bool
 }

-func newWeightedFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) {
-	base, err := newRerankBase(collSchema, funcSchema, WeightedName, true)
+func newWeightedFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema, pkTypeOverride ...schemapb.DataType) (Reranker, error) {
+	base, err := newRerankBase(collSchema, funcSchema, WeightedName, true, pkTypeOverride...)
 	if err != nil {
 		return nil, err
 	}
diff --git a/tests/python_client/milvus_client/test_milvus_client_struct_array.py b/tests/python_client/milvus_client/test_milvus_client_struct_array.py
index 4832ad10514f3..86d78110d606b 100644
--- a/tests/python_client/milvus_client/test_milvus_client_struct_array.py
+++ b/tests/python_client/milvus_client/test_milvus_client_struct_array.py
@@ -3895,8 +3895,8 @@ def test_struct_array_range_search_not_supported(self):
         # Step 2: Perform range search with selected radius and range_filter
         # For COSINE: radius < distance <= range_filter
         error = {
-            ct.err_code: 65535,
-            ct.err_msg: "range search is not supported for vector array",
+            ct.err_code: 1100,
+            ct.err_msg: "range search is not supported for vector array fields",
         }
         self.search(
             client,