Skip to content

Commit 0824fb9

Browse files
author
Cyril LAY
committed
Use large-v3-turbo as default model,simplify tests, fix Dockerfile and docs
1 parent 236c146 commit 0824fb9

7 files changed

Lines changed: 87 additions & 227 deletions

File tree

README.md

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,36 @@ export LOGGING_CONFIG=logging-config.yaml
3333
python app/main.py
3434
```
3535

36+
## Deployment
37+
38+
### Build the image
39+
40+
```bash
41+
docker build -f app/Dockerfile -t whisperx-openai-api .
42+
```
43+
44+
### Run
45+
46+
```bash
47+
docker run -d \
48+
--gpus all \
49+
-p 8000:8000 \
50+
-e API_KEY=your-api-key \
51+
-e HF_TOKEN=your-hf-token \
52+
-v /path/to/models:/data/models \
53+
whisperx-openai-api
54+
```
55+
56+
Models are downloaded on first startup and cached in `/data/models`. Mount a persistent volume to avoid re-downloading on restart.
57+
58+
**1 worker is recommended.** GPU inference is serialized internally : multiple workers each load a full model copy in VRAM, and it doesn't improve throughput unless you have multiple GPUs.
59+
60+
To scale workers (each worker loads its own model in VRAM):
61+
62+
```bash
63+
docker run -d --gpus all ... -e WORKERS=2 whisperx-openai-api
64+
```
65+
3666
## Testing
3767

3868
Tests mock actual inference and can be run locally:
@@ -52,7 +82,7 @@ Check the [documentation to run integration tests](docs/testing_with_gpu.md) on
5282
| -------- | ----------- | ------- |
5383
| API_KEY | API key for API access | Required |
5484
| HF_TOKEN | Hugging Face token (required for diarization) | Required |
55-
| TRANSCRIBE_MODEL | WhisperX model to load | `large-v2` |
85+
| TRANSCRIBE_MODEL | WhisperX model to load | `large-v3-turbo` |
5686
| BATCH_SIZE | Transcription batch size | `16` |
5787
| DIARIZE_MODEL | Pyannote diarization model | `pyannote/speaker-diarization-community-1` |
5888
| PRELOADED_ALIGN_MODEL_LANGUAGES | Languages to pre-load alignment models for | `["en", "fr", "nl", "de"]` |
@@ -64,6 +94,6 @@ Check the [documentation to run integration tests](docs/testing_with_gpu.md) on
6494
| WORKERS | Number of uvicorn workers (each loads its own model in VRAM) | `1` |
6595
| RELOAD | Enable auto-reload | `false` |
6696
| ROOT_PATH | API root path | `None` |
67-
| LOGGING_CONFIG | Path to logging config file | `None` |
97+
| LOGGING_CONFIG | Path to logging config file | `logging-config.yaml` |
6898
| DEBUG | Enable debug logging | `false` |
6999

app/Dockerfile

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,4 @@ COPY --chown=whisperuser:whisperuser ./logging-config.yaml /app/logging-config.y
7575

7676
USER whisperuser
7777

78-
CMD gunicorn main:app \
79-
--workers ${WORKERS:-1} \
80-
--worker-class uvicorn.workers.UvicornWorker \
81-
--timeout 120 \
82-
--bind 0.0.0.0:${PORT:-8000}
78+
ENTRYPOINT ["sh", "-c", "gunicorn main:app --workers ${WORKERS:-1} --worker-class uvicorn.workers.UvicornWorker --timeout 120 --bind 0.0.0.0:${PORT:-8000}"]

app/tests/conftest.py

Lines changed: 23 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -2,115 +2,58 @@
22
import sys
33
from unittest.mock import MagicMock, patch
44

5-
# ---------------------------------------------------------------------------
6-
# 1. Mock heavy third-party libraries in sys.modules BEFORE any app import.
7-
# app modules import whisperx, torch, numpy at the top level.
8-
# ---------------------------------------------------------------------------
9-
_mock_np = MagicMock()
5+
import pytest
6+
7+
# Mock libs before import
108
_mock_torch = MagicMock()
119
_mock_torch.cuda.is_available.return_value = False
1210
_mock_torch.float32 = "float32"
1311

1412
_mock_whisperx = MagicMock()
15-
FAKE_LANGUAGES = {"en": "english", "fr": "french", "cz": "czech"}
16-
FAKE_ALIGN_MODELS_HF = {"en": "WAV2VEC2_ASR_BASE_960H", "fr": "some-fr-model"}
17-
FAKE_ALIGN_MODELS_TORCH = {"en": "WAV2VEC2_ASR_BASE_960H"}
18-
19-
_mock_whisperx.utils.LANGUAGES = FAKE_LANGUAGES
20-
_mock_whisperx.alignment.DEFAULT_ALIGN_MODELS_HF = FAKE_ALIGN_MODELS_HF
21-
_mock_whisperx.alignment.DEFAULT_ALIGN_MODELS_TORCH = FAKE_ALIGN_MODELS_TORCH
13+
_mock_whisperx.utils.LANGUAGES = {"en": "english", "fr": "french", "cz": "czech"}
14+
_mock_whisperx.alignment.DEFAULT_ALIGN_MODELS_HF = {"en": "...", "fr": "..."}
15+
_mock_whisperx.alignment.DEFAULT_ALIGN_MODELS_TORCH = {}
2216

23-
for mod_name, mock_obj in {
24-
"numpy": _mock_np,
25-
"np": _mock_np,
17+
for name, mock in {
18+
"numpy": MagicMock(),
2619
"torch": _mock_torch,
2720
"whisperx": _mock_whisperx,
2821
"whisperx.utils": _mock_whisperx.utils,
2922
"whisperx.alignment": _mock_whisperx.alignment,
3023
"whisperx.asr": _mock_whisperx.asr,
3124
"whisperx.diarize": _mock_whisperx.diarize,
3225
}.items():
33-
sys.modules.setdefault(mod_name, mock_obj)
26+
sys.modules.setdefault(name, mock)
3427

35-
# ---------------------------------------------------------------------------
36-
# 2. Set required env vars BEFORE any app module is imported.
37-
# config.py and security.py evaluate settings at module level.
38-
# ---------------------------------------------------------------------------
39-
TEST_API_KEY = "test-api-key"
40-
TEST_HF_TOKEN = "test-hf-token"
28+
os.environ.setdefault("API_KEY", "test-key")
29+
os.environ.setdefault("HF_TOKEN", "test-token")
4130

42-
os.environ.setdefault("API_KEY", TEST_API_KEY)
43-
os.environ.setdefault("HF_TOKEN", TEST_HF_TOKEN)
31+
# App imports after mocking)
4432

45-
# ---------------------------------------------------------------------------
46-
# 3. Now it is safe to import app modules.
47-
# ---------------------------------------------------------------------------
4833
from fastapi import FastAPI # noqa: E402
4934
from fastapi.testclient import TestClient # noqa: E402
50-
import pytest # noqa: E402
5135

5236
from endpoints import audio # noqa: E402
5337
from utils.config import Settings, get_settings # noqa: E402
5438
from utils.security import check_api_key # noqa: E402
5539

56-
_FAKE_WORDS = [
57-
{"word": "Hello", "start": 0.0, "end": 0.7, "score": 0.95},
58-
{"word": "world.", "start": 0.8, "end": 1.5, "score": 0.92},
59-
]
60-
61-
MOCK_TRANSCRIPTION_RESULT = {
62-
"segments": [
63-
{"start": 0.0, "end": 1.5, "text": "Hello world.", "words": _FAKE_WORDS}
64-
],
65-
"word_segments": _FAKE_WORDS,
40+
FAKE_TRANSCRIPTION = {
41+
"segments": [{"start": 0.0, "end": 1.5, "text": "Hello world.", "speaker": "SPEAKER_00"}],
6642
}
6743

68-
FAKE_AUDIO = MagicMock(name="fake_audio_array")
69-
70-
71-
def _test_settings() -> Settings:
72-
return Settings(
73-
api_key=TEST_API_KEY,
74-
hf_token=TEST_HF_TOKEN,
75-
transcribe_model="large-v2",
76-
batch_size=4,
77-
)
78-
7944

8045
@pytest.fixture()
8146
def client():
82-
"""TestClient with dependency overrides and lifespan disabled."""
8347
app = FastAPI()
8448
app.include_router(audio.router, prefix="/v1")
49+
app.dependency_overrides[get_settings] = lambda: Settings(
50+
api_key="test-key", hf_token="test-token", transcribe_model="large-v3-turbo"
51+
)
52+
app.dependency_overrides[check_api_key] = lambda: "test-key"
8553

86-
app.dependency_overrides[get_settings] = _test_settings
87-
app.dependency_overrides[check_api_key] = lambda: TEST_API_KEY
88-
89-
with TestClient(app) as c:
54+
with (
55+
patch.object(audio.whisperx, "load_audio", return_value=MagicMock()),
56+
patch.object(audio, "transcribe", return_value=FAKE_TRANSCRIPTION),
57+
TestClient(app) as c,
58+
):
9059
yield c
91-
92-
93-
@pytest.fixture()
94-
def mock_whisperx():
95-
"""Patch whisperx symbols used directly in the endpoint module."""
96-
with patch.object(
97-
audio.whisperx, "load_audio", return_value=FAKE_AUDIO
98-
) as load_audio:
99-
yield {"load_audio": load_audio, "fake_audio": FAKE_AUDIO}
100-
101-
102-
@pytest.fixture()
103-
def mock_transcribe():
104-
"""Patch the transcribe service function as imported in the endpoint module."""
105-
with patch.object(
106-
audio,
107-
"transcribe",
108-
return_value=MOCK_TRANSCRIPTION_RESULT,
109-
) as mock:
110-
yield mock
111-
112-
113-
@pytest.fixture()
114-
def sample_audio_bytes() -> bytes:
115-
"""Minimal bytes to simulate an uploaded audio file."""
116-
return b"\x00" * 1024
Lines changed: 27 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -1,153 +1,43 @@
1-
"""Tests for the POST /v1/audio/transcriptions endpoint."""
2-
3-
import os
4-
5-
from utils.config import Settings
1+
"""Tests for POST /v1/audio/transcriptions."""
62

73
ENDPOINT = "/v1/audio/transcriptions"
4+
AUDIO = b"\x00" * 512
85

96

10-
def _post_to_transcribe_endpoint(client, audio_bytes, **form_fields):
11-
"""Helper: POST a file upload to the transcription endpoint."""
12-
return client.post(
13-
ENDPOINT,
14-
files={"file": ("test.wav", audio_bytes, "audio/wav")},
15-
data=form_fields,
16-
)
17-
18-
19-
# Test Success
20-
21-
22-
class TestTranscribeSuccess:
23-
def test_default_params(
24-
self, client, mock_whisperx, mock_transcribe, sample_audio_bytes
25-
):
26-
"""Successful transcription with no explicit model or language."""
27-
response = _post_to_transcribe_endpoint(client, sample_audio_bytes)
28-
29-
assert response.status_code == 200
30-
body = response.json()
31-
assert "segments" in body
32-
assert body["segments"][0]["words"][0]["word"] == "Hello"
33-
34-
def test_with_language(
35-
self, client, mock_whisperx, mock_transcribe, sample_audio_bytes
36-
):
37-
"""Explicit language is forwarded to the transcribe service."""
38-
response = _post_to_transcribe_endpoint(
39-
client, sample_audio_bytes, language="en"
40-
)
41-
42-
assert response.status_code == 200
43-
mock_transcribe.assert_called_once()
44-
assert mock_transcribe.call_args.args[2] == "en"
45-
46-
def test_with_matching_model(
47-
self, client, mock_whisperx, mock_transcribe, sample_audio_bytes
48-
):
49-
"""Explicit model that matches the configured model succeeds."""
50-
response = _post_to_transcribe_endpoint(
51-
client, sample_audio_bytes, model="large-v2"
52-
)
53-
54-
assert response.status_code == 200
55-
56-
57-
# Test Validation Errors
58-
59-
60-
class TestTranscribeValidation:
61-
def test_unsupported_transcribe_language(
62-
self, client, mock_whisperx, mock_transcribe, sample_audio_bytes
63-
):
64-
"""Language not in whisperx.utils.LANGUAGES returns 400."""
65-
response = _post_to_transcribe_endpoint(
66-
client, sample_audio_bytes, language="xx"
67-
)
68-
69-
assert response.status_code == 400
70-
assert "Unsupported language" in response.json()["detail"]
71-
assert "for transcription" in response.json()["detail"]
72-
73-
def test_unsupported_align_language(
74-
self, client, mock_whisperx, mock_transcribe, sample_audio_bytes
75-
):
76-
"""Language in LANGUAGES but missing from alignment dicts returns 400."""
77-
78-
# NB: "cz" is supported for transcribe but not align
79-
response = _post_to_transcribe_endpoint(
80-
client, sample_audio_bytes, language="cz"
81-
)
82-
83-
assert response.status_code == 400
84-
assert "Unsupported language" in response.json()["detail"]
85-
assert "for alignment" in response.json()["detail"]
86-
87-
def test_wrong_model(
88-
self, client, mock_whisperx, mock_transcribe, sample_audio_bytes
89-
):
90-
"""Model that differs from configured model returns 404."""
91-
response = _post_to_transcribe_endpoint(
92-
client, sample_audio_bytes, model="tiny"
93-
)
7+
def post(client, model=None, language=None, response_format=None):
8+
data = {}
9+
if model:
10+
data["model"] = model
11+
if language:
12+
data["language"] = language
13+
if response_format:
14+
data["response_format"] = response_format
15+
return client.post(ENDPOINT, files={"file": ("test.wav", AUDIO, "audio/wav")}, data=data)
9416

95-
assert response.status_code == 404
96-
assert "Model not found" in response.json()["detail"]
9717

98-
def test_missing_file(self, client, mock_whisperx, mock_transcribe):
99-
"""Request without a file upload returns 422."""
100-
response = client.post(ENDPOINT)
18+
def test_transcription(client):
19+
body = post(client).json()
20+
assert body["text"] == "Hello world."
21+
assert body["segments"] is None
10122

102-
assert response.status_code == 422
10323

24+
def test_diarized(client):
25+
body = post(client, response_format="diarized_json").json()
26+
assert body["segments"][0]["text"] == "Hello world."
27+
assert body["segments"][0]["speaker"] == "SPEAKER_00"
10428

105-
# Behavior tests
10629

30+
def test_wrong_model_returns_404(client):
31+
assert post(client, model="tiny").status_code == 404
10732

108-
class TestTranscribeBehaviour:
109-
def test_load_audio_called_with_temp_path(
110-
self, client, mock_whisperx, mock_transcribe, sample_audio_bytes
111-
):
112-
"""whisperx.load_audio is called with a temp file path that is cleaned up."""
113-
response = _post_to_transcribe_endpoint(client, sample_audio_bytes)
11433

115-
assert response.status_code == 200
116-
mock_whisperx["load_audio"].assert_called_once()
117-
temp_path = mock_whisperx["load_audio"].call_args.args[0]
118-
assert isinstance(temp_path, str)
119-
# Temp file should have been deleted by the endpoint
120-
assert not os.path.exists(temp_path)
34+
def test_english(client):
35+
assert post(client, language="en").status_code == 200
12136

122-
def test_temp_file_extension_preserved(
123-
self, client, mock_whisperx, mock_transcribe
124-
):
125-
"""Temp file preserves the original upload extension."""
126-
response = client.post(
127-
ENDPOINT,
128-
files={"file": ("interview.ogg", b"\x00" * 512, "audio/mpeg")},
129-
)
13037

131-
assert response.status_code == 200
132-
temp_path = mock_whisperx["load_audio"].call_args.args[0]
133-
assert temp_path.endswith(".ogg")
38+
def test_french(client):
39+
assert post(client, language="fr").status_code == 200
13440

135-
def test_transcribe_called_with_correct_args(
136-
self, client, mock_whisperx, mock_transcribe, sample_audio_bytes
137-
):
138-
"""The transcribe service receives (audio_array, settings, language)."""
139-
response = _post_to_transcribe_endpoint(
140-
client, sample_audio_bytes, language="fr"
141-
)
14241

143-
assert response.status_code == 200
144-
mock_transcribe.assert_called_once()
145-
args = mock_transcribe.call_args.args
146-
# 0: Numpy audio array returned by load_audio
147-
assert args[0] is mock_whisperx["fake_audio"]
148-
# 1: Settings instance with expected values
149-
assert isinstance(args[1], Settings)
150-
assert args[1].transcribe_model == "large-v2"
151-
assert args[1].batch_size == 4
152-
# 2: language
153-
assert args[2] == "fr"
42+
def test_missing_file_returns_422(client):
43+
assert client.post(ENDPOINT).status_code == 422

app/tests_with_gpu/test_transcribe.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,11 @@ def test_real_audio_output(self, integration_client, sample_ogg, expected_output
3737
response = integration_client.post(
3838
ENDPOINT,
3939
files={"file": ("sample_en_1.ogg", sample_ogg, "audio/ogg")},
40+
data={"response_format": "diarized_json"},
4041
)
4142

4243
assert response.status_code == 200
4344
body = response.json()
4445

46+
assert body["text"] == expected_output["text"]
4547
assert body["segments"] == expected_output["segments"]

0 commit comments

Comments
 (0)