Skip to content

Commit 4d31bb4

Browse files
authored
Allow any reset during startup, not just RESET_SOFTWARE (#694)
* Allow any reset during startup, not just `RESET_SOFTWARE` * Attempt to reset multiple times * Reduce reset timeout * Deal with spurious resets better * Update unit tests to account for five retries on connect * Add some tests * Bump default Python version to 3.11.14 * Drop default Python version * Drop `PRE_COMMIT_CACHE_PATH`
1 parent 55bdc1c commit 4d31bb4

File tree

7 files changed

+120
-21
lines changed

7 files changed

+120
-21
lines changed

.github/workflows/ci.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@ jobs:
1010
with:
1111
CODE_FOLDER: bellows
1212
CACHE_VERSION: 2
13-
PYTHON_VERSION_DEFAULT: 3.11.0
14-
PRE_COMMIT_CACHE_PATH: ~/.cache/pre-commit
1513
MINIMUM_COVERAGE_PERCENTAGE: 99
1614
secrets:
1715
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}

bellows/ash.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,7 @@ def rstack_frame_received(self, frame: RStackFrame) -> None:
558558

559559
self._tx_seq = 0
560560
self._rx_seq = 0
561+
self._cancel_pending_data_frames(NcpFailure(code=frame.reset_code))
561562
self._change_ack_timeout(T_RX_ACK_INIT)
562563
self._ezsp_protocol.reset_received(frame.reset_code)
563564

@@ -582,7 +583,7 @@ def error_frame_received(self, frame: ErrorFrame) -> None:
582583
def _enter_failed_state(self, reset_code: t.NcpResetCode) -> None:
583584
self._ncp_state = NcpState.FAILED
584585
self._cancel_pending_data_frames(NcpFailure(code=reset_code))
585-
self._ezsp_protocol.reset_received(reset_code)
586+
self._ezsp_protocol.error_received(reset_code)
586587

