Skip to content

Commit b310ba3

Browse files
committed
Add validation on unit of measurement and show SelectSelector if we know the device class
1 parent 8484073 commit b310ba3

File tree

5 files changed

+92
-8
lines changed

5 files changed

+92
-8
lines changed

homeassistant/components/mqtt/config_flow.py

+38-7
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from homeassistant.components.hassio import AddonError, AddonManager, AddonState
3030
from homeassistant.components.sensor import (
3131
CONF_STATE_CLASS,
32+
DEVICE_CLASS_UNITS,
3233
SensorDeviceClass,
3334
SensorStateClass,
3435
)
@@ -271,10 +272,28 @@
271272
class PlatformField:
272273
"""Stores a platform config field schema, required flag and validator."""
273274

274-
selector: Selector
275+
selector: Selector[Any] | Callable[..., Selector[Any]]
275276
required: bool
276277
validator: Callable[..., Any]
277278
error: str | None = None
279+
custom_filtering: bool = False
280+
281+
282+
@callback
283+
def unit_of_measurement_selector(user_data: dict[str, Any | None]) -> Selector:
284+
"""Return a context based unit of measurement selector."""
285+
if (
286+
user_data is None
287+
or (device_class := user_data.get(CONF_DEVICE_CLASS)) is None
288+
or device_class not in DEVICE_CLASS_UNITS
289+
):
290+
return TEXT_SELECTOR
291+
return SelectSelector(
292+
SelectSelectorConfig(
293+
options=[str(uom) for uom in DEVICE_CLASS_UNITS[device_class]],
294+
custom_value=True,
295+
)
296+
)
278297

279298

