Skip to content

Commit 70910bc

Browse files
authored
Cleanly shut down the serial port on disconnect (#259)
* Cleanly shut down the serial port on disconnect * Send `connection_lost` even if we do not have an open serial connection * Call `super().close()` in `SerialProtocol` * Use `self._transport.write` instead of `send_data` * Let zigpy handle flow control * Bump minimum zigpy version * Fix unit tests * Make `api` an async fixture to grab reference to loop early * Set default pytest-asyncio fixture loop scope * Fix unit test failing due to event loop caching issue in pytest-asyncio * Bring test coverage up
1 parent aa26bbd commit 70910bc

File tree

7 files changed

+87
-66
lines changed

7 files changed

+87
-66
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ license = {text = "GPL-3.0"}
1515
requires-python = ">=3.8"
1616
dependencies = [
1717
"voluptuous",
18-
"zigpy>=0.68.0",
18+
"zigpy>=0.70.0",
1919
'async-timeout; python_version<"3.11"',
2020
]
2121

@@ -47,6 +47,7 @@ ignore_errors = true
4747

4848
[tool.pytest.ini_options]
4949
asyncio_mode = "auto"
50+
asyncio_default_fixture_loop_scope = "function"
5051

5152
[tool.flake8]
5253
exclude = [".venv", ".git", ".tox", "docs", "venv", "bin", "lib", "deps", "build"]

tests/test_api.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,23 @@
2525

2626

2727
@pytest.fixture
28-
def gateway():
28+
async def gateway():
2929
return uart.Gateway(api=None)
3030

3131

3232
@pytest.fixture
33-
def api(gateway, mock_command_rsp):
33+
async def api(gateway, mock_command_rsp):
34+
loop = asyncio.get_running_loop()
35+
3436
async def mock_connect(config, api):
37+
transport = MagicMock()
38+
transport.close = MagicMock(
39+
side_effect=lambda: loop.call_soon(gateway.connection_lost, None)
40+
)
41+
3542
gateway._api = api
36-
gateway.connection_made(MagicMock())
43+
gateway.connection_made(transport)
44+
3745
return gateway
3846

3947
with patch("zigpy_deconz.uart.connect", side_effect=mock_connect):
@@ -178,15 +186,33 @@ async def test_connect(api, mock_command_rsp):
178186
await api.connect()
179187

180188

189+
async def test_connect_failure(api, mock_command_rsp):
190+
transport = None
191+
192+
def mock_version(*args, **kwargs):
193+
nonlocal transport
194+
transport = api._uart._transport
195+
196+
raise asyncio.TimeoutError()
197+
198+
with patch.object(api, "version", side_effect=mock_version):
199+
# We connect but fail to probe
200+
with pytest.raises(asyncio.TimeoutError):
201+
await api.connect()
202+
203+
assert api._uart is None
204+
assert len(transport.close.mock_calls) == 1
205+
206+
181207
async def test_close(api):
182208
await api.connect()
183209

184210
uart = api._uart
185-
uart.close = MagicMock(wraps=uart.close)
211+
uart.disconnect = AsyncMock()
186212

187-
api.close()
213+
await api.disconnect()
188214
assert api._uart is None
189-
assert uart.close.call_count == 1
215+
assert uart.disconnect.call_count == 1
190216

191217

192218
def test_commands():
@@ -898,11 +924,9 @@ async def test_data_poller(api, mock_command_rsp):
898924

899925
# The task is cancelled on close
900926
task = api._data_poller_task
901-
api.close()
927+
await api.disconnect()
902928
assert api._data_poller_task is None
903-
904-
if sys.version_info >= (3, 11):
905-
assert task.cancelling()
929+
assert task.done()
906930

907931

908932
async def test_get_device_state(api, mock_command_rsp):

tests/test_application.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ async def test_connect_failure(app):
187187
with patch.object(application, "Deconz") as api_mock:
188188
api = api_mock.return_value = MagicMock()
189189
api.connect = AsyncMock(side_effect=RuntimeError("Broken"))
190+
api.disconnect = AsyncMock()
190191

191192
app._api = None
192193

@@ -195,16 +196,16 @@ async def test_connect_failure(app):
195196

196197
assert app._api is None
197198
api.connect.assert_called_once()
198-
api.close.assert_called_once()
199+
api.disconnect.assert_called_once()
199200

200201

201202
async def test_disconnect(app):
202-
api_close = app._api.close = MagicMock()
203+
api_disconnect = app._api.disconnect = AsyncMock()
203204

204205
await app.disconnect()
205206

206207
assert app._api is None
207-
assert api_close.call_count == 1
208+
assert api_disconnect.call_count == 1
208209

209210

210211
async def test_disconnect_no_api(app):

tests/test_uart.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
from unittest import mock
55

66
import pytest
7-
from zigpy.config import CONF_DEVICE_BAUDRATE, CONF_DEVICE_PATH
7+
from zigpy.config import (
8+
CONF_DEVICE_BAUDRATE,
9+
CONF_DEVICE_FLOW_CONTROL,
10+
CONF_DEVICE_PATH,
11+
)
812
import zigpy.serial
913

1014
from zigpy_deconz import uart
@@ -28,7 +32,12 @@ async def mock_conn(loop, protocol_factory, **kwargs):
2832
monkeypatch.setattr(zigpy.serial, "create_serial_connection", mock_conn)
2933

3034
await uart.connect(
31-
{CONF_DEVICE_PATH: "/dev/null", CONF_DEVICE_BAUDRATE: 115200}, api
35+
{
36+
CONF_DEVICE_PATH: "/dev/null",
37+
CONF_DEVICE_BAUDRATE: 115200,
38+
CONF_DEVICE_FLOW_CONTROL: None,
39+
},
40+
api,
3241
)
3342

3443

zigpy_deconz/api.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
else:
1515
from asyncio import timeout as asyncio_timeout # pragma: no cover
1616

17-
from zigpy.config import CONF_DEVICE_PATH
1817
from zigpy.datastructures import PriorityLock
1918
from zigpy.types import (
2019
APSStatus,
@@ -461,37 +460,37 @@ def protocol_version(self) -> int:
461460

462461
async def connect(self) -> None:
463462
assert self._uart is None
463+
464464
self._uart = await zigpy_deconz.uart.connect(self._config, self)
465465

466-
await self.version()
466+
try:
467+
await self.version()
468+
device_state_rsp = await self.send_command(CommandId.device_state)
469+
except Exception:
470+
await self.disconnect()
471+
self._uart = None
472+
raise
467473

468-
device_state_rsp = await self.send_command(CommandId.device_state)
469474
self._device_state = device_state_rsp["device_state"]
470475

471476
self._data_poller_task = asyncio.create_task(self._data_poller())
472477

473-
def connection_lost(self, exc: Exception) -> None:
478+
def connection_lost(self, exc: Exception | None) -> None:
474479
"""Lost serial connection."""
475-
LOGGER.debug(
476-
"Serial %r connection lost unexpectedly: %r",
477-
self._config[CONF_DEVICE_PATH],
478-
exc,
479-
)
480-
481480
if self._app is not None:
482481
self._app.connection_lost(exc)
483482

484-
def close(self):
485-
self._app = None
486-
483+
async def disconnect(self):
487484
if self._data_poller_task is not None:
488485
self._data_poller_task.cancel()
489486
self._data_poller_task = None
490487

491488
if self._uart is not None:
492-
self._uart.close()
489+
await self._uart.disconnect()
493490
self._uart = None
494491

492+
self._app = None
493+
495494
def _get_command_priority(self, command: Command) -> int:
496495
return {
497496
# The watchdog is fed using `write_parameter` and `get_device_state` so they

zigpy_deconz/uart.py

Lines changed: 21 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,50 @@
11
"""Uart module."""
22

3+
from __future__ import annotations
4+
35
import asyncio
46
import binascii
57
import logging
6-
from typing import Callable, Dict
8+
from typing import Any, Callable
79

810
import zigpy.config
911
import zigpy.serial
1012

1113
LOGGER = logging.getLogger(__name__)
1214

1315

14-
class Gateway(asyncio.Protocol):
16+
class Gateway(zigpy.serial.SerialProtocol):
1517
END = b"\xC0"
1618
ESC = b"\xDB"
1719
ESC_END = b"\xDC"
1820
ESC_ESC = b"\xDD"
1921

20-
def __init__(self, api, connected_future=None):
22+
def __init__(self, api):
2123
"""Initialize instance of the UART gateway."""
22-
24+
super().__init__()
2325
self._api = api
24-
self._buffer = b""
25-
self._connected_future = connected_future
26-
self._transport = None
2726

28-
def connection_lost(self, exc) -> None:
27+
def connection_lost(self, exc: Exception | None) -> None:
2928
"""Port was closed expectedly or unexpectedly."""
29+
super().connection_lost(exc)
3030

31-
if exc is not None:
32-
LOGGER.warning("Lost connection: %r", exc, exc_info=exc)
33-
34-
self._api.connection_lost(exc)
35-
36-
def connection_made(self, transport):
37-
"""Call this when the uart connection is established."""
38-
39-
LOGGER.debug("Connection made")
40-
self._transport = transport
41-
if self._connected_future and not self._connected_future.done():
42-
self._connected_future.set_result(True)
31+
if self._api is not None:
32+
self._api.connection_lost(exc)
4333

4434
def close(self):
45-
self._transport.close()
35+
super().close()
36+
self._api = None
4637

47-
def send(self, data):
38+
def send(self, data: bytes) -> None:
4839
"""Send data, taking care of escaping and framing."""
49-
LOGGER.debug("Send: %s", binascii.hexlify(data).decode())
5040
checksum = bytes(self._checksum(data))
5141
frame = self._escape(data + checksum)
5242
self._transport.write(self.END + frame + self.END)
5343

54-
def data_received(self, data):
44+
def data_received(self, data: bytes) -> None:
5545
"""Handle data received from the uart."""
56-
self._buffer += data
46+
super().data_received(data)
47+
5748
while self._buffer:
5849
end = self._buffer.find(self.END)
5950
if end < 0:
@@ -121,23 +112,19 @@ def _checksum(self, data):
121112
return bytes(ret)
122113

123114

124-
async def connect(config: Dict[str, any], api: Callable) -> Gateway:
125-
loop = asyncio.get_running_loop()
126-
connected_future = loop.create_future()
127-
protocol = Gateway(api, connected_future)
115+
async def connect(config: dict[str, Any], api: Callable) -> Gateway:
116+
protocol = Gateway(api)
128117

129118
LOGGER.debug("Connecting to %s", config[zigpy.config.CONF_DEVICE_PATH])
130119

131120
_, protocol = await zigpy.serial.create_serial_connection(
132-
loop=loop,
121+
loop=asyncio.get_running_loop(),
133122
protocol_factory=lambda: protocol,
134123
url=config[zigpy.config.CONF_DEVICE_PATH],
135124
baudrate=config[zigpy.config.CONF_DEVICE_BAUDRATE],
136-
xonxoff=False,
125+
flow_control=config[zigpy.config.CONF_DEVICE_FLOW_CONTROL],
137126
)
138127

139-
await connected_future
140-
141-
LOGGER.debug("Connected to %s", config[zigpy.config.CONF_DEVICE_PATH])
128+
await protocol.wait_until_connected()
142129

143130
return protocol

zigpy_deconz/zigbee/application.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ async def connect(self):
9797
try:
9898
await api.connect()
9999
except Exception:
100-
api.close()
100+
await api.disconnect()
101101
raise
102102

103103
self._api = api
@@ -109,7 +109,7 @@ async def disconnect(self):
109109
self._delayed_neighbor_scan_task = None
110110

111111
if self._api is not None:
112-
self._api.close()
112+
await self._api.disconnect()
113113
self._api = None
114114

115115
async def permit_with_link_key(self, node: t.EUI64, link_key: t.KeyData, time_s=60):

0 commit comments

Comments
 (0)