Skip to content

Commit d9e45cf

Browse files
committed
Make options and last_reset_template conditional, use sections for advanced settings
1 parent 19f6cc0 commit d9e45cf

File tree

4 files changed

+202
-73
lines changed

4 files changed

+202
-73
lines changed

homeassistant/components/mqtt/config_flow.py

+75-44
Original file line numberDiff line numberDiff line change
@@ -241,8 +241,6 @@
241241
translation_key=CONF_PLATFORM,
242242
)
243243
)
244-
RESET_IF_EMPTY = {CONF_OPTIONS}
245-
246244
TEMPLATE_SELECTOR = TemplateSelector(TemplateSelectorConfig())
247245

248246
SUBENTRY_AVAILABILITY_SCHEMA = vol.Schema(
@@ -346,7 +344,12 @@ def unit_of_measurement_selector(user_data: dict[str, Any | None]) -> Selector:
346344
cv.positive_int,
347345
section="advanced_settings",
348346
),
349-
CONF_OPTIONS: PlatformField(OPTIONS_SELECTOR, False, cv.ensure_list),
347+
CONF_OPTIONS: PlatformField(
348+
OPTIONS_SELECTOR,
349+
False,
350+
cv.ensure_list,
351+
conditions=({"device_class": "enum"},),
352+
),
350353
},
351354
}
352355
PLATFORM_MQTT_FIELDS = {
@@ -379,7 +382,8 @@ def unit_of_measurement_selector(user_data: dict[str, Any | None]) -> Selector:
379382
},
380383
}
381384
ENTITY_CONFIG_VALIDATOR: dict[
382-
str, Callable[[dict[str, Any], dict[str, str]], dict[str, Any]] | None
385+
str,
386+
Callable[[dict[str, Any], dict[str, str], list[str] | None], dict[str, Any]] | None,
383387
] = {
384388
Platform.NOTIFY.value: None,
385389
Platform.SENSOR.value: validate_sensor_state_and_device_class_config,
@@ -445,22 +449,36 @@ def validate_field(
445449
errors[field] = error
446450

447451

452+
@callback
453+
def _check_conditions(
454+
platform_field: PlatformField, component: dict[str, Any] | None = None
455+
) -> bool:
456+
"""Only include field if one of conditions match, or no conditions are set."""
457+
if platform_field.conditions is None or component is None:
458+
return True
459+
return any(
460+
all(component.get(key) == value for key, value in condition.items())
461+
for condition in platform_field.conditions
462+
)
463+
464+
448465
@callback
449466
def validate_user_input(
450467
user_input: dict[str, Any],
451468
data_schema_fields: dict[str, PlatformField],
452469
errors: dict[str, str],
453-
config_validator: Callable[[dict[str, Any], dict[str, str]], dict[str, str]]
470+
component_data: dict[str, Any] | None,
471+
config_validator: Callable[
472+
[dict[str, Any], dict[str, str], list[str] | None], dict[str, str]
473+
]
454474
| None = None,
455-
component_data: dict[str, Any] | None = None,
456475
) -> dict[str, Any]:
457476
"""Validate user input."""
458477
# Merge sections
478+
reset_fields: list[str] = []
459479
merged_user_input: dict[str, Any] = {}
460480
for key, value in user_input.items():
461481
# Omit empty lists that are not allowed to be empty
462-
if key in RESET_IF_EMPTY and not value:
463-
continue
464482
if isinstance(value, dict):
465483
merged_user_input.update(value)
466484
else:
@@ -474,10 +492,30 @@ def validate_user_input(
474492
errors[field] = data_schema_fields[field].error or "invalid_input"
475493

476494
if config_validator is not None:
477-
config = merged_user_input
478-
if component_data is not None:
479-
config |= component_data
480-
config_validator(config, errors)
495+
if TYPE_CHECKING:
496+
assert component_data is not None
497+
schema_fields = tuple(
498+
{
499+
key
500+
for key, platform_field in data_schema_fields.items()
501+
if _check_conditions(platform_field, component_data)
502+
}
503+
- set(merged_user_input)
504+
)
505+
config_validator(
506+
{
507+
key: value
508+
for key, value in component_data.items()
509+
if key not in schema_fields
510+
}
511+
| merged_user_input,
512+
errors,
513+
reset_fields,
514+
)
515+
516+
for field in reset_fields:
517+
if component_data and field in component_data:
518+
del component_data[field]
481519

482520
return merged_user_input
483521

@@ -490,25 +528,14 @@ def data_schema_from_fields(
490528
user_input: dict[str, Any] | None = None,
491529
) -> vol.Schema:
492530
"""Generate custom data schema from platform fields."""
493-
494-
def _check_conditions(
495-
platform_field: PlatformField,
496-
) -> bool:
497-
"""Only include field if one of conditions match, or no conditions are set."""
498-
if platform_field.conditions is None or component is None:
499-
return True
500-
return any(
501-
all(component.get(key) == value for key, value in condition.items())
502-
for condition in platform_field.conditions
503-
)
504-
505-
user_data = component
531+
user_data = deepcopy(component)
506532
if user_data is not None and user_input is not None:
507533
user_data |= user_input
508534
sections: dict[str | None, None] = {
509535
field_details.section: None for field_details in data_schema_fields.values()
510536
}
511537
data_schema: dict[Any, Any] = {}
538+
all_data_element_options: set[Any] = set()
512539
for schema_section in sections:
513540
data_schema_element = {
514541
vol.Required(field_name, default=field_details.default)
@@ -521,23 +548,25 @@ def _check_conditions(
521548
for field_name, field_details in data_schema_fields.items()
522549
if field_details.section == schema_section
523550
and (not field_details.exclude_from_reconfig or not reconfig)
524-
and _check_conditions(field_details)
551+
and _check_conditions(field_details, user_data)
525552
}
553+
data_element_options = set(data_schema_element)
554+
all_data_element_options |= data_element_options
555+
if schema_section is None:
556+
data_schema.update(data_schema_element)
557+
continue
526558
collapsed = (
527-
bool(set(data_schema_element) - set(user_data)) # type: ignore[arg-type]
559+
bool(data_element_options - set(user_data)) # type: ignore[arg-type]
528560
if user_data is not None
529561
else True
530562
)
531-
if schema_section is None:
532-
data_schema.update(data_schema_element)
533-
continue
534563
data_schema[vol.Optional(schema_section)] = section(
535564
vol.Schema(data_schema_element), SectionConfig({"collapsed": collapsed})
536565
)
537566

538567
# Reset all fields from the component not in the schema
539568
if component:
540-
filtered_fields = set(data_schema_fields) - set(data_schema)
569+
filtered_fields = set(data_schema_fields) - all_data_element_options
541570
for field in filtered_fields:
542571
if field in component:
543572
del component[field]
@@ -1066,7 +1095,7 @@ class MQTTSubentryFlowHandler(ConfigSubentryFlow):
10661095

10671096
@callback
10681097
def update_component_fields(
1069-
self, data_schema: vol.Schema, user_input: dict[str, Any]
1098+
self, data_schema: vol.Schema, merged_user_input: dict[str, Any]
10701099
) -> None:
10711100
"""Update the componment fields."""
10721101
if TYPE_CHECKING:
@@ -1076,10 +1105,10 @@ def update_component_fields(
10761105
for field in [
10771106
form_field
10781107
for form_field in data_schema.schema
1079-
if form_field in component_data and form_field not in user_input
1108+
if form_field in component_data and form_field not in merged_user_input
10801109
]:
10811110
component_data.pop(field)
1082-
component_data.update(user_input)
1111+
component_data.update(merged_user_input)
10831112

10841113
@callback
10851114
def generate_names(self) -> tuple[str, str]:
@@ -1097,7 +1126,7 @@ def generate_names(self) -> tuple[str, str]:
10971126

10981127
@callback
10991128
def add_suggested_values_from_component_data_to_schema(
1100-
self, data_schema: vol.Schema, data_schema_fields: dict[str, PlatformField]
1129+
self, data_schema: vol.Schema
11011130
) -> vol.Schema:
11021131
"""Add suggestions from component data to data schema."""
11031132
if TYPE_CHECKING:
@@ -1123,14 +1152,15 @@ def _apply_suggested_value(
11231152
if not isinstance(value, section):
11241153
continue
11251154
data_section_schema = value.schema.schema
1126-
value.schema = vol.Schema(
1155+
new_schema = vol.Schema(
11271156
{
11281157
_apply_suggested_value(
11291158
section_field, component.get(section_field)
11301159
): section_field_selector
11311160
for section_field, section_field_selector in data_section_schema.items()
11321161
}
11331162
)
1163+
value.schema = new_schema
11341164

11351165
return vol.Schema(schema)
11361166

@@ -1182,16 +1212,16 @@ async def async_step_entity(
11821212
data_schema_fields = COMMON_ENTITY_FIELDS
11831213
entity_name_label: str = ""
11841214
platform_label: str = ""
1215+
component: dict[str, Any] | None = None
11851216
if reconfig := (self._component_id is not None):
1186-
name: str | None = self._subentry_data["components"][
1187-
self._component_id
1188-
].get(CONF_NAME)
1217+
component = self._subentry_data["components"][self._component_id]
1218+
name: str | None = component.get(CONF_NAME)
11891219
platform_label = f"{self._subentry_data['components'][self._component_id][CONF_PLATFORM]} "
11901220
entity_name_label = f" ({name})" if name is not None else ""
11911221
data_schema = data_schema_from_fields(data_schema_fields, reconfig=reconfig)
11921222
if user_input is not None:
11931223
merged_user_input = validate_user_input(
1194-
user_input, data_schema_fields, errors
1224+
user_input, data_schema_fields, errors, component
11951225
)
11961226
if not errors:
11971227
if self._component_id is None:
@@ -1202,7 +1232,7 @@ async def async_step_entity(
12021232
data_schema = self.add_suggested_values_to_schema(data_schema, user_input)
12031233
elif self.source == SOURCE_RECONFIGURE and self._component_id is not None:
12041234
data_schema = self.add_suggested_values_from_component_data_to_schema(
1205-
data_schema, data_schema_fields
1235+
data_schema
12061236
)
12071237
device_name = self._subentry_data[CONF_DEVICE][CONF_NAME]
12081238
return self.async_show_form(
@@ -1291,6 +1321,7 @@ async def async_step_entity_platform_config(
12911321
user_input,
12921322
data_schema_fields,
12931323
errors,
1324+
component,
12941325
ENTITY_CONFIG_VALIDATOR[platform],
12951326
)
12961327
if not errors:
@@ -1300,7 +1331,7 @@ async def async_step_entity_platform_config(
13001331
data_schema = self.add_suggested_values_to_schema(data_schema, user_input)
13011332
else:
13021333
data_schema = self.add_suggested_values_from_component_data_to_schema(
1303-
data_schema, data_schema_fields
1334+
data_schema
13041335
)
13051336

13061337
device_name, full_entity_name = self.generate_names()
@@ -1340,8 +1371,8 @@ async def async_step_mqtt_platform_config(
13401371
user_input,
13411372
data_schema_fields,
13421373
errors,
1374+
component,
13431375
ENTITY_CONFIG_VALIDATOR[platform],
1344-
self._subentry_data["components"][self._component_id],
13451376
)
13461377
if not errors:
13471378
self.update_component_fields(data_schema, merged_user_input)
@@ -1353,7 +1384,7 @@ async def async_step_mqtt_platform_config(
13531384
data_schema = self.add_suggested_values_to_schema(data_schema, user_input)
13541385
else:
13551386
data_schema = self.add_suggested_values_from_component_data_to_schema(
1356-
data_schema, data_schema_fields
1387+
data_schema
13571388
)
13581389
device_name, full_entity_name = self.generate_names()
13591390
return self.async_show_form(

homeassistant/components/mqtt/strings.json

+3-2
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@
250250
},
251251
"sections": {
252252
"advanced_settings": {
253-
"name": "Advanced options",
253+
"name": "Advanced settings",
254254
"data": {
255255
"expire_after": "Expire after"
256256
},
@@ -274,7 +274,8 @@
274274
"invalid_uom": "The unit of measurement is not allowed with the selected device class, please use the correct device class, or pick a valid unit of measurement from the list",
275275
"invalid_url": "Invalid URL",
276276
"options_not_allowed_with_state_class_or_uom": "The 'Options' setting is not allowed when state class or unit of measurement are used",
277-
"options_device_class_enum": "The 'Options' setting must be used with the Enumeration device class'"
277+
"options_device_class_enum": "The 'Options' setting must be used with the Enumeration device class'. If you continue, the existing options will be reset",
278+
"options_with_enum_device_class": "Configure options for the enumeration sensor"
278279
}
279280
}
280281
},

homeassistant/components/mqtt/util.py

+29-11
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import os
1010
from pathlib import Path
1111
import tempfile
12-
from typing import Any
12+
from typing import TYPE_CHECKING, Any
1313

1414
import voluptuous as vol
1515

@@ -435,23 +435,30 @@ def learn_more_url(platform: str) -> str:
435435

436436
@callback
437437
def validate_sensor_state_and_device_class_config(
438-
config: ConfigType, errors: dict[str, str] | None = None
438+
config: ConfigType,
439+
errors: dict[str, str] | None = None,
440+
reset_fields: list[str] | None = None,
439441
) -> ConfigType:
440442
"""Validate the sensor options, state and device class config."""
441443
if (
442444
CONF_LAST_RESET_VALUE_TEMPLATE in config
443445
and (state_class := config.get(CONF_STATE_CLASS)) != SensorStateClass.TOTAL
444446
):
445-
raise vol.Invalid(
446-
f"The option `{CONF_LAST_RESET_VALUE_TEMPLATE}` cannot be used "
447-
f"together with state class `{state_class}`"
448-
)
447+
if errors is None:
448+
raise vol.Invalid(
449+
f"The option `{CONF_LAST_RESET_VALUE_TEMPLATE}` cannot be used "
450+
f"together with state class `{state_class}`"
451+
)
452+
if reset_fields is not None:
453+
reset_fields.append(CONF_LAST_RESET_VALUE_TEMPLATE)
449454

450455
# Only allow `options` to be set for `enum` sensors
451456
# to limit the possible sensor values
452457
if (options := config.get(CONF_OPTIONS)) is not None:
453458
if not options:
454-
raise vol.Invalid("An empty options list is not allowed")
459+
if reset_fields is None:
460+
raise vol.Invalid("An empty options list is not allowed")
461+
reset_fields.append(CONF_OPTIONS)
455462
if config.get(CONF_STATE_CLASS) or config.get(CONF_UNIT_OF_MEASUREMENT):
456463
if errors is None:
457464
raise vol.Invalid(
@@ -467,11 +474,22 @@ def validate_sensor_state_and_device_class_config(
467474
f"together with device class `{SensorDeviceClass.ENUM}`, "
468475
f"got `{CONF_DEVICE_CLASS}` '{device_class}'"
469476
)
470-
errors[CONF_OPTIONS] = "options_device_class_enum"
477+
if TYPE_CHECKING:
478+
assert reset_fields is not None
479+
errors[CONF_DEVICE_CLASS] = "options_device_class_enum"
480+
reset_fields.append(CONF_OPTIONS)
471481

472-
if (device_class := config.get(CONF_DEVICE_CLASS)) is None or (
473-
unit_of_measurement := config.get(CONF_UNIT_OF_MEASUREMENT)
474-
) is None:
482+
if (
483+
(device_class := config.get(CONF_DEVICE_CLASS)) == SensorDeviceClass.ENUM
484+
and errors is not None
485+
and CONF_OPTIONS not in config
486+
):
487+
errors[CONF_OPTIONS] = "options_with_enum_device_class"
488+
489+
if (
490+
device_class is None
491+
or (unit_of_measurement := config.get(CONF_UNIT_OF_MEASUREMENT)) is None
492+
):
475493
return config
476494

477495
if (

0 commit comments

Comments
 (0)