Skip to content

Commit ffa9f12

Browse files
committed
add vllm and mistral unit tests
1 parent 6710ba9 commit ffa9f12

11 files changed

Lines changed: 566 additions & 53 deletions

File tree

api/endpoints/chat.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -58,6 +58,7 @@ async def chat_completions(
5858
request_context: ContextVar[RequestContext] = Depends(get_request_context),
5959
) -> JSONResponse | StreamingResponseWithStatusCode:
6060
"""Creates a model response for the given chat conversation."""
61+
6162
model_provider = await model_registry.get_model_provider(
6263
model=body.model,
6364
endpoint=EndpointRoute.CHAT_COMPLETIONS,

api/helpers/models/_modelregistry.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -129,7 +129,7 @@ async def setup(self, models: list[ModelConfiguration], postgres_session: AsyncS
129129
)
130130
logger.info(f"Router {model.name} are created (id: {router_id})")
131131
except RouterAlreadyExistsException:
132-
continue
132+
pass
133133
except RouterAliasAlreadyExistsException:
134134
continue
135135
except Exception as e:

api/infrastructure/http/model/_mistralmodelhttpclient.py

Lines changed: 44 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22

33
from mistralai.client.models import AudioChunk, ChatCompletionRequest, TextChunk, UserMessage
44

5-
from api.infrastructure.fastapi.schemas.models import ModelsResponse
5+
from api.infrastructure.fastapi.schemas.models import ModelResponse, ModelsResponse
66
from api.infrastructure.http.model._modelhttpclient import ModelHttpClient, ModelHttpClientEndpoints
77
from api.schemas.audio import AudioTranscription
88
from api.schemas.core.models import RequestContent
@@ -14,31 +14,28 @@ class MistralModelHttpClient(ModelHttpClient):
1414
# request formatting
1515
@staticmethod
1616
def format_chat_completion_request(request_content: RequestContent) -> RequestContent:
17-
try:
18-
request_content.body = ChatCompletionRequest(**request_content.body).model_dump(by_alias=True)
19-
except Exception:
20-
# apply a minimal formatting and ignore error to let the provider raise the correct 422 error
21-
# see https://docs.mistral.ai/api#operation-chat_completion_v1_chat_completions_post
22-
request_content.body = {
23-
"frequency_penalty": request_content.body.get("frequency_penalty") or 0.0,
24-
"max_tokens": request_content.body.get("max_tokens"),
25-
"messages": request_content.body.get("messages"),
26-
"model": request_content.body.get("model"),
27-
"n": request_content.body.get("n"),
28-
"parallel_tool_calls": request_content.body.get("parallel_tool_calls") or False,
29-
"prediction": request_content.body.get("prediction") or {},
30-
"presence_penalty": request_content.body.get("presence_penalty") or 0.0,
31-
"prompt_mode": request_content.body.get("prompt_mode"),
32-
"random_seed": request_content.body.get("random_seed") or request_content.body.get("seed"),
33-
"response_format": request_content.body.get("response_format") or {"type": "text"},
34-
"safe_prompt": request_content.body.get("safe_prompt") or False,
35-
"stop": request_content.body.get("stop") or [],
36-
"stream": request_content.body.get("stream") or False,
37-
"temperature": request_content.body.get("temperature"),
38-
"tool_choice": request_content.body.get("tool_choice"),
39-
"tools": request_content.body.get("tools"),
40-
"top_p": request_content.body.get("top_p") or 1.0,
41-
}
17+
# see https://docs.mistral.ai/api#operation-chat_completion_v1_chat_completions_post
18+
request_content.body = {
19+
"frequency_penalty": request_content.body.get("frequency_penalty") or 0.0,
20+
"max_tokens": request_content.body.get("max_tokens"),
21+
"messages": request_content.body.get("messages"),
22+
"model": request_content.body.get("model"),
23+
"n": request_content.body.get("n"),
24+
"parallel_tool_calls": request_content.body.get("parallel_tool_calls") or False,
25+
"prediction": request_content.body.get("prediction") or {},
26+
"presence_penalty": request_content.body.get("presence_penalty") or 0.0,
27+
"prompt_mode": request_content.body.get("prompt_mode"),
28+
"random_seed": request_content.body.get("random_seed") or request_content.body.get("seed"),
29+
"response_format": request_content.body.get("response_format") or {"type": "text"},
30+
"safe_prompt": request_content.body.get("safe_prompt") or False,
31+
"stop": request_content.body.get("stop") or [],
32+
"stream": request_content.body.get("stream") or False,
33+
"temperature": request_content.body.get("temperature"),
34+
"tool_choice": request_content.body.get("tool_choice"),
35+
"tools": request_content.body.get("tools"),
36+
"top_p": request_content.body.get("top_p") or 1.0,
37+
}
38+
4239
return request_content
4340

4441
@staticmethod
@@ -57,16 +54,30 @@ def format_audio_transcription_request(request_content: RequestContent) -> Reque
5754
).model_dump()
5855
request_content.files = {}
5956
request_content.form = {}
57+
6058
return request_content
6159

