Skip to content

Commit 281fb83

Browse files
committed
fix(cloud_asr): usage of old API
1 parent c6aa4a4 commit 281fb83

2 files changed

Lines changed: 51 additions & 50 deletions

File tree

src/arduino/app_bricks/cloud_asr/cloud_asr.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@
99
import threading
1010
import time
1111
from contextlib import contextmanager
12-
from typing import Generator, Optional, Union, Iterator, Generator, cast
12+
from typing import Generator, Union, Iterator, Generator, cast
1313

1414
import numpy as np
1515

1616
from arduino.app_peripherals.microphone import Microphone
17+
from arduino.app_peripherals.microphone.base_microphone import BaseMicrophone
1718
from arduino.app_utils import Logger, brick
1819

1920
from .providers import ASRProvider, CloudProvider, DEFAULT_PROVIDER, provider_factory
@@ -45,25 +46,36 @@ def __init__(
4546
self,
4647
api_key: str = os.getenv("API_KEY", ""),
4748
provider: CloudProvider = DEFAULT_PROVIDER,
48-
mic: Optional[Microphone] = None,
49+
mic: BaseMicrophone | None = None,
4950
language: str = os.getenv("LANGUAGE", ""),
5051
silence_timeout: float = 10.0,
5152
):
52-
if mic:
53+
if mic is not None:
5354
logger.info(f"[{self.__class__.__name__}] Using provided microphone: {mic}")
5455
self._mic = mic
56+
self._owns_mic = False
5557
else:
5658
self._mic = Microphone()
59+
self._owns_mic = True
5760

5861
self._language = language
5962
self.silence_timeout = silence_timeout
60-
self._mic_lock = threading.Lock()
6163
self._provider: ASRProvider = provider_factory(
6264
api_key=api_key,
6365
name=provider,
6466
language=self._language,
6567
sample_rate=self._mic.sample_rate,
6668
)
69+
70+
def start(self):
71+
"""Start the ASR service by initializing the microphone."""
72+
if self._owns_mic:
73+
self._mic.start()
74+
75+
def stop(self):
76+
"""Stop the ASR service by releasing the microphone."""
77+
if self._owns_mic:
78+
self._mic.stop()
6779

