Skip to content

Commit cebe4aa

Browse files
Refactor condition API (#168815)
Co-authored-by: Artur Pragacz <artur@pragacz.com>
1 parent 32b9a21 commit cebe4aa

5 files changed

Lines changed: 230 additions & 163 deletions

File tree

homeassistant/components/device_automation/condition.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from homeassistant.helpers import config_validation as cv
1212
from homeassistant.helpers.condition import (
1313
Condition,
14-
ConditionChecker,
1514
ConditionCheckerType,
1615
ConditionConfig,
1716
)
@@ -54,6 +53,7 @@ class DeviceCondition(Condition):
5453
"""Device condition."""
5554

5655
_config: ConfigType
56+
_platform_checker: ConditionCheckerType
5757

5858
@classmethod
5959
async def async_validate_complete_config(
@@ -87,20 +87,19 @@ def __init__(self, hass: HomeAssistant, config: ConditionConfig) -> None:
8787
assert config.options is not None
8888
self._config = config.options
8989

90-
async def async_get_checker(self) -> ConditionChecker:
91-
"""Test a device condition."""
90+
async def async_setup(self) -> None:
91+
"""Set up a device condition."""
9292
platform = await async_get_device_automation_platform(
9393
self._hass, self._config[CONF_DOMAIN], DeviceAutomationType.CONDITION
9494
)
95-
platform_checker = platform.async_condition_from_config(
95+
self._platform_checker = platform.async_condition_from_config(
9696
self._hass, self._config
9797
)
9898

99-
def checker(variables: TemplateVarsType = None, **kwargs: Any) -> bool:
100-
result = platform_checker(self._hass, variables)
101-
return result is not False
102-
103-
return checker
99+
def _async_check(self, variables: TemplateVarsType = None, **kwargs: Any) -> bool:
100+
"""Check the condition."""
101+
result = self._platform_checker(self._hass, variables)
102+
return result is not False
104103

105104

106105
CONDITIONS: dict[str, type[Condition]] = {

homeassistant/components/sun/condition.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from homeassistant.helpers.automation import move_top_level_schema_fields_to_options
1414
from homeassistant.helpers.condition import (
1515
Condition,
16-
ConditionChecker,
1716
ConditionCheckParams,
1817
ConditionConfig,
1918
condition_trace_set_result,
@@ -151,19 +150,20 @@ def __init__(self, hass: HomeAssistant, config: ConditionConfig) -> None:
151150
super().__init__(hass, config)
152151
assert config.options is not None
153152
self._options = config.options
154-
155-
async def async_get_checker(self) -> ConditionChecker:
156-
"""Wrap action method with sun based condition."""
157-
before = self._options.get("before")
158-
after = self._options.get("after")
159-
before_offset = self._options.get("before_offset")
160-
after_offset = self._options.get("after_offset")
161-
162-
def sun_if(**kwargs: Unpack[ConditionCheckParams]) -> bool:
163-
"""Validate time based if-condition."""
164-
return sun(self._hass, before, after, before_offset, after_offset)
165-
166-
return sun_if
153+
self._before = self._options.get("before")
154+
self._after = self._options.get("after")
155+
self._before_offset = self._options.get("before_offset")
156+
self._after_offset = self._options.get("after_offset")
157+
158+
def _async_check(self, **kwargs: Unpack[ConditionCheckParams]) -> bool:
159+
"""Check the condition."""
160+
return sun(
161+
self._hass,
162+
self._before,
163+
self._after,
164+
self._before_offset,
165+
self._after_offset,
166+
)
167167

168168

169169
CONDITIONS: dict[str, type[Condition]] = {

homeassistant/components/zone/condition.py

Lines changed: 29 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from homeassistant.helpers.automation import move_top_level_schema_fields_to_options
2323
from homeassistant.helpers.condition import (
2424
Condition,
25-
ConditionChecker,
2625
ConditionCheckParams,
2726
ConditionConfig,
2827
)
@@ -117,44 +116,39 @@ def __init__(self, hass: HomeAssistant, config: ConditionConfig) -> None:
117116
super().__init__(hass, config)
118117
assert config.options is not None
119118
self._options = config.options
120-
121-
async def async_get_checker(self) -> ConditionChecker:
122-
"""Wrap action method with zone based condition."""
123-
entity_ids = self._options.get(CONF_ENTITY_ID, [])
124-
zone_entity_ids = self._options.get(CONF_ZONE, [])
125-
126-
def if_in_zone(**kwargs: Unpack[ConditionCheckParams]) -> bool:
127-
"""Test if condition."""
128-
errors = []
129-
130-
all_ok = True
131-
for entity_id in entity_ids:
132-
entity_ok = False
133-
for zone_entity_id in zone_entity_ids:
134-
try:
135-
if zone(self._hass, zone_entity_id, entity_id):
136-
entity_ok = True
137-
except ConditionErrorMessage as ex:
138-
errors.append(
139-
ConditionErrorMessage(
140-
"zone",
141-
(
142-
f"error matching {entity_id} with {zone_entity_id}:"
143-
f" {ex.message}"
144-
),
145-
)
119+
self._entity_ids = self._options.get(CONF_ENTITY_ID, [])
120+
self._zone_entity_ids = self._options.get(CONF_ZONE, [])
121+
122+
def _async_check(self, **kwargs: Unpack[ConditionCheckParams]) -> bool:
123+
"""Test if condition."""
124+
errors = []
125+
126+
all_ok = True
127+
for entity_id in self._entity_ids:
128+
entity_ok = False
129+
for zone_entity_id in self._zone_entity_ids:
130+
try:
131+
if zone(self._hass, zone_entity_id, entity_id):
132+
entity_ok = True
133+
except ConditionErrorMessage as ex:
134+
errors.append(
135+
ConditionErrorMessage(
136+
"zone",
137+
(
138+
f"error matching {entity_id} with {zone_entity_id}:"
139+
f" {ex.message}"
140+
),
146141
)
142+
)
147143

148-
if not entity_ok:
149-
all_ok = False
150-
151-
# Raise the errors only if no definitive result was found
152-
if errors and not all_ok:
153-
raise ConditionErrorContainer("zone", errors=errors)
144+
if not entity_ok:
145+
all_ok = False
154146

155-
return all_ok
147+
# Raise the errors only if no definitive result was found
148+
if errors and not all_ok:
149+
raise ConditionErrorContainer("zone", errors=errors)
156150

157-
return if_in_zone
151+
return all_ok
158152

159153

160154
CONDITIONS: dict[str, type[Condition]] = {

0 commit comments

Comments
 (0)