Skip to content

feat: Support vllm and tei rerank #41947

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions configs/milvus.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1327,3 +1327,10 @@ function:
voyageai:
credential: # The name in the crendential configuration item
url: # Your voyageai embedding url, Default is the official embedding url
rerank:
model:
providers:
tei:
enable: true # Whether to enable TEI rerank service
vllm:
enable: true # Whether to enable vllm rerank service
38 changes: 22 additions & 16 deletions internal/proxy/task_search.go
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,17 @@

func (t *searchTask) hybridSearchRank(ctx context.Context, span trace.Span, multipleMilvusResults []*milvuspb.SearchResults, searchMetrics []string) error {
var err error
processRerank := func(ctx context.Context, results []*milvuspb.SearchResults) (*milvuspb.SearchResults, error) {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-call-rerank-function-udf")
defer sp.End()

params := rerank.NewSearchParams(
t.Nq, t.rankParams.limit, t.rankParams.offset, t.rankParams.roundDecimal,
t.rankParams.groupByFieldId, t.rankParams.groupSize, t.rankParams.strictGroupSize, searchMetrics,
)
return t.functionScore.Process(ctx, params, results)
}

// The first step of hybrid search is without meta information. If rerank requires meta data, we need to do requery.
// At this time, outputFields and rerank input_fields will be recalled.
// If we want to save memory, we can only recall the rerank input_fields in this step, and recall the output_fields in the third step
Expand Down Expand Up @@ -682,12 +693,7 @@
for i := 0; i < len(multipleMilvusResults); i++ {
multipleMilvusResults[i].Results.FieldsData = fields[i]
}
params := rerank.NewSearchParams(
t.Nq, t.rankParams.limit, t.rankParams.offset, t.rankParams.roundDecimal,
t.rankParams.groupByFieldId, t.rankParams.groupSize, t.rankParams.strictGroupSize, searchMetrics,
)

if t.result, err = t.functionScore.Process(ctx, params, multipleMilvusResults); err != nil {
if t.result, err = processRerank(ctx, multipleMilvusResults); err != nil {
return err
}
if fields, err := t.reorganizeRequeryResults(ctx, queryResult.GetFieldsData(), []*schemapb.IDs{t.result.Results.Ids}); err != nil {
Expand All @@ -696,11 +702,7 @@
t.result.Results.FieldsData = fields[0]
}
} else {
params := rerank.NewSearchParams(
t.Nq, t.rankParams.limit, t.rankParams.offset, t.rankParams.roundDecimal,
t.rankParams.groupByFieldId, t.rankParams.groupSize, t.rankParams.strictGroupSize, searchMetrics,
)
if t.result, err = t.functionScore.Process(ctx, params, multipleMilvusResults); err != nil {
if t.result, err = processRerank(ctx, multipleMilvusResults); err != nil {

Check warning on line 705 in internal/proxy/task_search.go

View check run for this annotation

Codecov / codecov/patch

internal/proxy/task_search.go#L705

Added line #L705 was not covered by tests
return err
}
}
Expand Down Expand Up @@ -823,11 +825,15 @@
}

if t.functionScore != nil && len(result.Results.FieldsData) != 0 {
params := rerank.NewSearchParams(t.Nq, t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(),
t.queryInfos[0].RoundDecimal, t.queryInfos[0].GroupByFieldId, t.queryInfos[0].GroupSize, t.queryInfos[0].StrictGroupSize, []string{metricType})
// rank only returns id and score
if t.result, err = t.functionScore.Process(ctx, params, []*milvuspb.SearchResults{result}); err != nil {
return err
{
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-call-rerank-function-udf")
defer sp.End()
params := rerank.NewSearchParams(t.Nq, t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(),
t.queryInfos[0].RoundDecimal, t.queryInfos[0].GroupByFieldId, t.queryInfos[0].GroupSize, t.queryInfos[0].StrictGroupSize, []string{metricType})
// rank only returns id and score
if t.result, err = t.functionScore.Process(ctx, params, []*milvuspb.SearchResults{result}); err != nil {
return err
}

Check warning on line 836 in internal/proxy/task_search.go

View check run for this annotation

Codecov / codecov/patch

internal/proxy/task_search.go#L835-L836

Added lines #L835 - L836 were not covered by tests
}
if !t.needRequery {
fields, err := t.reorganizeRequeryResults(ctx, result.Results.FieldsData, []*schemapb.IDs{t.result.Results.Ids})
Expand Down
7 changes: 4 additions & 3 deletions internal/util/function/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,16 +109,17 @@ const (
siliconflowAKEnvStr string = "MILVUSAI_SILICONFLOW_API_KEY"
)

// TEI
// TEI and vllm

const (
ingestionPromptParamKey string = "ingestion_prompt"
searchPromptParamKey string = "search_prompt"
maxClientBatchSizeParamKey string = "max_client_batch_size"
truncationDirectionParamKey string = "truncation_direction"
endpointParamKey string = "endpoint"
EndpointParamKey string = "endpoint"

enableTeiEnvStr string = "MILVUSAI_ENABLE_TEI"
EnableTeiEnvStr string = "MILVUSAI_ENABLE_TEI"
EnableVllmEnvStr string = "MILVUSAI_ENABLE_VLLM"
)

func parseAKAndURL(credentials *credentials.Credentials, params []*commonpb.KeyValuePair, confParams map[string]string, apiKeyEnv string) (string, string, error) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ func (c *AliDashScopeEmbedding) Embedding(modelName string, texts []string, dim
"Content-Type": "application/json",
"Authorization": fmt.Sprintf("Bearer %s", c.apiKey),
}
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3)
if err != nil {
return nil, err
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func (c *CohereEmbedding) Embedding(modelName string, texts []string, inputType
"Content-Type": "application/json",
"Authorization": fmt.Sprintf("bearer %s", c.apiKey),
}
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion internal/util/function/models/openai/openai_embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ func (c *openAIBase) embedding(url string, headers map[string]string, modelName

ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second)
defer cancel()
body, err := utils.RetrySend(ctx, data, http.MethodPost, url, headers, 3, 1)
body, err := utils.RetrySend(ctx, data, http.MethodPost, url, headers, 3)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions internal/util/function/models/openai/openai_embedding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,8 @@ func TestEmbeddingFailed(t *testing.T) {
func TestTimeout(t *testing.T) {
var st int32 = 0
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// (Timeout 1s + Wait 1s) * Retry 3
time.Sleep(6 * time.Second)
// (Timeout 1s + 2s + 4s + Wait 1s * 3)
time.Sleep(11 * time.Second)
atomic.AddInt32(&st, 1)
w.WriteHeader(http.StatusUnauthorized)
}))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func (c *SiliconflowEmbedding) Embedding(modelName string, texts []string, encod
"Content-Type": "application/json",
"Authorization": fmt.Sprintf("Bearer %s", c.apiKey),
}
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion internal/util/function/models/tei/tei.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func (c *TEIEmbedding) Embedding(texts []string, truncate bool, truncationDirect
if c.apiKey != "" {
headers["Authorization"] = fmt.Sprintf("Bearer %s", c.apiKey)
}
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3)
if err != nil {
return nil, err
}
Expand Down
7 changes: 5 additions & 2 deletions internal/util/function/models/utils/embedding_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"context"
"fmt"
"io"
"math/rand"
"net/http"
"time"
)
Expand All @@ -45,7 +46,7 @@ func send(req *http.Request) ([]byte, error) {
return body, nil
}

func RetrySend(ctx context.Context, data []byte, httpMethod string, url string, headers map[string]string, maxRetries int, retryDelay int) ([]byte, error) {
func RetrySend(ctx context.Context, data []byte, httpMethod string, url string, headers map[string]string, maxRetries int) ([]byte, error) {
var err error
var body []byte
for i := 0; i < maxRetries; i++ {
Expand All @@ -60,7 +61,9 @@ func RetrySend(ctx context.Context, data []byte, httpMethod string, url string,
if err == nil {
return body, nil
}
time.Sleep(time.Duration(retryDelay) * time.Second)
backoffDelay := 1 << uint(i) * time.Second
jitter := time.Duration(rand.Int63n(int64(backoffDelay / 4)))
time.Sleep(backoffDelay + jitter)
}
return nil, err
}
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func (c *VertexAIEmbedding) Embedding(modelName string, texts []string, dim int6
"Content-Type": "application/json",
"Authorization": fmt.Sprintf("Bearer %s", token),
}
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3)
if err != nil {
return nil, err
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func (c *VoyageAIEmbedding) Embedding(modelName string, texts []string, dim int,
"Content-Type": "application/json",
"Authorization": fmt.Sprintf("Bearer %s", c.apiKey),
}
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3)
if err != nil {
return nil, err
}
Expand Down
27 changes: 5 additions & 22 deletions internal/util/function/rerank/decay_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,34 +55,17 @@ type DecayFunction[T PKType, R int32 | int64 | float32 | float64] struct {
}

func newDecayFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) {
pkType := schemapb.DataType_None
for _, field := range collSchema.Fields {
if field.IsPrimaryKey {
pkType = field.DataType
}
}

if pkType == schemapb.DataType_None {
return nil, fmt.Errorf("Collection %s can not found pk field", collSchema.Name)
}

base, err := newRerankBase(collSchema, funcSchema, decayFunctionName, false, pkType)
base, err := newRerankBase(collSchema, funcSchema, decayFunctionName, false)
if err != nil {
return nil, err
}

if len(base.GetInputFieldNames()) != 1 {
return nil, fmt.Errorf("Decay function only supoorts single input, but gets [%s] input", base.GetInputFieldNames())
}

var inputType schemapb.DataType
for _, field := range collSchema.Fields {
if field.Name == base.GetInputFieldNames()[0] {
inputType = field.DataType
}
return nil, fmt.Errorf("Decay function only supports single input, but gets [%s] input", base.GetInputFieldNames())
}

if pkType == schemapb.DataType_Int64 {
inputType := base.GetInputFieldTypes()[0]
if base.pkType == schemapb.DataType_Int64 {
switch inputType {
case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32:
return newFunction[int64, int32](base, funcSchema)
Expand Down Expand Up @@ -160,7 +143,7 @@ func newFunction[T PKType, R int32 | int64 | float32 | float64](base *RerankBase
}

if decayFunc.decay <= 0 || decayFunc.decay >= 1 {
return nil, fmt.Errorf("Decay function param: decay must 0 < decay < 1, but got %f", decayFunc.offset)
return nil, fmt.Errorf("Decay function param: decay must 0 < decay < 1, but got %f", decayFunc.decay)
}

switch decayFunc.functionName {
Expand Down
2 changes: 1 addition & 1 deletion internal/util/function/rerank/decay_function_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func (s *DecayFunctionSuite) TestNewDecayErrors() {
{
functionSchema.InputFieldNames = []string{"ts", "pk"}
_, err := newDecayFunction(schema, functionSchema)
s.ErrorContains(err, "Decay function only supoorts single input, but gets")
s.ErrorContains(err, "Decay function only supports single input, but gets")
}

{
Expand Down
5 changes: 4 additions & 1 deletion internal/util/function/rerank/function_score.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

const (
decayFunctionName string = "decay"
modelFunctionName string = "model"
)

type SearchParams struct {
Expand Down Expand Up @@ -92,8 +93,10 @@
switch rerankerName {
case decayFunctionName:
rerankFunc, newRerankErr = newDecayFunction(collSchema, funcSchema)
case modelFunctionName:
rerankFunc, newRerankErr = newModelFunction(collSchema, funcSchema)

Check warning on line 97 in internal/util/function/rerank/function_score.go

View check run for this annotation

Codecov / codecov/patch

internal/util/function/rerank/function_score.go#L96-L97

Added lines #L96 - L97 were not covered by tests
default:
return nil, fmt.Errorf("Unsupported rerank function: [%s] , list of supported [%s]", rerankerName, decayFunctionName)
return nil, fmt.Errorf("Unsupported rerank function: [%s] , list of supported [%s,%s]", rerankerName, decayFunctionName, modelFunctionName)
}

if newRerankErr != nil {
Expand Down
Loading
Loading