Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 51 additions & 24 deletions homeassistant/components/websocket_api/automation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from homeassistant.const import CONF_TARGET
from homeassistant.core import HomeAssistant
from homeassistant.helpers import target as target_helpers
from homeassistant.helpers import entity_registry as er, target as target_helpers
from homeassistant.helpers.condition import (
async_get_all_descriptions as async_get_all_condition_descriptions,
)
Expand Down Expand Up @@ -92,12 +92,14 @@ class _AutomationComponentLookupData:

component: str
filters: list[_EntityFilter]
primary_entities_only: bool = True

@classmethod
def create(cls, component: str, target_description: dict[str, Any]) -> Self:
"""Build automation component lookup data from target description."""
filters: list[_EntityFilter] = []

primary_entities_only = target_description.get("primary_entities_only", True)
entity_filters_config = target_description.get("entity", [])
for entity_filter_config in entity_filters_config:
entity_filter = _EntityFilter(
Expand All @@ -110,14 +112,29 @@ def create(cls, component: str, target_description: dict[str, Any]) -> Self:
)
filters.append(entity_filter)

return cls(component=component, filters=filters)
return cls(
component=component,
filters=filters,
primary_entities_only=primary_entities_only,
)

def matches(
self, hass: HomeAssistant, entity_id: str, domain: str, integration: str
self,
hass: HomeAssistant,
entity_id: str,
domain: str,
integration: str,
check_entity_category: bool,
) -> bool:
"""Return if entity matches ANY of the filters."""
if check_entity_category and self.primary_entities_only:
entry = er.async_get(hass).async_get(entity_id)
if entry is not None and entry.entity_category is not None:
Comment thread
abmantis marked this conversation as resolved.
return False

if not self.filters:
return True

return any(
f.matches(hass, entity_id, domain, integration) for f in self.filters
)
Expand Down Expand Up @@ -220,6 +237,7 @@ def _async_get_automation_components_for_target(
hass,
target_helpers.TargetSelection(target_selection),
expand_group=expand_group,
primary_entities_only=False,
)
_LOGGER.debug("Extracted entities for lookup: %s", extracted)

Expand All @@ -232,30 +250,39 @@ def _async_get_automation_components_for_target(

entity_infos = entity_sources(hass)
matched_components: set[str] = set()
for entity_id in extracted.referenced | extracted.indirectly_referenced:
if lookup_table.component_count == len(matched_components):
# All automation components matched already, so we don't need to iterate further
break

entity_info = entity_infos.get(entity_id)
if entity_info is None:
_LOGGER.debug("No entity source found for %s", entity_id)
continue

entity_domain = entity_id.split(".")[0]
entity_integration = entity_info["domain"]
for domain in (entity_domain, entity_integration, None):
if not (
domain_component_data := lookup_table.domain_components.get(domain)
):
def _match_components(entities: set[str], check_entity_category: bool) -> None:
for entity_id in entities:
if lookup_table.component_count == len(matched_components):
# All automation components matched already, so we don't need to iterate further
break

entity_info = entity_infos.get(entity_id)
if entity_info is None:
_LOGGER.debug("No entity source found for %s", entity_id)
continue
for component_data in domain_component_data:
if component_data.component in matched_components:
continue
if component_data.matches(
hass, entity_id, entity_domain, entity_integration

entity_domain = entity_id.split(".")[0]
entity_integration = entity_info["domain"]
for domain in (entity_domain, entity_integration, None):
if not (
domain_component_data := lookup_table.domain_components.get(domain)
):
matched_components.add(component_data.component)
continue
for component_data in domain_component_data:
if component_data.component in matched_components:
continue
if component_data.matches(
hass,
entity_id,
entity_domain,
entity_integration,
check_entity_category,
):
matched_components.add(component_data.component)

_match_components(extracted.referenced, check_entity_category=False)
_match_components(extracted.indirectly_referenced, check_entity_category=True)

return matched_components

Expand Down
Loading