Skip to content

Commit b4ff31e

Browse files
committed
Fix restart bugs and harden config write path
- Send SIGUSR1 instead of SIGTERM from the web UI restart handler so the process reloads config in place instead of exiting.
- Re-derive device_types, device_ids and skip_test after reloading cfg on restart so changed DEVICE_TYPE / DEVICE_IDS take effect.
- Hold _CONFIG_WRITE_LOCK for the entire read-modify-write transaction to prevent concurrent requests from clobbering each other.
- Reject duplicate section names in the order list.
1 parent acaf76b commit b4ff31e

4 files changed

Lines changed: 105 additions & 90 deletions

File tree

src/astrameter/health_service.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -189,7 +189,7 @@ async def _handle_api_config_post(self, request):
189189
)
190190

191191
async def _handle_api_restart(self, request):
192-
"""Acknowledge POST /api/restart and schedule a graceful shutdown via SIGTERM."""
192+
"""Acknowledge POST /api/restart and schedule an in-process restart via SIGUSR1."""
193193
import signal
194194

195195
response = web.Response(
@@ -199,10 +199,10 @@ async def _handle_api_restart(self, request):
199199
content_type="application/json",
200200
)
201201
logger.info("Restart requested via web UI")
202-
# Send SIGTERM to ourselves so the asyncio main loop handles graceful
203-
# shutdown (finally blocks, powermeter/MQTT cleanup) before the
204-
# external supervisor (Docker, systemd, HA addon) restarts the process.
205-
threading.Timer(0.5, lambda: os.kill(os.getpid(), signal.SIGTERM)).start()
202+
# Send SIGUSR1 so the handler in main.py sets restart_requested=True
203+
# before raising KeyboardInterrupt, causing the outer loop to reload
204+
# the config and re-run instead of exiting.
205+
threading.Timer(0.5, lambda: os.kill(os.getpid(), signal.SIGUSR1)).start()
206206
return response
207207

208208
async def _handle_not_found(self, request):

src/astrameter/main.py

