Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions custom_components/stream_assist/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from homeassistant.helpers.device_registry import DeviceEntryType
from homeassistant.helpers.entity import Entity, DeviceInfo
from homeassistant.helpers.entity_component import EntityComponent

from homeassistant.helpers.dispatcher import async_dispatcher_send
from .stream import Stream

_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -122,7 +122,7 @@ def internal_event_callback(event: PipelineEvent):
if event.data
else {"timestamp": event.timestamp}
)

async_dispatcher_send(hass, "simple_state_pipeline_event", event)
if event.type == PipelineEventType.STT_START:
if player_entity_id and (media_id := data.get("stt_start_media")):
play_media(hass, player_entity_id, media_id, "audio")
Expand Down
7 changes: 3 additions & 4 deletions custom_components/stream_assist/sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from homeassistant.helpers.entity_platform import AddEntitiesCallback

from .core import EVENTS, init_entity

from .simple_state import SimpleState

async def async_setup_entry(
hass: HomeAssistant,
Expand All @@ -27,10 +27,9 @@ async def async_setup_entry(
if event == "tts" and not pipeline.tts_engine:
continue
entities.append(StreamAssistSensor(config_entry, event))

simple_state_sensor = SimpleState(hass, config_entry)
async_add_entities([simple_state_sensor], update_before_add=True)
async_add_entities(entities)


class StreamAssistSensor(SensorEntity):
_attr_native_value = STATE_IDLE

Expand Down
127 changes: 127 additions & 0 deletions custom_components/stream_assist/simple_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from __future__ import annotations
import asyncio
import logging
import io
from homeassistant.components.assist_pipeline.pipeline import (
PipelineEvent,
PipelineEventType,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from mutagen.mp3 import MP3
from homeassistant.helpers.network import get_url
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.const import STATE_IDLE
from homeassistant.components.sensor import SensorEntity
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity import DeviceInfo
from .core import init_entity
from .core import DOMAIN
_LOGGER = logging.getLogger(__name__)

class SimpleState(SensorEntity):
_attr_native_value = STATE_IDLE

@property
def streamassist_entity_name(self):
"""Return the entity name."""
return "simple_state"

def __init__(self, hass: HomeAssistant, config_entry: ConfigEntry) -> None:
"""Initialize the SimpleState sensor."""
self.hass = hass
self.config_entry = config_entry
self._attr_native_value = STATE_IDLE
self.schedule_update_ha_state()
self.tts_duration = 0
init_entity(self, "simple_state", config_entry)


async def async_added_to_hass(self) -> None:
"""Subscribe to pipeline events."""
self.remove_dispatcher = async_dispatcher_connect(
self.hass, "simple_state_pipeline_event", self.on_pipeline_event
)
self.async_on_remove(self.remove_dispatcher)

def on_pipeline_event(self, event: PipelineEvent):
"""Handle a pipeline event."""

def getSimpleState(t: PipelineEventType) -> str:
match t:
case PipelineEventType.ERROR:
return "error"
case PipelineEventType.WAKE_WORD_END:
_LOGGER.debug("PipelineEventType.WAKE_WORD_END TRIGGERED.")
return "detected"
case PipelineEventType.STT_START:
self.hass.loop.call_soon_threadsafe(self._handle_listening_state)
return None
case PipelineEventType.INTENT_START:
return "processing"
case PipelineEventType.TTS_END:
tts_url = event.data["tts_output"]["url"]
self.hass.loop.call_soon_threadsafe(
self._handle_tts_end, tts_url
)
return "responding"
case _:
return None

state = getSimpleState(event.type)
if state is not None:
self._update_state(state)

def _update_state(self, state: str):
"""Update the state safely from the event loop."""
self._attr_native_value = state
self.schedule_update_ha_state()

def _handle_listening_state(self):
"""Handle the delayed update for the 'listening' state."""
async def handle():
try:
await asyncio.sleep(0.5)
self._update_state("listening")
except Exception as e:
_LOGGER.error(f"Error in _handle_listening_state: {e}")

self.hass.loop.create_task(handle())

async def get_tts_duration(self, hass: HomeAssistant, tts_url: str) -> float:
try:
if tts_url.startswith('/'):
base_url = get_url(hass)
full_url = f"{base_url}{tts_url}"
else:
full_url = tts_url

session = async_get_clientsession(hass)
async with session.get(full_url) as response:
if response.status != 200:
_LOGGER.error(f"Failed to fetch TTS audio: HTTP {response.status}")
return 0

content = await response.read()

audio = MP3(io.BytesIO(content))
return audio.info.length
except Exception as e:
_LOGGER.error(f"Error getting TTS duration: {e}")
return 0

def _handle_tts_end(self, tts_url: str):
"""Handle the end of TTS and store its duration."""

async def handle():
try:
duration = await self.get_tts_duration(self.hass, tts_url)
self.tts_duration = duration
_LOGGER.debug(f"Stored TTS duration: {duration} seconds")

await asyncio.sleep(duration - 0.5)
self._update_state("finished")
except Exception as e:
_LOGGER.error(f"Error in _handle_tts_end: {e}")

self.hass.loop.create_task(handle())