587588
def _write_frame(
588589
self,

bellows/ezsp/__init__.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626

2727
from . import v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v16, v17
2828

29+
RESET_ATTEMPTS = 5
30+
2931
EZSP_LATEST = v17.EZSPv17.VERSION
3032
LOGGER = logging.getLogger(__name__)
3133
MTOR_MIN_INTERVAL = 60
@@ -130,12 +132,24 @@ async def connect(self, *, use_thread: bool = True) -> None:
130132
assert self._gw is None
131133
self._gw = await bellows.uart.connect(self._config, self, use_thread=use_thread)
132134

133-
try:
135+
for attempt in range(RESET_ATTEMPTS):
134136
self._protocol = v4.EZSPv4(self.handle_callback, self._gw)
135-
await self.startup_reset()
136-
except Exception:
137-
await self.disconnect()
138-
raise
137+
138+
try:
139+
await self.startup_reset()
140+
break
141+
except Exception as exc:
142+
if attempt + 1 < RESET_ATTEMPTS:
143+
LOGGER.debug(
144+
"EZSP startup/reset failed, retrying (%d/%d): %r",
145+
attempt + 1,
146+
RESET_ATTEMPTS,
147+
exc,
148+
)
149+
continue
150+
151+
await self.disconnect()
152+
raise
139153

140154
async def reset(self):
141155
LOGGER.debug("Resetting EZSP")

bellows/uart.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import bellows.types as t
1111

1212
LOGGER = logging.getLogger(__name__)
13-
RESET_TIMEOUT = 5
13+
RESET_TIMEOUT = 3
1414

1515

1616
class Gateway(zigpy.serial.SerialProtocol):
@@ -33,21 +33,22 @@ def data_received(self, data):
3333

3434
def reset_received(self, code: t.NcpResetCode) -> None:
3535
"""Reset acknowledgement frame receive handler"""
36-
# not a reset we've requested. Signal api reset
37-
if code is not t.NcpResetCode.RESET_SOFTWARE:
38-
self._api.enter_failed_state(code)
39-
return
36+
LOGGER.debug("Received reset: %r", code)
4037

4138
if self._reset_future and not self._reset_future.done():
4239
self._reset_future.set_result(True)
4340
elif self._startup_reset_future and not self._startup_reset_future.done():
4441
self._startup_reset_future.set_result(True)
4542
else:
43+
self._api.enter_failed_state(code)
4644
LOGGER.warning("Received an unexpected reset: %r", code)
4745

4846
def error_received(self, code: t.NcpResetCode) -> None:
4947
"""Error frame receive handler."""
50-
self._api.enter_failed_state(code)
48+
if self._reset_future is not None or self._startup_reset_future is not None:
49+
LOGGER.debug("Ignoring spurious error during reset: %r", code)
50+
else:
51+
self._api.enter_failed_state(code)
5152

5253
async def wait_for_startup_reset(self) -> None:
5354
"""Wait for the first reset frame on startup."""

tests/test_ash.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,6 +605,43 @@ async def test_ash_end_to_end(transport_cls: type[FakeTransport]) -> None:
605605
await host.send_data(b"ncp NAKing until failure")
606606

607607

608+
async def test_rstack_cancels_pending_frames() -> None:
609+
"""Test that RSTACK frame cancels pending data frames."""
610+
host_ezsp = MagicMock()
611+
ncp_ezsp = MagicMock()
612+
613+
host = ash.AshProtocol(host_ezsp)
614+
ncp = AshNcpProtocol(ncp_ezsp)
615+
616+
host_transport = FakeTransport(ncp)
617+
ncp_transport = FakeTransport(host)
618+
619+
host.connection_made(host_transport)
620+
ncp.connection_made(ncp_transport)
621+
622+
# Pause the NCP transport so ACKs can't be sent back, creating a pending frame
623+
ncp_transport.paused = True
624+
625+
# Start sending data without awaiting - this will create a pending frame
626+
send_task = asyncio.create_task(host.send_data(b"test data"))
627+
628+
# Give task time to start and create the pending frame
629+
await asyncio.sleep(0.1)
630+
631+
# Verify we have a pending frame
632+
assert len(host._pending_data_frames) == 1
633+
634+
# Trigger RSTACK frame to cancel the pending frame
635+
rstack = ash.RStackFrame(version=2, reset_code=t.NcpResetCode.RESET_POWER_ON)
636+
host.rstack_frame_received(rstack)
637+
638+
# Verify task was cancelled with NcpFailure containing the reset code
639+
with pytest.raises(ash.NcpFailure) as exc_info:
640+
await send_task
641+
642+
assert exc_info.value.code == t.NcpResetCode.RESET_POWER_ON
643+
644+
608645
def test_ncp_failure_comparison() -> None:
609646
exc1 = ash.NcpFailure(code=t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT)
610647
exc2 = ash.NcpFailure(code=t.NcpResetCode.RESET_POWER_ON)

tests/test_ezsp.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,11 +299,33 @@ async def test_ezsp_connect_failure(disconnect_mock, reset_mock, version_mock):
299299
await ezsp.connect()
300300

301301
assert conn_mock.await_count == 1
302-
assert reset_mock.await_count == 1
303-
assert version_mock.await_count == 1
302+
assert reset_mock.await_count == 5
303+
assert version_mock.await_count == 5
304304
assert disconnect_mock.call_count == 1
305305

306306

307+
@pytest.mark.parametrize("failures_before_success", [1, 2, 3, 4])
308+
@patch.object(EZSP, "disconnect", new_callable=AsyncMock)
309+
async def test_ezsp_connect_retry_success(disconnect_mock, failures_before_success):
310+
"""Test connection succeeding after N failures."""
311+
call_count = 0
312+
313+
async def startup_reset_mock():
314+
nonlocal call_count
315+
call_count += 1
316+
if call_count <= failures_before_success:
317+
raise RuntimeError(f"Startup failed (attempt {call_count})")
318+
319+
with patch("bellows.uart.connect"):
320+
ezsp = make_ezsp(version=4)
321+
322+
with patch.object(ezsp, "startup_reset", side_effect=startup_reset_mock):
323+
await ezsp.connect()
324+
325+
assert call_count == failures_before_success + 1
326+
assert disconnect_mock.call_count == 0
327+
328+
307329
async def test_ezsp_newer_version(ezsp_f):
308330
"""Test newer version of ezsp."""
309331
with patch.object(

tests/test_uart.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -211,9 +211,18 @@ def on_transport_close():
211211
assert len(threads) == 0
212212

213213

214-
async def test_wait_for_startup_reset(gw):
214+
@pytest.mark.parametrize(
215+
"reset_code",
216+
[
217+
t.NcpResetCode.RESET_SOFTWARE,
218+
t.NcpResetCode.RESET_POWER_ON,
219+
t.NcpResetCode.RESET_WATCHDOG,
220+
t.NcpResetCode.RESET_EXTERNAL,
221+
],
222+
)
223+
async def test_wait_for_startup_reset(gw, reset_code):
215224
loop = asyncio.get_running_loop()
216-
loop.call_later(0.01, gw.reset_received, t.NcpResetCode.RESET_SOFTWARE)
225+
loop.call_later(0.01, gw.reset_received, reset_code)
217226

218227
assert gw._startup_reset_future is None
219228
await gw.wait_for_startup_reset()
@@ -239,8 +248,25 @@ async def test_callbacks(gw):
239248
]
240249

241250

242-
def test_reset_propagation(gw):
243-
gw.reset_received(t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT)
251+
async def test_error_received_during_reset_ignored(gw):
252+
# Set up a reset future to simulate being in the middle of a reset
253+
loop = asyncio.get_running_loop()
254+
gw._reset_future = loop.create_future()
255+
256+
# Error should be ignored (not trigger failed state)
257+
gw.error_received(t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT)
258+
assert gw._api.enter_failed_state.call_count == 0
259+
260+
# Clean up
261+
gw._reset_future.cancel()
262+
263+
264+
def test_unexpected_reset_triggers_failed_state(gw):
265+
# When no reset is expected, any reset should trigger failed state
266+
assert gw._reset_future is None
267+
assert gw._startup_reset_future is None
268+
269+
gw.reset_received(t.NcpResetCode.RESET_SOFTWARE)
244270
assert gw._api.enter_failed_state.mock_calls == [
245-
call(t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT)
271+
call(t.NcpResetCode.RESET_SOFTWARE)
246272
]

0 commit comments

Comments
 (0)