@@ -38,12 +38,12 @@ def mock_tts_cache_dir_autouse(mock_tts_cache_dir: Path) -> None:
38
38
"""Mock the TTS cache dir with empty dir."""
39
39
40
40
41
- def _empty_wav () -> bytes :
41
+ def _empty_wav (framerate = 16000 ) -> bytes :
42
42
"""Return bytes of an empty WAV file."""
43
43
with io .BytesIO () as wav_io :
44
44
wav_file : wave .Wave_write = wave .open (wav_io , "wb" )
45
45
with wav_file :
46
- wav_file .setframerate (16000 )
46
+ wav_file .setframerate (framerate )
47
47
wav_file .setsampwidth (2 )
48
48
wav_file .setnchannels (1 )
49
49
@@ -307,10 +307,11 @@ async def async_pipeline_from_audio_stream(
307
307
assert satellite .state == AssistSatelliteState .RESPONDING
308
308
309
309
# Proceed with media output
310
+ mock_tts_result_stream = MockResultStream (hass , "wav" , _empty_wav ())
310
311
event_callback (
311
312
assist_pipeline .PipelineEvent (
312
313
type = assist_pipeline .PipelineEventType .TTS_END ,
313
- data = {"tts_output" : {"media_id " : _MEDIA_ID }},
314
+ data = {"tts_output" : {"token " : mock_tts_result_stream . token }},
314
315
)
315
316
)
316
317
@@ -326,22 +327,11 @@ def tts_response_finished():
326
327
original_tts_response_finished ()
327
328
done .set ()
328
329
329
- async def async_get_media_source_audio (
330
- hass : HomeAssistant ,
331
- media_source_id : str ,
332
- ) -> tuple [str , bytes ]:
333
- assert media_source_id == _MEDIA_ID
334
- return ("wav" , _empty_wav ())
335
-
336
330
with (
337
331
patch (
338
332
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream" ,
339
333
new = async_pipeline_from_audio_stream ,
340
334
),
341
- patch (
342
- "homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio" ,
343
- new = async_get_media_source_audio ,
344
- ),
345
335
patch .object (satellite , "tts_response_finished" , tts_response_finished ),
346
336
):
347
337
satellite ._tones = Tones (0 )
@@ -457,10 +447,11 @@ async def async_pipeline_from_audio_stream(*args, **kwargs):
457
447
)
458
448
459
449
# Proceed with media output
450
+ mock_tts_result_stream = MockResultStream (hass , "wav" , _empty_wav ())
460
451
event_callback (
461
452
assist_pipeline .PipelineEvent (
462
453
type = assist_pipeline .PipelineEventType .TTS_END ,
463
- data = {"tts_output" : {"media_id " : _MEDIA_ID }},
454
+ data = {"tts_output" : {"token " : mock_tts_result_stream . token }},
464
455
)
465
456
)
466
457
@@ -474,22 +465,9 @@ async def async_send_audio(audio_bytes: bytes, **kwargs):
474
465
# Block here to force a timeout in _send_tts
475
466
await asyncio .sleep (2 )
476
467
477
- async def async_get_media_source_audio (
478
- hass : HomeAssistant ,
479
- media_source_id : str ,
480
- ) -> tuple [str , bytes ]:
481
- # Should time out immediately
482
- return ("wav" , _empty_wav ())
483
-
484
- with (
485
- patch (
486
- "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream" ,
487
- new = async_pipeline_from_audio_stream ,
488
- ),
489
- patch (
490
- "homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio" ,
491
- new = async_get_media_source_audio ,
492
- ),
468
+ with patch (
469
+ "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream" ,
470
+ new = async_pipeline_from_audio_stream ,
493
471
):
494
472
satellite ._tts_extra_timeout = 0.001
495
473
for tone in Tones :
@@ -568,29 +546,18 @@ async def async_pipeline_from_audio_stream(*args, **kwargs):
568
546
)
569
547
570
548
# Proceed with media output
549
+ # Should fail because it's not "wav"
550
+ mock_tts_result_stream = MockResultStream (hass , "mp3" , b"" )
571
551
event_callback (
572
552
assist_pipeline .PipelineEvent (
573
553
type = assist_pipeline .PipelineEventType .TTS_END ,
574
- data = {"tts_output" : {"media_id " : _MEDIA_ID }},
554
+ data = {"tts_output" : {"token " : mock_tts_result_stream . token }},
575
555
)
576
556
)
577
557
578
- async def async_get_media_source_audio (
579
- hass : HomeAssistant ,
580
- media_source_id : str ,
581
- ) -> tuple [str , bytes ]:
582
- # Should fail because it's not "wav"
583
- return ("mp3" , b"" )
584
-
585
- with (
586
- patch (
587
- "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream" ,
588
- new = async_pipeline_from_audio_stream ,
589
- ),
590
- patch (
591
- "homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio" ,
592
- new = async_get_media_source_audio ,
593
- ),
558
+ with patch (
559
+ "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream" ,
560
+ new = async_pipeline_from_audio_stream ,
594
561
):
595
562
satellite .transport = Mock ()
596
563
@@ -663,36 +630,18 @@ async def async_pipeline_from_audio_stream(*args, **kwargs):
663
630
)
664
631
665
632
# Proceed with media output
633
+ # Should fail because the sample rate is 22050 Hz, not 16 kHz
634
+ mock_tts_result_stream = MockResultStream (hass , "wav" , _empty_wav (22050 ))
666
635
event_callback (
667
636
assist_pipeline .PipelineEvent (
668
637
type = assist_pipeline .PipelineEventType .TTS_END ,
669
- data = {"tts_output" : {"media_id " : _MEDIA_ID }},
638
+ data = {"tts_output" : {"token " : mock_tts_result_stream . token }},
670
639
)
671
640
)
672
641
673
- async def async_get_media_source_audio (
674
- hass : HomeAssistant ,
675
- media_source_id : str ,
676
- ) -> tuple [str , bytes ]:
677
- # Should fail because it's not 16Khz, 16-bit mono
678
- with io .BytesIO () as wav_io :
679
- wav_file : wave .Wave_write = wave .open (wav_io , "wb" )
680
- with wav_file :
681
- wav_file .setframerate (22050 )
682
- wav_file .setsampwidth (2 )
683
- wav_file .setnchannels (2 )
684
-
685
- return ("wav" , wav_io .getvalue ())
686
-
687
- with (
688
- patch (
689
- "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream" ,
690
- new = async_pipeline_from_audio_stream ,
691
- ),
692
- patch (
693
- "homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio" ,
694
- new = async_get_media_source_audio ,
695
- ),
642
+ with patch (
643
+ "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream" ,
644
+ new = async_pipeline_from_audio_stream ,
696
645
):
697
646
satellite .transport = Mock ()
698
647
@@ -878,10 +827,11 @@ async def test_announce(
878
827
assert err .value .translation_domain == "voip"
879
828
assert err .value .translation_key == "non_tts_announcement"
880
829
830
+ mock_tts_result_stream = MockResultStream (hass , "wav" , _empty_wav ())
881
831
announcement = assist_satellite .AssistSatelliteAnnouncement (
882
832
message = "test announcement" ,
883
833
media_id = _MEDIA_ID ,
884
- tts_token = "test- token" ,
834
+ tts_token = mock_tts_result_stream . token ,
885
835
original_media_id = _MEDIA_ID ,
886
836
media_id_source = "tts" ,
887
837
)
@@ -907,7 +857,9 @@ async def test_announce(
907
857
async with asyncio .timeout (1 ):
908
858
await announce_task
909
859
910
- mock_send_tts .assert_called_once_with (_MEDIA_ID , wait_for_tone = False )
860
+ mock_send_tts .assert_called_once_with (
861
+ mock_tts_result_stream , wait_for_tone = False
862
+ )
911
863
912
864
913
865
@pytest .mark .usefixtures ("socket_enabled" )
@@ -926,10 +878,11 @@ async def test_voip_id_is_ip_address(
926
878
& assist_satellite .AssistSatelliteEntityFeature .ANNOUNCE
927
879
)
928
880
881
+ mock_tts_result_stream = MockResultStream (hass , "wav" , _empty_wav ())
929
882
announcement = assist_satellite .AssistSatelliteAnnouncement (
930
883
message = "test announcement" ,
931
884
media_id = _MEDIA_ID ,
932
- tts_token = "test- token" ,
885
+ tts_token = mock_tts_result_stream . token ,
933
886
original_media_id = _MEDIA_ID ,
934
887
media_id_source = "tts" ,
935
888
)
@@ -960,7 +913,9 @@ async def test_voip_id_is_ip_address(
960
913
async with asyncio .timeout (1 ):
961
914
await announce_task
962
915
963
- mock_send_tts .assert_called_once_with (_MEDIA_ID , wait_for_tone = False )
916
+ mock_send_tts .assert_called_once_with (
917
+ mock_tts_result_stream , wait_for_tone = False
918
+ )
964
919
965
920
966
921
@pytest .mark .usefixtures ("socket_enabled" )
@@ -979,10 +934,11 @@ async def test_announce_timeout(
979
934
& assist_satellite .AssistSatelliteEntityFeature .ANNOUNCE
980
935
)
981
936
937
+ mock_tts_result_stream = MockResultStream (hass , "wav" , _empty_wav ())
982
938
announcement = assist_satellite .AssistSatelliteAnnouncement (
983
939
message = "test announcement" ,
984
940
media_id = _MEDIA_ID ,
985
- tts_token = "test- token" ,
941
+ tts_token = mock_tts_result_stream . token ,
986
942
original_media_id = _MEDIA_ID ,
987
943
media_id_source = "tts" ,
988
944
)
@@ -1020,10 +976,11 @@ async def test_start_conversation(
1020
976
& assist_satellite .AssistSatelliteEntityFeature .START_CONVERSATION
1021
977
)
1022
978
979
+ mock_tts_result_stream = MockResultStream (hass , "wav" , _empty_wav ())
1023
980
announcement = assist_satellite .AssistSatelliteAnnouncement (
1024
981
message = "test announcement" ,
1025
982
media_id = _MEDIA_ID ,
1026
- tts_token = "test- token" ,
983
+ tts_token = mock_tts_result_stream . token ,
1027
984
original_media_id = _MEDIA_ID ,
1028
985
media_id_source = "tts" ,
1029
986
)
@@ -1061,10 +1018,11 @@ async def async_pipeline_from_audio_stream(
1061
1018
)
1062
1019
1063
1020
# Proceed with media output
1021
+ mock_tts_result_stream = MockResultStream (hass , "wav" , _empty_wav ())
1064
1022
event_callback (
1065
1023
assist_pipeline .PipelineEvent (
1066
1024
type = assist_pipeline .PipelineEventType .TTS_END ,
1067
- data = {"tts_output" : {"media_id " : _MEDIA_ID }},
1025
+ data = {"tts_output" : {"token " : mock_tts_result_stream . token }},
1068
1026
)
1069
1027
)
1070
1028
0 commit comments