6260
# response formatting
6361
@staticmethod
64-
def format_client_response_to_models_response(request_content: RequestContent, response_data: dict) -> ModelsResponse:
65-
for model in response_data.get("data", []):
66-
model.update({"type": None})
67-
return ModelsResponse(**response_data)
62+
def format_response_to_models_response(request_content: RequestContent, response_data: dict) -> ModelsResponse:
63+
return ModelsResponse(
64+
data=[
65+
ModelResponse(
66+
id=model.get("id"),
67+
created=model.get("created"),
68+
owned_by=model.get("owned_by"),
69+
max_context_length=model.get("max_context_length"),
70+
aliases=model.get("aliases", []),
71+
)
72+
for model in response_data.get("data", [])
73+
]
74+
)
6875

6976
@staticmethod
70-
def format_client_response_to_audio_transcription_response(request_content: RequestContent, response_data: dict) -> AudioTranscription:
71-
text = response_data["choices"][0]["message"]["content"]
72-
return AudioTranscription(text=text)
77+
def format_response_to_audio_transcription_response(request_content: RequestContent, response_data: dict) -> AudioTranscription:
78+
return AudioTranscription(
79+
id=response_data["id"],
80+
model=response_data["model"],
81+
text=response_data["choices"][0]["message"]["content"],
82+
usage=response_data["usage"],
83+
)

api/infrastructure/http/model/_modelhttpclient.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -93,6 +93,8 @@ def format_audio_transcription_request(self, request_content: RequestContent) ->
9393
@staticmethod
9494
def format_chat_completion_request(request_content: RequestContent) -> RequestContent:
9595
"""This method can be overridden by children clients to format the chat completion request."""
96+
# @TODO: setup default temperature by model (default=1.0)
97+
# @TODO: catch stream options to avoid double usage computation
9698
return request_content
9799

98100
@staticmethod

api/infrastructure/http/model/_vllmmodelhttpclient.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,5 +9,13 @@ class VllmModelHttpClient(ModelHttpClient):
99
# response formatting
1010
@staticmethod
1111
def format_response_to_models_response(request_content: RequestContent, response_data: dict) -> ModelsResponse:
12-
data = [ModelResponse(max_context_length=model.get("max_model_len"), **model) for model in response_data.get("helpers/data", [])]
12+
data = [
13+
ModelResponse(
14+
id=model.get("id"),
15+
created=model.get("created"),
16+
owned_by=model.get("owned_by"),
17+
max_context_length=model.get("max_model_len"),
18+
)
19+
for model in response_data.get("data", [])
20+
]
1321
return ModelsResponse(data=data)

api/infrastructure/model/_modelprovidergateway.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -39,5 +39,11 @@ def _build_client(provider_type, url, key, timeout, model_name) -> ModelHttpClie
3939
}
4040

4141
return provider_class[provider_type](
42-
url=url, key=key, timeout=timeout, model_name=model_name, model_active_params=None, model_total_params=None, model_hosting_zone=None
42+
url=url,
43+
key=key,
44+
timeout=timeout,
45+
model_name=model_name,
46+
model_active_params=None,
47+
model_total_params=None,
48+
model_hosting_zone=None,
4349
)

api/schemas/chat.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -37,8 +37,8 @@ class CreateChatCompletion(BaseModel):
3737
stop: str | list[str] | None = Field(default_factory=list, description="Up to 4 sequences where the API will stop generating further tokens.") # fmt: off
3838
stream: Literal[True, False] | None = Field(default=False, description="If set, partial message deltas will be sent. Tokens will be sent as data-only server-sent events as they become available, with the stream terminated by a data: [DONE] message.") # fmt: off
3939
stream_options: Any | None = Field(default=None, description="Options for streaming response. Only set this when you set `stream: true`.") # fmt: off
40-
temperature: float | None = Field(default=0.7, description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or `top_p` but not both.") # fmt: off
41-
top_p: float | None = Field(default=1, description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.<br>We generally recommend altering this or `temperature` but not both.") # fmt: off
40+
temperature: float | None = Field(default=None, description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or `top_p` but not both.") # fmt: off
41+
top_p: float | None = Field(default=None, description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.<br>We generally recommend altering this or `temperature` but not both.") # fmt: off
4242
tools: Annotated[list[dict| SearchTool] | None, Field(description="A list of tools the model may call. Currently, only functions are supported as a tool. Support function calling and build-in tools (currently only SearchTool). Use this to provide a list of functions the model may generate JSON inputs for.")] | None = Field(default=None) # fmt: off
4343
tool_choice: Any = Field(default="none", description="Controls which (if any) tool is called by the model. `none` means the model will not call any tool and instead generates a message. `auto` means the model can pick between generating a message or calling one or more tools. `required` means the model must call one or more tools. Specifying a particular tool via `{\"type\": \"function\", \"function\": {\"name\": \"my_function\"}}` forces the model to call that tool.<br>`none` is the default when no tools are present. `auto` is the default if tools are present.") # fmt: off
4444
parallel_tool_calls: bool | None = Field(default=False, description="Whether to call tools in parallel or sequentially. If true, the model will call tools in parallel. If false, the model will call tools sequentially. If None, the model will call tools in parallel if the model supports it, otherwise it will call tools sequentially.") # fmt: off

0 commit comments

Comments
 (0)