1313import logging
1414import re
1515import sys
16- from typing import TYPE_CHECKING , Any , Protocol , cast
16+ from typing import TYPE_CHECKING , Any , Protocol , TypedDict , Unpack , cast , overload
1717
1818import voluptuous as vol
1919
@@ -298,7 +298,7 @@ def __init__(self, hass: HomeAssistant, config: ConditionConfig) -> None:
298298 self ._hass = hass
299299
300300 @abc .abstractmethod
301- async def async_get_checker (self ) -> ConditionCheckerType :
301+ async def async_get_checker (self ) -> ConditionChecker :
302302 """Get the condition checker."""
303303
304304
@@ -319,7 +319,23 @@ class ConditionConfig:
319319 target : dict [str , Any ] | None = None
320320
321321
322- type ConditionCheckerType = Callable [[HomeAssistant , TemplateVarsType ], bool | None ]
322+ class ConditionCheckParams (TypedDict , total = False ):
323+ """Condition check params."""
324+
325+ variables : TemplateVarsType
326+
327+
328+ class ConditionChecker (Protocol ):
329+ """Protocol for condition checker callable with typed kwargs."""
330+
331+ def __call__ (self , ** kwargs : Unpack [ConditionCheckParams ]) -> bool :
332+ """Check the condition."""
333+
334+
335+ type ConditionCheckerType = Callable [[HomeAssistant , TemplateVarsType ], bool ]
336+ type ConditionCheckerTypeOptional = Callable [
337+ [HomeAssistant , TemplateVarsType ], bool | None
338+ ]
323339
324340
325341def condition_trace_append (variables : TemplateVarsType , path : str ) -> TraceElement :
@@ -374,7 +390,21 @@ def trace_condition(variables: TemplateVarsType) -> Generator[TraceElement]:
374390 trace_stack_pop (trace_stack_cv )
375391
376392
377- def trace_condition_function (condition : ConditionCheckerType ) -> ConditionCheckerType :
393+ @overload
394+ def trace_condition_function (
395+ condition : ConditionCheckerType ,
396+ ) -> ConditionCheckerType : ...
397+
398+
399+ @overload
400+ def trace_condition_function (
401+ condition : ConditionCheckerTypeOptional ,
402+ ) -> ConditionCheckerTypeOptional : ...
403+
404+
405+ def trace_condition_function (
406+ condition : ConditionCheckerType | ConditionCheckerTypeOptional ,
407+ ) -> ConditionCheckerType | ConditionCheckerTypeOptional :
378408 """Wrap a condition function to enable basic tracing."""
379409
380410 @ft .wraps (condition )
@@ -420,10 +450,20 @@ async def _async_get_condition_platform(
420450 ) from None
421451
422452
453+ async def _async_get_checker (condition : Condition ) -> ConditionCheckerType :
454+ new_checker = await condition .async_get_checker ()
455+
456+ @trace_condition_function
457+ def checker (hass : HomeAssistant , variables : TemplateVarsType = None ) -> bool :
458+ return new_checker (variables = variables )
459+
460+ return checker
461+
462+
423463async def async_from_config (
424464 hass : HomeAssistant ,
425465 config : ConfigType ,
426- ) -> ConditionCheckerType :
466+ ) -> ConditionCheckerTypeOptional :
427467 """Turn a condition configuration into a method.
428468
429469 Should be run on the event loop.
@@ -466,7 +506,7 @@ def disabled_condition(
466506 target = config .get (CONF_TARGET ),
467507 ),
468508 )
469- return await condition . async_get_checker ( )
509+ return await _async_get_checker ( condition )
470510
471511 for fmt in (ASYNC_FROM_CONFIG_FORMAT , FROM_CONFIG_FORMAT ):
472512 factory = getattr (sys .modules [__name__ ], fmt .format (condition_key ), None )
@@ -1131,7 +1171,7 @@ async def async_conditions_from_config(
11311171 name : str ,
11321172) -> Callable [[TemplateVarsType ], bool ]:
11331173 """AND all conditions."""
1134- checks : list [ ConditionCheckerType ] = [
1174+ checks = [
11351175 await async_from_config (hass , condition_config )
11361176 for condition_config in condition_configs
11371177 ]
@@ -1330,7 +1370,6 @@ async def async_get_all_descriptions(
13301370 continue
13311371
13321372 description = {"fields" : yaml_description .get ("fields" , {})}
1333-
13341373 if (target := yaml_description .get ("target" )) is not None :
13351374 description ["target" ] = target
13361375
0 commit comments