Skip to content

Commit 8aa30b0

Browse files
authored
Migrate VoIP to use Assist Pipeline TTS tokens (#139671)
* Migrate VoIP to use pipeline token * Migrate announcements to use TTS token
1 parent 871a7c8 commit 8aa30b0

File tree

2 files changed

+64
-92
lines changed

2 files changed

+64
-92
lines changed

homeassistant/components/voip/assist_satellite.py

+27-13
Original file line numberDiff line numberDiff line change
@@ -408,10 +408,18 @@ async def _play_announcement(
408408
"""Play an announcement once."""
409409
_LOGGER.debug("Playing announcement")
410410

411-
try:
412-
await asyncio.sleep(_ANNOUNCEMENT_BEFORE_DELAY)
413-
await self._send_tts(announcement.original_media_id, wait_for_tone=False)
411+
if announcement.tts_token is None:
412+
_LOGGER.error("Only TTS announcements are supported")
413+
return
414+
415+
await asyncio.sleep(_ANNOUNCEMENT_BEFORE_DELAY)
416+
stream = tts.async_get_stream(self.hass, announcement.tts_token)
417+
if stream is None:
418+
_LOGGER.error("TTS stream no longer available")
419+
return
414420

421+
try:
422+
await self._send_tts(stream, wait_for_tone=False)
415423
if not self._run_pipeline_after_announce:
416424
# Delay before looping announcement
417425
await asyncio.sleep(_ANNOUNCEMENT_AFTER_DELAY)
@@ -442,11 +450,14 @@ def on_pipeline_event(self, event: PipelineEvent) -> None:
442450
)
443451
elif event.type == PipelineEventType.TTS_END:
444452
# Send TTS audio to caller over RTP
445-
if event.data and (tts_output := event.data["tts_output"]):
446-
media_id = tts_output["media_id"]
453+
if (
454+
event.data
455+
and (tts_output := event.data["tts_output"])
456+
and (stream := tts.async_get_stream(self.hass, tts_output["token"]))
457+
):
447458
self.config_entry.async_create_background_task(
448459
self.hass,
449-
self._send_tts(media_id),
460+
self._send_tts(tts_stream=stream),
450461
"voip_pipeline_tts",
451462
)
452463
else:
@@ -457,19 +468,22 @@ def on_pipeline_event(self, event: PipelineEvent) -> None:
457468
self._pipeline_had_error = True
458469
_LOGGER.warning(event)
459470

460-
async def _send_tts(self, media_id: str, wait_for_tone: bool = True) -> None:
471+
async def _send_tts(
472+
self,
473+
tts_stream: tts.ResultStream,
474+
wait_for_tone: bool = True,
475+
) -> None:
461476
"""Send TTS audio to caller via RTP."""
462477
try:
463478
if self.transport is None:
464479
return # not connected
465480

466-
extension, data = await tts.async_get_media_source_audio(
467-
self.hass,
468-
media_id,
469-
)
481+
data = b"".join([chunk async for chunk in tts_stream.async_stream_result()])
470482

471-
if extension != "wav":
472-
raise ValueError(f"Only WAV audio can be streamed, got {extension}")
483+
if tts_stream.extension != "wav":
484+
raise ValueError(
485+
f"Only TTS WAV audio can be streamed, got {tts_stream.extension}"
486+
)
473487

474488
if wait_for_tone and ((self._tones & Tones.PROCESSING) == Tones.PROCESSING):
475489
# Don't overlap TTS and processing beep

tests/components/voip/test_voip.py

+37-79
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@ def mock_tts_cache_dir_autouse(mock_tts_cache_dir: Path) -> None:
3838
"""Mock the TTS cache dir with empty dir."""
3939

4040

41-
def _empty_wav() -> bytes:
41+
def _empty_wav(framerate=16000) -> bytes:
4242
"""Return bytes of an empty WAV file."""
4343
with io.BytesIO() as wav_io:
4444
wav_file: wave.Wave_write = wave.open(wav_io, "wb")
4545
with wav_file:
46-
wav_file.setframerate(16000)
46+
wav_file.setframerate(framerate)
4747
wav_file.setsampwidth(2)
4848
wav_file.setnchannels(1)
4949

@@ -307,10 +307,11 @@ async def async_pipeline_from_audio_stream(
307307
assert satellite.state == AssistSatelliteState.RESPONDING
308308

309309
# Proceed with media output
310+
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav())
310311
event_callback(
311312
assist_pipeline.PipelineEvent(
312313
type=assist_pipeline.PipelineEventType.TTS_END,
313-
data={"tts_output": {"media_id": _MEDIA_ID}},
314+
data={"tts_output": {"token": mock_tts_result_stream.token}},
314315
)
315316
)
316317

@@ -326,22 +327,11 @@ def tts_response_finished():
326327
original_tts_response_finished()
327328
done.set()
328329

329-
async def async_get_media_source_audio(
330-
hass: HomeAssistant,
331-
media_source_id: str,
332-
) -> tuple[str, bytes]:
333-
assert media_source_id == _MEDIA_ID
334-
return ("wav", _empty_wav())
335-
336330
with (
337331
patch(
338332
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
339333
new=async_pipeline_from_audio_stream,
340334
),
341-
patch(
342-
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
343-
new=async_get_media_source_audio,
344-
),
345335
patch.object(satellite, "tts_response_finished", tts_response_finished),
346336
):
347337
satellite._tones = Tones(0)
@@ -457,10 +447,11 @@ async def async_pipeline_from_audio_stream(*args, **kwargs):
457447
)
458448