Lines changed: 61 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -390,6 +390,65 @@ async def async_main(
390390
logger.exception("Error stopping health check service")
391391

392392

393+
def _resolve_device_config(
394+
cfg: configparser.ConfigParser, args: argparse.Namespace
395+
) -> tuple[list[str], list[str], bool]:
396+
"""Derive device_types, device_ids and skip_test from *cfg* and CLI *args*."""
397+
device_types = (
398+
args.device_types
399+
if args.device_types is not None
400+
else [
401+
dt.strip()
402+
for dt in cfg.get("GENERAL", "DEVICE_TYPE", fallback="shellypro3em").split(
403+
","
404+
)
405+
if dt.strip()
406+
]
407+
)
408+
skip_test = (
409+
args.skip_powermeter_test
410+
if args.skip_powermeter_test is not None
411+
else cfg.getboolean("GENERAL", "SKIP_POWERMETER_TEST", fallback=False)
412+
)
413+
414+
device_ids: list[str] = list(args.device_ids) if args.device_ids is not None else []
415+
if not device_ids:
416+
cfg_device_ids = cfg.get("GENERAL", "DEVICE_IDS", fallback="").strip()
417+
if cfg_device_ids:
418+
device_ids = [
419+
did.strip() for did in cfg_device_ids.split(",") if did.strip()
420+
]
421+
while len(device_ids) < len(device_types):
422+
device_type = device_types[len(device_ids)]
423+
if device_type in ["shellypro3em", "shellyemg3", "shellyproem50"]:
424+
device_ids.append(f"{device_type}-ec4609c439c{len(device_ids) + 1}")
425+
else:
426+
device_ids.append(f"device-{len(device_ids) + 1}")
427+
428+
if "shellypro3em" in device_types:
429+
shellypro3em_index = device_types.index("shellypro3em")
430+
device_types[shellypro3em_index] = "shellypro3em_old"
431+
device_types.append("shellypro3em_new")
432+
device_ids.append(device_ids[shellypro3em_index])
433+
434+
ct_ports = []
435+
for device_type in device_types:
436+
if device_type in ["ct002", "ct003"]:
437+
section = get_ct_section(device_type, cfg)
438+
ct_ports.append(cfg.getint(section, "UDP_PORT", fallback=UDP_PORT))
439+
if len(ct_ports) != len(set(ct_ports)):
440+
raise ValueError(
441+
"Multiple CT002/CT003 devices are configured with the same UDP port. "
442+
"Set UDP_PORT in [CT002]/[CT003] to avoid conflicts."
443+
)
444+
445+
logger.info(f"Device Types: {device_types}")
446+
logger.info(f"Device IDs: {device_ids}")
447+
logger.info(f"Skip Test: {skip_test}")
448+
449+
return device_types, device_ids, skip_test
450+
451+
393452
def main():
394453
parser = argparse.ArgumentParser(description="Power meter device emulator")
395454
parser.add_argument(
@@ -444,61 +503,7 @@ def main():
444503
"Git commit not logged (set GIT_COMMIT_SHA at image build for CI images)"
445504
)
446505

447-
# Load general settings
448-
device_types = (
449-
args.device_types
450-
if args.device_types is not None
451-
else [
452-
dt.strip()
453-
for dt in cfg.get("GENERAL", "DEVICE_TYPE", fallback="shellypro3em").split(
454-
","
455-
)
456-
if dt.strip()
457-
]
458-
)
459-
skip_test = (
460-
args.skip_powermeter_test
461-
if args.skip_powermeter_test is not None
462-
else cfg.getboolean("GENERAL", "SKIP_POWERMETER_TEST", fallback=False)
463-
)
464-
465-
device_ids = args.device_ids if args.device_ids is not None else []
466-
# Load device IDs from config if not provided via CLI
467-
if not device_ids:
468-
cfg_device_ids = cfg.get("GENERAL", "DEVICE_IDS", fallback="").strip()
469-
if cfg_device_ids:
470-
device_ids = [
471-
did.strip() for did in cfg_device_ids.split(",") if did.strip()
472-
]
473-
# Fill missing device IDs with default format
474-
while len(device_ids) < len(device_types):
475-
device_type = device_types[len(device_ids)]
476-
if device_type in ["shellypro3em", "shellyemg3", "shellyproem50"]:
477-
device_ids.append(f"{device_type}-ec4609c439c{len(device_ids) + 1}")
478-
else:
479-
device_ids.append(f"device-{len(device_ids) + 1}")
480-
481-
# For backward compatibility, replace shellypro3em with shellypro3em_old and shellypro3em_new
482-
if "shellypro3em" in device_types:
483-
shellypro3em_index = device_types.index("shellypro3em")
484-
device_types[shellypro3em_index] = "shellypro3em_old"
485-
device_types.append("shellypro3em_new")
486-
device_ids.append(device_ids[shellypro3em_index])
487-
488-
ct_ports = []
489-
for device_type in device_types:
490-
if device_type in ["ct002", "ct003"]:
491-
section = get_ct_section(device_type, cfg)
492-
ct_ports.append(cfg.getint(section, "UDP_PORT", fallback=UDP_PORT))
493-
if len(ct_ports) != len(set(ct_ports)):
494-
raise ValueError(
495-
"Multiple CT002/CT003 devices are configured with the same UDP port. "
496-
"Set UDP_PORT in [CT002]/[CT003] to avoid conflicts."
497-
)
498-
499-
logger.info(f"Device Types: {device_types}")
500-
logger.info(f"Device IDs: {device_ids}")
501-
logger.info(f"Skip Test: {skip_test}")
506+
device_types, device_ids, skip_test = _resolve_device_config(cfg, args)
502507

503508
# Apply command line throttling override if specified
504509
if args.throttle_interval is not None:
@@ -586,6 +591,7 @@ def _restart_handler(signum, frame):
586591
logger.info("Restarting service…")
587592
cfg = configparser.ConfigParser(dict_type=OrderedDict, interpolation=None)
588593
cfg.read(args.config)
594+
device_types, device_ids, skip_test = _resolve_device_config(cfg, args)
589595
except RuntimeError as exc:
590596
logger.error("%s", exc)
591597
exit(1)

src/astrameter/web_config.py

Lines changed: 34 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -132,6 +132,8 @@ def _validate_config_payload(sections: dict, order: list) -> None:
132132
"""Raise ValueError if any section name, key, or value would corrupt the INI."""
133133
if not isinstance(order, list) or any(not isinstance(s, str) for s in order):
134134
raise ValueError("'order' must be a list of section names")
135+
if len(order) != len(set(order)):
136+
raise ValueError("'order' contains duplicate section names")
135137
for section, pairs in sections.items():
136138
if (
137139
not isinstance(section, str)
@@ -164,37 +166,39 @@ def write_config_from_dict(config_path: str, sections: dict, order: list) -> Non
164166
_validate_config_payload(sections, order)
165167
write_order = list(order) + [s for s in sections if s not in order]
166168

167-
updater = ConfigUpdater()
168-
updater.optionxform = str # type: ignore[assignment] # preserve key case
169-
170-
if os.path.exists(config_path):
171-
updater.read(config_path)
172-
173-
# Update existing sections and add new keys / remove stale keys.
174-
for section_name, new_pairs in sections.items():
175-
if updater.has_section(section_name):
176-
for key in set(updater.options(section_name)) - new_pairs.keys():
177-
updater.remove_option(section_name, key)
178-
else:
179-
updater.add_section(section_name)
180-
for key, value in new_pairs.items():
181-
updater.set(section_name, key, value)
182-
183-
# Remove sections not present in the incoming payload.
184-
for section_name in list(updater.sections()):
185-
if section_name not in sections:
186-
updater.remove_section(section_name)
187-
188-
# Re-order sections to match *write_order* by rebuilding from
189-
# detached copies. Only needed when the order actually differs.
190-
current_order = updater.sections()
191-
desired = [s for s in write_order if s in sections]
192-
if current_order != desired:
193-
detached = {name: updater[name].detach() for name in list(updater.sections())}
194-
for name in desired:
195-
updater.add_section(detached[name])
196-
197169
with _CONFIG_WRITE_LOCK:
170+
updater = ConfigUpdater()
171+
updater.optionxform = str # type: ignore[assignment] # preserve key case
172+
173+
if os.path.exists(config_path):
174+
updater.read(config_path)
175+
176+
# Update existing sections and add new keys / remove stale keys.
177+
for section_name, new_pairs in sections.items():
178+
if updater.has_section(section_name):
179+
for key in set(updater.options(section_name)) - new_pairs.keys():
180+
updater.remove_option(section_name, key)
181+
else:
182+
updater.add_section(section_name)
183+
for key, value in new_pairs.items():
184+
updater.set(section_name, key, value)
185+
186+
# Remove sections not present in the incoming payload.
187+
for section_name in list(updater.sections()):
188+
if section_name not in sections:
189+
updater.remove_section(section_name)
190+
191+
# Re-order sections to match *write_order* by rebuilding from
192+
# detached copies. Only needed when the order actually differs.
193+
current_order = updater.sections()
194+
desired = [s for s in write_order if s in sections]
195+
if current_order != desired:
196+
detached = {
197+
name: updater[name].detach() for name in list(updater.sections())
198+
}
199+
for name in desired:
200+
updater.add_section(detached[name])
201+
198202
_atomic_write_lines(config_path, [str(updater)])
199203

200204

src/astrameter/web_config_test.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -82,6 +82,11 @@ def test_validate_value_with_newline():
8282
_validate_config_payload({"S": {"k": "a\nb"}}, ["S"])
8383

8484

85+
def test_validate_duplicate_order():
86+
with pytest.raises(ValueError, match="duplicate section names"):
87+
_validate_config_payload({"S": {"k": "v"}}, ["S", "S"])
88+
89+
8590
# ---------- write_config_from_dict — new file ----------
8691

8792

0 commit comments

Comments
 (0)