Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions homeassistant/components/device_automation/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Protocol
from typing import Any, Protocol

import voluptuous as vol

Expand All @@ -11,18 +11,15 @@
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.condition import (
Condition,
ConditionChecker,
ConditionCheckerType,
ConditionConfig,
trace_condition_function,
)
from homeassistant.helpers.typing import ConfigType
from homeassistant.helpers.typing import ConfigType, TemplateVarsType

from . import DeviceAutomationType, async_get_device_automation_platform
from .helpers import async_validate_device_automation_config

if TYPE_CHECKING:
from homeassistant.helpers import condition


class DeviceAutomationConditionProtocol(Protocol):
"""Define the format of device_condition modules.
Expand Down Expand Up @@ -90,15 +87,21 @@ def __init__(self, hass: HomeAssistant, config: ConditionConfig) -> None:
assert config.options is not None
self._config = config.options

async def async_get_checker(self) -> condition.ConditionCheckerType:
async def async_get_checker(self) -> ConditionChecker:
"""Test a device condition."""
platform = await async_get_device_automation_platform(
self._hass, self._config[CONF_DOMAIN], DeviceAutomationType.CONDITION
)
return trace_condition_function(
platform.async_condition_from_config(self._hass, self._config)
platform_checker = platform.async_condition_from_config(
self._hass, self._config
)

def checker(variables: TemplateVarsType = None, **kwargs: Any) -> bool:
Comment thread
arturpragacz marked this conversation as resolved.
result = platform_checker(self._hass, variables)
return result is not False
Comment thread
arturpragacz marked this conversation as resolved.

return checker


CONDITIONS: dict[str, type[Condition]] = {
"_device": DeviceCondition,
Expand Down
17 changes: 8 additions & 9 deletions homeassistant/components/light/condition.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Provides conditions for lights."""

from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Final, override
from typing import TYPE_CHECKING, Any, Final, Unpack, override

import voluptuous as vol

Expand All @@ -10,11 +10,11 @@
from homeassistant.helpers import config_validation as cv, target
from homeassistant.helpers.condition import (
Condition,
ConditionCheckerType,
ConditionChecker,
ConditionCheckParams,
ConditionConfig,
trace_condition_function,
)
from homeassistant.helpers.typing import ConfigType, TemplateVarsType
from homeassistant.helpers.typing import ConfigType

from .const import DOMAIN