459449
# Proceed with media output
450+
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav())
460451
event_callback(
461452
assist_pipeline.PipelineEvent(
462453
type=assist_pipeline.PipelineEventType.TTS_END,
463-
data={"tts_output": {"media_id": _MEDIA_ID}},
454+
data={"tts_output": {"token": mock_tts_result_stream.token}},
464455
)
465456
)
466457

@@ -474,22 +465,9 @@ async def async_send_audio(audio_bytes: bytes, **kwargs):
474465
# Block here to force a timeout in _send_tts
475466
await asyncio.sleep(2)
476467

477-
async def async_get_media_source_audio(
478-
hass: HomeAssistant,
479-
media_source_id: str,
480-
) -> tuple[str, bytes]:
481-
# Should time out immediately
482-
return ("wav", _empty_wav())
483-
484-
with (
485-
patch(
486-
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
487-
new=async_pipeline_from_audio_stream,
488-
),
489-
patch(
490-
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
491-
new=async_get_media_source_audio,
492-
),
468+
with patch(
469+
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
470+
new=async_pipeline_from_audio_stream,
493471
):
494472
satellite._tts_extra_timeout = 0.001
495473
for tone in Tones:
@@ -568,29 +546,18 @@ async def async_pipeline_from_audio_stream(*args, **kwargs):
568546
)
569547

570548
# Proceed with media output
549+
# Should fail because it's not "wav"
550+
mock_tts_result_stream = MockResultStream(hass, "mp3", b"")
571551
event_callback(
572552
assist_pipeline.PipelineEvent(
573553
type=assist_pipeline.PipelineEventType.TTS_END,
574-
data={"tts_output": {"media_id": _MEDIA_ID}},
554+
data={"tts_output": {"token": mock_tts_result_stream.token}},
575555
)
576556
)
577557

578-
async def async_get_media_source_audio(
579-
hass: HomeAssistant,
580-
media_source_id: str,
581-
) -> tuple[str, bytes]:
582-
# Should fail because it's not "wav"
583-
return ("mp3", b"")
584-
585-
with (
586-
patch(
587-
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
588-
new=async_pipeline_from_audio_stream,
589-
),
590-
patch(
591-
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
592-
new=async_get_media_source_audio,
593-
),
558+
with patch(
559+
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
560+
new=async_pipeline_from_audio_stream,
594561
):
595562
satellite.transport = Mock()
596563

@@ -663,36 +630,18 @@ async def async_pipeline_from_audio_stream(*args, **kwargs):
663630
)
664631

665632
# Proceed with media output
633+
# Should fail because it's not 16 kHz
634+
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav(22050))
666635
event_callback(
667636
assist_pipeline.PipelineEvent(
668637
type=assist_pipeline.PipelineEventType.TTS_END,
669-
data={"tts_output": {"media_id": _MEDIA_ID}},
638+
data={"tts_output": {"token": mock_tts_result_stream.token}},
670639
)
671640
)
672641

