Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 110 additions & 15 deletions app/endpoints/audio.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,28 @@
import asyncio
import logging
import os
import tempfile
import time
from functools import partial
from typing import Annotated, Optional

import numpy as np
import tiktoken
from fastapi import (
APIRouter,
Depends,
File,
Form,
HTTPException,
Request,
Security,
UploadFile,
)
from fastapi.responses import PlainTextResponse
from services.transcription import transcribe
from utils.lifespan import gpu_executor
import whisperx

from schemas.audio import AudioTranscription, AudioTranscriptionVerbose
from schemas.audio import AudioTranscription, InputTokenDetails, Segment, Usage
from utils.config import Settings, get_settings
from utils.exceptions import ModelNotFoundException
from utils.security import check_api_key
Expand All @@ -26,30 +31,108 @@

router = APIRouter()

# Divisor applied to len(audio) to turn a sample count into a duration.
# NOTE(review): presumably the rate whisperx.load_audio resamples to (16 kHz) - confirm.
WHISPERX_SAMPLE_RATE = 16_000
# Multiplier applied to the audio duration to estimate the number of audio
# input tokens reported in the response's usage section.
AUDIO_TOKENS_PER_SECOND = 10

@router.post("/audio/transcriptions")

# Accepted values for the `response_format` form field; the transcription
# endpoint rejects any other value with an HTTP 400.
SUPPORTED_RESPONSE_FORMATS = {"json", "text", "verbose_json", "diarized_json", "srt", "vtt"}


def _format_timestamp(seconds: float, separator: str) -> str:
ms = round(seconds * 1000)
hours, ms = divmod(ms, 3_600_000)
minutes, ms = divmod(ms, 60_000)
secs, ms = divmod(ms, 1000)
return f"{hours:02d}:{minutes:02d}:{secs:02d}{separator}{ms:03d}"


def _format_srt(segments: list[dict]) -> str:
    """Serialize transcription segments as SubRip (SRT) text.

    Each segment must provide "start", "end" and "text" keys. Cues are
    numbered from 1, separated by a blank line, and the result always ends
    with a single newline.
    """

    def render_cue(index: int, seg: dict) -> str:
        # SRT uses a comma between seconds and milliseconds.
        begin = _format_timestamp(seg["start"], ",")
        finish = _format_timestamp(seg["end"], ",")
        caption = seg["text"].strip()
        return "\n".join((str(index), f"{begin} --> {finish}", caption))

    cues = [
        render_cue(index, seg) for index, seg in enumerate(segments, start=1)
    ]
    return "\n\n".join(cues) + "\n"


def _format_vtt(segments: list[dict]) -> str:
    """Serialize transcription segments as WebVTT text.

    The output starts with the "WEBVTT" header, followed by one unnumbered
    cue per segment (each needs "start", "end" and "text" keys). Blocks are
    separated by a blank line and the result ends with a single newline.
    """

    def render_cue(seg: dict) -> str:
        # WebVTT uses a dot between seconds and milliseconds.
        begin = _format_timestamp(seg["start"], ".")
        finish = _format_timestamp(seg["end"], ".")
        return f"{begin} --> {finish}\n{seg['text'].strip()}"

    return "\n\n".join(["WEBVTT", *map(render_cue, segments)]) + "\n"


def _build_response(
    result: dict, audio: np.ndarray, is_diarize: bool, is_verbose: bool
) -> AudioTranscription:
    """Assemble the JSON transcription response from a pipeline result.

    Args:
        result: Transcription output; its "segments" list (default empty)
            supplies the text, and "language" is read in verbose mode.
        audio: The decoded audio. Its length divided by
            WHISPERX_SAMPLE_RATE is used as the duration.
        is_diarize: When true, per-segment details are included.
        is_verbose: When true, per-segment details plus task, language and
            duration are included.

    Returns:
        An AudioTranscription holding the concatenated text, optional
        segments and verbose metadata, and a token usage estimate.
    """
    raw_segments = result.get("segments", [])
    transcript = "".join(seg["text"] for seg in raw_segments)

    # Usage: audio tokens are derived from the clip duration, output tokens
    # from the o200k_base encoding of the transcript.
    duration = len(audio) / WHISPERX_SAMPLE_RATE
    audio_token_count = round(duration * AUDIO_TOKENS_PER_SECOND)
    encoder = tiktoken.get_encoding("o200k_base")
    output_token_count = len(encoder.encode(transcript))

    # Per-segment details are only exposed for the diarized/verbose formats.
    if is_diarize or is_verbose:
        detailed_segments = []
        for index, seg in enumerate(raw_segments):
            detailed_segments.append(
                Segment(
                    id=index,
                    text=seg["text"],
                    start=seg["start"],
                    end=seg["end"],
                    speaker=seg.get("speaker"),
                )
            )
    else:
        detailed_segments = None

    # Task, language and duration are reported in verbose mode only.
    if is_verbose:
        task, language, reported_duration = (
            "transcribe",
            result.get("language"),
            duration,
        )
    else:
        task = language = reported_duration = None

    usage = Usage(
        input_tokens=audio_token_count,
        input_token_details=InputTokenDetails(audio_tokens=audio_token_count),
        output_tokens=output_token_count,
        total_tokens=audio_token_count + output_token_count,
    )
    return AudioTranscription(
        task=task,
        language=language,
        duration=reported_duration,
        text=transcript,
        segments=detailed_segments,
        usage=usage,
    )


@router.post("/audio/transcriptions", dependencies=[Security(check_api_key)])
async def audio_transcriptions(
request: Request,
settings: Annotated[Settings, Depends(get_settings)],
file: UploadFile = File(...),
model: Optional[str] = Form(None),
api_key=Security(check_api_key),
language: Optional[str] = Form(None),
) -> AudioTranscription | AudioTranscriptionVerbose:
response_format: str = Form("json"),

