Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 103 additions & 7 deletions homeassistant/helpers/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,22 @@
SERVICE_GET_EVENTS,
)
from homeassistant.components.cover import INTENT_CLOSE_COVER, INTENT_OPEN_COVER
from homeassistant.components.homeassistant import async_should_expose
from homeassistant.components.homeassistant import (
DOMAIN as HOMEASSISTANT_DOMAIN,
async_should_expose,
)
from homeassistant.components.intent import async_device_supports_timers
from homeassistant.components.script import DOMAIN as SCRIPT_DOMAIN
from homeassistant.components.todo import DOMAIN as TODO_DOMAIN, TodoServices
from homeassistant.components.weather import INTENT_GET_WEATHER
from homeassistant.const import (
ATTR_DOMAIN,
ATTR_ENTITY_ID,
ATTR_SERVICE,
EVENT_HOMEASSISTANT_CLOSE,
EVENT_SERVICE_REMOVED,
SERVICE_TURN_OFF,
SERVICE_TURN_ON,
)
from homeassistant.core import Context, Event, HomeAssistant, callback, split_entity_id
from homeassistant.exceptions import HomeAssistantError
Expand Down Expand Up @@ -317,6 +323,82 @@ async def async_call(
return IntentResponseDict(intent_response)


class EntityControlTool(Tool):
"""LLM tool for controlling exposed entities by exact entity_id."""

def __init__(self, service_name: str) -> None:
"""Init the class."""
self._service_name = service_name
service_label = (
unicode_slug.slugify(service_name, separator="_").title().replace("_", "")
)
self.name = f"HassEntity{service_label}"
self.description = (
f"Execute Home Assistant {service_name} for exposed entities by exact entity_id. "
"Use entity_id values from the static context when a specific device is requested."
)
self.parameters = vol.Schema({vol.Required(ATTR_ENTITY_ID): [cv.entity_id]})

async def async_call(
self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
) -> JsonObjectType:
"""Call the Home Assistant turn service for exposed entities."""
if ATTR_ENTITY_ID not in tool_input.tool_args:
return {
"success": False,
"error": f"required key not provided @ data['{ATTR_ENTITY_ID}']",
}

try:
entity_ids = [
cv.entity_id(entity_id)
for entity_id in cv.ensure_list(tool_input.tool_args[ATTR_ENTITY_ID])
Comment on lines +353 to +355
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Guard missing entity_id before reading tool arguments

EntityControlTool.async_call indexes tool_input.tool_args["entity_id"] directly, so malformed tool calls that omit entity_id raise KeyError instead of returning a structured tool error. In the conversation pipeline, chat_log only converts HomeAssistantError and vol.Invalid into tool results, so this uncaught exception can abort the tool-call flow instead of letting the assistant recover. This is especially likely with LLM-generated calls where required args are occasionally missing.

Useful? React with 👍 / 👎.

Comment on lines +353 to +355
]
except vol.Invalid as err:
return {"success": False, "error": str(err)}

valid_entity_ids: list[str] = []
failed: dict[str, str] = {}

for entity_id in entity_ids:
if hass.states.get(entity_id) is None:
failed[entity_id] = "Entity does not exist"
continue

if llm_context.assistant and not async_should_expose(
hass, llm_context.assistant, entity_id
):
failed[entity_id] = "Entity is not exposed to this assistant"
continue

domain = split_entity_id(entity_id)[0]
if not hass.services.has_service(domain, self._service_name):
failed[entity_id] = (
f"Domain {domain} does not support {self._service_name}"
)
continue

valid_entity_ids.append(entity_id)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Exclude unsupported domains from entity control results

This loop treats every exposed entity with a current state as actionable, but homeassistant.turn_on/turn_off only executes per-domain services and silently skips unsupported domains (it just logs a warning in homeassistant/components/homeassistant/__init__.py, lines 148-173). As a result, if the LLM passes an exposed non-turnable entity (for example a sensor), this tool can return success: true and include it in done even though no action was performed, causing incorrect confirmations to users.

Useful? React with 👍 / 👎.


if valid_entity_ids:
await hass.services.async_call(
HOMEASSISTANT_DOMAIN,
self._service_name,
{ATTR_ENTITY_ID: valid_entity_ids},
Comment on lines +363 to +387
context=llm_context.context,
blocking=True,
)

return cast(
JsonObjectType,
{
"success": bool(valid_entity_ids) and not failed,
"done": valid_entity_ids,
"failed": failed,
},
)


class IntentResponseDict(dict):
"""Dictionary to represent an intent response resulting from a tool call."""

Expand Down Expand Up @@ -495,7 +577,9 @@ def _async_get_preable(self, llm_context: LLMContext) -> list[str]:
"When controlling Home Assistant always call the intent tools. "
"Use HassTurnOn to lock and HassTurnOff to unlock a lock. "
"When controlling a device, prefer passing just name and domain. "
"When controlling an area, prefer passing just area name and domain."
"When controlling an area, prefer passing just area name and domain. "
"When the user requests a specific exposed entity and the static context "
"contains its entity_id, use the exact entity_id tool."
)
]
area: ar.AreaEntry | None = None
Expand Down Expand Up @@ -543,7 +627,14 @@ def _async_get_exposed_entities_prompt(
prompt.append(
"Static Context: An overview of the areas and the devices in this smart home:"
)
prompt.append(yaml_util.dump(list(exposed_entities["entities"].values())))
prompt.append(
yaml_util.dump(
[
{"entity_id": entity_id, **info}
for entity_id, info in exposed_entities["entities"].items()
]
)
)

return prompt

Expand Down Expand Up @@ -613,6 +704,14 @@ def _async_get_tools(
)

if exposed_domains:
tools.extend(
EntityControlTool(service_name)
for service_name in (SERVICE_TURN_ON, SERVICE_TURN_OFF)
if any(
self.hass.services.has_service(domain, service_name)
for domain in exposed_domains
)
)
tools.append(GetLiveContextTool())

return tools
Expand Down Expand Up @@ -686,10 +785,7 @@ def _get_exposed_entities(
area_names.append(area_entry.name)
area_names.extend(area_entry.aliases)

info: dict[str, Any] = {
"names": ", ".join(names),
"domain": state.domain,
}
info: dict[str, Any] = {"names": ", ".join(names), "domain": state.domain}

if include_state:
info["state"] = state.state
Expand Down
3 changes: 2 additions & 1 deletion tests/components/mcp_server/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@
EVENT_PREFIX = "event: "
DATA_PREFIX = "data: "
EXPECTED_PROMPT_SUFFIX = """
- names: Kitchen Light
- entity_id: light.kitchen
names: Kitchen Light
domain: light
areas: Kitchen
"""
Expand Down
Loading
Loading