Skip to content

Commit f6544fa

Browse files
authored
fix(audio): fix request_format for Albert integration (#665)
* fix(tests): audio transcription and chat tests

* Update unit coverage badge

---------

Co-authored-by: leoguillaume <leoguillaume@users.noreply.github.com>
1 parent daf7590 commit f6544fa

6 files changed

Lines changed: 40 additions & 35 deletions

File tree

.github/.env.ci.example

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -6,10 +6,11 @@
66
# CELERY_EXTRA_ARGS="--loglevel DEBUG"
77

88
# Dependencies
9-
POSTGRES_PORT=5432
10-
REDIS_PORT=6379
11-
ELASTICSEARCH_PORT=9200
12-
RABBITMQ_PORT=5672
9+
POSTGRES_PORT=15432
10+
REDIS_PORT=16379
11+
ELASTICSEARCH_PORT=19200
12+
RABBITMQ_PORT=25672
13+
RABBITMQ_UI_PORT=15672
1314

1415
# Secrets (to complete)
1516
# ALBERT_API_KEY=

.github/badges/coverage.json

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1 +1 @@
1-
{"schemaVersion":1,"label":"coverage","message":"51.05%","color":"red"}
1+
{"schemaVersion":1,"label":"coverage","message":"51.16%","color":"red"}

.github/compose.ci.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@ services:
3434
image: rabbitmq:3.13-management
3535
restart: always
3636
ports:
37-
- "15672:15672"
37+
- "${RABBITMQ_UI_PORT:-15672}:15672"
3838
- "${RABBITMQ_PORT:-5672}:5672"
3939
environment:
4040
- RABBITMQ_DEFAULT_USER=rabbitmq

api/endpoints/audio.py

Lines changed: 15 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -1,17 +1,14 @@
11
from contextvars import ContextVar
2-
from typing import Literal
2+
from typing import Annotated
33

4-
from fastapi import APIRouter, Depends, File, Form, Request, Security, UploadFile
4+
from fastapi import APIRouter, Depends, Form, Request, Security
55
from fastapi.responses import JSONResponse, PlainTextResponse
66
from redis.asyncio import Redis as AsyncRedis
77
from sqlalchemy.ext.asyncio import AsyncSession
88

99
from api.helpers._accesscontroller import AccessController
1010
from api.helpers.models import ModelRegistry
11-
from api.schemas.audio import (
12-
AudioTranscription,
13-
AudioTranscriptionLanguage,
14-
)
11+
from api.schemas.audio import AudioTranscription, CreateAudioTranscription
1512
from api.schemas.core.context import RequestContext
1613
from api.schemas.core.models import RequestContent
1714
from api.utils.dependencies import get_model_registry, get_postgres_session, get_redis_client, get_request_context
@@ -20,16 +17,12 @@
2017
router = APIRouter(prefix="/v1", tags=[ROUTER__AUDIO.title()])
2118

2219

23-
# fmt: off
24-
@router.post(path=ENDPOINT__AUDIO_TRANSCRIPTIONS, dependencies=[Security(dependency=AccessController())], status_code=200, response_model=AudioTranscription)
20+
@router.post(
21+
path=ENDPOINT__AUDIO_TRANSCRIPTIONS, dependencies=[Security(dependency=AccessController())], status_code=200, response_model=AudioTranscription
22+
)
2523
async def audio_transcriptions(
2624
request: Request,
27-
file: UploadFile = File(description="The audio file object (not file name) to transcribe, in one of these formats: mp3 or wav."),
28-
model: str = Form(default=..., description="ID of the model to use. Call `/v1/models` endpoint to get the list of available models, only `automatic-speech-recognition` model type is supported."),
29-
language: AudioTranscriptionLanguage = Form(default=AudioTranscriptionLanguage.EMPTY, description="The language of the output audio. If the output language is different than the audio language, the audio language will be translated into the output language. Supplying the output language in ISO-639-1 (e.g. en, fr) format will improve accuracy and latency."),
30-
prompt: str | None = Form(default=None, description="An optional text to tell the model what to do with the input audio. Default is `Transcribe this audio in this language : {language}`"),
31-
response_format: Literal["json", "text"] = Form(default="json", description="The format of the transcript output, in one of these formats: `json` or `text`."),
32-
temperature: float | None = Form(default=None, ge=0, le=1, description="The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use log probability to automatically increase the temperature until certain thresholds are hit."),
25+
data: Annotated[CreateAudioTranscription, Form()],
3326
model_registry: ModelRegistry = Depends(get_model_registry),
3427
redis_client: AsyncRedis = Depends(get_redis_client),
3528
postgres_session: AsyncSession = Depends(get_postgres_session),
@@ -39,28 +32,29 @@ async def audio_transcriptions(
3932
Transcribes audio into the input language.
4033
"""
4134
model_provider = await model_registry.get_model_provider(
42-
model=model,
35+
model=data.model,
4336
endpoint=ENDPOINT__AUDIO_TRANSCRIPTIONS,
4437
postgres_session=postgres_session,
4538
redis_client=redis_client,
46-
request_context=request_context
39+
request_context=request_context,
4740
)
4841

49-
file_content = await file.read()
50-
form = {"model": model, "response_format": response_format, "temperature": temperature, "language": language.value, "prompt": prompt}
42+
file_content = await data.file.read()
43+
form = data.model_dump()
44+
form.pop("file")
5145

5246
response = await model_provider.forward_request(
5347
request_content=RequestContent(
5448
method="POST",
55-
model=model,
49+
model=data.model,
5650
endpoint=ENDPOINT__AUDIO_TRANSCRIPTIONS,
57-
files={"file": (file.filename, file_content, file.content_type)},
51+
files={"file": (data.file.filename, file_content, data.file.content_type)},
5852
form=form,
5953
),
6054
redis_client=redis_client,
6155
)
6256

63-
if response_format == "text":
57+
if data.response_format == "text":
6458
response = PlainTextResponse(content=response.json()["text"], status_code=response.status_code)
6559
else:
6660
response = JSONResponse(content=AudioTranscription(**response.json()).model_dump(), status_code=response.status_code)

api/schemas/audio.py

Lines changed: 17 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,10 @@
11
import base64
22
from enum import Enum
3+
from typing import Literal
34

5+
from fastapi import File, UploadFile
46
from mistralai.models import AudioChunk, ChatCompletionRequest, TextChunk, UserMessage
5-
from pydantic import Field
7+
from pydantic import Field, field_validator
68

79
from api.schemas import BaseModel
810
from api.schemas.admin.providers import ProviderType
@@ -18,6 +20,17 @@
1820

1921

2022
class CreateAudioTranscription(BaseModel):
23+
file: UploadFile = File(description="The audio file object (not file name) to transcribe, in one of these formats: mp3 or wav.") # fmt: off
24+
model: str = Field(default=..., description="ID of the model to use. Call `/v1/models` endpoint to get the list of available models, only `automatic-speech-recognition` model type is supported.") # fmt: off
25+
language: AudioTranscriptionLanguage = Field(default=AudioTranscriptionLanguage.EMPTY, description="The language of the output audio. If the output language is different than the audio language, the audio language will be translated into the output language. Supplying the output language in ISO-639-1 (e.g. en, fr) format will improve accuracy and latency.") # fmt: off
26+
prompt: str | None = Field(default=None, description="An optional text to tell the model what to do with the input audio. Default is `Transcribe this audio in this language : {language}`") # fmt: off
27+
response_format: Literal["json", "text"] = Field(default="json", description="The format of the transcript output, in one of these formats: `json` or `text`.") # fmt: off
28+
temperature: float | None = Field(default=None, ge=0, le=1, description="The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use log probability to automatically increase the temperature until certain thresholds are hit.") # fmt: off
29+
30+
@field_validator("language", mode="after")
31+
def extract_value_language(cls, language: AudioTranscriptionLanguage) -> str:
32+
return language.value
33+
2134
@staticmethod
2235
def format_request(provider_type: ProviderType, request_content: RequestContent):
2336
match provider_type:
@@ -43,13 +56,10 @@ def format_request(provider_type: ProviderType, request_content: RequestContent)
4356
return request_content
4457

4558
case ProviderType.VLLM:
46-
if request_content.form["language"] == AudioTranscriptionLanguage.EMPTY:
47-
request_content.form.pop("language")
59+
request_content.form["language"] = "en" if request_content.form["language"] == "" else request_content.form["language"]
60+
request_content.form["temperature"] = 0 if request_content.form["temperature"] is None else request_content.form["temperature"]
61+
request_content.form["prompt"] = "" if request_content.form["prompt"] is None else request_content.form["prompt"]
4862

49-
if request_content.form.get("temperature") is None:
50-
request_content.form.pop("temperature")
51-
52-
request_content.form["response_format"] = "json"
5363
return request_content
5464

5565
case _:

api/tests/integ/test_chat.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -197,7 +197,7 @@ def test_chat_completions_search_no_collections(self, client: TestClient, setup)
197197
},
198198
}
199199
response = client.post_without_permissions(url=f"/v1{ENDPOINT__CHAT_COMPLETIONS}", json=params)
200-
assert response.status_code == 200, response.text
200+
assert response.status_code == 422, response.text
201201

202202
def test_chat_completions_search_template(self, client: TestClient, setup):
203203
"""Test the GET /chat/completions search template."""

0 commit comments

Comments
 (0)