Check warning on line 110 in app/endpoints/audio.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Use "Annotated" type hints for FastAPI dependency injection

See more on https://sonarcloud.io/project/issues?id=suitenumerique_meet-whisperx&issues=AZ5PtT0DjNg73_MR9fGC&open=AZ5PtT0DjNg73_MR9fGC&pullRequest=53
) -> AudioTranscription:
"""
Audio transcription API (custom implementation).

/!\ Note: This endpoint is **not** OpenAI API compatible.
The response format does not follow the OpenAI specification.
Audio transcription API compatible with the OpenAI transcription response format.
Supported response_format values: "json" (default), "text", "verbose_json", "diarized_json", "srt", "vtt".
"""
logger.info("Request received. transcribe model: %s, language: %s", model, language)
logger.info("Request received. Transcribe model: %s, language: %s", model, language)

Check warning on line 116 in app/endpoints/audio.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Change this code to not log user-controlled data.

See more on https://sonarcloud.io/project/issues?id=suitenumerique_meet-whisperx&issues=AZ5PtT0EjNg73_MR9fGE&open=AZ5PtT0EjNg73_MR9fGE&pullRequest=53

if language is not None and (language not in whisperx.utils.LANGUAGES):
if response_format not in SUPPORTED_RESPONSE_FORMATS:
raise HTTPException(

Check failure on line 119 in app/endpoints/audio.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Document this HTTPException with status code 400 in the "responses" parameter.

See more on https://sonarcloud.io/project/issues?id=suitenumerique_meet-whisperx&issues=AZ5PtT0DjNg73_MR9fGD&open=AZ5PtT0DjNg73_MR9fGD&pullRequest=53
status_code=400,
detail=f"Unsupported response_format '{response_format}'. Must be one of: {sorted(SUPPORTED_RESPONSE_FORMATS)}.",
)

is_diarize = response_format == "diarized_json"
is_verbose = response_format == "verbose_json"
is_text = response_format == "text"
is_srt = response_format == "srt"
is_vtt = response_format == "vtt"

if language is not None and language not in whisperx.utils.LANGUAGES:
raise HTTPException(
status_code=400,
detail=f"Unsupported language '{language}' for transcription.",
)
if language is not None and language not in (
if is_diarize and language is not None and language not in (
whisperx.alignment.DEFAULT_ALIGN_MODELS_HF
| whisperx.alignment.DEFAULT_ALIGN_MODELS_TORCH
):
Expand Down Expand Up @@ -78,6 +161,18 @@
audio = whisperx.load_audio(temp_file_path)
os.remove(temp_file_path)

result = transcribe(audio, settings, language)
loop = asyncio.get_event_loop()
result = await loop.run_in_executor(
gpu_executor, partial(transcribe, audio, settings, language, is_diarize=is_diarize)
)

segments = result.get("segments", [])

if is_text:
return PlainTextResponse(content="".join(seg["text"] for seg in segments))
if is_srt:
return PlainTextResponse(content=_format_srt(segments), media_type="application/x-subrip")
if is_vtt:
return PlainTextResponse(content=_format_vtt(segments), media_type="text/vtt")

return AudioTranscription(**result)
return _build_response(result, audio, is_diarize=is_diarize, is_verbose=is_verbose)
43 changes: 19 additions & 24 deletions app/schemas/audio.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,33 @@
from typing import List, Optional, Dict, Any
from typing import List, Optional
from pydantic import BaseModel


class Word(BaseModel):
    """Represents a single word in the transcription"""

    # The word's text; the only required field.
    word: str
    # Optional timing of the word; None when not provided.
    # NOTE(review): presumably offsets in seconds from the start of the audio - confirm.
    start: Optional[float] = None
    end: Optional[float] = None
    # Optional numeric score; None when not provided.
    # NOTE(review): presumably the alignment confidence - confirm against the producer.
    score: Optional[float] = None
    # Optional speaker label; None when not provided.
    speaker: Optional[str] = None


class Segment(BaseModel):
"""Represents a segment of transcribed audio"""

id: int
type: str = "transcript.text.segment"
text: str
start: float
end: float
text: str
words: List[Word]
speaker: Optional[str] = None


class AudioTranscription(BaseModel):
"""Base audio transcription model"""
class InputTokenDetails(BaseModel):
text_tokens: int = 0
audio_tokens: int

segments: List[Segment]
word_segments: Optional[List[Dict[str, Any]]] = None

class Usage(BaseModel):
type: str = "tokens"
input_tokens: int
input_token_details: InputTokenDetails
output_tokens: int
total_tokens: int

class AudioTranscriptionVerbose(AudioTranscription):
"""Extended audio transcription model with additional details"""

language: str
duration: float
class AudioTranscription(BaseModel):
task: Optional[str] = None
language: Optional[str] = None
duration: Optional[float] = None
text: str
words: List[Word]
segments: Optional[List[Segment]] = None
usage: Usage
8 changes: 5 additions & 3 deletions app/services/transcription.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ def transcribe(
audio: np.ndarray,
settings: Settings,
language: str | None = None,
is_diarize: bool = False,
) -> dict:
"""Run the full transcription pipeline: transcribe, align, and diarize."""
"""Run the transcription pipeline: transcribe, and optionally diarize."""
result = _transcribe_audio(audio, settings, language)
result = _align_transcription(audio, result, settings)
result = _diarize_and_assign_speakers(audio, result, settings)
if is_diarize:
result = _align_transcription(audio, result, settings)
result = _diarize_and_assign_speakers(audio, result, settings)
return result


Expand Down