|
2 | 2 | import sys |
3 | 3 | from unittest.mock import MagicMock, patch |
4 | 4 |
|
5 | | -# --------------------------------------------------------------------------- |
6 | | -# 1. Mock heavy third-party libraries in sys.modules BEFORE any app import. |
7 | | -# app modules import whisperx, torch, numpy at the top level. |
8 | | -# --------------------------------------------------------------------------- |
9 | | -_mock_np = MagicMock() |
| 5 | +import pytest |
| 6 | + |
| 7 | +# Mock libs before import |
10 | 8 | _mock_torch = MagicMock() |
11 | 9 | _mock_torch.cuda.is_available.return_value = False |
12 | 10 | _mock_torch.float32 = "float32" |
13 | 11 |
|
14 | 12 | _mock_whisperx = MagicMock() |
15 | | -FAKE_LANGUAGES = {"en": "english", "fr": "french", "cz": "czech"} |
16 | | -FAKE_ALIGN_MODELS_HF = {"en": "WAV2VEC2_ASR_BASE_960H", "fr": "some-fr-model"} |
17 | | -FAKE_ALIGN_MODELS_TORCH = {"en": "WAV2VEC2_ASR_BASE_960H"} |
18 | | - |
19 | | -_mock_whisperx.utils.LANGUAGES = FAKE_LANGUAGES |
20 | | -_mock_whisperx.alignment.DEFAULT_ALIGN_MODELS_HF = FAKE_ALIGN_MODELS_HF |
21 | | -_mock_whisperx.alignment.DEFAULT_ALIGN_MODELS_TORCH = FAKE_ALIGN_MODELS_TORCH |
| 13 | +_mock_whisperx.utils.LANGUAGES = {"en": "english", "fr": "french", "cz": "czech"} |
| 14 | +_mock_whisperx.alignment.DEFAULT_ALIGN_MODELS_HF = {"en": "...", "fr": "..."} |
| 15 | +_mock_whisperx.alignment.DEFAULT_ALIGN_MODELS_TORCH = {} |
22 | 16 |
|
23 | | -for mod_name, mock_obj in { |
24 | | - "numpy": _mock_np, |
25 | | - "np": _mock_np, |
| 17 | +for name, mock in { |
| 18 | + "numpy": MagicMock(), |
26 | 19 | "torch": _mock_torch, |
27 | 20 | "whisperx": _mock_whisperx, |
28 | 21 | "whisperx.utils": _mock_whisperx.utils, |
29 | 22 | "whisperx.alignment": _mock_whisperx.alignment, |
30 | 23 | "whisperx.asr": _mock_whisperx.asr, |
31 | 24 | "whisperx.diarize": _mock_whisperx.diarize, |
32 | 25 | }.items(): |
33 | | - sys.modules.setdefault(mod_name, mock_obj) |
| 26 | + sys.modules.setdefault(name, mock) |
34 | 27 |
|
35 | | -# --------------------------------------------------------------------------- |
36 | | -# 2. Set required env vars BEFORE any app module is imported. |
37 | | -# config.py and security.py evaluate settings at module level. |
38 | | -# --------------------------------------------------------------------------- |
39 | | -TEST_API_KEY = "test-api-key" |
40 | | -TEST_HF_TOKEN = "test-hf-token" |
| 28 | +os.environ.setdefault("API_KEY", "test-key") |
| 29 | +os.environ.setdefault("HF_TOKEN", "test-token") |
41 | 30 |
|
42 | | -os.environ.setdefault("API_KEY", TEST_API_KEY) |
43 | | -os.environ.setdefault("HF_TOKEN", TEST_HF_TOKEN) |
| 31 | +# App imports after mocking) |
44 | 32 |
|
45 | | -# --------------------------------------------------------------------------- |
46 | | -# 3. Now it is safe to import app modules. |
47 | | -# --------------------------------------------------------------------------- |
48 | 33 | from fastapi import FastAPI # noqa: E402 |
49 | 34 | from fastapi.testclient import TestClient # noqa: E402 |
50 | | -import pytest # noqa: E402 |
51 | 35 |
|
52 | 36 | from endpoints import audio # noqa: E402 |
53 | 37 | from utils.config import Settings, get_settings # noqa: E402 |
54 | 38 | from utils.security import check_api_key # noqa: E402 |
55 | 39 |
|
56 | | -_FAKE_WORDS = [ |
57 | | - {"word": "Hello", "start": 0.0, "end": 0.7, "score": 0.95}, |
58 | | - {"word": "world.", "start": 0.8, "end": 1.5, "score": 0.92}, |
59 | | -] |
60 | | - |
61 | | -MOCK_TRANSCRIPTION_RESULT = { |
62 | | - "segments": [ |
63 | | - {"start": 0.0, "end": 1.5, "text": "Hello world.", "words": _FAKE_WORDS} |
64 | | - ], |
65 | | - "word_segments": _FAKE_WORDS, |
| 40 | +FAKE_TRANSCRIPTION = { |
| 41 | + "segments": [{"start": 0.0, "end": 1.5, "text": "Hello world.", "speaker": "SPEAKER_00"}], |
66 | 42 | } |
67 | 43 |
|
68 | | -FAKE_AUDIO = MagicMock(name="fake_audio_array") |
69 | | - |
70 | | - |
71 | | -def _test_settings() -> Settings: |
72 | | - return Settings( |
73 | | - api_key=TEST_API_KEY, |
74 | | - hf_token=TEST_HF_TOKEN, |
75 | | - transcribe_model="large-v2", |
76 | | - batch_size=4, |
77 | | - ) |
78 | | - |
79 | 44 |
|
80 | 45 | @pytest.fixture() |
81 | 46 | def client(): |
82 | | - """TestClient with dependency overrides and lifespan disabled.""" |
83 | 47 | app = FastAPI() |
84 | 48 | app.include_router(audio.router, prefix="/v1") |
| 49 | + app.dependency_overrides[get_settings] = lambda: Settings( |
| 50 | + api_key="test-key", hf_token="test-token", transcribe_model="large-v3-turbo" |
| 51 | + ) |
| 52 | + app.dependency_overrides[check_api_key] = lambda: "test-key" |
85 | 53 |
|
86 | | - app.dependency_overrides[get_settings] = _test_settings |
87 | | - app.dependency_overrides[check_api_key] = lambda: TEST_API_KEY |
88 | | - |
89 | | - with TestClient(app) as c: |
| 54 | + with ( |
| 55 | + patch.object(audio.whisperx, "load_audio", return_value=MagicMock()), |
| 56 | + patch.object(audio, "transcribe", return_value=FAKE_TRANSCRIPTION), |
| 57 | + TestClient(app) as c, |
| 58 | + ): |
90 | 59 | yield c |
91 | | - |
92 | | - |
93 | | -@pytest.fixture() |
94 | | -def mock_whisperx(): |
95 | | - """Patch whisperx symbols used directly in the endpoint module.""" |
96 | | - with patch.object( |
97 | | - audio.whisperx, "load_audio", return_value=FAKE_AUDIO |
98 | | - ) as load_audio: |
99 | | - yield {"load_audio": load_audio, "fake_audio": FAKE_AUDIO} |
100 | | - |
101 | | - |
102 | | -@pytest.fixture() |
103 | | -def mock_transcribe(): |
104 | | - """Patch the transcribe service function as imported in the endpoint module.""" |
105 | | - with patch.object( |
106 | | - audio, |
107 | | - "transcribe", |
108 | | - return_value=MOCK_TRANSCRIPTION_RESULT, |
109 | | - ) as mock: |
110 | | - yield mock |
111 | | - |
112 | | - |
113 | | -@pytest.fixture() |
114 | | -def sample_audio_bytes() -> bytes: |
115 | | - """Minimal bytes to simulate an uploaded audio file.""" |
116 | | - return b"\x00" * 1024 |
0 commit comments