673-
async def async_get_media_source_audio(
674-
hass: HomeAssistant,
675-
media_source_id: str,
676-
) -> tuple[str, bytes]:
677-
# Should fail because it's not 16Khz, 16-bit mono
678-
with io.BytesIO() as wav_io:
679-
wav_file: wave.Wave_write = wave.open(wav_io, "wb")
680-
with wav_file:
681-
wav_file.setframerate(22050)
682-
wav_file.setsampwidth(2)
683-
wav_file.setnchannels(2)
684-
685-
return ("wav", wav_io.getvalue())
686-
687-
with (
688-
patch(
689-
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
690-
new=async_pipeline_from_audio_stream,
691-
),
692-
patch(
693-
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
694-
new=async_get_media_source_audio,
695-
),
642+
with patch(
643+
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
644+
new=async_pipeline_from_audio_stream,
696645
):
697646
satellite.transport = Mock()
698647

@@ -878,10 +827,11 @@ async def test_announce(
878827
assert err.value.translation_domain == "voip"
879828
assert err.value.translation_key == "non_tts_announcement"
880829

830+
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav())
881831
announcement = assist_satellite.AssistSatelliteAnnouncement(
882832
message="test announcement",
883833
media_id=_MEDIA_ID,
884-
tts_token="test-token",
834+
tts_token=mock_tts_result_stream.token,
885835
original_media_id=_MEDIA_ID,
886836
media_id_source="tts",
887837
)
@@ -907,7 +857,9 @@ async def test_announce(
907857
async with asyncio.timeout(1):
908858
await announce_task
909859

910-
mock_send_tts.assert_called_once_with(_MEDIA_ID, wait_for_tone=False)
860+
mock_send_tts.assert_called_once_with(
861+
mock_tts_result_stream, wait_for_tone=False
862+
)
911863

912864

913865
@pytest.mark.usefixtures("socket_enabled")
@@ -926,10 +878,11 @@ async def test_voip_id_is_ip_address(
926878
& assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
927879
)
928880

881+
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav())
929882
announcement = assist_satellite.AssistSatelliteAnnouncement(
930883
message="test announcement",
931884
media_id=_MEDIA_ID,
932-
tts_token="test-token",
885+
tts_token=mock_tts_result_stream.token,
933886
original_media_id=_MEDIA_ID,
934887
media_id_source="tts",
935888
)
@@ -960,7 +913,9 @@ async def test_voip_id_is_ip_address(
960913
async with asyncio.timeout(1):
961914
await announce_task
962915

963-
mock_send_tts.assert_called_once_with(_MEDIA_ID, wait_for_tone=False)
916+
mock_send_tts.assert_called_once_with(
917+
mock_tts_result_stream, wait_for_tone=False
918+
)
964919

965920

966921
@pytest.mark.usefixtures("socket_enabled")
@@ -979,10 +934,11 @@ async def test_announce_timeout(
979934
& assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
980935
)
981936

937+
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav())
982938
announcement = assist_satellite.AssistSatelliteAnnouncement(
983939
message="test announcement",
984940
media_id=_MEDIA_ID,
985-
tts_token="test-token",
941+
tts_token=mock_tts_result_stream.token,
986942
original_media_id=_MEDIA_ID,
987943
media_id_source="tts",
988944
)
@@ -1020,10 +976,11 @@ async def test_start_conversation(
1020976
& assist_satellite.AssistSatelliteEntityFeature.START_CONVERSATION
1021977
)
1022978

979+
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav())
1023980
announcement = assist_satellite.AssistSatelliteAnnouncement(
1024981
message="test announcement",
1025982
media_id=_MEDIA_ID,
1026-
tts_token="test-token",
983+
tts_token=mock_tts_result_stream.token,
1027984
original_media_id=_MEDIA_ID,
1028985
media_id_source="tts",
1029986
)
@@ -1061,10 +1018,11 @@ async def async_pipeline_from_audio_stream(
10611018
)
10621019

10631020
# Proceed with media output
1021+
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav())
10641022
event_callback(
10651023
assist_pipeline.PipelineEvent(
10661024
type=assist_pipeline.PipelineEventType.TTS_END,
1067-
data={"tts_output": {"media_id": _MEDIA_ID}},
1025+
data={"tts_output": {"token": mock_tts_result_stream.token}},
10681026
)
10691027
)
10701028

0 commit comments

Comments
 (0)