6880
def _transcribe_stream(self, duration: float = 60.0) -> Generator[ASREvent, None, None]:
6981
"""Perform continuous speech-to-text recognition with detailed events.
@@ -84,12 +96,6 @@ def _transcribe_stream(self, duration: float = 60.0) -> Generator[ASREvent, None
8496
overall_deadline = time.monotonic() + duration
8597
silence_deadline = time.monotonic() + self.silence_timeout
8698

87-
with self._mic_lock:
88-
if self._mic.is_recording.is_set():
89-
raise RuntimeError("Microphone is busy.")
90-
self._mic.start()
91-
logger.info(f"[{self.__class__.__name__}] Microphone started.")
92-
9399
def _send():
94100
try:
95101
for chunk in self._mic.stream():
@@ -178,10 +184,6 @@ def _recv():
178184
finally:
179185
logger.info("Releasing ASR resources...")
180186
stop_event.set()
181-
with self._mic_lock:
182-
if self._mic.is_recording.is_set():
183-
self._mic.stop()
184-
logger.info(f"[{self.__class__.__name__}] Microphone stopped.")
185187
send_thread.join(timeout=1)
186188
recv_thread.join(timeout=1)
187189
provider.stop()

tests/arduino/app_bricks/cloud_asr/test_cloud_asr.py

Lines changed: 35 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -12,33 +12,35 @@
1212

1313
from arduino.app_bricks.cloud_asr import CloudASR, CloudProvider
1414
from arduino.app_bricks.cloud_asr.providers import ASRProviderEvent, ASRProviderError
15+
from arduino.app_peripherals.microphone.base_microphone import BaseMicrophone
1516
from arduino.app_utils.app import App
1617

1718

18-
class MockMicrophone:
19+
class MockMicrophone(BaseMicrophone):
1920
"""Lightweight microphone stub that yields pre-loaded chunks."""
2021

21-
def __init__(self, chunks: Iterable, sample_rate: int = 16000, delay_between_chunks: float = 0.0):
22-
self.sample_rate = sample_rate
23-
self.is_recording = threading.Event()
22+
def __init__(
23+
self,
24+
chunks: Iterable,
25+
sample_rate: int = 16000,
26+
channels: int = 1,
27+
format: type | np.dtype | str = np.int16,
28+
buffer_size: int = 1024,
29+
auto_reconnect: bool = True
30+
):
31+
super().__init__(sample_rate=sample_rate, channels=channels, format=format, buffer_size=buffer_size, auto_reconnect=auto_reconnect)
2432
self._chunks: List = list(chunks)
25-
self._delay = delay_between_chunks
26-
self.start_calls = 0
27-
self.stop_calls = 0
2833

29-
def start(self):
30-
self.start_calls += 1
31-
self.is_recording.set()
34+
def _open_microphone(self):
35+
pass
3236

33-
def stop(self):
34-
self.stop_calls += 1
35-
self.is_recording.clear()
37+
def _close_microphone(self):
38+
pass
3639

37-
def stream(self):
38-
while self.is_recording.is_set() and self._chunks:
39-
if self._delay:
40-
time.sleep(self._delay)
41-
yield self._chunks.pop(0)
40+
def _read_audio(self):
41+
if not self._chunks:
42+
return None
43+
return self._chunks.pop(0)
4244

4345

4446
class DummyProvider:
@@ -88,21 +90,19 @@ def _factory(
8890

8991
def test_transcribe_stream_use_microphone_state(make_provider):
9092
mic = MockMicrophone(chunks=[])
93+
mic.start()
9194
provider = make_provider(events=[ASRProviderEvent(type="text", data="mock")])
9295
asr = CloudASR(api_key="dummy", mic=mic, provider=CloudProvider.OPENAI_TRANSCRIBE)
9396

9497
try:
9598
with asr.transcribe_stream() as stream:
9699
next(stream)
97-
assert mic.start_calls == 1
98-
assert mic.is_recording.is_set()
99100
assert provider.start_called is True
100101

101-
assert mic.stop_calls == 1
102-
assert not mic.is_recording.is_set()
103102
assert provider.stop_called is True
104103
finally:
105-
App.unregister(asr)
104+
asr.stop()
105+
mic.stop()
106106

107107

108108
def test_transcribe_stream_aggregates_partial_text_in_append_mode(make_provider):
@@ -112,10 +112,8 @@ def test_transcribe_stream_aggregates_partial_text_in_append_mode(make_provider)
112112
ASRProviderEvent(type="text", data=None),
113113
]
114114
audio_chunks = [np.array([1, 2, 3], dtype=np.int16), None, np.array([4, 5, 6], dtype=np.int16)]
115-
mic = MockMicrophone(
116-
chunks=audio_chunks,
117-
delay_between_chunks=0.002,
118-
)
115+
mic = MockMicrophone(audio_chunks)
116+
mic.start()
119117
provider = make_provider(events=events, partial_mode="append", audio_chunks_len=sum(ch is not None for ch in audio_chunks))
120118
asr = CloudASR(api_key="dummy", mic=mic, provider=CloudProvider.OPENAI_TRANSCRIBE)
121119

@@ -127,8 +125,9 @@ def test_transcribe_stream_aggregates_partial_text_in_append_mode(make_provider)
127125
if ev.type == "text":
128126
break
129127
finally:
130-
App.unregister(asr)
131-
128+
asr.stop()
129+
mic.stop()
130+
132131
assert provider.start_called is True
133132
assert [msg.type for msg in results] == ["partial_text", "partial_text", "text"]
134133
assert [msg.data for msg in results[:2]] == ["Hel", "lo"]
@@ -149,10 +148,8 @@ def test_transcribe_stream_resets_partial_buffer_in_replace_mode(make_provider):
149148
ASRProviderEvent(type="text", data=None),
150149
]
151150
audio_chunks = [np.ones(4, dtype=np.int16) for _ in range(5)]
152-
mic = MockMicrophone(
153-
chunks=audio_chunks,
154-
delay_between_chunks=0.002,
155-
)
151+
mic = MockMicrophone(audio_chunks)
152+
mic.start()
156153
provider = make_provider(events=events, partial_mode="replace", audio_chunks_len=sum(ch is not None for ch in audio_chunks))
157154
asr = CloudASR(api_key="dummy", mic=mic, provider=CloudProvider.GOOGLE_SPEECH)
158155

@@ -167,7 +164,8 @@ def test_transcribe_stream_resets_partial_buffer_in_replace_mode(make_provider):
167164
if text_count == 2:
168165
break
169166
finally:
170-
App.unregister(asr)
167+
asr.stop()
168+
mic.stop()
171169

172170
assert provider.start_called is True
173171
assert [msg.type for msg in results] == ["partial_text", "partial_text", "text", "partial_text", "text"]
@@ -186,8 +184,8 @@ def recv(self):
186184

187185
mic = MockMicrophone(
188186
chunks=[np.array([7, 8], dtype=np.int16), np.array([9, 10], dtype=np.int16)],
189-
delay_between_chunks=0.001,
190187
)
188+
mic.start()
191189
asr = CloudASR(api_key="dummy", mic=mic, provider=CloudProvider.OPENAI_TRANSCRIBE)
192190

193191
try:
@@ -197,7 +195,8 @@ def recv(self):
197195
assert isinstance(exc, ASRProviderError)
198196
assert str(exc) == "boom"
199197
finally:
200-
App.unregister(asr)
198+
asr.stop()
199+
mic.stop()
201200

202201
assert provider.start_called is True
203202
assert provider.stop_called is True

0 commit comments

Comments
 (0)