Expand Down Expand Up @@ -61,7 +61,7 @@ def __init__(
self._state = state

@override
async def async_get_checker(self) -> ConditionCheckerType:
async def async_get_checker(self) -> ConditionChecker:
"""Get the condition checker."""

def check_any_match_state(states: list[str]) -> bool:
Expand All @@ -78,12 +78,11 @@ def check_all_match_state(states: list[str]) -> bool:
elif self._behavior == BEHAVIOR_ALL:
matcher = check_all_match_state

@trace_condition_function
def test_state(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
def test_state(**kwargs: Unpack[ConditionCheckParams]) -> bool:
"""Test state condition."""
target_selection = target.TargetSelection(self._target)
targeted_entities = target.async_extract_referenced_entity_ids(
hass, target_selection, expand_group=False
self._hass, target_selection, expand_group=False
)
referenced_entity_ids = targeted_entities.referenced.union(
targeted_entities.indirectly_referenced
Expand All @@ -96,7 +95,7 @@ def test_state(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
light_entity_states = [
state.state
for entity_id in light_entity_ids
if (state := hass.states.get(entity_id))
if (state := self._hass.states.get(entity_id))
and state.state in STATE_CONDITION_VALID_STATES
]
return matcher(light_entity_states)
Expand Down
15 changes: 7 additions & 8 deletions homeassistant/components/sun/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

from datetime import datetime, timedelta
from typing import Any, cast
from typing import Any, Unpack, cast

import voluptuous as vol

Expand All @@ -13,14 +13,14 @@
from homeassistant.helpers.automation import move_top_level_schema_fields_to_options
from homeassistant.helpers.condition import (
Condition,
ConditionCheckerType,
ConditionChecker,
ConditionCheckParams,
ConditionConfig,
condition_trace_set_result,
condition_trace_update_result,
trace_condition_function,
)
from homeassistant.helpers.sun import get_astral_event_date
from homeassistant.helpers.typing import ConfigType, TemplateVarsType
from homeassistant.helpers.typing import ConfigType
from homeassistant.util import dt as dt_util

_OPTIONS_SCHEMA_DICT: dict[vol.Marker, Any] = {
Expand Down Expand Up @@ -154,17 +154,16 @@ def __init__(self, hass: HomeAssistant, config: ConditionConfig) -> None:
assert config.options is not None
self._options = config.options

async def async_get_checker(self) -> ConditionCheckerType:
async def async_get_checker(self) -> ConditionChecker:
"""Wrap action method with sun based condition."""
before = self._options.get("before")
after = self._options.get("after")
before_offset = self._options.get("before_offset")
after_offset = self._options.get("after_offset")

@trace_condition_function
def sun_if(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
def sun_if(**kwargs: Unpack[ConditionCheckParams]) -> bool:
"""Validate time based if-condition."""
return sun(hass, before, after, before_offset, after_offset)
return sun(self._hass, before, after, before_offset, after_offset)

return sun_if

Expand Down
15 changes: 7 additions & 8 deletions homeassistant/components/zone/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import Any, cast
from typing import Any, Unpack, cast

import voluptuous as vol

Expand All @@ -22,11 +22,11 @@
from homeassistant.helpers.automation import move_top_level_schema_fields_to_options
from homeassistant.helpers.condition import (
Condition,
ConditionCheckerType,
ConditionChecker,
ConditionCheckParams,
ConditionConfig,
trace_condition_function,
)
from homeassistant.helpers.typing import ConfigType, TemplateVarsType
from homeassistant.helpers.typing import ConfigType

from . import in_zone

Expand Down Expand Up @@ -118,13 +118,12 @@ def __init__(self, hass: HomeAssistant, config: ConditionConfig) -> None:
assert config.options is not None
self._options = config.options

async def async_get_checker(self) -> ConditionCheckerType:
async def async_get_checker(self) -> ConditionChecker:
"""Wrap action method with zone based condition."""
entity_ids = self._options.get(CONF_ENTITY_ID, [])
zone_entity_ids = self._options.get(CONF_ZONE, [])

@trace_condition_function
def if_in_zone(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
def if_in_zone(**kwargs: Unpack[ConditionCheckParams]) -> bool:
"""Test if condition."""
errors = []

Expand All @@ -133,7 +132,7 @@ def if_in_zone(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
entity_ok = False
for zone_entity_id in zone_entity_ids:
try:
if zone(hass, zone_entity_id, entity_id):
if zone(self._hass, zone_entity_id, entity_id):
entity_ok = True
except ConditionErrorMessage as ex:
errors.append(
Expand Down
55 changes: 47 additions & 8 deletions homeassistant/helpers/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import logging
import re
import sys
from typing import TYPE_CHECKING, Any, Protocol, cast
from typing import TYPE_CHECKING, Any, Protocol, TypedDict, Unpack, cast, overload

import voluptuous as vol

Expand Down Expand Up @@ -298,7 +298,7 @@ def __init__(self, hass: HomeAssistant, config: ConditionConfig) -> None:
self._hass = hass

@abc.abstractmethod
async def async_get_checker(self) -> ConditionCheckerType:
async def async_get_checker(self) -> ConditionChecker:
"""Get the condition checker."""


Expand All @@ -319,7 +319,23 @@ class ConditionConfig:
target: dict[str, Any] | None = None


type ConditionCheckerType = Callable[[HomeAssistant, TemplateVarsType], bool | None]
class ConditionCheckParams(TypedDict, total=False):
"""Condition check params."""

variables: TemplateVarsType


class ConditionChecker(Protocol):
"""Protocol for condition checker callable with typed kwargs."""

def __call__(self, **kwargs: Unpack[ConditionCheckParams]) -> bool:
"""Check the condition."""


type ConditionCheckerType = Callable[[HomeAssistant, TemplateVarsType], bool]
type ConditionCheckerTypeOptional = Callable[
[HomeAssistant, TemplateVarsType], bool | None
]


def condition_trace_append(variables: TemplateVarsType, path: str) -> TraceElement:
Expand Down Expand Up @@ -374,7 +390,21 @@ def trace_condition(variables: TemplateVarsType) -> Generator[TraceElement]:
trace_stack_pop(trace_stack_cv)


def trace_condition_function(condition: ConditionCheckerType) -> ConditionCheckerType:
@overload
def trace_condition_function(
condition: ConditionCheckerType,
) -> ConditionCheckerType: ...


@overload
def trace_condition_function(
condition: ConditionCheckerTypeOptional,
) -> ConditionCheckerTypeOptional: ...


def trace_condition_function(
condition: ConditionCheckerType | ConditionCheckerTypeOptional,
) -> ConditionCheckerType | ConditionCheckerTypeOptional:
"""Wrap a condition function to enable basic tracing."""

@ft.wraps(condition)
Expand Down Expand Up @@ -420,10 +450,20 @@ async def _async_get_condition_platform(
) from None


async def _async_get_checker(condition: Condition) -> ConditionCheckerType:
new_checker = await condition.async_get_checker()

@trace_condition_function
def checker(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
return new_checker(variables=variables)

return checker


async def async_from_config(
hass: HomeAssistant,
config: ConfigType,
) -> ConditionCheckerType:
) -> ConditionCheckerTypeOptional:
"""Turn a condition configuration into a method.

Should be run on the event loop.
Expand Down Expand Up @@ -466,7 +506,7 @@ def disabled_condition(
target=config.get(CONF_TARGET),
),
)
return await condition.async_get_checker()
return await _async_get_checker(condition)

for fmt in (ASYNC_FROM_CONFIG_FORMAT, FROM_CONFIG_FORMAT):
factory = getattr(sys.modules[__name__], fmt.format(condition_key), None)
Expand Down Expand Up @@ -1131,7 +1171,7 @@ async def async_conditions_from_config(
name: str,
) -> Callable[[TemplateVarsType], bool]:
"""AND all conditions."""
checks: list[ConditionCheckerType] = [
checks = [
await async_from_config(hass, condition_config)
for condition_config in condition_configs
]
Expand Down Expand Up @@ -1330,7 +1370,6 @@ async def async_get_all_descriptions(
continue

description = {"fields": yaml_description.get("fields", {})}

if (target := yaml_description.get("target")) is not None:
description["target"] = target

Expand Down
20 changes: 13 additions & 7 deletions homeassistant/helpers/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@
from homeassistant.util.signal_type import SignalType, SignalTypeFormat

from . import condition, config_validation as cv, service, template
from .condition import ConditionCheckerType, trace_condition_function
from .condition import ConditionCheckerTypeOptional, trace_condition_function
from .dispatcher import async_dispatcher_connect, async_dispatcher_send_internal
from .event import async_call_later, async_track_template
from .script_variables import ScriptRunVariables, ScriptVariables
Expand Down Expand Up @@ -675,12 +675,14 @@ async def _async_step_sequence(self) -> None:

### Condition actions ###

async def _async_get_condition(self, config: ConfigType) -> ConditionCheckerType:
async def _async_get_condition(
self, config: ConfigType
) -> ConditionCheckerTypeOptional:
return await self._script._async_get_condition(config) # noqa: SLF001

def _test_conditions(
self,
conditions: list[ConditionCheckerType],
conditions: list[ConditionCheckerTypeOptional],
name: str,
condition_path: str | None = None,
) -> bool | None:
Expand Down Expand Up @@ -1404,12 +1406,12 @@ def _referenced_extract_ids(data: Any, key: str, found: set[str]) -> None:


class _ChooseData(TypedDict):
choices: list[tuple[list[ConditionCheckerType], Script]]
choices: list[tuple[list[ConditionCheckerTypeOptional], Script]]
default: Script | None


class _IfData(TypedDict):
if_conditions: list[ConditionCheckerType]
if_conditions: list[ConditionCheckerTypeOptional]
if_then: Script
if_else: Script | None

Expand Down Expand Up @@ -1486,7 +1488,9 @@ def __init__(
self._max_exceeded = max_exceeded
if script_mode == SCRIPT_MODE_QUEUED:
self._queue_lck = asyncio.Lock()
self._config_cache: dict[frozenset[tuple[str, str]], ConditionCheckerType] = {}
self._config_cache: dict[
frozenset[tuple[str, str]], ConditionCheckerTypeOptional
] = {}
self._repeat_script: dict[int, Script] = {}
self._choose_data: dict[int, _ChooseData] = {}
self._if_data: dict[int, _IfData] = {}
Expand Down Expand Up @@ -1857,7 +1861,9 @@ async def async_stop(
return
await asyncio.shield(create_eager_task(self._async_stop(aws, update_state)))

async def _async_get_condition(self, config: ConfigType) -> ConditionCheckerType:
async def _async_get_condition(
self, config: ConfigType
) -> ConditionCheckerTypeOptional:
config_cache_key = frozenset((k, str(v)) for k, v in config.items())
if not (cond := self._config_cache.get(config_cache_key)):
cond = await condition.async_from_config(self._hass, config)
Expand Down
Loading
Loading