|
1 | | -"""Tests for the POST /v1/audio/transcriptions endpoint.""" |
2 | | - |
3 | | -import os |
4 | | - |
5 | | -from utils.config import Settings |
| 1 | +"""Tests for POST /v1/audio/transcriptions.""" |
6 | 2 |
|
7 | 3 | ENDPOINT = "/v1/audio/transcriptions" |
| 4 | +AUDIO = b"\x00" * 512 |
8 | 5 |
|
9 | 6 |
|
10 | | -def _post_to_transcribe_endpoint(client, audio_bytes, **form_fields): |
11 | | - """Helper: POST a file upload to the transcription endpoint.""" |
12 | | - return client.post( |
13 | | - ENDPOINT, |
14 | | - files={"file": ("test.wav", audio_bytes, "audio/wav")}, |
15 | | - data=form_fields, |
16 | | - ) |
17 | | - |
18 | | - |
19 | | -# Test Success |
20 | | - |
21 | | - |
22 | | -class TestTranscribeSuccess: |
23 | | - def test_default_params( |
24 | | - self, client, mock_whisperx, mock_transcribe, sample_audio_bytes |
25 | | - ): |
26 | | - """Successful transcription with no explicit model or language.""" |
27 | | - response = _post_to_transcribe_endpoint(client, sample_audio_bytes) |
28 | | - |
29 | | - assert response.status_code == 200 |
30 | | - body = response.json() |
31 | | - assert "segments" in body |
32 | | - assert body["segments"][0]["words"][0]["word"] == "Hello" |
33 | | - |
34 | | - def test_with_language( |
35 | | - self, client, mock_whisperx, mock_transcribe, sample_audio_bytes |
36 | | - ): |
37 | | - """Explicit language is forwarded to the transcribe service.""" |
38 | | - response = _post_to_transcribe_endpoint( |
39 | | - client, sample_audio_bytes, language="en" |
40 | | - ) |
41 | | - |
42 | | - assert response.status_code == 200 |
43 | | - mock_transcribe.assert_called_once() |
44 | | - assert mock_transcribe.call_args.args[2] == "en" |
45 | | - |
46 | | - def test_with_matching_model( |
47 | | - self, client, mock_whisperx, mock_transcribe, sample_audio_bytes |
48 | | - ): |
49 | | - """Explicit model that matches the configured model succeeds.""" |
50 | | - response = _post_to_transcribe_endpoint( |
51 | | - client, sample_audio_bytes, model="large-v2" |
52 | | - ) |
53 | | - |
54 | | - assert response.status_code == 200 |
55 | | - |
56 | | - |
57 | | -# Test Validation Errors |
58 | | - |
59 | | - |
60 | | -class TestTranscribeValidation: |
61 | | - def test_unsupported_transcribe_language( |
62 | | - self, client, mock_whisperx, mock_transcribe, sample_audio_bytes |
63 | | - ): |
64 | | - """Language not in whisperx.utils.LANGUAGES returns 400.""" |
65 | | - response = _post_to_transcribe_endpoint( |
66 | | - client, sample_audio_bytes, language="xx" |
67 | | - ) |
68 | | - |
69 | | - assert response.status_code == 400 |
70 | | - assert "Unsupported language" in response.json()["detail"] |
71 | | - assert "for transcription" in response.json()["detail"] |
72 | | - |
73 | | - def test_unsupported_align_language( |
74 | | - self, client, mock_whisperx, mock_transcribe, sample_audio_bytes |
75 | | - ): |
76 | | - """Language in LANGUAGES but missing from alignment dicts returns 400.""" |
77 | | - |
78 | | - # NB: "cz" is supported for transcribe but not align |
79 | | - response = _post_to_transcribe_endpoint( |
80 | | - client, sample_audio_bytes, language="cz" |
81 | | - ) |
82 | | - |
83 | | - assert response.status_code == 400 |
84 | | - assert "Unsupported language" in response.json()["detail"] |
85 | | - assert "for alignment" in response.json()["detail"] |
86 | | - |
87 | | - def test_wrong_model( |
88 | | - self, client, mock_whisperx, mock_transcribe, sample_audio_bytes |
89 | | - ): |
90 | | - """Model that differs from configured model returns 404.""" |
91 | | - response = _post_to_transcribe_endpoint( |
92 | | - client, sample_audio_bytes, model="tiny" |
93 | | - ) |
| 7 | +def post(client, model=None, language=None, response_format=None): |
| 8 | + data = {} |
| 9 | + if model: |
| 10 | + data["model"] = model |
| 11 | + if language: |
| 12 | + data["language"] = language |
| 13 | + if response_format: |
| 14 | + data["response_format"] = response_format |
| 15 | + return client.post(ENDPOINT, files={"file": ("test.wav", AUDIO, "audio/wav")}, data=data) |
94 | 16 |
|
95 | | - assert response.status_code == 404 |
96 | | - assert "Model not found" in response.json()["detail"] |
97 | 17 |
|
98 | | - def test_missing_file(self, client, mock_whisperx, mock_transcribe): |
99 | | - """Request without a file upload returns 422.""" |
100 | | - response = client.post(ENDPOINT) |
| 18 | +def test_transcription(client): |
| 19 | + body = post(client).json() |
| 20 | + assert body["text"] == "Hello world." |
| 21 | + assert body["segments"] is None |
101 | 22 |
|
102 | | - assert response.status_code == 422 |
103 | 23 |
|
| 24 | +def test_diarized(client): |
| 25 | + body = post(client, response_format="diarized_json").json() |
| 26 | + assert body["segments"][0]["text"] == "Hello world." |
| 27 | + assert body["segments"][0]["speaker"] == "SPEAKER_00" |
104 | 28 |
|
105 | | -# Behavior tests |
106 | 29 |
|
| 30 | +def test_wrong_model_returns_404(client): |
| 31 | + assert post(client, model="tiny").status_code == 404 |
107 | 32 |
|
108 | | -class TestTranscribeBehaviour: |
109 | | - def test_load_audio_called_with_temp_path( |
110 | | - self, client, mock_whisperx, mock_transcribe, sample_audio_bytes |
111 | | - ): |
112 | | - """whisperx.load_audio is called with a temp file path that is cleaned up.""" |
113 | | - response = _post_to_transcribe_endpoint(client, sample_audio_bytes) |
114 | 33 |
|
115 | | - assert response.status_code == 200 |
116 | | - mock_whisperx["load_audio"].assert_called_once() |
117 | | - temp_path = mock_whisperx["load_audio"].call_args.args[0] |
118 | | - assert isinstance(temp_path, str) |
119 | | - # Temp file should have been deleted by the endpoint |
120 | | - assert not os.path.exists(temp_path) |
| 34 | +def test_english(client): |
| 35 | + assert post(client, language="en").status_code == 200 |
121 | 36 |
|
122 | | - def test_temp_file_extension_preserved( |
123 | | - self, client, mock_whisperx, mock_transcribe |
124 | | - ): |
125 | | - """Temp file preserves the original upload extension.""" |
126 | | - response = client.post( |
127 | | - ENDPOINT, |
128 | | - files={"file": ("interview.ogg", b"\x00" * 512, "audio/mpeg")}, |
129 | | - ) |
130 | 37 |
|
131 | | - assert response.status_code == 200 |
132 | | - temp_path = mock_whisperx["load_audio"].call_args.args[0] |
133 | | - assert temp_path.endswith(".ogg") |
| 38 | +def test_french(client): |
| 39 | + assert post(client, language="fr").status_code == 200 |
134 | 40 |
|
135 | | - def test_transcribe_called_with_correct_args( |
136 | | - self, client, mock_whisperx, mock_transcribe, sample_audio_bytes |
137 | | - ): |
138 | | - """The transcribe service receives (audio_array, settings, language).""" |
139 | | - response = _post_to_transcribe_endpoint( |
140 | | - client, sample_audio_bytes, language="fr" |
141 | | - ) |
142 | 41 |
|
143 | | - assert response.status_code == 200 |
144 | | - mock_transcribe.assert_called_once() |
145 | | - args = mock_transcribe.call_args.args |
146 | | - # 0: Numpy audio array returned by load_audio |
147 | | - assert args[0] is mock_whisperx["fake_audio"] |
148 | | - # 1: Settings instance with expected values |
149 | | - assert isinstance(args[1], Settings) |
150 | | - assert args[1].transcribe_model == "large-v2" |
151 | | - assert args[1].batch_size == 4 |
152 | | - # 2: language |
153 | | - assert args[2] == "fr" |
| 42 | +def test_missing_file_returns_422(client): |
| 43 | + assert client.post(ENDPOINT).status_code == 422 |
0 commit comments