Skip to content

Commit 9194a47

Browse files
authored
Merge pull request #739 from ftnext/test-whisper
2 parents 4924857 + 8da6e42 commit 9194a47

File tree

4 files changed, with 80 additions and 18 deletions.

.github/workflows/unittests.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@ jobs:
3131
- name: Install Python dependencies
3232
run: |
3333
python -m pip install 'pocketsphinx<5'
34-
python -m pip install git+https://github.com/openai/whisper.git soundfile
34+
python -m pip install openai-whisper soundfile
3535
python -m pip install openai
3636
python -m pip install .
3737
- name: Test with unittest

README.rst

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -169,7 +169,7 @@ Whisper (for Whisper users)
169169
~~~~~~~~~~~~~~~~~~~~~~~~~~~
170170
Whisper is **required if and only if you want to use whisper** (``recognizer_instance.recognize_whisper``).
171171

172-
You can install it with ``python3 -m pip install git+https://github.com/openai/whisper.git soundfile``.
172+
You can install it with ``python3 -m pip install openai-whisper soundfile``.
173173

174174
Whisper API (for Whisper API users)
175175
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

tests/test_recognition.py

Lines changed: 0 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,6 @@ def setUp(self):
1414
self.AUDIO_FILE_EN = os.path.join(os.path.dirname(os.path.realpath(__file__)), "english.wav")
1515
self.AUDIO_FILE_FR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "french.aiff")
1616
self.AUDIO_FILE_ZH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "chinese.flac")
17-
self.WHISPER_CONFIG = {"temperature": 0}
1817

1918
def test_recognizer_attributes(self):
2019
r = sr.Recognizer()
@@ -81,21 +80,6 @@ def test_ibm_chinese(self):
8180
with sr.AudioFile(self.AUDIO_FILE_ZH) as source: audio = r.record(source)
8281
self.assertEqual(r.recognize_ibm(audio, username=os.environ["IBM_USERNAME"], password=os.environ["IBM_PASSWORD"], language="zh-CN"), u"砸 自己 的 脚 ")
8382

84-
def test_whisper_english(self):
85-
r = sr.Recognizer()
86-
with sr.AudioFile(self.AUDIO_FILE_EN) as source: audio = r.record(source)
87-
self.assertEqual(r.recognize_whisper(audio, language="english", **self.WHISPER_CONFIG), " 1, 2, 3")
88-
89-
def test_whisper_french(self):
90-
r = sr.Recognizer()
91-
with sr.AudioFile(self.AUDIO_FILE_FR) as source: audio = r.record(source)
92-
self.assertEqual(r.recognize_whisper(audio, language="french", **self.WHISPER_CONFIG), " et c'est la dictée numéro 1.")
93-
94-
def test_whisper_chinese(self):
95-
r = sr.Recognizer()
96-
with sr.AudioFile(self.AUDIO_FILE_ZH) as source: audio = r.record(source)
97-
self.assertEqual(r.recognize_whisper(audio, model="small", language="chinese", **self.WHISPER_CONFIG), u"砸自己的腳")
98-
9983

10084
if __name__ == "__main__":
10185
unittest.main()

tests/test_whisper_recognition.py

Lines changed: 78 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,78 @@
1+
from unittest import TestCase
2+
from unittest.mock import MagicMock, patch
3+
4+
import numpy as np
5+
6+
from speech_recognition import AudioData, Recognizer
7+
8+
9+
@patch("speech_recognition.io.BytesIO")
10+
@patch("soundfile.read")
11+
@patch("torch.cuda.is_available")
12+
@patch("whisper.load_model")
13+
class RecognizeWhisperTestCase(TestCase):
14+
def test_default_parameters(
15+
self, load_model, is_available, sf_read, BytesIO
16+
):
17+
whisper_model = load_model.return_value
18+
transcript = whisper_model.transcribe.return_value
19+
audio_array = MagicMock()
20+
dummy_sampling_rate = 99_999
21+
sf_read.return_value = (audio_array, dummy_sampling_rate)
22+
23+
recognizer = Recognizer()
24+
audio_data = MagicMock(spec=AudioData)
25+
actual = recognizer.recognize_whisper(audio_data)
26+
27+
self.assertEqual(actual, transcript.__getitem__.return_value)
28+
load_model.assert_called_once_with("base")
29+
audio_data.get_wav_data.assert_called_once_with(convert_rate=16000)
30+
BytesIO.assert_called_once_with(audio_data.get_wav_data.return_value)
31+
sf_read.assert_called_once_with(BytesIO.return_value)
32+
audio_array.astype.assert_called_once_with(np.float32)
33+
whisper_model.transcribe.assert_called_once_with(
34+
audio_array.astype.return_value,
35+
language=None,
36+
task=None,
37+
fp16=is_available.return_value,
38+
)
39+
transcript.__getitem__.assert_called_once_with("text")
40+
41+
def test_return_as_dict(self, load_model, is_available, sf_read, BytesIO):
42+
whisper_model = load_model.return_value
43+
audio_array = MagicMock()
44+
dummy_sampling_rate = 99_999
45+
sf_read.return_value = (audio_array, dummy_sampling_rate)
46+
47+
recognizer = Recognizer()
48+
audio_data = MagicMock(spec=AudioData)
49+
actual = recognizer.recognize_whisper(audio_data, show_dict=True)
50+
51+
self.assertEqual(actual, whisper_model.transcribe.return_value)
52+
53+
def test_pass_parameters(self, load_model, is_available, sf_read, BytesIO):
54+
whisper_model = load_model.return_value
55+
transcript = whisper_model.transcribe.return_value
56+
audio_array = MagicMock()
57+
dummy_sampling_rate = 99_999
58+
sf_read.return_value = (audio_array, dummy_sampling_rate)
59+
60+
recognizer = Recognizer()
61+
audio_data = MagicMock(spec=AudioData)
62+
actual = recognizer.recognize_whisper(
63+
audio_data,
64+
model="small",
65+
language="english",
66+
translate=True,
67+
temperature=0,
68+
)
69+
70+
self.assertEqual(actual, transcript.__getitem__.return_value)
71+
load_model.assert_called_once_with("small")
72+
whisper_model.transcribe.assert_called_once_with(
73+
audio_array.astype.return_value,
74+
language="english",
75+
task="translate",
76+
fp16=is_available.return_value,
77+
temperature=0,
78+
)

0 commit comments

Comments
 (0)