27
27
28
28
from homeassistant .components .file_upload import process_uploaded_file
29
29
from homeassistant .components .hassio import AddonError , AddonManager , AddonState
30
+ from homeassistant .components .sensor import (
31
+ CONF_STATE_CLASS ,
32
+ DEVICE_CLASS_UNITS ,
33
+ SensorDeviceClass ,
34
+ SensorStateClass ,
35
+ )
30
36
from homeassistant .config_entries import (
31
37
SOURCE_RECONFIGURE ,
32
38
ConfigEntry ,
45
51
ATTR_SW_VERSION ,
46
52
CONF_CLIENT_ID ,
47
53
CONF_DEVICE ,
54
+ CONF_DEVICE_CLASS ,
48
55
CONF_DISCOVERY ,
49
56
CONF_HOST ,
50
57
CONF_NAME ,
53
60
CONF_PLATFORM ,
54
61
CONF_PORT ,
55
62
CONF_PROTOCOL ,
63
+ CONF_UNIT_OF_MEASUREMENT ,
56
64
CONF_USERNAME ,
65
+ CONF_VALUE_TEMPLATE ,
57
66
)
58
67
from homeassistant .core import HomeAssistant , callback
59
68
from homeassistant .data_entry_flow import AbortFlow
97
106
CONF_COMMAND_TOPIC ,
98
107
CONF_DISCOVERY_PREFIX ,
99
108
CONF_ENTITY_PICTURE ,
109
+ CONF_EXPIRE_AFTER ,
100
110
CONF_KEEPALIVE ,
111
+ CONF_LAST_RESET_VALUE_TEMPLATE ,
112
+ CONF_OPTIONS ,
101
113
CONF_QOS ,
102
114
CONF_RETAIN ,
115
+ CONF_STATE_TOPIC ,
116
+ CONF_SUGGESTED_DISPLAY_PRECISION ,
103
117
CONF_TLS_INSECURE ,
104
118
CONF_TRANSPORT ,
105
119
CONF_WILL_MESSAGE ,
127
141
from .util import (
128
142
async_create_certificate_temp_files ,
129
143
get_file_path ,
144
+ learn_more_url ,
130
145
valid_birth_will ,
131
146
valid_publish_topic ,
132
147
valid_qos_schema ,
133
148
valid_subscribe_topic ,
134
149
valid_subscribe_topic_template ,
150
+ validate_sensor_state_and_device_class_config ,
135
151
)
136
152
137
153
_LOGGER = logging .getLogger (__name__ )
211
227
)
212
228
213
229
# Subentry selectors
214
- SUBENTRY_PLATFORMS = [Platform .NOTIFY ]
230
+ RESET_IF_EMPTY = {CONF_OPTIONS }
231
+ SUBENTRY_PLATFORMS = [Platform .NOTIFY , Platform .SENSOR ]
215
232
SUBENTRY_PLATFORM_SELECTOR = SelectSelector (
216
233
SelectSelectorConfig (
217
234
options = [platform .value for platform in SUBENTRY_PLATFORMS ],
221
238
)
222
239
TEMPLATE_SELECTOR = TemplateSelector (TemplateSelectorConfig ())
223
240
241
+ # Sensor specific selectors
242
+ SENSOR_DEVICE_CLASS_SELECTOR = SelectSelector (
243
+ SelectSelectorConfig (
244
+ options = [device_class .value for device_class in SensorDeviceClass ],
245
+ mode = SelectSelectorMode .DROPDOWN ,
246
+ translation_key = CONF_DEVICE_CLASS ,
247
+ )
248
+ )
249
+ SENSOR_STATE_CLASS_SELECTOR = SelectSelector (
250
+ SelectSelectorConfig (
251
+ options = [device_class .value for device_class in SensorStateClass ],
252
+ mode = SelectSelectorMode .DROPDOWN ,
253
+ translation_key = CONF_STATE_CLASS ,
254
+ )
255
+ )
256
+ OPTIONS_SELECTOR = SelectSelector (
257
+ SelectSelectorConfig (
258
+ options = [],
259
+ custom_value = True ,
260
+ multiple = True ,
261
+ )
262
+ )
263
+ SUGGESTED_DISPLAY_PRECISION_SELECTOR = NumberSelector (
264
+ NumberSelectorConfig (mode = NumberSelectorMode .BOX , min = 0 , max = 9 )
265
+ )
266
+ EXIRE_AFTER_SELECTOR = NumberSelector (
267
+ NumberSelectorConfig (mode = NumberSelectorMode .BOX , min = 0 )
268
+ )
269
+
224
270
225
271
@dataclass (frozen = True )
226
272
class PlatformField :
227
273
"""Stores a platform config field schema, required flag and validator."""
228
274
229
- selector : Selector
275
+ selector : Selector [ Any ] | Callable [..., Selector [ Any ]]
230
276
required : bool
231
277
validator : Callable [..., Any ]
232
278
error : str | None = None
279
+ custom_filtering : bool = False
280
+
281
+
282
+ @callback
283
+ def unit_of_measurement_selector (user_data : dict [str , Any | None ]) -> Selector :
284
+ """Return a context based unit of measurement selector."""
285
+ if (
286
+ user_data is None
287
+ or (device_class := user_data .get (CONF_DEVICE_CLASS )) is None
288
+ or device_class not in DEVICE_CLASS_UNITS
289
+ ):
290
+ return TEXT_SELECTOR
291
+ return SelectSelector (
292
+ SelectSelectorConfig (
293
+ options = [str (uom ) for uom in DEVICE_CLASS_UNITS [device_class ]],
294
+ custom_value = True ,
295
+ )
296
+ )
233
297
234
298
235
299
COMMON_ENTITY_FIELDS = {
@@ -240,7 +304,20 @@ class PlatformField:
240
304
241
305
COMMON_MQTT_FIELDS = {
242
306
CONF_QOS : PlatformField (QOS_SELECTOR , False , valid_qos_schema ),
243
- CONF_RETAIN : PlatformField (BOOLEAN_SELECTOR , False , bool ),
307
+ }
308
+ PLATFORM_ENTITY_FIELDS = {
309
+ Platform .NOTIFY .value : {},
310
+ Platform .SENSOR .value : {
311
+ CONF_DEVICE_CLASS : PlatformField (SENSOR_DEVICE_CLASS_SELECTOR , False , str ),
312
+ CONF_STATE_CLASS : PlatformField (SENSOR_STATE_CLASS_SELECTOR , False , str ),
313
+ CONF_UNIT_OF_MEASUREMENT : PlatformField (
314
+ unit_of_measurement_selector , False , str , custom_filtering = True
315
+ ),
316
+ CONF_SUGGESTED_DISPLAY_PRECISION : PlatformField (
317
+ SUGGESTED_DISPLAY_PRECISION_SELECTOR , False , cv .positive_int
318
+ ),
319
+ CONF_OPTIONS : PlatformField (OPTIONS_SELECTOR , False , cv .ensure_list ),
320
+ },
244
321
}
245
322
PLATFORM_MQTT_FIELDS = {
246
323
Platform .NOTIFY .value : {
@@ -250,8 +327,27 @@ class PlatformField:
250
327
CONF_COMMAND_TEMPLATE : PlatformField (
251
328
TEMPLATE_SELECTOR , False , cv .template , "invalid_template"
252
329
),
330
+ CONF_RETAIN : PlatformField (BOOLEAN_SELECTOR , False , bool ),
331
+ },
332
+ Platform .SENSOR .value : {
333
+ CONF_STATE_TOPIC : PlatformField (
334
+ TEXT_SELECTOR , True , valid_subscribe_topic , "invalid_subscribe_topic"
335
+ ),
336
+ CONF_VALUE_TEMPLATE : PlatformField (
337
+ TEMPLATE_SELECTOR , False , cv .template , "invalid_template"
338
+ ),
339
+ CONF_LAST_RESET_VALUE_TEMPLATE : PlatformField (
340
+ TEMPLATE_SELECTOR , False , cv .template , "invalid_template"
341
+ ),
342
+ CONF_EXPIRE_AFTER : PlatformField (EXIRE_AFTER_SELECTOR , False , cv .positive_int ),
253
343
},
254
344
}
345
+ ENTITY_CONFIG_VALIDATOR : dict [
346
+ str , Callable [[dict [str , Any ], dict [str , str ]], dict [str , Any ]] | None
347
+ ] = {
348
+ Platform .NOTIFY .value : None ,
349
+ Platform .SENSOR .value : validate_sensor_state_and_device_class_config ,
350
+ }
255
351
256
352
MQTT_DEVICE_SCHEMA = vol .Schema (
257
353
{
@@ -318,6 +414,9 @@ def validate_user_input(
318
414
user_input : dict [str , Any ],
319
415
data_schema_fields : dict [str , PlatformField ],
320
416
errors : dict [str , str ],
417
+ config_validator : Callable [[dict [str , Any ], dict [str , str ]], dict [str , str ]]
418
+ | None = None ,
419
+ component_data : dict [str , Any ] | None = None ,
321
420
) -> None :
322
421
"""Validate user input."""
323
422
for field , value in user_input .items ():
@@ -327,15 +426,30 @@ def validate_user_input(
327
426
except (ValueError , vol .Invalid ):
328
427
errors [field ] = data_schema_fields [field ].error or "invalid_input"
329
428
429
+ if config_validator is not None :
430
+ config = user_input
431
+ if component_data is not None :
432
+ config |= component_data
433
+ config_validator (config , errors )
434
+
330
435
331
436
@callback
332
- def data_schema_from_fields (data_schema_fields : dict [str , PlatformField ]) -> vol .Schema :
333
- """Generate data schema from platform fields."""
437
+ def data_schema_from_fields (
438
+ data_schema_fields : dict [str , PlatformField ],
439
+ component : dict [str , Any ] | None = None ,
440
+ user_input : dict [str , Any ] | None = None ,
441
+ ) -> vol .Schema :
442
+ """Generate custom data schema from platform fields."""
443
+ user_data = component
444
+ if user_data is not None and user_input is not None :
445
+ user_data |= user_input
334
446
return vol .Schema (
335
447
{
336
448
vol .Required (field_name )
337
449
if field_details .required
338
- else vol .Optional (field_name ): field_details .selector
450
+ else vol .Optional (field_name ): field_details .selector (user_data ) # type: ignore[operator]
451
+ if field_details .custom_filtering
452
+ else field_details .selector
339
453
for field_name , field_details in data_schema_fields .items ()
340
454
}
341
455
)
@@ -878,6 +992,34 @@ def update_component_fields(
878
992
component_data .pop (field )
879
993
component_data .update (user_input )
880
994
995
+ @callback
996
+ def reset_if_empty (self , user_input : dict [str , Any ]) -> None :
997
+ """Reset fields in componment config that are not in the user_input."""
998
+ if TYPE_CHECKING :
999
+ assert self ._component_id is not None
1000
+ for field in [
1001
+ form_field
1002
+ for form_field in user_input
1003
+ if form_field in RESET_IF_EMPTY
1004
+ and form_field in RESET_IF_EMPTY
1005
+ and not user_input [form_field ]
1006
+ ]:
1007
+ user_input .pop (field )
1008
+
1009
+ @callback
1010
+ def generate_names (self ) -> tuple [str , str ]:
1011
+ """Generate the device and full entity name."""
1012
+ if TYPE_CHECKING :
1013
+ assert self ._component_id is not None
1014
+ device_name = self ._subentry_data [CONF_DEVICE ][CONF_NAME ]
1015
+ if entity_name := self ._subentry_data ["components" ][self ._component_id ].get (
1016
+ CONF_NAME
1017
+ ):
1018
+ full_entity_name : str = f"{ device_name } { entity_name } "
1019
+ else :
1020
+ full_entity_name = device_name
1021
+ return device_name , full_entity_name
1022
+
881
1023
async def async_step_user (
882
1024
self , user_input : dict [str , Any ] | None = None
883
1025
) -> SubentryFlowResult :
@@ -933,7 +1075,7 @@ async def async_step_entity(
933
1075
self ._component_id = uuid4 ().hex
934
1076
self ._subentry_data ["components" ].setdefault (self ._component_id , {})
935
1077
self .update_component_fields (data_schema , user_input )
936
- return await self .async_step_mqtt_platform_config ()
1078
+ return await self .async_step_entity_platform_config ()
937
1079
data_schema = self .add_suggested_values_to_schema (data_schema , user_input )
938
1080
elif self .source == SOURCE_RECONFIGURE and self ._component_id is not None :
939
1081
data_schema = self .add_suggested_values_to_schema (
@@ -993,6 +1135,52 @@ async def async_step_delete_entity(
993
1135
return await self .async_step_summary_menu ()
994
1136
return self ._show_update_or_delete_form ("delete_entity" )
995
1137
1138
+ async def async_step_entity_platform_config (
1139
+ self , user_input : dict [str , Any ] | None = None
1140
+ ) -> SubentryFlowResult :
1141
+ """Configure platform entity details."""
1142
+ if TYPE_CHECKING :
1143
+ assert self ._component_id is not None
1144
+ component = self ._subentry_data ["components" ][self ._component_id ]
1145
+ platform = component [CONF_PLATFORM ]
1146
+ if not (data_schema_fields := PLATFORM_ENTITY_FIELDS [platform ]):
1147
+ return await self .async_step_mqtt_platform_config ()
1148
+ errors : dict [str , str ] = {}
1149
+ device_name , full_entity_name = self .generate_names ()
1150
+
1151
+ data_schema = data_schema_from_fields (data_schema_fields , component , user_input )
1152
+ if user_input is not None :
1153
+ # Test entity fields against the validator
1154
+ self .reset_if_empty (user_input )
1155
+ validate_user_input (
1156
+ user_input ,
1157
+ data_schema_fields ,
1158
+ errors ,
1159
+ ENTITY_CONFIG_VALIDATOR [platform ],
1160
+ )
1161
+ if not errors :
1162
+ self .update_component_fields (data_schema , user_input )
1163
+ return await self .async_step_mqtt_platform_config ()
1164
+
1165
+ data_schema = self .add_suggested_values_to_schema (data_schema , user_input )
1166
+ else :
1167
+ data_schema = self .add_suggested_values_to_schema (
1168
+ data_schema , self ._subentry_data ["components" ][self ._component_id ]
1169
+ )
1170
+
1171
+ return self .async_show_form (
1172
+ step_id = "entity_platform_config" ,
1173
+ data_schema = data_schema ,
1174
+ description_placeholders = {
1175
+ "mqtt_device" : device_name ,
1176
+ CONF_PLATFORM : platform ,
1177
+ "entity" : full_entity_name ,
1178
+ "url" : learn_more_url (platform ),
1179
+ },
1180
+ errors = errors ,
1181
+ last_step = False ,
1182
+ )
1183
+
996
1184
async def async_step_mqtt_platform_config (
997
1185
self , user_input : dict [str , Any ] | None = None
998
1186
) -> SubentryFlowResult :
@@ -1002,19 +1190,20 @@ async def async_step_mqtt_platform_config(
1002
1190
assert self ._component_id is not None
1003
1191
device_name = self ._subentry_data [CONF_DEVICE ][CONF_NAME ]
1004
1192
platform = self ._subentry_data ["components" ][self ._component_id ][CONF_PLATFORM ]
1005
- entity_name : str | None
1006
- if entity_name := self ._subentry_data ["components" ][self ._component_id ].get (
1007
- CONF_NAME
1008
- ):
1009
- full_entity_name : str = f"{ device_name } { entity_name } "
1010
- else :
1011
- full_entity_name = device_name
1193
+ device_name , full_entity_name = self .generate_names ()
1012
1194
1013
1195
data_schema_fields = PLATFORM_MQTT_FIELDS [platform ] | COMMON_MQTT_FIELDS
1014
1196
data_schema = data_schema_from_fields (data_schema_fields )
1015
1197
if user_input is not None :
1016
1198
# Test entity fields against the validator
1017
- validate_user_input (user_input , data_schema_fields , errors )
1199
+ self .reset_if_empty (user_input )
1200
+ validate_user_input (
1201
+ user_input ,
1202
+ data_schema_fields ,
1203
+ errors ,
1204
+ ENTITY_CONFIG_VALIDATOR [platform ],
1205
+ self ._subentry_data ["components" ][self ._component_id ],
1206
+ )
1018
1207
if not errors :
1019
1208
self .update_component_fields (data_schema , user_input )
1020
1209
self ._component_id = None
@@ -1035,6 +1224,7 @@ async def async_step_mqtt_platform_config(
1035
1224
"mqtt_device" : device_name ,
1036
1225
CONF_PLATFORM : platform ,
1037
1226
"entity" : full_entity_name ,
1227
+ "url" : learn_more_url (platform ),
1038
1228
},
1039
1229
errors = errors ,
1040
1230
last_step = False ,
0 commit comments