```python
from unittest import TestCase
from unittest.mock import MagicMock, patch

from speech_recognition import AudioData, Recognizer
from speech_recognition.recognizers import whisper


# Class decorators are applied bottom-up, so the mocks arrive in each test
# method innermost-first: OpenAI, then BytesIO, then environ.
@patch("speech_recognition.recognizers.whisper.os.environ")
@patch("speech_recognition.recognizers.whisper.BytesIO")
@patch("openai.OpenAI")
class RecognizeWhisperApiTestCase(TestCase):
    def test_recognize_default_arguments(self, OpenAI, BytesIO, environ):
        client = OpenAI.return_value
        transcript = client.audio.transcriptions.create.return_value

        recognizer = MagicMock(spec=Recognizer)
        audio_data = MagicMock(spec=AudioData)

        actual = whisper.recognize_whisper_api(recognizer, audio_data)

        # With no arguments given, the call should fall back to the
        # "whisper-1" model and api_key=None.
        self.assertEqual(actual, transcript.text)
        audio_data.get_wav_data.assert_called_once_with()
        BytesIO.assert_called_once_with(audio_data.get_wav_data.return_value)
        OpenAI.assert_called_once_with(api_key=None)
        client.audio.transcriptions.create.assert_called_once_with(
            file=BytesIO.return_value, model="whisper-1"
        )

    def test_recognize_pass_arguments(self, OpenAI, BytesIO, environ):
        client = OpenAI.return_value

        recognizer = MagicMock(spec=Recognizer)
        audio_data = MagicMock(spec=AudioData)

        actual = whisper.recognize_whisper_api(
            recognizer, audio_data, model="x-whisper", api_key="OPENAI_API_KEY"
        )

        # Explicit model and api_key arguments should be forwarded verbatim.
        OpenAI.assert_called_once_with(api_key="OPENAI_API_KEY")
        client.audio.transcriptions.create.assert_called_once_with(
            file=BytesIO.return_value, model="x-whisper"
        )
```
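For context, here is a minimal sketch of the function these tests exercise, reconstructed from the assertions above rather than taken from speech_recognition's actual source: wrap the WAV bytes in a `BytesIO`, build an `OpenAI` client, and forward the buffer to the Whisper transcription endpoint. The file name assignment and the omitted `OPENAI_API_KEY` environment check (hinted at by the `os.environ` patch) are assumptions not covered by these tests.

```python
from io import BytesIO

from openai import OpenAI


def recognize_whisper_api(recognizer, audio_data, model="whisper-1", api_key=None):
    """Sketch of the call flow the tests assert; not the library's actual source."""
    # recognizer is accepted for API consistency with other recognize_* functions
    # but is not used in this sketch.

    # Convert the captured audio to WAV bytes and wrap them in a file-like object.
    wav_data = BytesIO(audio_data.get_wav_data())
    # Assumption: a file name helps the API detect the format (not asserted by the tests).
    wav_data.name = "audio.wav"

    # api_key=None lets the OpenAI client fall back to the OPENAI_API_KEY environment variable.
    client = OpenAI(api_key=api_key)

    # Call the Whisper transcription endpoint and return the recognized text.
    transcript = client.audio.transcriptions.create(file=wav_data, model=model)
    return transcript.text
```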