Skip to content

Commit 83ea6cd

Browse files
Rerank support for vLLM in HuggingFace Serving Runtime (kserve#4376)
Signed-off-by: ayush <ayush.sawant@nutanix.com>
Signed-off-by: tmbochenski <31276485+tmbochenski@users.noreply.github.com>
Co-authored-by: tmbochenski <31276485+tmbochenski@users.noreply.github.com>
1 parent db60f59 commit 83ea6cd

14 files changed

Lines changed: 490 additions & 46 deletions

File tree

python/huggingfaceserver/huggingfaceserver/encoder_model.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -60,6 +60,8 @@
6060
EmbeddingRequest,
6161
EmbeddingResponseData,
6262
ErrorResponse,
63+
Rerank,
64+
RerankRequest,
6365
UsageInfo,
6466
)
6567

@@ -439,3 +441,13 @@ async def create_embedding(
439441

440442
except Exception as e:
441443
raise OpenAIError(f"Error during embedding creation: {e}") from e
444+
445+
async def create_rerank(
446+
self,
447+
request: RerankRequest,
448+
raw_request: Optional[Request] = None,
449+
context: Optional[Dict[str, Any]] = None,
450+
) -> Union[AsyncGenerator[str, None], Rerank, ErrorResponse]:
451+
raise OpenAIError(
452+
"Rerank is not implemented for Encoder model with huggingface backend"
453+
)

python/huggingfaceserver/huggingfaceserver/vllm/vllm_model.py

Lines changed: 78 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,37 +12,43 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Any, Dict, Optional, Union, AsyncGenerator
16-
import torch
1715
from argparse import Namespace
18-
from fastapi import Request
16+
from typing import Any, Dict, Optional, Union, AsyncGenerator
1917
from http import HTTPStatus
2018

21-
from kserve.protocol.rest.openai.errors import create_error_response
22-
from kserve.protocol.rest.openai import OpenAIEncoderModel, OpenAIGenerativeModel
23-
from kserve.protocol.rest.openai.types import (
24-
Completion,
25-
ChatCompletion,
26-
CompletionRequest,
27-
ChatCompletionRequest,
28-
EmbeddingRequest,
29-
Embedding,
30-
ErrorResponse,
31-
)
32-
33-
import vllm.envs as envs
19+
import torch
20+
from fastapi import Request
3421
from vllm import AsyncEngineArgs
22+
import vllm.envs as envs
3523
from vllm.entrypoints.logger import RequestLogger
3624
from vllm.engine.protocol import EngineClient
3725
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
3826
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
3927
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
28+
from vllm.entrypoints.openai.serving_score import ServingScores
4029
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
4130
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
4231
from vllm.entrypoints.openai.cli_args import validate_parsed_serve_args
4332
from vllm.entrypoints.chat_utils import load_chat_template
4433
from vllm.entrypoints.openai.protocol import ErrorResponse as engineError
4534
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
35+
36+
from kserve.protocol.rest.openai.errors import create_error_response
37+
from kserve.protocol.rest.openai import (
38+
OpenAIEncoderModel,
39+
OpenAIGenerativeModel,
40+
)
41+
from kserve.protocol.rest.openai.types import (
42+
Completion,
43+
ChatCompletion,
44+
CompletionRequest,
45+
ChatCompletionRequest,
46+
EmbeddingRequest,
47+
Embedding,
48+
ErrorResponse,
49+
RerankRequest,
50+
Rerank,
51+
)
4652
from .utils import build_async_engine_client_from_engine_args, build_vllm_engine_args
4753

4854

@@ -53,7 +59,11 @@ class VLLMModel(
5359
vllm_engine_args: AsyncEngineArgs = None
5460
args: Namespace = None
5561
ready: bool = False
62+
openai_serving_models: Optional[OpenAIServingModels] = None
5663
openai_serving_completion: Optional[OpenAIServingCompletion] = None
64+
openai_serving_chat: Optional[OpenAIServingChat] = None
65+
openai_serving_embedding: Optional[OpenAIServingEmbedding] = None
66+
serving_reranking: Optional[ServingScores] = None
5767

5868
def __init__(
5969
self,
@@ -68,6 +78,9 @@ def __init__(
6878
self.vllm_engine_args = engine_args
6979
self.request_logger = request_logger
7080
self.model_name = model_name
81+
self.base_model_paths = []
82+
self.log_stats = True
83+
self.model_config = None
7184

7285
async def start_engine(self):
7386
if self.args.tool_parser_plugin and len(self.args.tool_parser_plugin) > 3:
@@ -169,6 +182,17 @@ async def start_engine(self):
169182
else None
170183
)
171184

185+
self.serving_reranking = (
186+
ServingScores(
187+
self.engine_client,
188+
self.model_config,
189+
self.openai_serving_models,
190+
request_logger=self.request_logger,
191+
)
192+
if self.model_config.task == "score"
193+
else None
194+
)
195+
172196
self.ready = True
173197
return self.ready
174198

@@ -201,6 +225,11 @@ async def create_completion(
201225
raw_request: Optional[Request] = None,
202226
context: Optional[Dict[str, Any]] = None,
203227
) -> Union[AsyncGenerator[str, None], Completion, ErrorResponse]:
228+
if self.openai_serving_completion is None:
229+
return create_error_response(
230+
message="The model does not support Completions API",
231+
status_code=HTTPStatus.BAD_REQUEST,
232+
)
204233
response = await self.openai_serving_completion.create_completion(
205234
request, raw_request
206235
)
@@ -221,6 +250,11 @@ async def create_chat_completion(
221250
raw_request: Optional[Request] = None,
222251
context: Optional[Dict[str, Any]] = None,
223252
) -> Union[AsyncGenerator[str, None], ChatCompletion, ErrorResponse]:
253+
if self.openai_serving_chat is None:
254+
return create_error_response(
255+
message="The model does not support Chat Completions API",
256+
status_code=HTTPStatus.BAD_REQUEST,
257+
)
224258
response = await self.openai_serving_chat.create_chat_completion(
225259
request, raw_request
226260
)
@@ -241,6 +275,11 @@ async def create_embedding(
241275
raw_request: Optional[Request] = None,
242276
context: Optional[Dict[str, Any]] = None,
243277
) -> Union[AsyncGenerator[str, None], Embedding, ErrorResponse]:
278+
if self.openai_serving_embedding is None:
279+
return create_error_response(
280+
message="The model does not support Embeddings API",
281+
status_code=HTTPStatus.BAD_REQUEST,
282+
)
244283
response = await self.openai_serving_embedding.create_embedding(
245284
request, raw_request
246285
)
@@ -254,3 +293,26 @@ async def create_embedding(
254293
)
255294

256295
return response
296+
297+
async def create_rerank(
298+
self,
299+
request: RerankRequest,
300+
raw_request: Optional[Request] = None,
301+
context: Optional[Dict[str, Any]] = None,
302+
) -> Union[AsyncGenerator[str, None], Rerank, ErrorResponse]:
303+
if self.serving_reranking is None:
304+
return create_error_response(
305+
message="The model does not support Rerank API",
306+
status_code=HTTPStatus.BAD_REQUEST,
307+
)
308+
response = await self.serving_reranking.do_rerank(request, raw_request)
309+
310+
if isinstance(response, engineError):
311+
return create_error_response(
312+
message=response.message,
313+
err_type=response.type,
314+
param=response.param,
315+
status_code=HTTPStatus(response.code),
316+
)
317+
318+
return response
Lines changed: 117 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,117 @@
1+
# Copyright 2025 The KServe Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import pytest
17+
import requests
18+
19+
from kserve.protocol.rest.openai.types import Rerank
20+
from server import RemoteOpenAIServer
21+
22+
23+
MODEL = "BAAI/bge-reranker-base"
24+
MODEL_NAME = "test-model"
25+
26+
27+
@pytest.fixture(scope="module")
28+
def server(): # noqa: F811
29+
args = [
30+
# use half precision for speed and memory savings in CI environment
31+
"--dtype",
32+
"bfloat16",
33+
"--max-model-len",
34+
"100",
35+
"--enforce-eager",
36+
]
37+
38+
with RemoteOpenAIServer(MODEL, MODEL_NAME, args) as remote_server:
39+
yield remote_server
40+
41+
42+
@pytest.mark.asyncio
43+
@pytest.mark.parametrize(
44+
"model_name",
45+
[MODEL_NAME],
46+
)
47+
async def test_rerank_texts(server: RemoteOpenAIServer, model_name: str):
48+
query = "What is the capital of France?"
49+
documents = [
50+
"The capital of Brazil is Brasilia.",
51+
"The capital of France is Paris.",
52+
]
53+
54+
rerank_response = requests.post(
55+
server.url_for("openai/v1", "rerank"),
56+
json={
57+
"model": model_name,
58+
"query": query,
59+
"documents": documents,
60+
},
61+
)
62+
rerank_response.raise_for_status()
63+
rerank = Rerank.model_validate(rerank_response.json())
64+
65+
assert rerank.id is not None
66+
assert rerank.results is not None
67+
assert len(rerank.results) == 2
68+
assert rerank.results[0].relevance_score >= 0.9
69+
assert rerank.results[1].relevance_score <= 0.01
70+
71+
72+
@pytest.mark.asyncio
73+
@pytest.mark.parametrize(
74+
"model_name",
75+
[MODEL_NAME],
76+
)
77+
async def test_top_n(server: RemoteOpenAIServer, model_name: str):
78+
query = "What is the capital of France?"
79+
documents = [
80+
"The capital of Brazil is Brasilia.",
81+
"The capital of France is Paris.",
82+
"Cross-encoder models are neat",
83+
]
84+
85+
rerank_response = requests.post(
86+
server.url_for("openai/v1", "rerank"),
87+
json={"model": model_name, "query": query, "documents": documents, "top_n": 2},
88+
)
89+
rerank_response.raise_for_status()
90+
rerank = Rerank.model_validate(rerank_response.json())
91+
92+
assert rerank.id is not None
93+
assert rerank.results is not None
94+
assert len(rerank.results) == 2
95+
assert rerank.results[0].relevance_score >= 0.9
96+
assert rerank.results[1].relevance_score <= 0.01
97+
98+
99+
@pytest.mark.asyncio
100+
@pytest.mark.parametrize(
101+
"model_name",
102+
[MODEL_NAME],
103+
)
104+
async def test_rerank_max_model_len(server: RemoteOpenAIServer, model_name: str):
105+
query = "What is the capital of France?" * 100
106+
documents = [
107+
"The capital of Brazil is Brasilia.",
108+
"The capital of France is Paris.",
109+
]
110+
111+
rerank_response = requests.post(
112+
server.url_for("openai/v1", "rerank"),
113+
json={"model": model_name, "query": query, "documents": documents},
114+
)
115+
assert rerank_response.status_code == 400
116+
# Assert just a small fragments of the response
117+
assert "Please reduce the length of the input." in rerank_response.text

0 commit comments

Comments (0)