Skip to content

Commit 4202c77

Browse files
feat: Support vllm and tei rerank (#41947)
#35856 Signed-off-by: junjie.jiang <[email protected]>
1 parent 14563ad commit 4202c77

23 files changed

+1013
-66
lines changed

configs/milvus.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1327,3 +1327,10 @@ function:
13271327
voyageai:
13281328
credential: # The name in the credential configuration item
13291329
url: # Your voyageai embedding url, Default is the official embedding url
1330+
rerank:
1331+
model:
1332+
providers:
1333+
tei:
1334+
enable: true # Whether to enable TEI rerank service
1335+
vllm:
1336+
enable: true # Whether to enable vllm rerank service

internal/proxy/task_search.go

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,17 @@ func mergeIDs(idsList []*schemapb.IDs) (*schemapb.IDs, int) {
646646

647647
func (t *searchTask) hybridSearchRank(ctx context.Context, span trace.Span, multipleMilvusResults []*milvuspb.SearchResults, searchMetrics []string) error {
648648
var err error
649+
processRerank := func(ctx context.Context, results []*milvuspb.SearchResults) (*milvuspb.SearchResults, error) {
650+
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-call-rerank-function-udf")
651+
defer sp.End()
652+
653+
params := rerank.NewSearchParams(
654+
t.Nq, t.rankParams.limit, t.rankParams.offset, t.rankParams.roundDecimal,
655+
t.rankParams.groupByFieldId, t.rankParams.groupSize, t.rankParams.strictGroupSize, searchMetrics,
656+
)
657+
return t.functionScore.Process(ctx, params, results)
658+
}
659+
649660
// The first step of hybrid search is without meta information. If rerank requires meta data, we need to do requery.
650661
// At this time, outputFields and rerank input_fields will be recalled.
651662
// To save memory, we could recall only the rerank input_fields in this step, and recall the output_fields in the third step
@@ -682,12 +693,7 @@ func (t *searchTask) hybridSearchRank(ctx context.Context, span trace.Span, mult
682693
for i := 0; i < len(multipleMilvusResults); i++ {
683694
multipleMilvusResults[i].Results.FieldsData = fields[i]
684695
}
685-
params := rerank.NewSearchParams(
686-
t.Nq, t.rankParams.limit, t.rankParams.offset, t.rankParams.roundDecimal,
687-
t.rankParams.groupByFieldId, t.rankParams.groupSize, t.rankParams.strictGroupSize, searchMetrics,
688-
)
689-
690-
if t.result, err = t.functionScore.Process(ctx, params, multipleMilvusResults); err != nil {
696+
if t.result, err = processRerank(ctx, multipleMilvusResults); err != nil {
691697
return err
692698
}
693699
if fields, err := t.reorganizeRequeryResults(ctx, queryResult.GetFieldsData(), []*schemapb.IDs{t.result.Results.Ids}); err != nil {
@@ -696,11 +702,7 @@ func (t *searchTask) hybridSearchRank(ctx context.Context, span trace.Span, mult
696702
t.result.Results.FieldsData = fields[0]
697703
}
698704
} else {
699-
params := rerank.NewSearchParams(
700-
t.Nq, t.rankParams.limit, t.rankParams.offset, t.rankParams.roundDecimal,
701-
t.rankParams.groupByFieldId, t.rankParams.groupSize, t.rankParams.strictGroupSize, searchMetrics,
702-
)
703-
if t.result, err = t.functionScore.Process(ctx, params, multipleMilvusResults); err != nil {
705+
if t.result, err = processRerank(ctx, multipleMilvusResults); err != nil {
704706
return err
705707
}
706708
}
@@ -823,11 +825,15 @@ func (t *searchTask) searchPostProcess(ctx context.Context, span trace.Span, toR
823825
}
824826

825827
if t.functionScore != nil && len(result.Results.FieldsData) != 0 {
826-
params := rerank.NewSearchParams(t.Nq, t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(),
827-
t.queryInfos[0].RoundDecimal, t.queryInfos[0].GroupByFieldId, t.queryInfos[0].GroupSize, t.queryInfos[0].StrictGroupSize, []string{metricType})
828-
// rank only returns id and score
829-
if t.result, err = t.functionScore.Process(ctx, params, []*milvuspb.SearchResults{result}); err != nil {
830-
return err
828+
{
829+
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-call-rerank-function-udf")
830+
defer sp.End()
831+
params := rerank.NewSearchParams(t.Nq, t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(),
832+
t.queryInfos[0].RoundDecimal, t.queryInfos[0].GroupByFieldId, t.queryInfos[0].GroupSize, t.queryInfos[0].StrictGroupSize, []string{metricType})
833+
// rank only returns id and score
834+
if t.result, err = t.functionScore.Process(ctx, params, []*milvuspb.SearchResults{result}); err != nil {
835+
return err
836+
}
831837
}
832838
if !t.needRequery {
833839
fields, err := t.reorganizeRequeryResults(ctx, result.Results.FieldsData, []*schemapb.IDs{t.result.Results.Ids})

internal/util/function/common.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,16 +109,17 @@ const (
109109
siliconflowAKEnvStr string = "MILVUSAI_SILICONFLOW_API_KEY"
110110
)
111111

112-
// TEI
112+
// TEI and vllm
113113

114114
const (
115115
ingestionPromptParamKey string = "ingestion_prompt"
116116
searchPromptParamKey string = "search_prompt"
117117
maxClientBatchSizeParamKey string = "max_client_batch_size"
118118
truncationDirectionParamKey string = "truncation_direction"
119-
endpointParamKey string = "endpoint"
119+
EndpointParamKey string = "endpoint"
120120

121-
enableTeiEnvStr string = "MILVUSAI_ENABLE_TEI"
121+
EnableTeiEnvStr string = "MILVUSAI_ENABLE_TEI"
122+
EnableVllmEnvStr string = "MILVUSAI_ENABLE_VLLM"
122123
)
123124

124125
func parseAKAndURL(credentials *credentials.Credentials, params []*commonpb.KeyValuePair, confParams map[string]string, apiKeyEnv string) (string, string, error) {

internal/util/function/models/ali/ali_dashscope_text_embedding.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ func (c *AliDashScopeEmbedding) Embedding(modelName string, texts []string, dim
140140
"Content-Type": "application/json",
141141
"Authorization": fmt.Sprintf("Bearer %s", c.apiKey),
142142
}
143-
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
143+
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3)
144144
if err != nil {
145145
return nil, err
146146
}

internal/util/function/models/cohere/cohere_text_embedding.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ func (c *CohereEmbedding) Embedding(modelName string, texts []string, inputType
106106
"Content-Type": "application/json",
107107
"Authorization": fmt.Sprintf("bearer %s", c.apiKey),
108108
}
109-
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
109+
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3)
110110
if err != nil {
111111
return nil, err
112112
}

internal/util/function/models/openai/openai_embedding.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ func (c *openAIBase) embedding(url string, headers map[string]string, modelName
149149

150150
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second)
151151
defer cancel()
152-
body, err := utils.RetrySend(ctx, data, http.MethodPost, url, headers, 3, 1)
152+
body, err := utils.RetrySend(ctx, data, http.MethodPost, url, headers, 3)
153153
if err != nil {
154154
return nil, err
155155
}

internal/util/function/models/openai/openai_embedding_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,8 +218,8 @@ func TestEmbeddingFailed(t *testing.T) {
218218
func TestTimeout(t *testing.T) {
219219
var st int32 = 0
220220
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
221-
// (Timeout 1s + Wait 1s) * Retry 3
222-
time.Sleep(6 * time.Second)
221+
// Retry 3: (Timeout 1s * 3) + exponential backoff waits (1s + 2s + 4s), plus jitter
222+
time.Sleep(11 * time.Second)
223223
atomic.AddInt32(&st, 1)
224224
w.WriteHeader(http.StatusUnauthorized)
225225
}))

internal/util/function/models/siliconflow/siliconflow_text_embedding.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ func (c *SiliconflowEmbedding) Embedding(modelName string, texts []string, encod
121121
"Content-Type": "application/json",
122122
"Authorization": fmt.Sprintf("Bearer %s", c.apiKey),
123123
}
124-
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
124+
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3)
125125
if err != nil {
126126
return nil, err
127127
}

internal/util/function/models/tei/tei.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ func (c *TEIEmbedding) Embedding(texts []string, truncate bool, truncationDirect
9393
if c.apiKey != "" {
9494
headers["Authorization"] = fmt.Sprintf("Bearer %s", c.apiKey)
9595
}
96-
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
96+
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3)
9797
if err != nil {
9898
return nil, err
9999
}

internal/util/function/models/utils/embedding_util.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"context"
2222
"fmt"
2323
"io"
24+
"math/rand"
2425
"net/http"
2526
"time"
2627
)
@@ -45,7 +46,7 @@ func send(req *http.Request) ([]byte, error) {
4546
return body, nil
4647
}
4748

48-
func RetrySend(ctx context.Context, data []byte, httpMethod string, url string, headers map[string]string, maxRetries int, retryDelay int) ([]byte, error) {
49+
func RetrySend(ctx context.Context, data []byte, httpMethod string, url string, headers map[string]string, maxRetries int) ([]byte, error) {
4950
var err error
5051
var body []byte
5152
for i := 0; i < maxRetries; i++ {
@@ -60,7 +61,9 @@ func RetrySend(ctx context.Context, data []byte, httpMethod string, url string,
6061
if err == nil {
6162
return body, nil
6263
}
63-
time.Sleep(time.Duration(retryDelay) * time.Second)
64+
backoffDelay := 1 << uint(i) * time.Second
65+
jitter := time.Duration(rand.Int63n(int64(backoffDelay / 4)))
66+
time.Sleep(backoffDelay + jitter)
6467
}
6568
return nil, err
6669
}

internal/util/function/models/vertexai/vertexai_text_embedding.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ func (c *VertexAIEmbedding) Embedding(modelName string, texts []string, dim int6
148148
"Content-Type": "application/json",
149149
"Authorization": fmt.Sprintf("Bearer %s", token),
150150
}
151-
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
151+
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3)
152152
if err != nil {
153153
return nil, err
154154
}

internal/util/function/models/voyageai/voyageai_text_embedding.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ func (c *VoyageAIEmbedding) Embedding(modelName string, texts []string, dim int,
138138
"Content-Type": "application/json",
139139
"Authorization": fmt.Sprintf("Bearer %s", c.apiKey),
140140
}
141-
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
141+
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3)
142142
if err != nil {
143143
return nil, err
144144
}

internal/util/function/rerank/decay_function.go

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -55,34 +55,17 @@ type DecayFunction[T PKType, R int32 | int64 | float32 | float64] struct {
5555
}
5656

5757
func newDecayFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) {
58-
pkType := schemapb.DataType_None
59-
for _, field := range collSchema.Fields {
60-
if field.IsPrimaryKey {
61-
pkType = field.DataType
62-
}
63-
}
64-
65-
if pkType == schemapb.DataType_None {
66-
return nil, fmt.Errorf("Collection %s can not found pk field", collSchema.Name)
67-
}
68-
69-
base, err := newRerankBase(collSchema, funcSchema, decayFunctionName, false, pkType)
58+
base, err := newRerankBase(collSchema, funcSchema, decayFunctionName, false)
7059
if err != nil {
7160
return nil, err
7261
}
7362

7463
if len(base.GetInputFieldNames()) != 1 {
75-
return nil, fmt.Errorf("Decay function only supoorts single input, but gets [%s] input", base.GetInputFieldNames())
76-
}
77-
78-
var inputType schemapb.DataType
79-
for _, field := range collSchema.Fields {
80-
if field.Name == base.GetInputFieldNames()[0] {
81-
inputType = field.DataType
82-
}
64+
return nil, fmt.Errorf("Decay function only supports single input, but gets [%s] input", base.GetInputFieldNames())
8365
}
8466

85-
if pkType == schemapb.DataType_Int64 {
67+
inputType := base.GetInputFieldTypes()[0]
68+
if base.pkType == schemapb.DataType_Int64 {
8669
switch inputType {
8770
case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32:
8871
return newFunction[int64, int32](base, funcSchema)
@@ -160,7 +143,7 @@ func newFunction[T PKType, R int32 | int64 | float32 | float64](base *RerankBase
160143
}
161144

162145
if decayFunc.decay <= 0 || decayFunc.decay >= 1 {
163-
return nil, fmt.Errorf("Decay function param: decay must 0 < decay < 1, but got %f", decayFunc.offset)
146+
return nil, fmt.Errorf("Decay function param: decay must 0 < decay < 1, but got %f", decayFunc.decay)
164147
}
165148

166149
switch decayFunc.functionName {

internal/util/function/rerank/decay_function_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ func (s *DecayFunctionSuite) TestNewDecayErrors() {
9090
{
9191
functionSchema.InputFieldNames = []string{"ts", "pk"}
9292
_, err := newDecayFunction(schema, functionSchema)
93-
s.ErrorContains(err, "Decay function only supoorts single input, but gets")
93+
s.ErrorContains(err, "Decay function only supports single input, but gets")
9494
}
9595

9696
{

internal/util/function/rerank/function_score.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import (
3232

3333
const (
3434
decayFunctionName string = "decay"
35+
modelFunctionName string = "model"
3536
)
3637

3738
type SearchParams struct {
@@ -92,8 +93,10 @@ func createFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.
9293
switch rerankerName {
9394
case decayFunctionName:
9495
rerankFunc, newRerankErr = newDecayFunction(collSchema, funcSchema)
96+
case modelFunctionName:
97+
rerankFunc, newRerankErr = newModelFunction(collSchema, funcSchema)
9598
default:
96-
return nil, fmt.Errorf("Unsupported rerank function: [%s] , list of supported [%s]", rerankerName, decayFunctionName)
99+
return nil, fmt.Errorf("Unsupported rerank function: [%s] , list of supported [%s,%s]", rerankerName, decayFunctionName, modelFunctionName)
97100
}
98101

99102
if newRerankErr != nil {

0 commit comments

Comments
 (0)