Skip to content

Commit 1208467

Browse files
Support vllm and tei rerank
Signed-off-by: junjie.jiang <[email protected]>
1 parent 2e35393 commit 1208467

File tree

13 files changed

+994
-40
lines changed

13 files changed

+994
-40
lines changed

configs/milvus.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1316,3 +1316,10 @@ function:
13161316
voyageai:
13171317
credential: # The name in the credential configuration item
13181318
url: # Your voyageai embedding url; default is the official embedding url
1319+
rerank:
1320+
model:
1321+
providers:
1322+
tei:
1323+
enable: true # Whether to enable TEI rerank service
1324+
vllm:
1325+
enable: true # Whether to enable vllm rerank service

internal/proxy/task_search.go

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -675,20 +675,26 @@ func (t *searchTask) hybridSearchRank(ctx context.Context, span trace.Span, mult
675675
for i := 0; i < len(multipleMilvusResults); i++ {
676676
multipleMilvusResults[i].Results.FieldsData = fields[i]
677677
}
678-
params := rerank.NewSearchParams(
679-
t.Nq, t.rankParams.limit, t.rankParams.offset, t.rankParams.roundDecimal,
680-
t.rankParams.groupByFieldId, t.rankParams.groupSize, t.rankParams.strictGroupSize, searchMetrics,
681-
)
682-
683-
if t.result, err = t.functionScore.Process(ctx, params, multipleMilvusResults); err != nil {
684-
return err
678+
{
679+
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-call-rerank-function-udf")
680+
defer sp.End()
681+
params := rerank.NewSearchParams(
682+
t.Nq, t.rankParams.limit, t.rankParams.offset, t.rankParams.roundDecimal,
683+
t.rankParams.groupByFieldId, t.rankParams.groupSize, t.rankParams.strictGroupSize, searchMetrics,
684+
)
685+
if t.result, err = t.functionScore.Process(ctx, params, multipleMilvusResults); err != nil {
686+
return err
687+
}
685688
}
686689
if fields, err := t.reorganizeRequeryResults(ctx, queryResult.GetFieldsData(), []*schemapb.IDs{t.result.Results.Ids}); err != nil {
687690
return err
688691
} else {
689692
t.result.Results.FieldsData = fields[0]
690693
}
691694
} else {
695+
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-call-rerank-function-udf")
696+
defer sp.End()
697+
692698
params := rerank.NewSearchParams(
693699
t.Nq, t.rankParams.limit, t.rankParams.offset, t.rankParams.roundDecimal,
694700
t.rankParams.groupByFieldId, t.rankParams.groupSize, t.rankParams.strictGroupSize, searchMetrics,
@@ -816,11 +822,15 @@ func (t *searchTask) searchPostProcess(ctx context.Context, span trace.Span, toR
816822
}
817823

818824
if t.functionScore != nil && len(result.Results.FieldsData) != 0 {
819-
params := rerank.NewSearchParams(t.Nq, t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(),
820-
t.queryInfos[0].RoundDecimal, t.queryInfos[0].GroupByFieldId, t.queryInfos[0].GroupSize, t.queryInfos[0].StrictGroupSize, []string{metricType})
821-
// rank only returns id and score
822-
if t.result, err = t.functionScore.Process(ctx, params, []*milvuspb.SearchResults{result}); err != nil {
823-
return err
825+
{
826+
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-call-rerank-function-udf")
827+
defer sp.End()
828+
params := rerank.NewSearchParams(t.Nq, t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(),
829+
t.queryInfos[0].RoundDecimal, t.queryInfos[0].GroupByFieldId, t.queryInfos[0].GroupSize, t.queryInfos[0].StrictGroupSize, []string{metricType})
830+
// rank only returns id and score
831+
if t.result, err = t.functionScore.Process(ctx, params, []*milvuspb.SearchResults{result}); err != nil {
832+
return err
833+
}
824834
}
825835
if !t.needRequery {
826836
fields, err := t.reorganizeRequeryResults(ctx, result.Results.FieldsData, []*schemapb.IDs{t.result.Results.Ids})

internal/util/function/common.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,16 +109,17 @@ const (
109109
siliconflowAKEnvStr string = "MILVUSAI_SILICONFLOW_API_KEY"
110110
)
111111

112-
// TEI
112+
// TEI and vllm
113113

114114
const (
115115
ingestionPromptParamKey string = "ingestion_prompt"
116116
searchPromptParamKey string = "search_prompt"
117117
maxClientBatchSizeParamKey string = "max_client_batch_size"
118118
truncationDirectionParamKey string = "truncation_direction"
119-
endpointParamKey string = "endpoint"
119+
EndpointParamKey string = "endpoint"
120120

121-
enableTeiEnvStr string = "MILVUSAI_ENABLE_TEI"
121+
EnableTeiEnvStr string = "MILVUSAI_ENABLE_TEI"
122+
EnableVllmEnvStr string = "MILVUSAI_ENABLE_VLLM"
122123
)
123124

124125
func parseAKAndURL(credentials *credentials.Credentials, params []*commonpb.KeyValuePair, confParams map[string]string, apiKeyEnv string) (string, string, error) {

internal/util/function/rerank/decay_function.go

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -55,24 +55,13 @@ type DecayFunction[T PKType, R int32 | int64 | float32 | float64] struct {
5555
}
5656

5757
func newDecayFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) {
58-
pkType := schemapb.DataType_None
59-
for _, field := range collSchema.Fields {
60-
if field.IsPrimaryKey {
61-
pkType = field.DataType
62-
}
63-
}
64-
65-
if pkType == schemapb.DataType_None {
66-
return nil, fmt.Errorf("Collection %s can not found pk field", collSchema.Name)
67-
}
68-
69-
base, err := newRerankBase(collSchema, funcSchema, decayFunctionName, false, pkType)
58+
base, err := newRerankBase(collSchema, funcSchema, decayFunctionName, false)
7059
if err != nil {
7160
return nil, err
7261
}
7362

7463
if len(base.GetInputFieldNames()) != 1 {
75-
return nil, fmt.Errorf("Decay function only supoorts single input, but gets [%s] input", base.GetInputFieldNames())
64+
return nil, fmt.Errorf("Decay function only supports single input, but gets [%s] input", base.GetInputFieldNames())
7665
}
7766

7867
var inputType schemapb.DataType
@@ -82,7 +71,7 @@ func newDecayFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemap
8271
}
8372
}
8473

85-
if pkType == schemapb.DataType_Int64 {
74+
if base.pkType == schemapb.DataType_Int64 {
8675
switch inputType {
8776
case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32:
8877
return newFunction[int64, int32](base, funcSchema)
@@ -160,7 +149,7 @@ func newFunction[T PKType, R int32 | int64 | float32 | float64](base *RerankBase
160149
}
161150

162151
if decayFunc.decay <= 0 || decayFunc.decay >= 1 {
163-
return nil, fmt.Errorf("Decay function param: decay must 0 < decay < 1, but got %f", decayFunc.offset)
152+
return nil, fmt.Errorf("Decay function param: decay must 0 < decay < 1, but got %f", decayFunc.decay)
164153
}
165154

166155
switch decayFunc.functionName {

internal/util/function/rerank/decay_function_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ func (s *DecayFunctionSuite) TestNewDecayErrors() {
9090
{
9191
functionSchema.InputFieldNames = []string{"ts", "pk"}
9292
_, err := newDecayFunction(schema, functionSchema)
93-
s.ErrorContains(err, "Decay function only supoorts single input, but gets")
93+
s.ErrorContains(err, "Decay function only supports single input, but gets")
9494
}
9595

9696
{

internal/util/function/rerank/function_score.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import (
3232

3333
const (
3434
decayFunctionName string = "decay"
35+
modelFunctionName string = "model"
3536
)
3637

3738
type SearchParams struct {
@@ -92,8 +93,10 @@ func createFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.
9293
switch rerankerName {
9394
case decayFunctionName:
9495
rerankFunc, newRerankErr = newDecayFunction(collSchema, funcSchema)
96+
case modelFunctionName:
97+
rerankFunc, newRerankErr = newModelFunction(collSchema, funcSchema)
9598
default:
96-
return nil, fmt.Errorf("Unsupported rerank function: [%s] , list of supported [%s]", rerankerName, decayFunctionName)
99+
return nil, fmt.Errorf("Unsupported rerank function: [%s] , list of supported [%s,%s]", rerankerName, decayFunctionName, modelFunctionName)
97100
}
98101

99102
if newRerankErr != nil {

0 commit comments

Comments
 (0)