From 910e3496d1c612f18b6b36e1a47ed7bbca8dc428 Mon Sep 17 00:00:00 2001 From: sphings79 <43515272+sphings79@users.noreply.github.com> Date: Mon, 4 May 2026 21:18:55 +0200 Subject: [PATCH 1/6] Enhance async_setup_entry with detailed comments Refactor setup entry process and improve error handling. Signed-off-by: sphings79 <43515272+sphings79@users.noreply.github.com> --- custom_components/marstek_modbus/__init__.py | 97 +++++++++----------- 1 file changed, 42 insertions(+), 55 deletions(-) diff --git a/custom_components/marstek_modbus/__init__.py b/custom_components/marstek_modbus/__init__.py index 027a487..655ad4d 100644 --- a/custom_components/marstek_modbus/__init__.py +++ b/custom_components/marstek_modbus/__init__.py @@ -9,6 +9,7 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant +from homeassistant.exceptions import ConfigEntryNotReady from .const import DOMAIN from .coordinator import MarstekCoordinator @@ -23,23 +24,11 @@ "button", "number", "binary_sensor", -] +] async def async_setup(hass: HomeAssistant, config: dict) -> bool: - """ - General setup of the integration. - - This is called once when Home Assistant starts. - It does not perform any configuration and always returns True. - - Args: - hass: Home Assistant instance. - config: Configuration dict. - - Returns: - True always. - """ + """General setup – called once when Home Assistant starts.""" return True @@ -47,83 +36,81 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """ Set up a config entry. - Initializes the coordinator for this entry and stores it in hass.data. - Forwards setup to platforms (e.g., sensor, select) used by this integration. - - Args: - hass: Home Assistant instance. - entry: ConfigEntry to setup. - - Returns: - True if setup successful, False otherwise. + Order of operations: + 1. Load register YAML for the configured device version. + 2. Connect to the Modbus gateway (raises ConfigEntryNotReady on failure). + 3. Forward setup to all entity platforms so entities are created and + added to HA (async_added_to_hass is complete for all of them). + 4. Query the entity registry for any user-enabled non-default entities + and add them to the coordinator's polling groups dynamically. + 5. Run the first coordinator refresh so every entity has a value + immediately (tick 1 polls all groups). """ try: - # Migrate legacy device_version tokens in existing config entries to - # the canonical SUPPORTED_VERSIONS strings. This handles older - # installations that used tokens like 'v1/v2' or 'v3'. + # Warn about unsupported device_version strings in existing entries raw_version = (entry.data.get("device_version") or "").strip() if raw_version: normalized = raw_version.lower() - # Consider anything not listed in SUPPORTED_VERSIONS as legacy/unsupported. allowed = {s.lower() for s in SUPPORTED_VERSIONS} if normalized not in allowed: _LOGGER.warning( - "Config entry %s uses unsupported device_version '%s'. Please remove and re-add the device with the correct device version. Supported versions: %s", - entry.entry_id, - raw_version, - ", ".join(SUPPORTED_VERSIONS), + "Config entry %s uses unsupported device_version '%s'. " + "Please remove and re-add the device. Supported: %s", + entry.entry_id, raw_version, ", ".join(SUPPORTED_VERSIONS), ) - # Create the coordinator for data management and attempt an initial - # connection before forwarding platform setup so the client is ready. + coordinator = MarstekCoordinator(hass, entry) hass.data.setdefault(DOMAIN, {})[entry.entry_id] = coordinator - # Load register definitions off the event loop to avoid blocking + # 1 – Load register definitions (blocking I/O runs in executor) try: await coordinator.async_load_registers(entry.data.get("device_version")) except Exception as err: - _LOGGER.warning("Failed loading register definitions for entry %s: %s", entry.entry_id, err) - - # Establish the Modbus connection upfront so the first refresh does not - # lazily reconnect on individual sensor reads, and failure is properly - # tracked from the start. + _LOGGER.warning( + "Failed loading register definitions for entry %s: %s", + entry.entry_id, err, + ) + + # 2 – Connect to Modbus gateway. + # ConfigEntryNotReady is intentionally NOT caught here – it must + # propagate to HA so the built-in retry mechanism (exponential + # backoff: 5 s → 10 s → 30 s → 60 s → …) kicks in automatically. + # This handles temporary failures such as a disconnected LAN cable + # or the device being in the middle of a reboot. await coordinator.async_init() - # Forward setup to all platforms defined in PLATFORMS + # 3 – Create all entity platforms (async_added_to_hass runs here) await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) - # Perform first refresh to ensure coordinator has up-to-date data + # 4 – Register user-enabled non-default entities for polling + await coordinator.async_register_enabled_entities() + + # 5 – First refresh: tick 1 polls ALL groups so every entity has + # a value immediately without waiting for slow intervals. + # ConfigEntryNotReady from here also propagates so HA retries. await coordinator.async_config_entry_first_refresh() return True + except ConfigEntryNotReady: + # Re-raise so HA schedules an automatic retry. + # The platforms forwarded in step 3 are cleaned up by HA automatically. + raise except Exception as err: _LOGGER.error("Error setting up entry %s: %s", entry.entry_id, err) return False async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: - """ - Unload a config entry and its associated platforms. - - Args: - hass: Home Assistant instance. - entry: ConfigEntry to unload. - - Returns: - True if unload successful, False otherwise. - """ + """Unload a config entry and its associated platforms.""" try: - # Unload all platforms for the entry unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) if unload_ok: - # Retrieve the coordinator and close it before removing coordinator = hass.data[DOMAIN][entry.entry_id] await coordinator.async_close() - # Remove coordinator reference from hass data hass.data[DOMAIN].pop(entry.entry_id, None) return unload_ok except Exception as err: _LOGGER.error("Error unloading entry %s: %s", entry.entry_id, err) - return False \ No newline at end of file + return False From 82d8eb6db03bfe23c4c1c2eb5e8345cd2603f135 Mon Sep 17 00:00:00 2001 From: sphings79 <43515272+sphings79@users.noreply.github.com> Date: Mon, 4 May 2026 21:19:54 +0200 Subject: [PATCH 2/6] Refactor MarstekNumber entity initialization Refactor number entity setup and improve comments for clarity. Signed-off-by: sphings79 <43515272+sphings79@users.noreply.github.com> --- custom_components/marstek_modbus/number.py | 124 +++++++-------------- 1 file changed, 42 insertions(+), 82 deletions(-) diff --git a/custom_components/marstek_modbus/number.py b/custom_components/marstek_modbus/number.py index 0a646dd..80171c4 100644 --- a/custom_components/marstek_modbus/number.py +++ b/custom_components/marstek_modbus/number.py @@ -1,7 +1,10 @@ """ Module for creating number entities for Marstek Venus battery devices. Numbers read Modbus registers asynchronously via the coordinator. -All entities are registered through the coordinator to enable centralized polling. + +SCALING NOTE: coordinator.data already contains scaled values. +native_value returns coordinator data as-is. +async_set_native_value converts engineering-unit input back to raw register value. """ import logging @@ -24,142 +27,99 @@ async def async_setup_entry( entry: ConfigEntry, async_add_entities: AddEntitiesCallback, ) -> None: - """ - Set up number entities when the config entry is loaded. - - This function retrieves the coordinator from hass.data, - creates number entities based on NUMBER_DEFINITIONS, - and registers them with Home Assistant. - - Args: - hass: Home Assistant instance. - entry: Configuration entry. - async_add_entities: Callback to add entities. - """ - # Retrieve the coordinator instance from hass data and add entities + """Set up number entities when the config entry is loaded.""" coordinator = hass.data[DOMAIN][entry.entry_id] - entities = [MarstekNumber(coordinator, definition) for definition in coordinator.NUMBER_DEFINITIONS] - async_add_entities(entities) + entities = [ + MarstekNumber(coordinator, definition) + for definition in coordinator.NUMBER_DEFINITIONS + ] + async_add_entities(entities) class MarstekNumber(CoordinatorEntity, NumberEntity): - """ - Representation of a Modbus number entity for Marstek Venus. - - Number state is read and write asynchronously via - the coordinator communicating with the Modbus device. - """ + """Modbus number entity for Marstek Venus.""" def __init__(self, coordinator: MarstekCoordinator, definition: dict): - """ - Initialize the number entity. - - Args: - coordinator: The data update coordinator instance. - definition: Dictionary containing sensor configuration. - """ super().__init__(coordinator) - # Store the key and definition self._key = definition["key"] - self.definition = definition - - # Assign the entity type to the coordinator mapping + self.definition = definition self.coordinator._entity_types[self._key] = self.entity_type - # Set entity attributes from definition self._attr_unique_id = f"{coordinator.config_entry.entry_id}_{self.definition['key']}" self._attr_has_entity_name = True self._attr_translation_key = definition["key"] - # Internal state variables - self._state = None self._register = definition["register"] - - # Set min, max, and step from definition if provided - self._attr_native_min_value = self.definition.get('min', 0) - self._attr_native_max_value = self.definition.get('max', 100) - self._attr_native_step = self.definition.get('step', 1) + self._attr_native_min_value = definition.get("min", 0) + self._attr_native_max_value = definition.get("max", 100) + self._attr_native_step = definition.get("step", 1) + self._attr_native_unit_of_measurement = definition.get("unit") self._scale = definition.get("scale", 1) - self._unit = definition.get("unit", None) - - # set category if defined in the definition - if "category" in self.definition: - self._attr_entity_category = EntityCategory(self.definition.get("category")) - # Set icon if defined in the button definition - if "icon" in self.definition: - self._attr_icon = self.definition.get("icon") - - # Optional: disable entity by default if specified in the definition + if "category" in definition: + self._attr_entity_category = EntityCategory(definition["category"]) + if "icon" in definition: + self._attr_icon = definition["icon"] if definition.get("enabled_by_default") is False: self._attr_entity_registry_enabled_default = False @property def entity_type(self) -> str: - """ - Return the type of this entity for logging purposes. - This allows the coordinator to show more descriptive messages. - """ return "number" @property def available(self) -> bool: - """ - Return True if the coordinator has successfully fetched data. - Used by Home Assistant to determine entity availability. - """ return self.coordinator.last_update_success @property def native_value(self) -> float | None: """ - Return the current value of the number entity. - Value is obtained from the coordinator's shared data dictionary. + Return the current value. + coordinator.data already contains scaled values – return as-is. """ data = self.coordinator.data if data is None: return None - raw_value = data.get(self._key) - return raw_value * self._scale if raw_value is not None else None + return data.get(self._key) async def async_set_native_value(self, value: float) -> None: """ - Write the given value to the Modbus register via the coordinator. - This updates the number entity in Home Assistant. + Write the given engineering-unit value to the Modbus register. + + Convert back to raw register value (reverse the scale), then + optimistically store the engineering-unit value in coordinator.data + so HA shows the correct state immediately. """ - # Convert the float value to an integer for Modbus - raw_value = int(value / self._scale) - - # Optimistically update the coordinator data so HA shows the new state immediately - self.coordinator.data[self._key] = raw_value + scale = self._scale if self._scale else 1 + raw_value = int(round(value / scale)) + + # Optimistic update: store the SCALED (engineering-unit) value so + # native_value returns the correct number without waiting for next poll + self.coordinator.data[self._key] = value self.async_write_ha_state() - # Write the value using the coordinator's async_write_value method success = await self.coordinator.async_write_value( register=self._register, value=raw_value, key=self._key, - scale=self._scale, - unit=self._unit, + scale=scale, + unit=self.definition.get("unit"), entity_type=self.entity_type, ) - - # Only refresh if write failed to get actual device state + if not success: - _LOGGER.debug("Write failed for %s, refreshing to get actual state", self._key) - await self.coordinator.async_read_value(self.definition, self._key, track_failure=False) + _LOGGER.debug( + "Write failed for %s, refreshing to get actual state", self._key + ) + await self.coordinator.async_request_refresh() @property def device_info(self) -> dict: - """ - Return device information for Home Assistant's device registry. - Includes identifiers, name, manufacturer, model, and entry type. - """ return { "identifiers": {(DOMAIN, self.coordinator.config_entry.entry_id)}, "name": self.coordinator.config_entry.title, "manufacturer": MANUFACTURER, "model": MODEL, "entry_type": "service", - } \ No newline at end of file + } From 1b59eee58a52392c7b037e925bca67637e5dceb6 Mon Sep 17 00:00:00 2001 From: sphings79 <43515272+sphings79@users.noreply.github.com> Date: Mon, 4 May 2026 21:20:23 +0200 Subject: [PATCH 3/6] Refactor Marstek sensor scaling and calculations Refactor Marstek sensor classes to ensure scaling is handled correctly by the coordinator. Adjusted methods to prevent double scaling and improved handling of sensor attributes and calculations. Signed-off-by: sphings79 <43515272+sphings79@users.noreply.github.com> --- custom_components/marstek_modbus/sensor.py | 236 +++++---------------- 1 file changed, 56 insertions(+), 180 deletions(-) diff --git a/custom_components/marstek_modbus/sensor.py b/custom_components/marstek_modbus/sensor.py index 1c075f9..301c9a1 100644 --- a/custom_components/marstek_modbus/sensor.py +++ b/custom_components/marstek_modbus/sensor.py @@ -3,6 +3,11 @@ All sensors now derive their values from the shared coordinator data. No separate async_update needed; coordinator handles polling. + +SCALING NOTE: The coordinator applies scale via extract_typed_value() before +storing values in coordinator.data. Sensor entities must NOT apply scale again. +The definition's "scale" and "precision" fields are only used for display hints, +not for re-scaling already-scaled coordinator data. """ import logging @@ -10,7 +15,7 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.components.sensor import SensorEntity from homeassistant.core import HomeAssistant -from homeassistant.helpers.entity import Entity, EntityCategory +from homeassistant.helpers.entity import EntityCategory from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.update_coordinator import CoordinatorEntity @@ -28,7 +33,6 @@ async def async_setup_entry( """Set up all Marstek sensors from definitions.""" coordinator = hass.data[DOMAIN][entry.entry_id] - # Create sensor entities from coordinator-provided definitions entities = [] sensor_groups = ( (MarstekSensor, coordinator.SENSOR_DEFINITIONS), @@ -39,7 +43,6 @@ async def async_setup_entry( for entity_cls, definitions in sensor_groups: entities.extend(entity_cls(coordinator, definition) for definition in definitions) - # Add all entities to Home Assistant async_add_entities(entities) @@ -49,24 +52,18 @@ class MarstekSensor(CoordinatorEntity, SensorEntity): def __init__(self, coordinator: MarstekCoordinator, definition: dict): super().__init__(coordinator) - # Store the key and definition self._key = definition["key"] - self.definition = definition - - # Assign the entity type to the coordinator mapping + self.definition = definition self.coordinator._entity_types[self._key] = self.entity_type - # Set entity attributes from definition self._attr_unique_id = f"{coordinator.config_entry.entry_id}_{self.definition['key']}" self._attr_has_entity_name = True self._attr_translation_key = definition["key"] - # Set basic attributes from definition self._attr_native_unit_of_measurement = definition.get("unit") self._attr_device_class = definition.get("device_class") self._attr_state_class = definition.get("state_class") - # Optional: entity category and icon if "category" in definition: self._attr_entity_category = EntityCategory(definition["category"]) if "icon" in definition: @@ -74,99 +71,67 @@ def __init__(self, coordinator: MarstekCoordinator, definition: dict): if definition.get("enabled_by_default") is False: self._attr_entity_registry_enabled_default = False - # Optional states mapping for int → label conversion self.states = definition.get("states") @property def entity_type(self) -> str: - """ - Return the type of this entity for logging purposes. - This allows the coordinator to show more descriptive messages. - """ return "sensor" @property def available(self) -> bool: - """Return True if coordinator has valid data for this sensor.""" - # Consider the sensor available when coordinator has provided a value - # for this key. This avoids sensors remaining 'unknown' when the - # coordinator had transient update failures but still supplies data. data = getattr(self.coordinator, "data", None) return isinstance(data, dict) and self._key in data @property def native_value(self): - """Return the value from coordinator data with scaling and states applied.""" + """ + Return the sensor value from coordinator data. + + IMPORTANT: coordinator.data already contains SCALED values. + Do NOT multiply by scale here – that would double-scale everything. + Only apply: states mapping, precision rounding, ems_version special case. + """ if self._key not in self.coordinator.data: return None + value = self.coordinator.data[self._key] - # Special handling for schedule data type: the sensor state should - # represent whether the schedule is enabled (boolean). The raw - # register list is exposed in attributes under `raw` and all decoding - # / interpretation is performed in `extra_state_attributes`. + # Schedule sensor: return boolean enabled state if self.definition.get("data_type") == "schedule": data = getattr(self.coordinator, "data", {}) or {} - # Prefer decoded attrs if coordinator provided them, otherwise - # attempt to decode from the raw register list. attrs = data.get(f"{self._key}_attrs") or {} enabled = None - if isinstance(attrs, dict) and "enabled" in attrs: try: enabled = bool(int(attrs.get("enabled") or 0)) except Exception: enabled = bool(attrs.get("enabled")) else: - # Try to decode from raw registers stored at data[self._key] raw = data.get(self._key) if isinstance(raw, (list, tuple)) and len(raw) >= 5: try: enabled = bool(int(raw[4])) except Exception: enabled = bool(raw[4]) - - # If we couldn't determine enabled state, return None (unknown) - if enabled is None: - return None - return enabled if isinstance(value, (int, float)): - # Special-case: EMS version is encoded as an integer where - # values with 4 digits encode a decimal in the last digit - # (e.g. 1573 -> 157.3), while 3-digit values are whole numbers - # (e.g. 158 -> 158). Handle that before applying generic scale. + # Special case: EMS version encoding if self._key == "ems_version": try: iv = int(value) - except Exception: - iv = None - - if iv is not None: if iv >= 1000: - # interpret last digit as decimal (tenths) value = round(iv / 10.0, 1) else: value = int(iv) - # return early after mapping; skip generic scaling if isinstance(value, float) and value.is_integer(): value = int(value) - # apply states mapping below - else: - # fall back to generic handling if conversion fails + except Exception: pass else: - # Apply scaling/offset and round according to precision. - scale = self.definition.get("scale", 1) - offset = self.definition.get("offset", 0) + # Coordinator already applied scale. Only round to display precision. precision = int(self.definition.get("precision", 0) or 0) - - value = float(value) * scale + offset - value = round(value, precision) - - # If the rounded value has no fractional component, return int - # so Home Assistant does not render an unnecessary trailing .0. + value = round(float(value), precision) if isinstance(value, float) and value.is_integer(): value = int(value) @@ -177,14 +142,12 @@ def native_value(self): @property def suggested_display_precision(self) -> int | None: - """Suggest display precision based on definition, but only if not a string or mapped state.""" if self.states: return None return self.definition.get("precision") @property def suggested_display_unit(self) -> str | None: - """Suggest display unit based on definition, but only if not a string or mapped state.""" if self.states: return None return self.definition.get("unit") @@ -201,13 +164,8 @@ def device_info(self) -> dict: @property def extra_state_attributes(self) -> dict: - """Return attributes for packed schedule sensors from coordinator data.""" data = self.coordinator.data or {} attrs = data.get(f"{self._key}_attrs") or {} - # For schedule types, enrich attributes with human-readable fields. - # If `_attrs` is not present but the coordinator stored the raw - # 5-register list in `data[key]`, decode that here so we don't - # duplicate decoding in the coordinator. if self.definition.get("data_type") == "schedule": if not isinstance(attrs, dict) or not attrs: raw = data.get(self._key) @@ -227,10 +185,6 @@ def extra_state_attributes(self) -> dict: def _fmt_time(t): try: t = int(t) - # Heuristic: device encodes times as HHMM (e.g. 200 -> 02:00, - # 610 -> 06:10) when the low two digits are < 60 and the - # value is within 0..2359. Otherwise treat value as - # minutes-since-midnight. if 0 <= t <= 2359 and (t % 100) < 60: hh = t // 100 mm = t % 100 @@ -241,45 +195,27 @@ def _fmt_time(t): except Exception: return t - # Debug logging for raw schedule data from coordinator - _LOGGER.warning( - "Raw schedule data for %s: value=%s attrs=%s", - self._key, - data.get(self._key), - attrs, - ) - days = attrs.get("days") try: dmask = int(days) if days is not None else 0 except Exception: dmask = 0 - # Bits are encoded with Monday at bit 0 (device ordering), but - # display should start with Sunday. Compute set using Monday-first - # mapping, then reorder to Sunday-first for presentation. weekday_names_mon = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"] selected_mon = [weekday_names_mon[i] for i in range(7) if (dmask >> i) & 1] display_order = ["Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"] selected = [d for d in display_order if d in selected_mon] - # Build a minimal enriched dict — do not duplicate raw fields. - enriched = {} - enriched["days_list"] = selected - enriched["start_time"] = _fmt_time(attrs.get("start")) - enriched["end_time"] = _fmt_time(attrs.get("end")) - - # Interpret mode into a human-friendly type and a separate watt attribute. - # NOTE: device uses signed mode where -1 == self consumption and - # signed values represent magnitude. Empirically the device - # uses negative -> charge and positive -> discharge (inverse - # of earlier assumption), so map accordingly. + enriched = { + "days_list": selected, + "start_time": _fmt_time(attrs.get("start")), + "end_time": _fmt_time(attrs.get("end")), + } + mode_raw = attrs.get("mode") mode = None power = None try: - if mode_raw is None: - mode = None - else: + if mode_raw is not None: m = int(mode_raw) if m == -1: mode = "self consumption" @@ -290,8 +226,7 @@ def _fmt_time(t): mode = "discharge" power = m except Exception: - mode = None - power = None + pass enriched["mode"] = mode enriched["power"] = power @@ -305,31 +240,25 @@ class MarstekCalculatedSensor(CoordinatorEntity, SensorEntity): """ Base class for calculated sensors that depend on multiple coordinator keys. - Handles registration of dependency keys and provides update handling. + SCALING NOTE: dependency values are read directly from coordinator.data, + which already contains scaled values. Do NOT multiply by scale again. """ def __init__(self, coordinator: MarstekCoordinator, definition: dict): - """Initialize the calculated sensor and register dependencies.""" super().__init__(coordinator) - # Store the key and definition self._key = definition["key"] self.definition = definition - - # Assign the entity type to the coordinator mapping self.coordinator._entity_types[self._key] = self.entity_type - # Set entity attributes from definition self._attr_unique_id = f"{coordinator.config_entry.entry_id}_{self.definition['key']}" self._attr_has_entity_name = True self._attr_translation_key = definition["key"] - # Set basic attributes from definition self._attr_native_unit_of_measurement = definition.get("unit") self._attr_device_class = definition.get("device_class") self._attr_state_class = definition.get("state_class") - # Optional: entity category and icon if "category" in definition: self._attr_entity_category = EntityCategory(definition["category"]) if "icon" in definition: @@ -337,41 +266,15 @@ def __init__(self, coordinator: MarstekCoordinator, definition: dict): if definition.get("enabled_by_default") is False: self._attr_entity_registry_enabled_default = False - # Register dependency keys in coordinator and set scales - for alias, dep_key in self.get_dependency_keys().items(): - if not dep_key: - continue - - self.coordinator._entity_types[dep_key] = "sensor" - - # Combine all definitions for iteration using coordinator-provided lists - if not hasattr(self, "_all_definitions"): - self._all_definitions = ( - self.coordinator.SENSOR_DEFINITIONS + self.coordinator.BINARY_SENSOR_DEFINITIONS - ) - all_definitions = self._all_definitions - - # Get scale from all definitions or fallback to current sensor dependency_defs - scale = next((d.get("scale", 1) for d in all_definitions if d.get("key") == dep_key), None) - scale = scale or self.definition.get("dependency_defs", {}).get(alias, 1) - - self.coordinator._scales[dep_key] = scale - def get_dependency_keys(self): - """Return the keys this sensor depends on.""" return self.definition.get("dependency_keys", {}) @property def entity_type(self) -> str: - """ - Return the type of this entity for logging purposes. - This allows the coordinator to show more descriptive messages. - """ return "sensor" @property def device_info(self) -> dict: - """Return device info so sensor is linked to the integration/device.""" return { "identifiers": {(DOMAIN, self.coordinator.config_entry.entry_id)}, "name": self.coordinator.config_entry.title, @@ -381,43 +284,40 @@ def device_info(self) -> dict: } def _handle_coordinator_update(self) -> None: - """ - Handle coordinator update by recalculating the sensor value. - - Calls the subclass's calculate_value method and updates state. - """ if not getattr(self.coordinator, "last_update_success", False): self._attr_native_value = None self.async_write_ha_state() return data = self.coordinator.data if isinstance(self.coordinator.data, dict) else {} - self._calculate(data) self.async_write_ha_state() def _calculate(self, data: dict) -> None: """ - Centralized method to check dependencies, log missing values, - calculate value, and update native_value attribute. + Check dependencies and calculate sensor value. + + Values in coordinator.data are already scaled – use them directly + without multiplying by scale again. """ dependency_keys = self.get_dependency_keys() dep_values = {} missing = [] - # dependency_keys is a dict alias -> actual key for alias, actual_key in dependency_keys.items(): val = data.get(actual_key) - scale = self.coordinator._scales.get(actual_key, 1) if val is None: missing.append(alias) else: - dep_values[alias] = float(val) * scale + # coordinator.data already contains scaled values – no re-scaling + dep_values[alias] = float(val) if missing: _LOGGER.warning( "%s missing required value(s): %s. Current data: %s. Cannot calculate value.", - self._key, ", ".join(missing), {k: data.get(v) for k, v in dependency_keys.items()}, + self._key, + ", ".join(missing), + {k: data.get(v) for k, v in dependency_keys.items()}, ) self._attr_native_value = None return @@ -426,55 +326,38 @@ def _calculate(self, data: dict) -> None: value = self.calculate_value(dep_values) _LOGGER.debug( "Calculated value for %s: %s (input values: %s)", - self._key, - value, - dep_values + self._key, value, dep_values, ) self._attr_native_value = value except Exception as ex: - _LOGGER.warning( - "Error calculating value for sensor %s: %s", self._key, ex - ) + _LOGGER.warning("Error calculating value for sensor %s: %s", self._key, ex) self._attr_native_value = None def calculate_value(self, dep_values: dict): - """ - Calculate the sensor value from scaled dependency values. - - Must be implemented by subclasses. - """ raise NotImplementedError class MarstekStoredEnergySensor(MarstekCalculatedSensor): - """ - Sensor calculating stored battery energy (kWh). + """Stored battery energy = SOC% × capacity (kWh).""" - Uses SOC (%) and battery total energy (kWh) from coordinator data. - """ def calculate_value(self, dep_values: dict): - """Calculate stored energy based on SOC and capacity dynamically.""" soc = dep_values.get("soc") capacity = dep_values.get("capacity") - stored_energy = round((soc / 100) * capacity, 2) - self._attr_native_value = stored_energy - return stored_energy + if soc is None or capacity in (None, 0): + return None + return round((soc / 100) * capacity, 2) class MarstekEfficiencySensor(MarstekCalculatedSensor): - """ - Calculate either Round Trip Efficiency (RTE) or Actual Conversion Efficiency. + """Round-trip or conversion efficiency sensor.""" - Mode is determined by 'mode' in the sensor definition: - - "round_trip": uses charge / discharge energy - - "conversion": uses battery_power / ac_power - """ def calculate_value(self, dep_values: dict): mode = self.definition.get("mode", "round_trip") + if mode == "round_trip": charge = dep_values.get("charge") discharge = dep_values.get("discharge") - if charge in (None, 0): + if not charge: return None efficiency = (discharge / charge) * 100 @@ -484,32 +367,25 @@ def calculate_value(self, dep_values: dict): if battery_power is None or ac_power is None: return None if battery_power > 0: - if ac_power == 0: - return None - efficiency = abs(battery_power) / abs(ac_power) * 100 + efficiency = abs(battery_power) / abs(ac_power) * 100 if ac_power else None else: - if battery_power == 0: - return None - efficiency = abs(ac_power) / abs(battery_power) * 100 + efficiency = abs(ac_power) / abs(battery_power) * 100 if battery_power else 0.0 + if efficiency is None: + return None else: _LOGGER.warning("%s unknown efficiency mode '%s'", self._key, mode) return None - efficiency_rounded = round(min(efficiency, 100.0), 1) - self._attr_native_value = efficiency_rounded - return efficiency_rounded + return round(min(efficiency, 100.0), 1) class MarstekBatteryCycleSensor(MarstekCalculatedSensor): - """Calculate estimated battery cycles from discharge energy and capacity.""" + """Estimated battery cycles = total discharge ÷ capacity.""" def calculate_value(self, dep_values: dict): discharge = dep_values.get("discharge") capacity = dep_values.get("capacity") - if discharge is None or capacity in (None, 0): + if discharge is None or not capacity: return None - - cycles = round(discharge / capacity, 2) - self._attr_native_value = cycles - return cycles \ No newline at end of file + return round(discharge / capacity, 2) From cb5dc39754e71e40b06d2c70af89239f7c37f323 Mon Sep 17 00:00:00 2001 From: sphings79 <43515272+sphings79@users.noreply.github.com> Date: Mon, 4 May 2026 21:21:28 +0200 Subject: [PATCH 4/6] Refactor Modbus client to use tmodbus library Refactor Modbus client to use tmodbus, adding async connection and read/write methods. Maintain backward compatibility with existing read helpers. Signed-off-by: sphings79 <43515272+sphings79@users.noreply.github.com> --- .../marstek_modbus/helpers/modbus_client.py | 913 ++++++++---------- 1 file changed, 390 insertions(+), 523 deletions(-) diff --git a/custom_components/marstek_modbus/helpers/modbus_client.py b/custom_components/marstek_modbus/helpers/modbus_client.py index d513b73..2c0776d 100644 --- a/custom_components/marstek_modbus/helpers/modbus_client.py +++ b/custom_components/marstek_modbus/helpers/modbus_client.py @@ -1,572 +1,439 @@ """ -Helper module for Modbus TCP communication using pymodbus. -Provides an abstraction for reading and writing registers from -a Marstek Venus battery system asynchronously. +modbus_client.py – tmodbus wrapper for marstek_modbus. + +Provides: + - create_client / disconnect_client (connection lifecycle) + - batch_read() (block-optimised multi-register read) + - extract_typed_value() (decode raw register cache → Python value) + - write_register / write_registers (single + multi write) + +Individual per-type read helpers (read_uint16 etc.) are kept for +backward-compatibility in case other files call them. + +Drop into: custom_components/marstek_modbus/helpers/modbus_client.py + (or wherever the integration currently places this file) """ -from pymodbus.client.tcp import AsyncModbusTcpClient -import asyncio -import socket -from typing import Optional +from __future__ import annotations import logging +import struct +from typing import Any -from ..const import DEFAULT_MESSAGE_WAIT_MS, DEFAULT_UNIT_ID +from tmodbus import create_async_tcp_client +from tmodbus.client import AsyncModbusClient +from tmodbus.exceptions import TModbusError, ModbusConnectionError _LOGGER = logging.getLogger(__name__) +# ── Block-read tuning ──────────────────────────────────────────────────────── +MAX_BLOCK_SIZE: int = 64 # max registers per single Modbus FC03 request + # (Modbus spec allows up to 125; 64 is a safe + # conservative value for RS-485 gateways) +MAX_GAP: int = 15 # bridge gaps of up to 15 registers between + # requested addresses – dramatically reduces the + # number of round-trips for sparse register maps + # (e.g. cell voltages at 34018-34033 + 34003 → 1 + # request instead of 2) +# ──────────────────────────────────────────────────────────────────────────── + + +# --------------------------------------------------------------------------- +# Connection +# --------------------------------------------------------------------------- + +async def create_client( + host: str, port: int, unit_id: int, timeout: float = 5.0 +) -> AsyncModbusClient: + """Create and connect a tmodbus TCP client.""" + client = create_async_tcp_client( + host, + port, + unit_id=unit_id, + timeout=timeout, + connect_timeout=timeout, + auto_reconnect=True, + wait_between_requests=0.05, # 50 ms – safe for Elfin EW11 / RS485 gateways + ) + await client.connect() + _LOGGER.debug("tmodbus connected to %s:%s unit_id=%s", host, port, unit_id) + return client + + +async def disconnect_client(client: AsyncModbusClient | None) -> None: + """Gracefully disconnect.""" + if client is not None: + try: + await client.disconnect() + except Exception as exc: # noqa: BLE001 + _LOGGER.debug("Error while disconnecting: %s", exc) + -class MarstekModbusClient: +# --------------------------------------------------------------------------- +# Block builder +# --------------------------------------------------------------------------- + +def _build_blocks( + addresses: list[int], + max_gap: int = MAX_GAP, + max_size: int = MAX_BLOCK_SIZE, + bad_gaps: set[tuple[int, int]] | None = None, +) -> list[tuple[int, int]]: """ - Wrapper for pymodbus AsyncModbusTcpClient with helper methods - for async reading/writing and interpreting common data types. + Group a list of register addresses into (start, count) read blocks. + + Rules: + - Gaps ≤ max_gap between addresses are bridged (those registers are + read but their values are simply available in the cache for free). + - A block is split once it would exceed max_size registers. + - Gaps listed in bad_gaps are never bridged regardless of size – they + caused a block failure on a previous poll and the device rejected them. + + Returns a list of (start_address, count) tuples, sorted by address. + """ + if not addresses: + return [] + + sorted_addrs = sorted(set(addresses)) + blocks: list[tuple[int, int]] = [] + block_start = sorted_addrs[0] + prev_addr = sorted_addrs[0] + + for addr in sorted_addrs[1:]: + gap = addr - prev_addr + new_size = addr - block_start + 1 + gap_is_bad = bad_gaps is not None and (prev_addr, addr) in bad_gaps + if gap <= max_gap and new_size <= max_size and not gap_is_bad: + prev_addr = addr + else: + blocks.append((block_start, prev_addr - block_start + 1)) + block_start = addr + prev_addr = addr + + blocks.append((block_start, prev_addr - block_start + 1)) + return blocks + + +# --------------------------------------------------------------------------- +# Batch read +# --------------------------------------------------------------------------- + +async def batch_read( + client: AsyncModbusClient, + addresses: list[int], + max_gap: int = MAX_GAP, + max_size: int = MAX_BLOCK_SIZE, + bad_gaps: set[tuple[int, int]] | None = None, + good_gaps: set[tuple[int, int]] | None = None, +) -> dict[int, int]: """ + Read all *addresses* using the fewest possible Modbus requests. - def __init__(self, host: str, port: int, message_wait_ms: int = DEFAULT_MESSAGE_WAIT_MS, timeout: int = 3, unit_id: int = DEFAULT_UNIT_ID): - """ - Initialize Modbus client with host, port, message wait time, timeout, and unit ID. - - Args: - host (str): IP address or hostname of Modbus server. - port (int): TCP port number. - message_wait_ms (int): Delay in ms between Modbus messages. - timeout (int): Connection timeout in seconds (default 3 for faster failure). - unit_id (int): Modbus Unit ID (slave ID), default is 1. - """ - self.host = host - self.port = port - self.timeout = timeout + Returns {register_address: raw_uint16_value} for every address that + was read (including gap-bridging registers). - # Normalize and guard message_wait_ms so it is never None - self.message_wait_ms = int(message_wait_ms) if message_wait_ms is not None else DEFAULT_MESSAGE_WAIT_MS + bad_gaps: mutable set of (left, right) gap pairs that caused block + failures in a previous poll. _build_blocks avoids bridging + these so the device is never asked about unsupported ranges. - # Precompute seconds sleep to avoid repeated float(None) errors - try: - self.message_wait_sec = max(0.0, float(self.message_wait_ms) / 1000.0) - except (TypeError, ValueError): - self.message_wait_sec = float(DEFAULT_MESSAGE_WAIT_MS) / 1000.0 - - # Create pymodbus async TCP client instance - self.client = AsyncModbusTcpClient( - host=host, - port=port, - timeout=timeout, - ) + good_gaps: mutable set of (left, right) gap pairs that were bridged + successfully at least once. A gap already in good_gaps is + NEVER added to bad_gaps – this protects block optimisations + from being destroyed by a temporary TCP outage (where every + block fails, not just the ones with unsupported registers). - # set message wait on client if supported - try: - self.client.message_wait_milliseconds = self.message_wait_ms - except AttributeError: - pass + Both sets are mutated in place; pass coordinator._bad_gaps / + coordinator._good_gaps to enable self-healing gap avoidance. - # Normalize and guard unit_id so it is never None - try: - self.unit_id = int(unit_id) - except (TypeError, ValueError): - self.unit_id = DEFAULT_UNIT_ID + Raises ConnectionError / OSError on Modbus failure. + """ + if not addresses: + return {} - # Lock to serialize outgoing Modbus requests to avoid transaction id collisions - self._request_lock = asyncio.Lock() + blocks = _build_blocks(addresses, max_gap, max_size, bad_gaps) - async def async_connect(self) -> bool: - """ - Connect asynchronously to the Modbus TCP server. + _LOGGER.debug( + "batch_read: %d unique addresses → %d block(s): %s", + len(set(addresses)), + len(blocks), + [(s, s + c - 1) for s, c in blocks], + ) - Returns: - bool: True if connection succeeded, False otherwise. - """ - # Always create a fresh client instance to avoid reusing internal - # buffers/state that may be left in an inconsistent state after - # network interruptions. This reduces "extra data" / parse errors - # and stale transaction id problems. + cache: dict[int, int] = {} + needed = set(addresses) # only addresses we actually need (not gap fillers) + + for start, count in blocks: try: - # Close and discard any existing client first - if self.client: - try: - result = self.client.close() - if asyncio.iscoroutine(result): - await result - except Exception: - pass - - # Create a new client instance - self.client = AsyncModbusTcpClient( - host=self.host, - port=self.port, - timeout=self.timeout, + regs = await client.read_holding_registers( + start_address=start, quantity=count ) - # restore configured properties where supported - try: - self.client.message_wait_milliseconds = self.message_wait_ms - except Exception: - pass - - connected = await self.client.connect() - - if connected: - # Small settle time so the device has time to flush and be ready - await asyncio.sleep(max(0.2, self.message_wait_sec)) - # Enable TCP keepalive so the OS probes dead connections quickly - # rather than waiting hours for the default kernel timeout. - try: - transport = getattr(self.client, "transport", None) - if transport is not None: - sock = transport.get_extra_info("socket") - if sock is not None: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - if hasattr(socket, "TCP_KEEPIDLE"): - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 60) - if hasattr(socket, "TCP_KEEPINTVL"): - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 10) - if hasattr(socket, "TCP_KEEPCNT"): - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 3) - _LOGGER.debug("TCP keepalive enabled on Modbus socket") - except Exception as ke: - _LOGGER.debug("Could not set TCP keepalive: %s", ke) - _LOGGER.info( - "Connected to Modbus server at %s:%s with unit %s", - self.host, - self.port, - self.unit_id, + for i, val in enumerate(regs): + cache[start + i] = val + + # Block succeeded – record every gap bridged within it as good. + # This protects these gaps from being poisoned as bad during a + # future TCP outage where all blocks fail simultaneously. + if good_gaps is not None: + block_needed = sorted( + addr for addr in needed if start <= addr < start + count ) - else: - _LOGGER.warning( - "Failed to connect to Modbus server at %s:%s with unit %s", - self.host, - self.port, - self.unit_id, + for i in range(len(block_needed) - 1): + left, right = block_needed[i], block_needed[i + 1] + if right - left > 1 and (left, right) not in good_gaps: + good_gaps.add((left, right)) + _LOGGER.debug( + "Confirmed good gap %d → %d (gap=%d): " + "protected from bad-gap poisoning on TCP outages.", + left, right, right - left, + ) + + except Exception as exc: # noqa: BLE001 + # One block failed (timeout, illegal address, connection drop…). + # Log and skip – do NOT abort the whole poll. + # The missing registers will simply stay absent from the cache. + _LOGGER.debug( + "Block %d–%d failed (%s: %s) – skipping", + start, start + count - 1, type(exc).__name__, exc, + ) + + # Record the gaps between consecutive needed addresses inside + # this failed block as bad – UNLESS the gap was already confirmed + # good (i.e. succeeded before). A confirmed-good gap failing now + # means a TCP outage, not an unsupported register range. + if bad_gaps is not None: + block_needed = sorted( + addr for addr in needed if start <= addr < start + count ) + for i in range(len(block_needed) - 1): + left, right = block_needed[i], block_needed[i + 1] + if right - left > 1: + already_good = good_gaps is not None and (left, right) in good_gaps + already_bad = (left, right) in bad_gaps + if not already_good and not already_bad: + bad_gaps.add((left, right)) + _LOGGER.info( + "Recorded bad gap %d → %d (gap=%d): " + "will not bridge this range in future requests.", + left, right, right - left, + ) - return bool(connected) - except Exception as e: - _LOGGER.exception("Exception while connecting to Modbus server: %s", e) - return False + # Try each needed address in this block individually as fallback. + # Use quantity=1 so a single bad register can't poison the rest. + for addr in range(start, start + count): + if addr not in needed: + continue + try: + single = await client.read_holding_registers( + start_address=addr, quantity=1 + ) + cache[addr] = single[0] + except Exception: # noqa: BLE001 + pass # register truly not available – leave absent - async def async_close(self) -> None: - """ - Close the Modbus TCP connection safely (sync or async) - and reset client reference. - """ - if not self.client: - return + return cache - try: - result = self.client.close() - if asyncio.iscoroutine(result): - await result - _LOGGER.debug("Modbus client closed successfully") - except Exception as e: - _LOGGER.debug("Error closing Modbus client: %s", e) - finally: - # Ensure client reference is cleared so future connect creates fresh instance - self.client = None - - async def async_reconnect(self) -> bool: - """Reconnect to the Modbus TCP server by closing and re-opening the connection.""" - async with self._request_lock: - _LOGGER.info("Reconnecting to Modbus server at %s:%s", self.host, self.port) - - try: - try: - await self.async_close() - except Exception as e: - _LOGGER.debug("Error closing Modbus client during reconnect: %s", e) - try: - connected = await self.async_connect() - except Exception as e: - _LOGGER.warning( - "Exception while reconnecting to Modbus server at %s:%s: %s", - self.host, - self.port, - e, - ) - return False +# --------------------------------------------------------------------------- +# Value extraction +# --------------------------------------------------------------------------- - if connected: - _LOGGER.info("Reconnected to Modbus server at %s:%s", self.host, self.port) - else: - _LOGGER.warning("Reconnect failed to Modbus server at %s:%s", self.host, self.port) +def extract_typed_value( + cache: dict[int, int], + address: int, + data_type: str, + count: int, + scale: float, +) -> int | float | str | None: + """ + Extract and scale a value from a batch_read cache dict. - return connected - except Exception as e: - _LOGGER.warning("Unhandled exception during reconnect: %s", e) - return False + Returns None if any required address is missing from the cache. - async def async_read_register( - self, - register: int, - data_type: str = "uint16", - count: Optional[int] = None, - bit_index: Optional[int] = None, - sensor_key: Optional[str] = None, - max_retries: int = 3, - retry_delay: float = 0.1, - ): - """ - Robustly read registers and interpret the data asynchronously with retries. - - Args: - register (int): Register address to read from. - data_type (str): Data type for interpretation, e.g. 'int16', 'int32', 'char', 'bit'. - count (Optional[int]): Number of registers to read (default depends on data_type). - bit_index (Optional[int]): Bit position for 'bit' data type (0-15). - sensor_key (Optional[str]): Sensor key for logging. - max_retries (int): Maximum number of read attempts. - retry_delay (float): Delay in seconds between retries. - - Returns: - int, str, bool, or None: Interpreted value or None on error. - """ + data_type: uint16 | int16 | uint32 | int32 | char + count: number of registers (1 for 16-bit, 2 for 32-bit, N for char) + scale: multiply raw integer result (not applied to char) + """ + # Check all required addresses are present + needed = range(address, address + count) + if any(a not in cache for a in needed): + return None - if count is None: - count = 2 if data_type in ["int32", "uint32"] else 1 + raw = cache[address] - if not (0 <= register <= 0xFFFF): - _LOGGER.error( - "Invalid register address: %d (0x%04X). Must be 0-65535.", - register, - register, - ) - return None + if data_type == "uint16": + value: int | float | str = raw - if not (1 <= count <= 125): # Modbus spec limit - _LOGGER.error( - "Invalid register count: %d. Must be between 1 and 125.", - count, - ) - return None + elif data_type == "int16": + value = struct.unpack(">h", struct.pack(">H", raw))[0] - attempt = 0 - while attempt < max_retries: - # Guard against client being None (closed during unload) - client_connected = False - try: - client_connected = bool(self.client and getattr(self.client, "connected", False)) - except Exception: - client_connected = False - - if not client_connected: - _LOGGER.warning( - "Modbus client not connected, attempting reconnect before register %d (0x%04X)", - register, - register, - ) - connected = await self.async_connect() - if not connected: - _LOGGER.error( - "Reconnect failed, skipping register %d (0x%04X)", - register, - register, - ) - return None - - try: - result = None - # Serialize Modbus requests to avoid overlapping frames and transaction id mismatches - async with self._request_lock: - try: - # Try multiple kwarg names for different pymodbus versions - read_method = getattr(self.client, "read_holding_registers") - for unit_kw in ("device_id", "unit", "slave"): - try: - result = await read_method(address=register, count=count, **{unit_kw: self.unit_id}) - break - except TypeError: - result = None - continue - finally: - # Short spacing after each request to give the device time - try: - await asyncio.sleep(self.message_wait_sec) - except asyncio.CancelledError: - raise - - if result is None: - _LOGGER.error( - "No response object returned for register %d (0x%04X) on attempt %d", - register, - register, - attempt + 1, - ) - elif getattr(result, "isError", lambda: False)(): - _LOGGER.error( - "Modbus read error at register %d (0x%04X) on attempt %d", - register, - register, - attempt + 1, - ) - elif not hasattr(result, "registers") or result.registers is None or len(result.registers) < count: - _LOGGER.warning( - "Incomplete data received at register %d (0x%04X) on attempt %d: expected %d registers, got %s", - register, - register, - attempt + 1, - count, - len(result.registers) if result.registers else 0, - ) - else: - regs = result.registers - _LOGGER.debug( - "Requesting register %d (0x%04X) from '%s' for sensor '%s' (type: %s, count: %s)", - register, - register, - self.host, - sensor_key or 'unknown', - data_type, - count, - ) - _LOGGER.debug("Received data from '%s' for register %d (0x%04X): %s", self.host, register, register, regs) - - if data_type == "int16": - val = regs[0] - return val - 0x10000 if val >= 0x8000 else val - - elif data_type == "uint16": - return regs[0] - - elif data_type == "int32": - if len(regs) < 2: - _LOGGER.warning( - "Expected 2 registers for int32 at register %d (0x%04X), got %s", - register, - register, - len(regs), - ) - return None - val = (regs[0] << 16) | regs[1] - return val - 0x100000000 if val >= 0x80000000 else val - - elif data_type == "uint32": - if len(regs) < 2: - _LOGGER.warning( - "Expected 2 registers for uint32 at register %d (0x%04X), got %s", - register, - register, - len(regs), - ) - return None - return (regs[0] << 16) | regs[1] - - elif data_type == "char": - byte_array = bytearray() - for reg in regs: - byte_array.append((reg >> 8) & 0xFF) - byte_array.append(reg & 0xFF) - return byte_array.decode("ascii", errors="ignore").rstrip('\x00') - - elif data_type == "schedule": - # Return a decoded dict for schedule blocks. - # 5 registers: days, start, end, mode (int16 signed), enabled - if len(regs) < 5: - _LOGGER.warning( - "Expected 5 registers for schedule at %d (0x%04X), got %s", - register, - register, - len(regs), - ) - return None - mode_raw = int(regs[3]) - mode_signed = mode_raw - 0x10000 if mode_raw >= 0x8000 else mode_raw - return { - "days": int(regs[0]), - "start": int(regs[1]), - "end": int(regs[2]), - "mode": mode_signed, - "enabled": int(regs[4]), - } - - elif data_type == "bit": - if bit_index is None or not (0 <= bit_index < 16): - raise ValueError("bit_index must be between 0 and 15 for bit data_type") - reg_val = regs[0] - return bool((reg_val >> bit_index) & 1) - - else: - raise ValueError(f"Unsupported data_type: {data_type}") - - except asyncio.CancelledError: - # Allow cancellation to propagate during Home Assistant shutdown - raise - except Exception as e: - # If the underlying cause is a CancelledError (pymodbus wraps it), - # propagate it so shutdown is not logged as an error. - cause = getattr(e, "__cause__", None) - if isinstance(cause, asyncio.CancelledError): - raise cause - - _LOGGER.exception( - "Exception during Modbus read at register %d (0x%04X) on attempt %d: %s", - register, - register, - attempt + 1, - e, - ) + elif data_type == "uint32": + value = (cache[address] << 16) | cache[address + 1] - attempt += 1 - if attempt < max_retries: - await asyncio.sleep(retry_delay) + elif data_type == "int32": + unsigned = (cache[address] << 16) | cache[address + 1] + value = struct.unpack(">i", struct.pack(">I", unsigned))[0] - _LOGGER.error( - "Failed to read register %d (0x%04X) after %d attempts", - register, - register, - max_retries, + elif data_type == "char": + raw_bytes = b"".join( + struct.pack(">H", cache[address + i]) for i in range(count) ) + return raw_bytes.split(b"\x00")[0].decode("ascii", errors="replace").strip() + + else: + _LOGGER.warning("Unknown data_type '%s' at address %s", data_type, address) return None - async def async_write_register( - self, - register: int, - value: int, - max_retries: int = 3, - retry_delay: float = 0.2, - ) -> bool: - """ - Write a single value to a Modbus holding register asynchronously with retries. + # Apply scale + if scale == 1.0 or scale == 1: + return int(value) + return round(float(value) * scale, 6) - Args: - register (int): Register address to write to. - value (int): Value to write. - max_retries (int): Maximum number of write attempts. - retry_delay (float): Delay in seconds between retries. - Returns: - bool: True if write was successful, False otherwise. - """ - # Input validation - if not (0 <= register <= 0xFFFF): - _LOGGER.error( - "Invalid register address for write: %d (0x%04X). Must be 0-65535.", - register, - register, - ) - return False +# --------------------------------------------------------------------------- +# Backward-compatible individual read helpers +# (kept so that any existing callers outside coordinator.py still work) +# --------------------------------------------------------------------------- + +async def read_registers( + client: AsyncModbusClient, address: int, count: int +) -> list[int]: + """Read *count* holding registers; return raw list[int].""" + cache = await batch_read(client, list(range(address, address + count)), + max_gap=0, max_size=MAX_BLOCK_SIZE) + return [cache[address + i] for i in range(count)] + + +async def read_uint16(client: AsyncModbusClient, address: int) -> int: + regs = await read_registers(client, address, 1) + return regs[0] + + +async def read_int16(client: AsyncModbusClient, address: int) -> int: + regs = await read_registers(client, address, 1) + return struct.unpack(">h", struct.pack(">H", regs[0]))[0] - # Expect caller to supply an already validated/converted 16-bit unsigned value. - if not isinstance(value, int): - _LOGGER.error("Invalid value type for write: %s. Must be int.", type(value)) - return False - if not (0 <= value <= 0xFFFF): - _LOGGER.error( - "Invalid value for write: %d. Must be 0-65535.", - value, +async def read_uint32(client: AsyncModbusClient, address: int) -> int: + regs = await read_registers(client, address, 2) + return (regs[0] << 16) | regs[1] + + +async def read_int32(client: AsyncModbusClient, address: int) -> int: + regs = await read_registers(client, address, 2) + unsigned = (regs[0] << 16) | regs[1] + return struct.unpack(">i", struct.pack(">I", unsigned))[0] + + +async def read_string(client: AsyncModbusClient, address: int, num_registers: int) -> str: + regs = await read_registers(client, address, num_registers) + raw_bytes = b"".join(struct.pack(">H", r) for r in regs) + return raw_bytes.split(b"\x00")[0].decode("ascii", errors="replace").strip() + + +# --------------------------------------------------------------------------- +# Write helpers +# --------------------------------------------------------------------------- + +async def write_register( + client: AsyncModbusClient, address: int, value: int +) -> None: + """Write a single holding register (FC 06).""" + try: + await client.write_single_register(address, value) + _LOGGER.debug("write_register %s = %s", address, value) + except ModbusConnectionError as exc: + raise ConnectionError( + f"Modbus connection error writing reg {address}: {exc}" + ) from exc + except TModbusError as exc: + raise OSError(f"Modbus error writing reg {address} = {value}: {exc}") from exc + + +async def write_registers( + client: AsyncModbusClient, address: int, values: list[int] +) -> None: + """Write multiple consecutive holding registers (FC 16).""" + try: + await client.write_multiple_registers(address, values) + _LOGGER.debug( + "write_registers %s…%s = %s", + address, address + len(values) - 1, values, + ) + except ModbusConnectionError as exc: + raise ConnectionError( + f"Modbus connection error writing regs at {address}: {exc}" + ) from exc + except TModbusError as exc: + raise OSError(f"Modbus error writing regs at {address}: {exc}") from exc + + +# --------------------------------------------------------------------------- +# MarstekModbusClient – Compatibility wrapper for config_flow.py +# --------------------------------------------------------------------------- + +class MarstekModbusClient: + """ + High-level Modbus client wrapper used by config_flow.py. + + Provides an async connect / close / read interface on top of tmodbus + so that config_flow.py does not need to know about tmodbus internals. + """ + + def __init__( + self, + host: str, + port: int, + *, + message_wait_ms: int | None = None, + timeout: float = 5.0, + unit_id: int = 1, + ) -> None: + self.host = host + self.port = port + self.timeout = timeout + self.unit_id = unit_id + self._wait_s = (message_wait_ms or 80) / 1000.0 + self._client: AsyncModbusClient | None = None + + async def async_connect(self) -> bool: + """Open the Modbus TCP connection. Returns True on success.""" + try: + self._client = await create_client( + self.host, self.port, self.unit_id, timeout=self.timeout ) + return True + except Exception as exc: # noqa: BLE001 + _LOGGER.debug("MarstekModbusClient.async_connect failed: %s", exc) return False - value_to_send = value - - attempt = 0 - while attempt < max_retries: - # Check client connection - client_connected = False - try: - client_connected = bool( - self.client and getattr(self.client, "connected", False) - ) - except Exception: - client_connected = False - - if not client_connected: - _LOGGER.warning( - "Modbus client not connected, attempting reconnect before write to register %d (0x%04X)", - register, - register, - ) - connected = await self.async_connect() - if not connected: - _LOGGER.error( - "Reconnect failed, skipping write to register %d (0x%04X)", - register, - register, - ) - return False - - # Additional safety check - if self.client is None: - _LOGGER.error("Modbus Client became None unexpectedly") - return False - - try: - _LOGGER.debug( - "Writing to register %d (0x%04X), value=%d (0x%04X), attempt=%d", - register, - register, - value, - value, - attempt + 1, - ) - result = None - async with self._request_lock: - try: - # Try multiple kwarg names for compatibility - for unit_kw in ("device_id", "unit", "slave"): - try: - result = await self.client.write_register( - address=register, value=value, **{unit_kw: self.unit_id} - ) - break - except TypeError: - result = None - continue - finally: - # Spacing after write - try: - await asyncio.sleep(self.message_wait_sec) - except asyncio.CancelledError: - raise - - # Check result - if result is None: - _LOGGER.warning( - "No response from write to register %d (0x%04X) on attempt %d", - register, - register, - attempt + 1, - ) - elif getattr(result, "isError", lambda: False)(): - _LOGGER.warning( - "Modbus write error at register %d (0x%04X) on attempt %d", - register, - register, - attempt + 1, - ) - else: - _LOGGER.debug( - "Write confirmed for register %d (0x%04X), value=%d", - register, - register, - value, - ) - return True - - except asyncio.CancelledError: - # Allow cancellation to propagate during shutdown - raise - - except Exception as e: - # If underlying cause is CancelledError, propagate it - cause = getattr(e, "__cause__", None) - if isinstance(cause, asyncio.CancelledError): - raise cause - - _LOGGER.exception( - "Exception during Modbus write at register %d (0x%04X) on attempt %d: %s", - register, - register, - attempt + 1, - e, - ) + async def async_close(self) -> None: + """Close the Modbus TCP connection.""" + await disconnect_client(self._client) + self._client = None - attempt += 1 - if attempt < max_retries: - await asyncio.sleep(retry_delay) + async def async_read_register( + self, + register: int, + data_type: str, + count: int, + sensor_key: str = "", + scale: float = 1.0, + ) -> int | float | str | None: + """ + Read one register (or a multi-register value) and return the decoded value. - _LOGGER.error( - "Failed to write to register %d (0x%04X) after %d attempts", - register, - register, - max_retries, - ) - return False \ No newline at end of file + Returns None on any Modbus error so callers can treat it as + "no response" without raising exceptions. + """ + if self._client is None: + return None + try: + addresses = list(range(register, register + count)) + cache = await batch_read(self._client, addresses, max_gap=0) + return extract_typed_value(cache, register, data_type, count, scale) + except Exception as exc: # noqa: BLE001 + _LOGGER.debug( + "MarstekModbusClient.async_read_register(%s, %s) failed: %s", + register, sensor_key, exc, + ) + return None From 3f022dcb7517802daa35ec5bbfc3434abead237e Mon Sep 17 00:00:00 2001 From: sphings79 <43515272+sphings79@users.noreply.github.com> Date: Mon, 4 May 2026 21:23:29 +0200 Subject: [PATCH 5/6] Update coordinator.py Signed-off-by: sphings79 <43515272+sphings79@users.noreply.github.com> --- .../marstek_modbus/coordinator.py | 1349 +++++++---------- 1 file changed, 578 insertions(+), 771 deletions(-) diff --git a/custom_components/marstek_modbus/coordinator.py b/custom_components/marstek_modbus/coordinator.py index 81c36dc..df751a5 100644 --- a/custom_components/marstek_modbus/coordinator.py +++ b/custom_components/marstek_modbus/coordinator.py @@ -1,841 +1,648 @@ """ -Handles all sensor polling via Home Assistant DataUpdateCoordinator, -with per-sensor intervals and optional skipping if not due. +coordinator.py – Marstek Venus Modbus DataUpdateCoordinator + (tmodbus + YAML-driven batch block-read) + +Interface expected by __init__.py: + coordinator = MarstekCoordinator(hass, entry) + await coordinator.async_load_registers(version_string) + await coordinator.async_init() + await coordinator.async_config_entry_first_refresh() + await coordinator.async_close() + coordinator._update_scan_intervals(scan_interval_dict) """ +from __future__ import annotations + import asyncio import logging +from dataclasses import dataclass from datetime import timedelta +from pathlib import Path +from typing import Any + +import yaml from homeassistant.config_entries import ConfigEntry +from homeassistant.const import CONF_HOST, CONF_PORT from homeassistant.core import HomeAssistant -from homeassistant.helpers.entity import Entity -from homeassistant.helpers.update_coordinator import DataUpdateCoordinator +from homeassistant.exceptions import ConfigEntryNotReady +from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed + +from .const import DOMAIN, DEFAULT_SCAN_INTERVALS +from .helpers.modbus_client import ( + create_client, + disconnect_client, + batch_read, + extract_typed_value, + write_register as _write_register, + write_registers as _write_registers, +) +from tmodbus.client import AsyncModbusClient +from tmodbus.exceptions import TModbusError -from .const import DEFAULT_SCAN_INTERVALS, SUPPORTED_VERSIONS, DEFAULT_UNIT_ID +_LOGGER = logging.getLogger(__name__) -from .helpers.modbus_client import MarstekModbusClient -from pathlib import Path +_WRITE_LOCK_TIMEOUT = 10.0 -_LOGGER = logging.getLogger(__name__) +# Map SUPPORTED_VERSIONS display strings → YAML filename in registers/ +_VERSION_YAML: dict[str, str] = { + "e v1/v2": "e_v12.yaml", + "e v3": "e_v3.yaml", + "d": "d.yaml", + "a": "a.yaml", +} + +_READABLE_SECTIONS = ( + "SENSOR_DEFINITIONS", + "BINARY_SENSOR_DEFINITIONS", + "SELECT_DEFINITIONS", + "SWITCH_DEFINITIONS", + "NUMBER_DEFINITIONS", +) + +_DEFAULT_COUNT: dict[str, int] = { + "uint16": 1, "int16": 1, "uint32": 2, "int32": 2, "char": 1, +} + +CONF_UNIT_ID = "unit_id" +CONF_DEVICE_VER = "device_version" + + +@dataclass(slots=True) +class ReadableEntry: + key: str; address: int; data_type: str; count: int; scale: float; priority: str + + +def _load_groups(yaml_path: Path) -> dict[str, list[ReadableEntry]]: + with open(yaml_path, encoding="utf-8") as fh: + raw: dict[str, Any] = yaml.safe_load(fh) or {} + + groups: dict[str, list[ReadableEntry]] = { + "high": [], "medium": [], "low": [], "very_low": [] + } + skipped: list[str] = [] + + for section in _READABLE_SECTIONS: + sec = raw.get(section) + if not isinstance(sec, dict): + continue + for key, entry in sec.items(): + if not isinstance(entry, dict) or "register" not in entry: + continue + + # ── Only poll registers that are enabled by default ─────────── + # Registers with enabled_by_default: false may not be supported by + # all device variants and can cause connection timeouts if polled + # unconditionally. Entities that the user explicitly enables in HA + # are added to the polling groups afterwards via + # async_register_enabled_entities() which reads the entity registry. + if not entry.get("enabled_by_default", True): + continue + # ──────────────────────────────────────────────────────────── + + raw_prio = str(entry.get("scan_interval", "")).strip().lower() + raw_prio = raw_prio.replace(" ", "_").replace("-", "_") + + if raw_prio not in groups: + if raw_prio == "": + # No scan_interval defined → use "very_low" as safe default + priority = "very_low" + else: + # Unknown value → skip and warn + skipped.append(f"{key}(unknown:{raw_prio})") + continue + else: + priority = raw_prio + + data_type = str(entry.get("data_type", "uint16")).lower() + if data_type not in _DEFAULT_COUNT: + data_type = "uint16" + count = int(entry.get("count", _DEFAULT_COUNT[data_type])) + if data_type in ("uint16", "int16"): + count = 1 + groups[priority].append(ReadableEntry( + key=key, address=int(entry["register"]), + data_type=data_type, count=count, + scale=float(entry.get("scale", 1.0)), priority=priority, + )) + + if skipped: + _LOGGER.warning("Skipped entries with unknown scan_interval value: %s", skipped) + + total = sum(len(v) for v in groups.values()) + _LOGGER.debug( + "Loaded %d register entries from %s (high=%d medium=%d low=%d very_low=%d)", + total, yaml_path.name, + len(groups["high"]), len(groups["medium"]), + len(groups["low"]), len(groups["very_low"]), + ) + return groups -def get_entity_type(entity) -> str: - """Determine entity type based on its class inheritance.""" - for base in entity.__class__.__mro__: - if issubclass(base, Entity) and base.__name__.endswith("Entity"): - return base.__name__.replace("Entity", "").lower() - return "entity" + +def _addresses_for(entries: list[ReadableEntry]) -> list[int]: + addrs: list[int] = [] + for e in entries: + addrs.extend(range(e.address, e.address + e.count)) + return addrs + + +def _ticks_for(interval_s: float, tick_s: float) -> int: + return max(1, round(interval_s / tick_s)) class MarstekCoordinator(DataUpdateCoordinator): - """Coordinator managing all Marstek Venus Modbus sensors.""" - - def __init__(self, hass: HomeAssistant, entry: ConfigEntry): - """Initialize the coordinator with connection parameters and update interval.""" - self.hass = hass - self.host = entry.data["host"] - self.port = entry.data["port"] - self.message_wait_ms = entry.data.get("message_wait_milliseconds") - self.timeout = entry.data.get("timeout") - self.unit_id = entry.data.get("unit_id", DEFAULT_UNIT_ID) - - # Mapping from sensor key to entity type for logging and processing - self._entity_types: dict[str, str] = {} + """Polls Marstek battery registers using tmodbus block-read.""" - # Store the config entry for potential future use - self.config_entry = entry - - # Scaling factors for sensors, if applicable - self._scales: dict[str, float] = {} - - # Load register/entity definitions for the device version selected in the config entry - # If device_version is missing (older installs), schedule a reauth flow so the user - # can pick the correct device version via a popup in the UI. Use a safe default - # to initialize the coordinator so the integration does not crash while waiting - # for the user to respond. - # Placeholder definitions — actual register definitions are loaded - # asynchronously to avoid blocking the event loop during __init__. - self.SENSOR_DEFINITIONS = [] - self.BINARY_SENSOR_DEFINITIONS = [] - self.SELECT_DEFINITIONS = [] - self.SWITCH_DEFINITIONS = [] - self.NUMBER_DEFINITIONS = [] - self.BUTTON_DEFINITIONS = [] - self.EFFICIENCY_SENSOR_DEFINITIONS = [] - self.STORED_ENERGY_SENSOR_DEFINITIONS = [] - self.CYCLE_SENSOR_DEFINITIONS = [] - - # Combine all sensor definitions for polling - self._all_definitions = [] - - # Initialize Modbus client for communication - self.client = MarstekModbusClient( - self.host, - self.port, - message_wait_ms=self.message_wait_ms, - timeout=self.timeout, - unit_id=self.unit_id, - ) + def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None: + self.host = entry.data.get(CONF_HOST, "") + self.port = int(entry.data.get(CONF_PORT, 502)) + self.unit_id = int(entry.data.get(CONF_UNIT_ID, 1)) + + self._entry = entry + self._client: AsyncModbusClient | None = None + self._lock = asyncio.Lock() + self._tick = 0 + + self._groups: dict[str, list[ReadableEntry]] = { + "high": [], "medium": [], "low": [], "very_low": [] + } + self._raw_yaml: dict[str, Any] = {} + + # Bad-gap tracking: when a batch_read block fails, the specific gaps + # between consecutive needed addresses inside that block are recorded + # here. _build_blocks will avoid bridging these gaps in future requests + # so the device is never asked to read unsupported register ranges again. + # The YAML registers themselves are NEVER removed – only the bridging + # of gaps between them is suppressed. + self._bad_gaps: set[tuple[int, int]] = set() + # Flag set by async_register_enabled_entities() to ensure the very_low + # group is polled once immediately after user-enabled entities are added, + # regardless of when tick 1 fired relative to entity registration. + self._needs_very_low_poll: bool = False + # Gaps that have been successfully bridged at least once. + # These are protected from bad_gaps poisoning during TCP outages: + # if a connection drop causes ALL blocks to fail, we do not want + # previously working gaps to be recorded as bad. + self._good_gaps: set[tuple[int, int]] = set() + + + # Registries used by entity classes to store metadata on the coordinator. + self._entity_types: dict[str, str] = {} + self._scales: dict[str, float] = {} + self._dependencies: dict = {} + self._precision: dict = {} + self._unit: dict = {} + self._device_class: dict = {} + self._state_class: dict = {} + self._enabled_by_default: dict = {} + self._icon: dict = {} + self._category: dict = {} + + # Merge defaults with any saved options + opts = {**DEFAULT_SCAN_INTERVALS, **entry.options} + high = opts.get("high", DEFAULT_SCAN_INTERVALS["high"]) + self._interval_high = high + self._medium_every = _ticks_for(opts.get("medium", DEFAULT_SCAN_INTERVALS["medium"]), high) + self._low_every = _ticks_for(opts.get("low", DEFAULT_SCAN_INTERVALS["low"]), high) + self._very_low_every = _ticks_for(opts.get("very_low", DEFAULT_SCAN_INTERVALS["very_low"]), high) - # Data storage for sensor values and timestamps of last updates - self.data: dict = {} - self._last_update_times: dict = {} - # Timestamps of last successful writes per key (for post-write read suppression) - self._last_write_times: dict = {} - # Timestamps when a read was last started per key (for stale-read detection) - self._read_start_times: dict = {} - - # Connection throttling to prevent endless retry attempts after repeated failures - self._consecutive_failures = 0 - self._max_consecutive_failures = 5 - self._connection_suspended = False - self._suspension_reset_time = None - - self._consecutive_timeout_cycles = 0 - self._max_consecutive_timeout_cycles = 3 - self._timeout_ratio_reconnect_threshold = 0.5 - - # Connection health tracking for diagnostics - self._last_successful_read = None - self._connection_established_at = None - - # Per-register failure tracking for exponential backoff. - # Counts consecutive failed reads per key; resets to 0 on first success. - # Effective poll interval = base_interval * 2^min(failures, 6), capped at 3600s. - self._register_failures: dict[str, int] = {} - # Tracks last *attempt* time (success or failure) for backoff interval calculation. - self._last_attempt_times: dict = {} - - # Prepare scan intervals (from config_entry.options or default) - options = entry.options or {} - self._update_scan_intervals(options) - - # Initialize the base DataUpdateCoordinator with the calculated interval super().__init__( - hass, - _LOGGER, - name="MarstekCoordinator", - update_interval=self.update_interval, + hass, _LOGGER, name=DOMAIN, + update_interval=timedelta(seconds=high), ) - - _LOGGER.debug("Coordinator initialized with update_interval: %s", self.update_interval) - - def _update_scan_intervals(self, options: dict): - """Update scan intervals from config options and compute update_interval (lowest interval always used).""" - old_intervals = getattr(self, "scan_intervals", {}).copy() if hasattr(self, "scan_intervals") else {} - self.scan_intervals = DEFAULT_SCAN_INTERVALS.copy() - - for key in DEFAULT_SCAN_INTERVALS: - if key in options: - try: - self.scan_intervals[key] = int(options[key]) - except Exception: - _LOGGER.warning("Invalid scan interval for %s: %s", key, options[key]) - - # Compute minimum interval for coordinator - min_interval = min(self.scan_intervals.values()) if self.scan_intervals else 30 - self.update_interval = timedelta(seconds=min_interval) - - # Update DataUpdateCoordinator's update_interval if coordinator is already initialized - if hasattr(self, "_listeners") and self._listeners is not None: - # update_interval is a property in DataUpdateCoordinator - try: - super(MarstekCoordinator, self.__class__).update_interval.fset(self, self.update_interval) - _LOGGER.debug( - "Coordinator update_interval changed dynamically to %s due to options change", - self.update_interval, - ) - except Exception as e: - _LOGGER.warning("Failed to update coordinator update_interval: %s", e) - _LOGGER.debug( - "Scan intervals updated. Old: %s, New: %s, Coordinator update_interval: %s", - old_intervals, - self.scan_intervals, - self.update_interval, - ) + # ── Lifecycle ───────────────────────────────────────────────────────── + async def async_load_registers(self, version_string: str | None) -> None: + """Load register YAML for the given device version string.""" + if not version_string: + _LOGGER.warning("No device_version set – skipping register load") + return - def register_entity_type(self, key: str, entity_type: str): - """Register the entity type for a given sensor key. - For calculated sensors with dependencies, ensure all dependency keys are registered. - """ - self._entity_types[key] = entity_type - - # Register all dependency keys with entity type and scale - definition = next((d for d in self.SENSOR_DEFINITIONS if d.get("key") == key), None) - if definition and "dependency_keys" in definition: - for dep_alias, dep_key in definition["dependency_keys"].items(): - if dep_key not in self._entity_types: - # Use the same entity type as the parent sensor - self._entity_types[dep_key] = entity_type - - # Retrieve scale from the dependency sensor definition - dep_def = next((d for d in self.SENSOR_DEFINITIONS if d.get("key") == dep_key), None) - if dep_def: - scale = dep_def.get("scale") - if scale is not None: - self._scales[dep_key] = scale - - def get_connection_diagnostics(self) -> dict: - """Return diagnostic information about the connection.""" - from homeassistant.util.dt import utcnow - now = utcnow() - - diagnostics = { - "host": self.host, - "port": self.port, - "consecutive_failures": self._consecutive_failures, - "connection_suspended": self._connection_suspended, - "last_successful_read": self._last_successful_read.isoformat() if self._last_successful_read else None, - "connection_established_at": self._connection_established_at.isoformat() if self._connection_established_at else None, - } - - if self._connection_suspended and self._suspension_reset_time: - diagnostics["suspension_expires_in_seconds"] = (self._suspension_reset_time - now).total_seconds() - - return diagnostics - - async def async_init(self): - """Asynchronously initialize the Modbus connection.""" - from homeassistant.util.dt import utcnow - connected = await self.client.async_connect() - if not connected: - _LOGGER.error("Failed to connect to Modbus device at %s:%d", self.host, self.port) - else: - self._connection_established_at = utcnow() - _LOGGER.info("Successfully connected to Modbus device at %s:%d", self.host, self.port) - return connected + yaml_file = _VERSION_YAML.get(version_string.strip().lower()) + if yaml_file is None: + _LOGGER.warning( + "Unknown device_version '%s'. Known: %s", + version_string, list(_VERSION_YAML), + ) + return + yaml_path = Path(__file__).parent / "registers" / yaml_file + if not yaml_path.exists(): + _LOGGER.error("Register file not found: %s", yaml_path) + return - async def async_load_registers(self, version: str | None = None): - """Load register definitions from YAML (off the event loop) and populate coordinator attributes. + # File I/O must run in the executor – HA forbids blocking calls in the event loop + def _load() -> tuple[dict, dict]: + with open(yaml_path, encoding="utf-8") as fh: + raw = yaml.safe_load(fh) or {} + return _load_groups(yaml_path), raw - This method must be called from async context (and will run the blocking - YAML load in the executor) to avoid performing file I/O inside __init__. - """ - # Determine used version and handle legacy/missing tokens the same way - raw_device_version = (version or "") or "" - if not str(raw_device_version).strip(): - # No device_version configured; use default first supported version - used_version = SUPPORTED_VERSIONS[0] - else: - used_version = raw_device_version + self._groups, self._raw_yaml = await self.hass.async_add_executor_job(_load) + async def async_init(self) -> None: + """Connect to the Modbus gateway. Raises ConfigEntryNotReady on failure.""" try: - data = await self.hass.async_add_executor_job(get_registers, used_version) - self.SENSOR_DEFINITIONS = data.get("SENSOR_DEFINITIONS", []) - self.BINARY_SENSOR_DEFINITIONS = data.get("BINARY_SENSOR_DEFINITIONS", []) - self.SELECT_DEFINITIONS = data.get("SELECT_DEFINITIONS", []) - self.SWITCH_DEFINITIONS = data.get("SWITCH_DEFINITIONS", []) - self.NUMBER_DEFINITIONS = data.get("NUMBER_DEFINITIONS", []) - self.BUTTON_DEFINITIONS = data.get("BUTTON_DEFINITIONS", []) - self.EFFICIENCY_SENSOR_DEFINITIONS = data.get("EFFICIENCY_SENSOR_DEFINITIONS", []) - self.STORED_ENERGY_SENSOR_DEFINITIONS = data.get("STORED_ENERGY_SENSOR_DEFINITIONS", []) - self.CYCLE_SENSOR_DEFINITIONS = data.get("CYCLE_SENSOR_DEFINITIONS", []) - - # Combine into a single list for polling - self._all_definitions = ( - self.SENSOR_DEFINITIONS - + self.BINARY_SENSOR_DEFINITIONS - + self.SELECT_DEFINITIONS - + self.NUMBER_DEFINITIONS - + self.SWITCH_DEFINITIONS + self._client = await create_client( + self.host, self.port, self.unit_id, timeout=5.0 ) - _LOGGER.debug("Loaded register definitions for version '%s' (%d entries)", used_version, len(self._all_definitions)) - except Exception as e: - _LOGGER.warning("Failed to load register definitions for version '%s': %s", used_version, e) - # Keep empty definitions as fallback; platforms will see no entities - self._all_definitions = [] - - async def async_read_value(self, sensor: dict, key: str, track_failure: bool = True): - """Helper to read a single sensor value from Modbus with logging and type checking. - - Args: - sensor: sensor definition dict - key: the sensor key - track_failure: if False, timeouts will not count towards timeout metrics + except (ConnectionError, OSError, TModbusError) as exc: + raise ConfigEntryNotReady( + f"Cannot connect to Marstek at {self.host}:{self.port} – {exc}" + ) from exc + + async def async_close(self) -> None: + """Close the Modbus connection.""" + await disconnect_client(self._client) + self._client = None + + def add_to_polling(self, key: str, definition: dict) -> None: """ - entity_type = self._entity_types.get(key, get_entity_type(sensor)) + Dynamically add a single register definition to the polling groups. - # Determine scale and unit - scale = self._scales.get(key, sensor.get("scale", 1)) - unit = sensor.get("unit", "N/A") + Called for entities that are enabled in HA but have + enabled_by_default: false in their YAML definition. + Silently skips entries that are already polled. + """ + # Skip if already in any group + for group in self._groups.values(): + if any(e.key == key for e in group): + return + + if "register" not in definition: + return + + # Force very_low for all dynamically-registered non-default entities. + # Their YAML scan_interval may be "high" or "medium", but reading + # registers the device does not support causes per-register fallback + # reads of up to 5 s each. At very_low (default 180 s) even a full + # block of 16 unsupported registers (16 × 5 s = 80 s) only blocks + # the coordinator once every three minutes instead of every 10–30 s. + priority = "very_low" + + data_type = str(definition.get("data_type", "uint16")).lower() + if data_type not in _DEFAULT_COUNT: + data_type = "uint16" + count = int(definition.get("count", _DEFAULT_COUNT[data_type])) + if data_type in ("uint16", "int16"): + count = 1 + + entry = ReadableEntry( + key=key, + address=int(definition["register"]), + data_type=data_type, + count=count, + scale=float(definition.get("scale", 1.0)), + priority=priority, + ) + self._groups[priority].append(entry) + _LOGGER.debug( + "Dynamically added '%s' (reg %s) to %s polling group", + key, definition["register"], priority, + ) - # Guard: ensure client exists - if not hasattr(self, "client") or self.client is None: - _LOGGER.error("Modbus client is not available when reading %s '%s'", entity_type, key) - return None + async def async_register_enabled_entities(self) -> None: + """ + Inspect the HA entity registry and add any enabled-but-not-default + entities to the coordinator polling groups. - try: - # 10 second timeout for individual reads to prevent hanging - value = await asyncio.wait_for( - self.client.async_read_register( - register=sensor["register"], - data_type=sensor.get("data_type", "uint16"), - count=sensor.get("count", 1), - sensor_key=key, - ), - timeout=10.0 - ) + Called once from async_setup_entry after all platforms have been + forwarded so that async_added_to_hass has completed for all entities. + """ + from homeassistant.helpers import entity_registry as er - # Accept primitive values and structured types (dict/list) returned - # by specialized data_type handlers (e.g., `schedule` returning a dict). - if isinstance(value, (int, float, bool, str, dict, list)): - _LOGGER.debug( - "Updated %s '%s': register=%d, value=%s, scale=%s, unit=%s", - entity_type, - key, - sensor["register"], - value, - scale, - unit, - ) - return value - _LOGGER.warning( - "Invalid value for %s '%s': %r (type %s)", - entity_type, - key, - value, - type(value).__name__, - ) - return None + ent_reg = er.async_get(self.hass) + entries = er.async_entries_for_config_entry(ent_reg, self._entry.entry_id) - except asyncio.TimeoutError: - if track_failure: - self._timeouts_in_cycle = getattr(self, "_timeouts_in_cycle", 0) + 1 - _LOGGER.warning( - "Timeout reading %s '%s' at register %d from %s:%d - connection may be slow or incorrect", - entity_type, key, sensor["register"], self.client.host, self.client.port - ) - return None - except Exception as e: - _LOGGER.error( - "Error reading %s '%s' at register %d: %s", - entity_type, key, sensor["register"], e, - ) - return None + # Build a flat lookup of all YAML register definitions + all_defs: dict[str, dict] = {} + for section in _READABLE_SECTIONS: + sec = self._raw_yaml.get(section) + if isinstance(sec, dict): + all_defs.update(sec) - async def async_write_value( - self, - register: int, - value: int, - key: str, - scale=None, - unit=None, - entity_type="unknown", - ): - """Write a value to a Modbus register asynchronously and log the operation.""" - # Guard: ensure client exists before attempting write - if not hasattr(self, "client") or self.client is None: - _LOGGER.error("Modbus client is not available when writing %s '%s'", entity_type, key) - return False + added = 0 + for reg_entry in entries: + if reg_entry.disabled: + continue # entity is disabled in HA – skip + + # unique_id format: "{entry_id}_{key}" + prefix = f"{self._entry.entry_id}_" + if not reg_entry.unique_id.startswith(prefix): + continue + key = reg_entry.unique_id[len(prefix):] + + defn = all_defs.get(key) + if defn is None: + continue # not a polled register (e.g. calculated sensor) + if defn.get("enabled_by_default", True): + continue # already in polling groups from _load_groups + + self.add_to_polling(key, defn) + added += 1 _LOGGER.debug( - "Writing to %s '%s': register=%d (0x%04X), value=%s", - entity_type, - key, - register, - register, - value, + "async_register_enabled_entities: added %d non-default enabled entities " + "to very_low polling group", + added, + ) + if added: + # Signal _fetch_tick to include very_low on the next poll so that + # user-enabled entities get an immediate value even if tick 1 fired + # (via select platform's update_before_add=True) before this method + # ran and therefore polled very_low without these entries. + self._needs_very_low_poll = True + + def _update_scan_intervals(self, intervals: dict[str, int]) -> None: + """Update polling intervals at runtime (called from options flow).""" + high = intervals.get("high", self._interval_high) + self._interval_high = high + self._medium_every = _ticks_for(intervals.get("medium", DEFAULT_SCAN_INTERVALS["medium"]), high) + self._low_every = _ticks_for(intervals.get("low", DEFAULT_SCAN_INTERVALS["low"]), high) + self._very_low_every = _ticks_for(intervals.get("very_low", DEFAULT_SCAN_INTERVALS["very_low"]), high) + self.update_interval = timedelta(seconds=high) + _LOGGER.debug( + "Scan intervals updated: high=%ds medium_every=%d low_every=%d very_low_every=%d", + high, self._medium_every, self._low_every, self._very_low_every, ) - # Determine data_type for this key (numbers typically in NUMBER_DEFINITIONS) - data_type = None - try: - defn = next((d for d in self.NUMBER_DEFINITIONS if d.get("key") == key), None) - if not defn: - # fallback to switches/selects if user configured writes elsewhere - defn = next((d for d in self.SWITCH_DEFINITIONS if d.get("key") == key), None) - if defn: - data_type = defn.get("data_type") - except Exception: - data_type = None - - # Default to uint16 when unknown - if not data_type: - data_type = "uint16" + # ── DataUpdateCoordinator ───────────────────────────────────────────── - # Convert/validate value according to data_type - value_to_send = None - if data_type == "int16": - if not isinstance(value, int): - _LOGGER.error("Value for %s '%s' must be int for data_type int16", entity_type, key) - return False - value_to_send = value & 0xFFFF - elif data_type == "uint16": - if not isinstance(value, int) or not (0 <= value <= 0xFFFF): - _LOGGER.error("Value for %s '%s' must be 0..65535 for data_type uint16", entity_type, key) - return False - value_to_send = value - else: - # Not implemented conversion for 32-bit types here - _LOGGER.error("Unsupported data_type '%s' for key '%s' on write", data_type, key) - return False + async def _async_update_data(self) -> dict[str, Any]: + if self._client is None: + _LOGGER.warning("Modbus client is None – attempting reconnect") + await self.async_init() + + self._tick += 1 + _LOGGER.debug("Coordinator poll tick %d", self._tick) try: - import asyncio as _asyncio - try: - success = await _asyncio.wait_for( - self.client.async_write_register(register=register, value=value_to_send), - timeout=10.0, - ) - except _asyncio.TimeoutError: - _LOGGER.error( - "Timeout writing to register 0x%X for %s '%s' - connection may be half-open", - register, - entity_type, - key, + # 120 s gives enough headroom even when many unsupported registers + # each hit the 5 s per-register fallback timeout + # (e.g. 16 cell-voltage registers × 5 s = 80 s). + async with asyncio.timeout(120): + async with self._lock: + return await self._fetch_tick() + except TimeoutError as exc: + raise UpdateFailed("Timeout polling Marstek registers") from exc + except (ConnectionError, OSError, TModbusError) as exc: + raise UpdateFailed(f"Modbus communication error: {exc}") from exc + + async def _fetch_tick(self) -> dict[str, Any]: + client = self._client + assert client is not None + + data: dict[str, Any] = dict(self.data) if self.data else {} + + if self._tick == 1: + # ── Initial full poll ──────────────────────────────────────────── + # Poll all 4 groups once so every entity has an immediate value. + # We deliberately poll group-by-group (not a single merged list) + # so that failures in very_low (user-enabled non-default registers + # that may not exist on this device variant) cannot contaminate + # the good_gaps of the vetted default registers in high/medium/low. + # Each group gets its own batch_read call with its own block layout. + _LOGGER.debug("Tick 1: initial full poll of all groups") + due = ["high", "medium", "low", "very_low"] + else: + due = ["high"] + if self._tick % self._medium_every == 0: + due.append("medium") + if self._tick % self._low_every == 0: + due.append("low") + if self._tick % self._very_low_every == 0: + due.append("very_low") + # If async_register_enabled_entities() added user-enabled entries to + # the very_low group AFTER tick 1 already fired without them, poll + # very_low once immediately so those entities get values right away. + if self._needs_very_low_poll and "very_low" not in due: + due.append("very_low") + self._needs_very_low_poll = False + _LOGGER.debug( + "Tick %d: adding one-shot very_low poll for newly registered " + "user-enabled entities", + self._tick, ) - return False - if success: - _LOGGER.debug( - "Successfully wrote to %s '%s': register=%d (0x%04X), value=%s, scale=%s, unit=%s", - entity_type, - key, - register, - register, - value_to_send, - scale if scale is not None else 1, - unit if unit is not None else "N/A", + _LOGGER.debug("Tick %d polling groups: %s", self._tick, due) + + for priority in due: + entries = self._groups[priority] + if not entries: + continue + try: + cache = await batch_read( + client, + _addresses_for(entries), + bad_gaps=self._bad_gaps, + good_gaps=self._good_gaps, ) - from homeassistant.util.dt import utcnow as _utcnow - self._last_write_times[key] = _utcnow() - return True - else: + except Exception as exc: # noqa: BLE001 _LOGGER.warning( - "Write operation failed for %s '%s': register=%d (0x%04X), value=%s", - entity_type, - key, - register, - register, - value, + "batch_read for %s group failed, skipping: %s", priority, exc ) - return False - - except Exception as e: - _LOGGER.error( - "Failed to write value %s to register 0x%X for %s '%s': %s", - value, - register, - entity_type, - key, - e - ) - return False + continue + for entry in entries: + value = extract_typed_value( + cache, entry.address, entry.data_type, entry.count, entry.scale + ) + if value is not None: + data[entry.key] = value + elif entry.key not in data: + data[entry.key] = None - async def _async_update_data(self): - """Update all sensors asynchronously with per-sensor interval skipping. + data.update(_calculate_derived(data, self._raw_yaml)) + return data - Buttons are excluded as they are not polled. - Sensors disabled in Home Assistant are skipped, except dependencies which are always fetched. - """ - from homeassistant.util.dt import utcnow - from homeassistant.helpers import entity_registry as er - now = utcnow() - updated_data = {} - - # Track if we actually attempted any reads (not just skipped due to intervals) - attempted_reads = 0 - successful_reads = 0 - self._timeouts_in_cycle = 0 - - # Connection throttling: if too many failures, temporarily stop attempting connections - if self._connection_suspended: - if self._suspension_reset_time and now > self._suspension_reset_time: - _LOGGER.info("Connection suspension expired - attempting reconnection") - self._connection_suspended = False - self._consecutive_failures = 0 - - # Force reconnect after suspension - try: - connected = await self.client.async_reconnect() - if connected: - _LOGGER.info("Successfully reconnected after suspension") - else: - _LOGGER.warning("Failed to reconnect after suspension - will retry next cycle") - return self.data or {} - except Exception as exc: - _LOGGER.error("Exception during reconnect: %s", exc) - return self.data or {} - else: - _LOGGER.debug("Connection suspended - skipping update to prevent resource exhaustion") - return self.data or {} - _LOGGER.debug("Coordinator poll tick at %s", now.isoformat()) + # ── Public accessors ────────────────────────────────────────────────── - # Get the entity registry to check for disabled entities - entity_registry = er.async_get(self.hass) + @property + def raw_yaml(self) -> dict[str, Any]: + return self._raw_yaml - # Collect all dependency keys from all definitions - all_definitions_for_deps = ( - self.EFFICIENCY_SENSOR_DEFINITIONS - + self.STORED_ENERGY_SENSOR_DEFINITIONS - + self.CYCLE_SENSOR_DEFINITIONS - ) - dependency_keys_set = { - dep_key - for defn in all_definitions_for_deps - for dep_key in defn.get("dependency_keys", {}).values() - if dep_key - } + # ── YAML section properties (used by entity platform builders) ──────── - # Debug logging - for dep_key in dependency_keys_set: - _LOGGER.debug("Dependency key '%s'", dep_key) - - # Iterate over each sensor definition to poll if due - for sensor in self._all_definitions: - key = sensor["key"] - entity_type = self._entity_types.get(key, get_entity_type(sensor)) - unique_id = f"{self.config_entry.entry_id}_{sensor['key']}" - registry_entry = entity_registry.async_get_entity_id(entity_type, self.config_entry.domain, unique_id) - - # Determine if the entity is disabled in Home Assistant - is_disabled = False - entry = entity_registry.entities.get(registry_entry) if registry_entry else None - if entry: - is_disabled = entry.disabled or entry.disabled_by is not None - - # Check if this key is a dependency key for any sensor - is_dependency = key in dependency_keys_set - - # Skip polling if entity is disabled unless it is a dependency key - if is_disabled: - if is_dependency: - _LOGGER.debug("Fetching disabled dependency key '%s'", key) - else: - _LOGGER.debug("Skipping disabled entity '%s'", sensor.get("name", key)) - continue + def _yaml_section_as_list(self, section: str) -> list[dict]: + """Return a YAML section as list[dict], each with "key" injected.""" + return [ + {"key": k, **v} + for k, v in self._raw_yaml.get(section, {}).items() + if isinstance(v, dict) + ] - # Determine polling interval for this sensor, using self.scan_intervals - interval_name = sensor.get("scan_interval") - interval = None - if interval_name: - interval = self.scan_intervals.get(interval_name) + @property + def SENSOR_DEFINITIONS(self) -> list[dict]: + return self._yaml_section_as_list("SENSOR_DEFINITIONS") - if interval is None: - _LOGGER.warning( - "%s '%s' has no scan_interval defined, skipping this poll", - entity_type, - key, - ) - continue + @property + def BINARY_SENSOR_DEFINITIONS(self) -> list[dict]: + return self._yaml_section_as_list("BINARY_SENSOR_DEFINITIONS") - # Skip read for 3s after a write to avoid reading back stale device state - last_write = self._last_write_times.get(key) - if last_write is not None and (now - last_write).total_seconds() < 3: - _LOGGER.debug("Suppressing read of '%s' after recent write", key) - continue + @property + def SELECT_DEFINITIONS(self) -> list[dict]: + return self._yaml_section_as_list("SELECT_DEFINITIONS") - # Apply per-register exponential backoff based on consecutive failures. - # This prevents hammering dead/removed registers at full poll rate. - failures = self._register_failures.get(key, 0) - backoff = min(2 ** failures, 64) # max 64x base interval - effective_interval = min(interval * backoff, 3600) + @property + def SWITCH_DEFINITIONS(self) -> list[dict]: + return self._yaml_section_as_list("SWITCH_DEFINITIONS") - last_attempt = self._last_attempt_times.get(key) - elapsed = (now - last_attempt).total_seconds() if last_attempt else None + @property + def NUMBER_DEFINITIONS(self) -> list[dict]: + return self._yaml_section_as_list("NUMBER_DEFINITIONS") - if elapsed is not None and elapsed < effective_interval: - _LOGGER.debug( - "Skipping %s '%s', last attempt %.1fs ago (effective interval %ds, failures=%d)", - entity_type, - key, - elapsed, - effective_interval, - failures, - ) - continue + @property + def BUTTON_DEFINITIONS(self) -> list[dict]: + return self._yaml_section_as_list("BUTTON_DEFINITIONS") - # Track that we're attempting a read - attempted_reads += 1 - self._read_start_times[key] = now - - # Attempt to read the sensor value from Modbus using helper function - value = await self.async_read_value(sensor, key) - - if value is not None: - # Special-case: for packed schedule sensors, store both the - # raw 5-register list as the main `data[key]` and the decoded - # dict under `data["_attrs"]` so sensors can expose - # attributes while the state remains the raw registers. - if sensor.get("data_type") == "schedule" and isinstance(value, dict): - try: - days = int(value.get("days") or 0) - except Exception: - days = value.get("days") - try: - start = int(value.get("start") or 0) - except Exception: - start = value.get("start") - try: - end = int(value.get("end") or 0) - except Exception: - end = value.get("end") - try: - enabled = int(value.get("enabled") or 0) - except Exception: - enabled = value.get("enabled") - - # Mode in attrs is signed; convert to unsigned 16-bit for raw register - try: - mode_signed = int(value.get("mode") or 0) - mode_raw = mode_signed & 0xFFFF - except Exception: - mode_raw = value.get("mode") - - raw_regs = [days, start, end, mode_raw, enabled] - - updated_data[key] = raw_regs - try: - updated_data[f"{key}_attrs"] = value - except Exception: - _LOGGER.exception("Failed to populate %s_attrs", key) - - _LOGGER.debug( - "Stored raw schedule for %s: %s and attrs: %s", - key, - raw_regs, - value, - ) - else: - updated_data[key] = value - - self._last_update_times[key] = now - self._last_attempt_times[key] = now - prev_failures = self._register_failures.get(key, 0) - if prev_failures > 0: - _LOGGER.info( - "%s '%s' recovered after %d consecutive failure(s)", - entity_type, key, prev_failures, - ) - self._register_failures[key] = 0 - successful_reads += 1 - else: - # Individual sensor read failed — increment backoff counter - self._last_attempt_times[key] = now - new_failures = self._register_failures.get(key, 0) + 1 - self._register_failures[key] = new_failures - next_backoff = min(2 ** new_failures, 64) - next_interval = min(interval * next_backoff, 3600) - # Log verbosely on first few failures and then only at milestones - if new_failures <= 3 or new_failures % 10 == 0: - _LOGGER.warning( - "Failed to read %s '%s' - value is None " - "(consecutive failures: %d, next poll in %ds)", - entity_type, key, new_failures, next_interval, - ) - else: - _LOGGER.debug( - "Failed to read %s '%s' - value is None (failure #%d)", - entity_type, key, new_failures, - ) - - # Connection retry logic: only track failures if we actually attempted reads - if attempted_reads > 0: - timeout_reads = int(getattr(self, "_timeouts_in_cycle", 0) or 0) - if successful_reads > 0: - # At least some data successfully retrieved - reset failure counter - if self._consecutive_failures > 0: - _LOGGER.info("Connection recovered after %d failures (successful reads: %d/%d)", - self._consecutive_failures, successful_reads, attempted_reads) - self._consecutive_failures = 0 - self._connection_suspended = False - self._last_successful_read = now - - if timeout_reads and (timeout_reads / attempted_reads) >= self._timeout_ratio_reconnect_threshold: - self._consecutive_timeout_cycles += 1 - _LOGGER.warning( - "High timeout rate detected (%d/%d) - consecutive timeout cycles: %d/%d", - timeout_reads, - attempted_reads, - self._consecutive_timeout_cycles, - self._max_consecutive_timeout_cycles, - ) - else: - self._consecutive_timeout_cycles = 0 - - if self._consecutive_timeout_cycles >= self._max_consecutive_timeout_cycles: - try: - _LOGGER.info( - "Attempting reconnect due to repeated timeouts (%d/%d cycles)", - self._consecutive_timeout_cycles, - self._max_consecutive_timeout_cycles, - ) - connected = await self.client.async_reconnect() - if connected: - _LOGGER.info("Successfully reconnected after repeated timeouts") - self._consecutive_timeout_cycles = 0 - self._connection_established_at = now - else: - _LOGGER.warning("Reconnect attempt after repeated timeouts failed") - except Exception as exc: - _LOGGER.error("Exception during reconnect after repeated timeouts: %s", exc) - elif successful_reads == 0: - # We attempted reads but ALL failed - connection issue - self._consecutive_failures += 1 - _LOGGER.warning("All read attempts failed (%d/%d) - consecutive failures: %d/%d", - successful_reads, attempted_reads, - self._consecutive_failures, self._max_consecutive_failures) - - # Try to reconnect immediately on failure (use reconnect helper) - try: - _LOGGER.info("Attempting immediate reconnection after read failures") - connected = await self.client.async_reconnect() - if connected: - _LOGGER.info("Successfully reconnected") - self._consecutive_failures = 0 - self._connection_established_at = now - else: - _LOGGER.warning("Immediate reconnection failed") - except Exception as exc: - _LOGGER.error("Exception during immediate reconnect: %s", exc) - - if self._consecutive_failures >= self._max_consecutive_failures: - # Too many failures - suspend connection attempts for 1 minute - self._connection_suspended = True - self._suspension_reset_time = now + timedelta(minutes=1) - _LOGGER.error( - "Connection suspended after %d consecutive failures. " - "Will retry in 1 minute to prevent resource exhaustion.", - self._consecutive_failures - ) - self._consecutive_timeout_cycles = 0 - else: - _LOGGER.debug("No sensors due for update in this cycle") - - # Defensive check - if self.data is None: - self.data = {} - - # Discard any read result that was overtaken by a write during this cycle. - # If a write completed after the read for a key was started, the read - # observed a pre-write device state and must not overwrite the fresh write. - for _k in list(updated_data.keys()): - _read_start = self._read_start_times.get(_k) - _last_write = self._last_write_times.get(_k) - if _read_start and _last_write and _last_write > _read_start: - _LOGGER.debug( - "Discarding stale read of '%s' — write completed after read started", _k - ) - del updated_data[_k] + @property + def EFFICIENCY_SENSOR_DEFINITIONS(self) -> list[dict]: + return self._yaml_section_as_list("EFFICIENCY_SENSOR_DEFINITIONS") + + @property + def STORED_ENERGY_SENSOR_DEFINITIONS(self) -> list[dict]: + return self._yaml_section_as_list("STORED_ENERGY_SENSOR_DEFINITIONS") - # Update the coordinator's data - self.data.update(updated_data) - return self.data - + @property + def CYCLE_SENSOR_DEFINITIONS(self) -> list[dict]: + return self._yaml_section_as_list("CYCLE_SENSOR_DEFINITIONS") - async def async_close(self): - """Close the Modbus client connection cleanly.""" + # ── Write interface ─────────────────────────────────────────────────── + + async def async_write_register(self, address: int, value: int) -> None: + if self._client is None: + raise UpdateFailed("Cannot write – Modbus client not connected") try: - await self.client.async_close() - _LOGGER.debug("Closed Modbus connection to %s:%d", self.host, self.port) - except Exception as e: - _LOGGER.warning("Error closing Modbus client: %s", e) - - -def get_registers(version: str): - """ - Return a dict with entity/register definitions for the given device version. - - The returned dict contains the keys: - - SENSOR_DEFINITIONS - - BINARY_SENSOR_DEFINITIONS - - SELECT_DEFINITIONS - - SWITCH_DEFINITIONS - - NUMBER_DEFINITIONS - - BUTTON_DEFINITIONS - - EFFICIENCY_SENSOR_DEFINITIONS - - STORED_ENERGY_SENSOR_DEFINITIONS - - CYCLE_SENSOR_DEFINITIONS - - If an unknown version is requested, the function falls back to the v1/v2 - register set (because v1 and v2 share the same registers in this integration). - """ - # Normalize incoming version value and accept legacy tokens. - version_raw = (version or "").strip() - version = version_raw.lower() - _LOGGER.info( - "Version '%s' mapped to '%s'" , - version_raw, - version, - ) - # Accept legacy tokens 'v1/v2' and 'v3' and automatically map them - # to the new tokens used by the integration ('e v1/v2', 'e v3'). - legacy_to_new = { - "v1/v2": "e v1/v2", - "v3": "e v3", - } - if version in legacy_to_new: - mapped = legacy_to_new[version] - _LOGGER.info( - "Mapping legacy device version '%s' to '%s' for backwards compatibility", - version_raw, - mapped, - ) - version = mapped - - # Validate against supported versions (case-insensitive) - allowed = {str(item).lower() for item in SUPPORTED_VERSIONS} - if version not in allowed: - raise ValueError( - "Unsupported or missing device version %r. Supported versions: %s" - % (version_raw, ", ".join(sorted(allowed))) + async with asyncio.timeout(_WRITE_LOCK_TIMEOUT): + async with self._lock: + await _write_register(self._client, address, value) + except TimeoutError as exc: + raise UpdateFailed(f"Timeout acquiring write lock for reg {address}") from exc + except (ConnectionError, OSError, TModbusError) as exc: + raise UpdateFailed(f"Failed writing reg {address}={value}: {exc}") from exc + + async def async_write_registers(self, address: int, values: list[int]) -> None: + if self._client is None: + raise UpdateFailed("Cannot write – Modbus client not connected") + try: + async with asyncio.timeout(_WRITE_LOCK_TIMEOUT): + async with self._lock: + await _write_registers(self._client, address, values) + except TimeoutError as exc: + raise UpdateFailed(f"Timeout acquiring write lock for regs at {address}") from exc + except (ConnectionError, OSError, TModbusError) as exc: + raise UpdateFailed(f"Failed writing regs at {address}: {exc}") from exc + + async def async_write_value( + self, + register: int, + value: int, + key: str = "", + scale: float = 1, + unit: str | None = None, + entity_type: str = "", + ) -> bool: + """ + Write a single raw integer value to a Modbus register. + + This is the unified write entry-point used by all entity platforms + (switch, select, number, button). The caller is responsible for + converting engineering-unit values to raw register integers before + calling this method. + + Returns True on success, False on failure. + """ + try: + await self.async_write_register(register, int(value)) + _LOGGER.debug( + "async_write_value: key=%s reg=%s raw=%s (scale=%s unit=%s type=%s)", + key, register, value, scale, unit, entity_type, + ) + return True + except Exception as exc: # noqa: BLE001 + _LOGGER.warning( + "async_write_value failed: key=%s reg=%s value=%s – %s", + key, register, value, exc, + ) + return False + + async def async_read_value( + self, + definition: dict, + key: str, + track_failure: bool = True, + ) -> None: + """ + Re-read a single register and update coordinator.data in place. + + Used by number entities after a failed write to restore the + actual device state without triggering a full coordinator refresh. + """ + if self._client is None: + return + try: + from .helpers.modbus_client import batch_read, extract_typed_value + + address = int(definition["register"]) + data_type = str(definition.get("data_type", "uint16")).lower() + count_map = {"uint16": 1, "int16": 1, "uint32": 2, "int32": 2, "char": 1} + count = int(definition.get("count", count_map.get(data_type, 1))) + if data_type in ("uint16", "int16"): + count = 1 + scale = float(definition.get("scale", 1.0)) + + cache = await batch_read(self._client, list(range(address, address + count)), max_gap=0) + value = extract_typed_value(cache, address, data_type, count, scale) + if value is not None and isinstance(self.data, dict): + self.data[key] = value + except Exception as exc: # noqa: BLE001 + if track_failure: + _LOGGER.debug("async_read_value failed for %s: %s", key, exc) + + +# --------------------------------------------------------------------------- +# Calculated sensors +# --------------------------------------------------------------------------- + +def _calculate_derived(data: dict[str, Any], raw_yaml: dict[str, Any]) -> dict[str, Any]: + result: dict[str, Any] = {} + + for key, defn in raw_yaml.get("EFFICIENCY_SENSOR_DEFINITIONS", {}).items(): + dep = defn.get("dependency_keys", {}) + mode = defn.get("mode", "round_trip") + if mode == "round_trip": + charge = data.get(dep.get("charge", "")) or 0.0 + discharge = data.get(dep.get("discharge", "")) or 0.0 + result[key] = round(discharge / charge * 100, 1) if charge > 0 else None + elif mode == "conversion": + batt = data.get(dep.get("battery_power", "")) or 0.0 + ac = data.get(dep.get("ac_power", "")) or 0.0 + result[key] = round(abs(ac / batt) * 100, 1) if batt != 0 else 0.0 + _LOGGER.debug( + "Calculated value for %s: %s (input values: %s)", + key, result.get(key), {k: data.get(v) for k, v in dep.items()}, ) - def _normalize_section(section): - """Convert mapping-based sections into the legacy list-of-dicts format.""" - if isinstance(section, dict): - normalized = [] - for key, value in section.items(): - entry = dict(value or {}) - entry.setdefault("key", key) - normalized.append(entry) - return normalized - if isinstance(section, list): - return section - return [] - - # Prefer YAML-based register definitions placed in the `registers/` folder. - # Map version tokens to YAML filenames. - filename_map = { - "e v1/v2": "e_v12.yaml", - "e v3": "e_v3.yaml", - "d": "d.yaml", - "a": "a.yaml", - } + for key, defn in raw_yaml.get("STORED_ENERGY_SENSOR_DEFINITIONS", {}).items(): + dep = defn.get("dependency_keys", {}) + soc = data.get(dep.get("soc", "")) or 0 + capacity = data.get(dep.get("capacity", "")) or 0.0 + result[key] = round(soc / 100 * capacity, 3) if soc and capacity else None + _LOGGER.debug( + "Calculated value for %s: %s (input values: %s)", + key, result.get(key), {k: data.get(v) for k, v in dep.items()}, + ) - yaml_filename = filename_map.get(version) - if yaml_filename: - yaml_path = Path(__file__).parent / "registers" / yaml_filename - if yaml_path.exists(): - try: - import yaml - - with open(yaml_path, "r", encoding="utf-8") as fh: - data = yaml.safe_load(fh) or {} - - return { - "SENSOR_DEFINITIONS": _normalize_section(data.get("SENSOR_DEFINITIONS")), - "BINARY_SENSOR_DEFINITIONS": _normalize_section(data.get("BINARY_SENSOR_DEFINITIONS")), - "SELECT_DEFINITIONS": _normalize_section(data.get("SELECT_DEFINITIONS")), - "SWITCH_DEFINITIONS": _normalize_section(data.get("SWITCH_DEFINITIONS")), - "NUMBER_DEFINITIONS": _normalize_section(data.get("NUMBER_DEFINITIONS")), - "BUTTON_DEFINITIONS": _normalize_section(data.get("BUTTON_DEFINITIONS")), - "EFFICIENCY_SENSOR_DEFINITIONS": _normalize_section( - data.get("EFFICIENCY_SENSOR_DEFINITIONS") - ), - "STORED_ENERGY_SENSOR_DEFINITIONS": _normalize_section( - data.get("STORED_ENERGY_SENSOR_DEFINITIONS") - ), - "CYCLE_SENSOR_DEFINITIONS": _normalize_section( - data.get("CYCLE_SENSOR_DEFINITIONS") - ), - } - except Exception as e: - _LOGGER.warning("Failed to load YAML registers %s: %s", yaml_path, e) + for key, defn in raw_yaml.get("CYCLE_SENSOR_DEFINITIONS", {}).items(): + dep = defn.get("dependency_keys", {}) + discharge = data.get(dep.get("discharge", "")) or 0.0 + capacity = data.get(dep.get("capacity", "")) or 0.0 + result[key] = round(discharge / capacity, 2) if capacity > 0 else None + _LOGGER.debug( + "Calculated value for %s: %s (input values: %s)", + key, result.get(key), {k: data.get(v) for k, v in dep.items()}, + ) + return result From 48099bf4866ea5f6f387163ebdcbe5ba75dac35f Mon Sep 17 00:00:00 2001 From: sphings79 <43515272+sphings79@users.noreply.github.com> Date: Wed, 13 May 2026 22:55:22 +0200 Subject: [PATCH 6/6] Update requirements in manifest.json Signed-off-by: sphings79 <43515272+sphings79@users.noreply.github.com> --- custom_components/marstek_modbus/manifest.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/custom_components/marstek_modbus/manifest.json b/custom_components/marstek_modbus/manifest.json index e3ff4fe..390551a 100644 --- a/custom_components/marstek_modbus/manifest.json +++ b/custom_components/marstek_modbus/manifest.json @@ -4,7 +4,7 @@ "version": "2026.3.4", "config_flow": true, "documentation": "https://github.com/viperrnmc/marstek_venus_modbus", - "requirements": ["pymodbus>=3.9.2"], + "requirements": ["tmodbus"], "codeowners": ["@ViperRNMC"], "iot_class": "local_polling" -} \ No newline at end of file +}