Skip to content

Commit 811c7d3

Browse files
authored
Modernise condition checker in helper (home-assistant#159159)
1 parent 41741ac commit 811c7d3

7 files changed

Lines changed: 99 additions & 54 deletions

File tree

homeassistant/components/device_automation/condition.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from typing import TYPE_CHECKING, Any, Protocol
5+
from typing import Any, Protocol
66

77
import voluptuous as vol
88

@@ -11,18 +11,15 @@
1111
from homeassistant.helpers import config_validation as cv
1212
from homeassistant.helpers.condition import (
1313
Condition,
14+
ConditionChecker,
1415
ConditionCheckerType,
1516
ConditionConfig,
16-
trace_condition_function,
1717
)
18-
from homeassistant.helpers.typing import ConfigType
18+
from homeassistant.helpers.typing import ConfigType, TemplateVarsType
1919

2020
from . import DeviceAutomationType, async_get_device_automation_platform
2121
from .helpers import async_validate_device_automation_config
2222

23-
if TYPE_CHECKING:
24-
from homeassistant.helpers import condition
25-
2623

2724
class DeviceAutomationConditionProtocol(Protocol):
2825
"""Define the format of device_condition modules.
@@ -90,15 +87,21 @@ def __init__(self, hass: HomeAssistant, config: ConditionConfig) -> None:
9087
assert config.options is not None
9188
self._config = config.options
9289

93-
async def async_get_checker(self) -> condition.ConditionCheckerType:
90+
async def async_get_checker(self) -> ConditionChecker:
9491
"""Test a device condition."""
9592
platform = await async_get_device_automation_platform(
9693
self._hass, self._config[CONF_DOMAIN], DeviceAutomationType.CONDITION
9794
)
98-
return trace_condition_function(
99-
platform.async_condition_from_config(self._hass, self._config)
95+
platform_checker = platform.async_condition_from_config(
96+
self._hass, self._config
10097
)
10198

99+
def checker(variables: TemplateVarsType = None, **kwargs: Any) -> bool:
100+
result = platform_checker(self._hass, variables)
101+
return result is not False
102+
103+
return checker
104+
102105

103106
CONDITIONS: dict[str, type[Condition]] = {
104107
"_device": DeviceCondition,

homeassistant/components/light/condition.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Provides conditions for lights."""
22

33
from collections.abc import Callable
4-
from typing import TYPE_CHECKING, Any, Final, override
4+
from typing import TYPE_CHECKING, Any, Final, Unpack, override
55

66
import voluptuous as vol
77

@@ -10,11 +10,11 @@
1010
from homeassistant.helpers import config_validation as cv, target
1111
from homeassistant.helpers.condition import (
1212
Condition,
13-
ConditionCheckerType,
13+
ConditionChecker,
14+
ConditionCheckParams,
1415
ConditionConfig,
15-
trace_condition_function,
1616
)
17-
from homeassistant.helpers.typing import ConfigType, TemplateVarsType
17+
from homeassistant.helpers.typing import ConfigType
1818

1919
from .const import DOMAIN
2020

@@ -61,7 +61,7 @@ def __init__(
6161
self._state = state
6262

6363
@override
64-
async def async_get_checker(self) -> ConditionCheckerType:
64+
async def async_get_checker(self) -> ConditionChecker:
6565
"""Get the condition checker."""
6666

6767
def check_any_match_state(states: list[str]) -> bool:
@@ -78,12 +78,11 @@ def check_all_match_state(states: list[str]) -> bool:
7878
elif self._behavior == BEHAVIOR_ALL:
7979
matcher = check_all_match_state
8080

81-
@trace_condition_function
82-
def test_state(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
81+
def test_state(**kwargs: Unpack[ConditionCheckParams]) -> bool:
8382
"""Test state condition."""
8483
target_selection = target.TargetSelection(self._target)
8584
targeted_entities = target.async_extract_referenced_entity_ids(
86-
hass, target_selection, expand_group=False
85+
self._hass, target_selection, expand_group=False
8786
)
8887
referenced_entity_ids = targeted_entities.referenced.union(
8988
targeted_entities.indirectly_referenced
@@ -96,7 +95,7 @@ def test_state(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
9695
light_entity_states = [
9796
state.state
9897
for entity_id in light_entity_ids
99-
if (state := hass.states.get(entity_id))
98+
if (state := self._hass.states.get(entity_id))
10099
and state.state in STATE_CONDITION_VALID_STATES
101100
]
102101
return matcher(light_entity_states)

homeassistant/components/sun/condition.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
from datetime import datetime, timedelta
6-
from typing import Any, cast
6+
from typing import Any, Unpack, cast
77

88
import voluptuous as vol
99

@@ -13,14 +13,14 @@
1313
from homeassistant.helpers.automation import move_top_level_schema_fields_to_options
1414
from homeassistant.helpers.condition import (
1515
Condition,
16-
ConditionCheckerType,
16+
ConditionChecker,
17+
ConditionCheckParams,
1718
ConditionConfig,
1819
condition_trace_set_result,
1920
condition_trace_update_result,
20-
trace_condition_function,
2121
)
2222
from homeassistant.helpers.sun import get_astral_event_date
23-
from homeassistant.helpers.typing import ConfigType, TemplateVarsType
23+
from homeassistant.helpers.typing import ConfigType
2424
from homeassistant.util import dt as dt_util
2525

2626
_OPTIONS_SCHEMA_DICT: dict[vol.Marker, Any] = {
@@ -154,17 +154,16 @@ def __init__(self, hass: HomeAssistant, config: ConditionConfig) -> None:
154154
assert config.options is not None
155155
self._options = config.options
156156

157-
async def async_get_checker(self) -> ConditionCheckerType:
157+
async def async_get_checker(self) -> ConditionChecker:
158158
"""Wrap action method with sun based condition."""
159159
before = self._options.get("before")
160160
after = self._options.get("after")
161161
before_offset = self._options.get("before_offset")
162162
after_offset = self._options.get("after_offset")
163163

164-
@trace_condition_function
165-
def sun_if(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
164+
def sun_if(**kwargs: Unpack[ConditionCheckParams]) -> bool:
166165
"""Validate time based if-condition."""
167-
return sun(hass, before, after, before_offset, after_offset)
166+
return sun(self._hass, before, after, before_offset, after_offset)
168167

169168
return sun_if
170169

homeassistant/components/zone/condition.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from typing import Any, cast
5+
from typing import Any, Unpack, cast
66

77
import voluptuous as vol
88

@@ -22,11 +22,11 @@
2222
from homeassistant.helpers.automation import move_top_level_schema_fields_to_options
2323
from homeassistant.helpers.condition import (
2424
Condition,
25-
ConditionCheckerType,
25+
ConditionChecker,
26+
ConditionCheckParams,
2627
ConditionConfig,
27-
trace_condition_function,
2828
)
29-
from homeassistant.helpers.typing import ConfigType, TemplateVarsType
29+
from homeassistant.helpers.typing import ConfigType
3030

3131
from . import in_zone
3232

@@ -118,13 +118,12 @@ def __init__(self, hass: HomeAssistant, config: ConditionConfig) -> None:
118118
assert config.options is not None
119119
self._options = config.options
120120

121-
async def async_get_checker(self) -> ConditionCheckerType:
121+
async def async_get_checker(self) -> ConditionChecker:
122122
"""Wrap action method with zone based condition."""
123123
entity_ids = self._options.get(CONF_ENTITY_ID, [])
124124
zone_entity_ids = self._options.get(CONF_ZONE, [])
125125

126-
@trace_condition_function
127-
def if_in_zone(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
126+
def if_in_zone(**kwargs: Unpack[ConditionCheckParams]) -> bool:
128127
"""Test if condition."""
129128
errors = []
130129

@@ -133,7 +132,7 @@ def if_in_zone(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
133132
entity_ok = False
134133
for zone_entity_id in zone_entity_ids:
135134
try:
136-
if zone(hass, zone_entity_id, entity_id):
135+
if zone(self._hass, zone_entity_id, entity_id):
137136
entity_ok = True
138137
except ConditionErrorMessage as ex:
139138
errors.append(

homeassistant/helpers/condition.py

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import logging
1414
import re
1515
import sys
16-
from typing import TYPE_CHECKING, Any, Protocol, cast
16+
from typing import TYPE_CHECKING, Any, Protocol, TypedDict, Unpack, cast, overload
1717

1818
import voluptuous as vol
1919

@@ -298,7 +298,7 @@ def __init__(self, hass: HomeAssistant, config: ConditionConfig) -> None:
298298
self._hass = hass
299299

300300
@abc.abstractmethod
301-
async def async_get_checker(self) -> ConditionCheckerType:
301+
async def async_get_checker(self) -> ConditionChecker:
302302
"""Get the condition checker."""
303303

304304

@@ -319,7 +319,23 @@ class ConditionConfig:
319319
target: dict[str, Any] | None = None
320320

321321

322-
type ConditionCheckerType = Callable[[HomeAssistant, TemplateVarsType], bool | None]
322+
class ConditionCheckParams(TypedDict, total=False):
323+
"""Condition check params."""
324+
325+
variables: TemplateVarsType
326+
327+
328+
class ConditionChecker(Protocol):
329+
"""Protocol for condition checker callable with typed kwargs."""
330+
331+
def __call__(self, **kwargs: Unpack[ConditionCheckParams]) -> bool:
332+
"""Check the condition."""
333+
334+
335+
type ConditionCheckerType = Callable[[HomeAssistant, TemplateVarsType], bool]
336+
type ConditionCheckerTypeOptional = Callable[
337+
[HomeAssistant, TemplateVarsType], bool | None
338+
]
323339

324340

325341
def condition_trace_append(variables: TemplateVarsType, path: str) -> TraceElement:
@@ -374,7 +390,21 @@ def trace_condition(variables: TemplateVarsType) -> Generator[TraceElement]:
374390
trace_stack_pop(trace_stack_cv)
375391

376392

377-
def trace_condition_function(condition: ConditionCheckerType) -> ConditionCheckerType:
393+
@overload
394+
def trace_condition_function(
395+
condition: ConditionCheckerType,
396+
) -> ConditionCheckerType: ...
397+
398+
399+
@overload
400+
def trace_condition_function(
401+
condition: ConditionCheckerTypeOptional,
402+
) -> ConditionCheckerTypeOptional: ...
403+
404+
405+
def trace_condition_function(
406+
condition: ConditionCheckerType | ConditionCheckerTypeOptional,
407+
) -> ConditionCheckerType | ConditionCheckerTypeOptional:
378408
"""Wrap a condition function to enable basic tracing."""
379409

380410
@ft.wraps(condition)
@@ -420,10 +450,20 @@ async def _async_get_condition_platform(
420450
) from None
421451

422452

453+
async def _async_get_checker(condition: Condition) -> ConditionCheckerType:
454+
new_checker = await condition.async_get_checker()
455+
456+
@trace_condition_function
457+
def checker(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
458+
return new_checker(variables=variables)
459+
460+
return checker
461+
462+
423463
async def async_from_config(
424464
hass: HomeAssistant,
425465
config: ConfigType,
426-
) -> ConditionCheckerType:
466+
) -> ConditionCheckerTypeOptional:
427467
"""Turn a condition configuration into a method.
428468
429469
Should be run on the event loop.
@@ -466,7 +506,7 @@ def disabled_condition(
466506
target=config.get(CONF_TARGET),
467507
),
468508
)
469-
return await condition.async_get_checker()
509+
return await _async_get_checker(condition)
470510

471511
for fmt in (ASYNC_FROM_CONFIG_FORMAT, FROM_CONFIG_FORMAT):
472512
factory = getattr(sys.modules[__name__], fmt.format(condition_key), None)
@@ -1131,7 +1171,7 @@ async def async_conditions_from_config(
11311171
name: str,
11321172
) -> Callable[[TemplateVarsType], bool]:
11331173
"""AND all conditions."""
1134-
checks: list[ConditionCheckerType] = [
1174+
checks = [
11351175
await async_from_config(hass, condition_config)
11361176
for condition_config in condition_configs
11371177
]
@@ -1330,7 +1370,6 @@ async def async_get_all_descriptions(
13301370
continue
13311371

13321372
description = {"fields": yaml_description.get("fields", {})}
1333-
13341373
if (target := yaml_description.get("target")) is not None:
13351374
description["target"] = target
13361375

homeassistant/helpers/script.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@
8686
from homeassistant.util.signal_type import SignalType, SignalTypeFormat
8787

8888
from . import condition, config_validation as cv, service, template
89-
from .condition import ConditionCheckerType, trace_condition_function
89+
from .condition import ConditionCheckerTypeOptional, trace_condition_function
9090
from .dispatcher import async_dispatcher_connect, async_dispatcher_send_internal
9191
from .event import async_call_later, async_track_template
9292
from .script_variables import ScriptRunVariables, ScriptVariables
@@ -675,12 +675,14 @@ async def _async_step_sequence(self) -> None:
675675

676676
### Condition actions ###
677677

678-
async def _async_get_condition(self, config: ConfigType) -> ConditionCheckerType:
678+
async def _async_get_condition(
679+
self, config: ConfigType
680+
) -> ConditionCheckerTypeOptional:
679681
return await self._script._async_get_condition(config) # noqa: SLF001
680682

681683
def _test_conditions(
682684
self,
683-
conditions: list[ConditionCheckerType],
685+
conditions: list[ConditionCheckerTypeOptional],
684686
name: str,
685687
condition_path: str | None = None,
686688
) -> bool | None:
@@ -1404,12 +1406,12 @@ def _referenced_extract_ids(data: Any, key: str, found: set[str]) -> None:
14041406

14051407

14061408
class _ChooseData(TypedDict):
1407-
choices: list[tuple[list[ConditionCheckerType], Script]]
1409+
choices: list[tuple[list[ConditionCheckerTypeOptional], Script]]
14081410
default: Script | None
14091411

14101412

14111413
class _IfData(TypedDict):
1412-
if_conditions: list[ConditionCheckerType]
1414+
if_conditions: list[ConditionCheckerTypeOptional]
14131415
if_then: Script
14141416
if_else: Script | None
14151417

@@ -1486,7 +1488,9 @@ def __init__(
14861488
self._max_exceeded = max_exceeded
14871489
if script_mode == SCRIPT_MODE_QUEUED:
14881490
self._queue_lck = asyncio.Lock()
1489-
self._config_cache: dict[frozenset[tuple[str, str]], ConditionCheckerType] = {}
1491+
self._config_cache: dict[
1492+
frozenset[tuple[str, str]], ConditionCheckerTypeOptional
1493+
] = {}
14901494
self._repeat_script: dict[int, Script] = {}
14911495
self._choose_data: dict[int, _ChooseData] = {}
14921496
self._if_data: dict[int, _IfData] = {}
@@ -1857,7 +1861,9 @@ async def async_stop(
18571861
return
18581862
await asyncio.shield(create_eager_task(self._async_stop(aws, update_state)))
18591863

1860-
async def _async_get_condition(self, config: ConfigType) -> ConditionCheckerType:
1864+
async def _async_get_condition(
1865+
self, config: ConfigType
1866+
) -> ConditionCheckerTypeOptional:
18611867
config_cache_key = frozenset((k, str(v)) for k, v in config.items())
18621868
if not (cond := self._config_cache.get(config_cache_key)):
18631869
cond = await condition.async_from_config(self._hass, config)

0 commit comments

Comments
 (0)