Skip to content

Commit 270c046

Browse files
cyrillayCyril LAY
andauthored
feat(perf) Improve performance and production readiness (#2)
* Make worker count increasable * Try adding gunicorn workers to scale the whisperX * Try GPU async offload, and fix logs * Use large-v3-turbo as default model, simplify tests, fix Dockerfile and docs * Remove build on test branch for dev purposes --------- Co-authored-by: Cyril LAY <cyril.lay.ext@mail.numerique.gouv.fr>
1 parent 891527e commit 270c046

12 files changed

Lines changed: 129 additions & 230 deletions

File tree

.github/workflows/docker.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ on:
44
push:
55
branches:
66
- main
7-
- feat/make-diarization-optional # tmp for dev
87
workflow_dispatch:
98

109
jobs:

README.md

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,36 @@ export LOGGING_CONFIG=logging-config.yaml
3333
python app/main.py
3434
```
3535

36+
## Deployment
37+
38+
### Build the image
39+
40+
```bash
41+
docker build -f app/Dockerfile -t whisperx-openai-api .
42+
```
43+
44+
### Run
45+
46+
```bash
47+
docker run -d \
48+
--gpus all \
49+
-p 8000:8000 \
50+
-e API_KEY=your-api-key \
51+
-e HF_TOKEN=your-hf-token \
52+
-v /path/to/models:/data/models \
53+
whisperx-openai-api
54+
```
55+
56+
Models are downloaded on first startup and cached in `/data/models`. Mount a persistent volume to avoid re-downloading on restart.
57+
58+
**1 worker is recommended.** GPU inference is serialized internally, so adding workers does not improve throughput unless you have multiple GPUs; each extra worker only loads another full copy of the model into VRAM.
59+
60+
To scale workers (each worker loads its own model in VRAM):
61+
62+
```bash
63+
docker run -d --gpus all ... -e WORKERS=2 whisperx-openai-api
64+
```
65+
3666
## Testing
3767

3868
Tests mock actual inference and can be run locally:
@@ -52,7 +82,7 @@ Check the [documentation to run integration tests](docs/testing_with_gpu.md) on
5282
| -------- | ----------- | ------- |
5383
| API_KEY | API key for API access | Required |
5484
| HF_TOKEN | Hugging Face token (required for diarization) | Required |
55-
| TRANSCRIBE_MODEL | WhisperX model to load | `large-v2` |
85+
| TRANSCRIBE_MODEL | WhisperX model to load | `large-v3-turbo` |
5686
| BATCH_SIZE | Transcription batch size | `16` |
5787
| DIARIZE_MODEL | Pyannote diarization model | `pyannote/speaker-diarization-community-1` |
5888
| PRELOADED_ALIGN_MODEL_LANGUAGES | Languages to pre-load alignment models for | `["en", "fr", "nl", "de"]` |
@@ -61,8 +91,9 @@ Check the [documentation to run integration tests](docs/testing_with_gpu.md) on
6191
| FILL_NEAREST | Fill nearest gaps in speaker assignment (diarization only) | `false` |
6292
| TIMEOUT_KEEP_ALIVE | Keep-alive timeout (seconds) | `60` |
6393
| PORT | Server port | `8000` |
94+
| WORKERS | Number of uvicorn workers (each loads its own model in VRAM) | `1` |
6495
| RELOAD | Enable auto-reload | `false` |
6596
| ROOT_PATH | API root path | `None` |
66-
| LOGGING_CONFIG | Path to logging config file | `None` |
97+
| LOGGING_CONFIG | Path to logging config file | `logging-config.yaml` |
6798
| DEBUG | Enable debug logging | `false` |
6899

app/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,4 +75,4 @@ COPY --chown=whisperuser:whisperuser ./logging-config.yaml /app/logging-config.y
7575

7676
USER whisperuser
7777

78-
CMD ["python3", "main.py"]
78+
ENTRYPOINT ["sh", "-c", "gunicorn main:app --workers ${WORKERS:-1} --worker-class uvicorn.workers.UvicornWorker --timeout 120 --bind 0.0.0.0:${PORT:-8000}"]

app/endpoints/audio.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import asyncio
12
import logging
23
import os
34
import tempfile
45
import time
6+
from functools import partial
57
from typing import Annotated, Optional
68

79
import numpy as np
@@ -16,6 +18,7 @@
1618
UploadFile,
1719
)
1820
from services.transcription import transcribe
21+
from utils.lifespan import gpu_executor
1922
import whisperx
2023

2124
from schemas.audio import AudioTranscription, InputTokenDetails, Segment, Usage
@@ -122,6 +125,9 @@ async def audio_transcriptions(
122125
audio = whisperx.load_audio(temp_file_path)
123126
os.remove(temp_file_path)
124127

125-
result = transcribe(audio, settings, language, is_diarize=is_diarize)
128+
loop = asyncio.get_event_loop()
129+
result = await loop.run_in_executor(
130+
gpu_executor, partial(transcribe, audio, settings, language, is_diarize=is_diarize)
131+
)
126132

127133
return _build_response(result, audio, is_diarize=is_diarize)

app/main.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
1+
import logging
2+
import logging.config
13
from typing import Annotated
24

35
from fastapi import Depends, FastAPI
46
import uvicorn
7+
import yaml
58

69
from endpoints import audio, models, monitoring
710
from utils.config import Settings, get_settings, settings
811
from utils.lifespan import lifespan
912

13+
if settings.logging_config:
14+
with open(settings.logging_config) as f:
15+
logging.config.dictConfig(yaml.safe_load(f))
16+
1017
# Setup FastAPI
1118
app = FastAPI(
1219
title="LaSuite Meet WhisperX",
@@ -39,4 +46,5 @@ async def info(settings: Annotated[Settings, Depends(get_settings)]):
3946
log_config=settings.logging_config,
4047
reload=settings.reload,
4148
timeout_keep_alive=settings.timeout_keep_alive,
49+
workers=settings.workers,
4250
)

app/tests/conftest.py

Lines changed: 23 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -2,115 +2,58 @@
22
import sys
33
from unittest.mock import MagicMock, patch
44

5-
# ---------------------------------------------------------------------------
6-
# 1. Mock heavy third-party libraries in sys.modules BEFORE any app import.
7-
# app modules import whisperx, torch, numpy at the top level.
8-
# ---------------------------------------------------------------------------
9-
_mock_np = MagicMock()
5+
import pytest
6+
7+
# Mock libs before import
108
_mock_torch = MagicMock()
119
_mock_torch.cuda.is_available.return_value = False
1210
_mock_torch.float32 = "float32"
1311

1412
_mock_whisperx = MagicMock()
15-
FAKE_LANGUAGES = {"en": "english", "fr": "french", "cz": "czech"}
16-
FAKE_ALIGN_MODELS_HF = {"en": "WAV2VEC2_ASR_BASE_960H", "fr": "some-fr-model"}
17-
FAKE_ALIGN_MODELS_TORCH = {"en": "WAV2VEC2_ASR_BASE_960H"}
18-
19-
_mock_whisperx.utils.LANGUAGES = FAKE_LANGUAGES
20-
_mock_whisperx.alignment.DEFAULT_ALIGN_MODELS_HF = FAKE_ALIGN_MODELS_HF
21-
_mock_whisperx.alignment.DEFAULT_ALIGN_MODELS_TORCH = FAKE_ALIGN_MODELS_TORCH
13+
_mock_whisperx.utils.LANGUAGES = {"en": "english", "fr": "french", "cz": "czech"}
14+
_mock_whisperx.alignment.DEFAULT_ALIGN_MODELS_HF = {"en": "...", "fr": "..."}
15+
_mock_whisperx.alignment.DEFAULT_ALIGN_MODELS_TORCH = {}
2216

23-
for mod_name, mock_obj in {
24-
"numpy": _mock_np,
25-
"np": _mock_np,
17+
for name, mock in {
18+
"numpy": MagicMock(),
2619
"torch": _mock_torch,
2720
"whisperx": _mock_whisperx,
2821
"whisperx.utils": _mock_whisperx.utils,
2922
"whisperx.alignment": _mock_whisperx.alignment,
3023
"whisperx.asr": _mock_whisperx.asr,
3124
"whisperx.diarize": _mock_whisperx.diarize,
3225
}.items():
33-
sys.modules.setdefault(mod_name, mock_obj)
26+
sys.modules.setdefault(name, mock)
3427

35-
# ---------------------------------------------------------------------------
36-
# 2. Set required env vars BEFORE any app module is imported.
37-
# config.py and security.py evaluate settings at module level.
38-
# ---------------------------------------------------------------------------
39-
TEST_API_KEY = "test-api-key"
40-
TEST_HF_TOKEN = "test-hf-token"
28+
os.environ.setdefault("API_KEY", "test-key")
29+
os.environ.setdefault("HF_TOKEN", "test-token")
4130

42-
os.environ.setdefault("API_KEY", TEST_API_KEY)
43-
os.environ.setdefault("HF_TOKEN", TEST_HF_TOKEN)
31+
# App imports after mocking
4432

45-
# ---------------------------------------------------------------------------
46-
# 3. Now it is safe to import app modules.
47-
# ---------------------------------------------------------------------------
4833
from fastapi import FastAPI # noqa: E402
4934
from fastapi.testclient import TestClient # noqa: E402
50-
import pytest # noqa: E402
5135

5236
from endpoints import audio # noqa: E402
5337
from utils.config import Settings, get_settings # noqa: E402
5438
from utils.security import check_api_key # noqa: E402
5539

56-
_FAKE_WORDS = [
57-
{"word": "Hello", "start": 0.0, "end": 0.7, "score": 0.95},
58-
{"word": "world.", "start": 0.8, "end": 1.5, "score": 0.92},
59-
]
60-
61-
MOCK_TRANSCRIPTION_RESULT = {
62-
"segments": [
63-
{"start": 0.0, "end": 1.5, "text": "Hello world.", "words": _FAKE_WORDS}
64-
],
65-
"word_segments": _FAKE_WORDS,
40+
FAKE_TRANSCRIPTION = {
41+
"segments": [{"start": 0.0, "end": 1.5, "text": "Hello world.", "speaker": "SPEAKER_00"}],
6642
}
6743

68-
FAKE_AUDIO = MagicMock(name="fake_audio_array")
69-
70-
71-
def _test_settings() -> Settings:
72-
return Settings(
73-
api_key=TEST_API_KEY,
74-
hf_token=TEST_HF_TOKEN,
75-
transcribe_model="large-v2",
76-
batch_size=4,
77-
)
78-
7944

8045
@pytest.fixture()
8146
def client():
82-
"""TestClient with dependency overrides and lifespan disabled."""
8347
app = FastAPI()
8448
app.include_router(audio.router, prefix="/v1")
49+
app.dependency_overrides[get_settings] = lambda: Settings(
50+
api_key="test-key", hf_token="test-token", transcribe_model="large-v3-turbo"
51+
)
52+
app.dependency_overrides[check_api_key] = lambda: "test-key"
8553

86-
app.dependency_overrides[get_settings] = _test_settings
87-
app.dependency_overrides[check_api_key] = lambda: TEST_API_KEY
88-
89-
with TestClient(app) as c:
54+
with (
55+
patch.object(audio.whisperx, "load_audio", return_value=MagicMock()),
56+
patch.object(audio, "transcribe", return_value=FAKE_TRANSCRIPTION),
57+
TestClient(app) as c,
58+
):
9059
yield c
91-
92-
93-
@pytest.fixture()
94-
def mock_whisperx():
95-
"""Patch whisperx symbols used directly in the endpoint module."""
96-
with patch.object(
97-
audio.whisperx, "load_audio", return_value=FAKE_AUDIO
98-
) as load_audio:
99-
yield {"load_audio": load_audio, "fake_audio": FAKE_AUDIO}
100-
101-
102-
@pytest.fixture()
103-
def mock_transcribe():
104-
"""Patch the transcribe service function as imported in the endpoint module."""
105-
with patch.object(
106-
audio,
107-
"transcribe",
108-
return_value=MOCK_TRANSCRIPTION_RESULT,
109-
) as mock:
110-
yield mock
111-
112-
113-
@pytest.fixture()
114-
def sample_audio_bytes() -> bytes:
115-
"""Minimal bytes to simulate an uploaded audio file."""
116-
return b"\x00" * 1024

0 commit comments

Comments
 (0)