280299
COMMON_ENTITY_FIELDS = {
@@ -291,7 +310,9 @@ class PlatformField:
291310
Platform.SENSOR.value: {
292311
CONF_DEVICE_CLASS: PlatformField(SENSOR_DEVICE_CLASS_SELECTOR, False, str),
293312
CONF_STATE_CLASS: PlatformField(SENSOR_STATE_CLASS_SELECTOR, False, str),
294-
CONF_UNIT_OF_MEASUREMENT: PlatformField(TEXT_SELECTOR, False, str),
313+
CONF_UNIT_OF_MEASUREMENT: PlatformField(
314+
unit_of_measurement_selector, False, str, custom_filtering=True
315+
),
295316
CONF_SUGGESTED_DISPLAY_PRECISION: PlatformField(
296317
SUGGESTED_DISPLAY_PRECISION_SELECTOR, False, cv.positive_int
297318
),
@@ -413,13 +434,22 @@ def validate_user_input(
413434

414435

415436
@callback
416-
def data_schema_from_fields(data_schema_fields: dict[str, PlatformField]) -> vol.Schema:
417-
"""Generate data schema from platform fields."""
437+
def data_schema_from_fields(
438+
data_schema_fields: dict[str, PlatformField],
439+
component: dict[str, Any] | None = None,
440+
user_input: dict[str, Any] | None = None,
441+
) -> vol.Schema:
442+
"""Generate custom data schema from platform fields."""
443+
user_data = component
444+
if user_data is not None and user_input is not None:
445+
user_data |= user_input
418446
return vol.Schema(
419447
{
420448
vol.Required(field_name)
421449
if field_details.required
422-
else vol.Optional(field_name): field_details.selector
450+
else vol.Optional(field_name): field_details.selector(user_data) # type: ignore[operator]
451+
if field_details.custom_filtering
452+
else field_details.selector
423453
for field_name, field_details in data_schema_fields.items()
424454
}
425455
)
@@ -1141,13 +1171,14 @@ async def async_step_entity_platform_config(
11411171
"""Configure platform entity details."""
11421172
if TYPE_CHECKING:
11431173
assert self._component_id is not None
1144-
platform = self._subentry_data["components"][self._component_id][CONF_PLATFORM]
1174+
component = self._subentry_data["components"][self._component_id]
1175+
platform = component[CONF_PLATFORM]
11451176
if not (data_schema_fields := PLATFORM_ENTITY_FIELDS[platform]):
11461177
return await self.async_step_mqtt_platform_config()
11471178
errors: dict[str, str] = {}
11481179
device_name, full_entity_name = self.generate_names()
11491180

1150-
data_schema = data_schema_from_fields(data_schema_fields)
1181+
data_schema = data_schema_from_fields(data_schema_fields, component, user_input)
11511182
if user_input is not None:
11521183
# Test entity fields against the validator
11531184
self.reset_if_empty(user_input)

homeassistant/components/mqtt/strings.json

+2-1
Original file line numberDiff line numberDiff line change
@@ -251,9 +251,10 @@
251251
"invalid_input": "Invalid value",
252252
"invalid_subscribe_topic": "Invalid subscribe topic",
253253
"invalid_template": "Invalid template",
254+
"invalid_uom": "The unit of measurement is not allowed with the selected device class, please use the correct device class, or pick a valid unit of measurement from the list",
254255
"invalid_url": "Invalid URL",
255256
"last_reset_not_with_state_class_total": "The last reset value template option should be used with state class 'Total' only",
256-
"options_not_allowed_with_state_class_or_uom": "The 'Options' setting is not allowed when state class or unit of measurement are used.",
257+
"options_not_allowed_with_state_class_or_uom": "The 'Options' setting is not allowed when state class or unit of measurement are used",
257258
"options_device_class_enum": "The 'Options' setting must be used with the Enumeration device class'"
258259
}
259260
}

homeassistant/components/mqtt/util.py

+19
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from homeassistant.components.sensor import (
1717
CONF_STATE_CLASS,
18+
DEVICE_CLASS_UNITS,
1819
SensorDeviceClass,
1920
SensorStateClass,
2021
)
@@ -426,11 +427,13 @@ def migrate_certificate_file_to_content(file_name_or_auto: str) -> str | None:
426427
return None
427428

428429

430+
@callback
429431
def learn_more_url(platform: str) -> str:
430432
"""Return the URL for the platform specific MQTT documentation."""
431433
return f"https://www.home-assistant.io/integrations/{platform}.mqtt/"
432434

433435

436+
@callback
434437
def validate_sensor_state_and_device_class_config(
435438
config: ConfigType, errors: dict[str, str] | None = None
436439
) -> ConfigType:
@@ -468,4 +471,20 @@ def validate_sensor_state_and_device_class_config(
468471
)
469472
errors[CONF_OPTIONS] = "options_device_class_enum"
470473

474+
if (device_class := config.get(CONF_DEVICE_CLASS)) is None or (
475+
unit_of_measurement := config.get(CONF_UNIT_OF_MEASUREMENT)
476+
) is None:
477+
return config
478+
479+
if (
480+
device_class in DEVICE_CLASS_UNITS
481+
and unit_of_measurement not in DEVICE_CLASS_UNITS[device_class]
482+
):
483+
if errors is None:
484+
raise vol.Invalid(
485+
f"The unit of measurement `{unit_of_measurement}` is not valid "
486+
f"together with device class `{device_class}`"
487+
)
488+
errors[CONF_UNIT_OF_MEASUREMENT] = "invalid_uom"
489+
471490
return config

tests/components/mqtt/test_config_flow.py

+7
Original file line numberDiff line numberDiff line change
@@ -2678,6 +2678,13 @@ async def test_migrate_of_incompatible_config_entry(
26782678
},
26792679
{"options": "options_not_allowed_with_state_class_or_uom"},
26802680
),
2681+
(
2682+
{
2683+
"device_class": "energy",
2684+
"unit_of_measurement": "ppm",
2685+
},
2686+
{"unit_of_measurement": "invalid_uom"},
2687+
),
26812688
),
26822689
{
26832690
"state_topic": "test-topic",

tests/components/mqtt/test_sensor.py

+26
Original file line numberDiff line numberDiff line change
@@ -870,6 +870,32 @@ async def test_invalid_device_class(
870870
assert "expected SensorDeviceClass or one of" in caplog.text
871871

872872

873+
@pytest.mark.parametrize(
874+
"hass_config",
875+
[
876+
{
877+
mqtt.DOMAIN: {
878+
sensor.DOMAIN: {
879+
"name": "test",
880+
"state_topic": "test-topic",
881+
"device_class": "energy",
882+
"unit_of_measurement": "ppm",
883+
}
884+
}
885+
}
886+
],
887+
)
888+
async def test_invalid_unit_of_measurement(
889+
mqtt_mock_entry: MqttMockHAClientGenerator, caplog: pytest.LogCaptureFixture
890+
) -> None:
891+
"""Test device_class with invalid unit of measurement."""
892+
assert await mqtt_mock_entry()
893+
assert (
894+
"The unit of measurement `ppm` is not valid together with device class `energy`"
895+
in caplog.text
896+
)
897+
898+
873899
@pytest.mark.parametrize(
874900
"hass_config",
875901
[

0 commit comments

Comments
 (0)