diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index abf9001e..c49490f9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -10,8 +10,6 @@ jobs: with: CODE_FOLDER: bellows CACHE_VERSION: 2 - PYTHON_VERSION_DEFAULT: 3.11.0 - PRE_COMMIT_CACHE_PATH: ~/.cache/pre-commit MINIMUM_COVERAGE_PERCENTAGE: 99 secrets: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} diff --git a/bellows/ash.py b/bellows/ash.py index 7ba8cd14..252845e0 100644 --- a/bellows/ash.py +++ b/bellows/ash.py @@ -558,6 +558,7 @@ def rstack_frame_received(self, frame: RStackFrame) -> None: self._tx_seq = 0 self._rx_seq = 0 + self._cancel_pending_data_frames(NcpFailure(code=frame.reset_code)) self._change_ack_timeout(T_RX_ACK_INIT) self._ezsp_protocol.reset_received(frame.reset_code) @@ -582,7 +583,7 @@ def error_frame_received(self, frame: ErrorFrame) -> None: def _enter_failed_state(self, reset_code: t.NcpResetCode) -> None: self._ncp_state = NcpState.FAILED self._cancel_pending_data_frames(NcpFailure(code=reset_code)) - self._ezsp_protocol.reset_received(reset_code) + self._ezsp_protocol.error_received(reset_code) def _write_frame( self, diff --git a/bellows/ezsp/__init__.py b/bellows/ezsp/__init__.py index 70a1f220..c5158f4a 100644 --- a/bellows/ezsp/__init__.py +++ b/bellows/ezsp/__init__.py @@ -26,6 +26,8 @@ from . import v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v16, v17 +RESET_ATTEMPTS = 5 + EZSP_LATEST = v17.EZSPv17.VERSION LOGGER = logging.getLogger(__name__) MTOR_MIN_INTERVAL = 60 @@ -130,12 +132,24 @@ async def connect(self, *, use_thread: bool = True) -> None: assert self._gw is None self._gw = await bellows.uart.connect(self._config, self, use_thread=use_thread) - try: + for attempt in range(RESET_ATTEMPTS): self._protocol = v4.EZSPv4(self.handle_callback, self._gw) - await self.startup_reset() - except Exception: - await self.disconnect() - raise + + try: + await self.startup_reset() + break + except Exception as exc: + if attempt + 1 < RESET_ATTEMPTS: + LOGGER.debug( + "EZSP startup/reset failed, retrying (%d/%d): %r", + attempt + 1, + RESET_ATTEMPTS, + exc, + ) + continue + + await self.disconnect() + raise async def reset(self): LOGGER.debug("Resetting EZSP") diff --git a/bellows/uart.py b/bellows/uart.py index ba377a0f..07066a8f 100644 --- a/bellows/uart.py +++ b/bellows/uart.py @@ -10,7 +10,7 @@ import bellows.types as t LOGGER = logging.getLogger(__name__) -RESET_TIMEOUT = 5 +RESET_TIMEOUT = 3 class Gateway(zigpy.serial.SerialProtocol): @@ -33,21 +33,22 @@ def data_received(self, data): def reset_received(self, code: t.NcpResetCode) -> None: """Reset acknowledgement frame receive handler""" - # not a reset we've requested. Signal api reset - if code is not t.NcpResetCode.RESET_SOFTWARE: - self._api.enter_failed_state(code) - return + LOGGER.debug("Received reset: %r", code) if self._reset_future and not self._reset_future.done(): self._reset_future.set_result(True) elif self._startup_reset_future and not self._startup_reset_future.done(): self._startup_reset_future.set_result(True) else: + self._api.enter_failed_state(code) LOGGER.warning("Received an unexpected reset: %r", code) def error_received(self, code: t.NcpResetCode) -> None: """Error frame receive handler.""" - self._api.enter_failed_state(code) + if self._reset_future is not None or self._startup_reset_future is not None: + LOGGER.debug("Ignoring spurious error during reset: %r", code) + else: + self._api.enter_failed_state(code) async def wait_for_startup_reset(self) -> None: """Wait for the first reset frame on startup.""" diff --git a/tests/test_ash.py b/tests/test_ash.py index af67fa02..8d60c02b 100644 --- a/tests/test_ash.py +++ b/tests/test_ash.py @@ -605,6 +605,43 @@ async def test_ash_end_to_end(transport_cls: type[FakeTransport]) -> None: await host.send_data(b"ncp NAKing until failure") +async def test_rstack_cancels_pending_frames() -> None: + """Test that RSTACK frame cancels pending data frames.""" + host_ezsp = MagicMock() + ncp_ezsp = MagicMock() + + host = ash.AshProtocol(host_ezsp) + ncp = AshNcpProtocol(ncp_ezsp) + + host_transport = FakeTransport(ncp) + ncp_transport = FakeTransport(host) + + host.connection_made(host_transport) + ncp.connection_made(ncp_transport) + + # Pause the NCP transport so ACKs can't be sent back, creating a pending frame + ncp_transport.paused = True + + # Start sending data without awaiting - this will create a pending frame + send_task = asyncio.create_task(host.send_data(b"test data")) + + # Give task time to start and create the pending frame + await asyncio.sleep(0.1) + + # Verify we have a pending frame + assert len(host._pending_data_frames) == 1 + + # Trigger RSTACK frame to cancel the pending frame + rstack = ash.RStackFrame(version=2, reset_code=t.NcpResetCode.RESET_POWER_ON) + host.rstack_frame_received(rstack) + + # Verify task was cancelled with NcpFailure containing the reset code + with pytest.raises(ash.NcpFailure) as exc_info: + await send_task + + assert exc_info.value.code == t.NcpResetCode.RESET_POWER_ON + + def test_ncp_failure_comparison() -> None: exc1 = ash.NcpFailure(code=t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT) exc2 = ash.NcpFailure(code=t.NcpResetCode.RESET_POWER_ON) diff --git a/tests/test_ezsp.py b/tests/test_ezsp.py index 4ba29f93..e37b3781 100644 --- a/tests/test_ezsp.py +++ b/tests/test_ezsp.py @@ -299,11 +299,33 @@ async def test_ezsp_connect_failure(disconnect_mock, reset_mock, version_mock): await ezsp.connect() assert conn_mock.await_count == 1 - assert reset_mock.await_count == 1 - assert version_mock.await_count == 1 + assert reset_mock.await_count == 5 + assert version_mock.await_count == 5 assert disconnect_mock.call_count == 1 +@pytest.mark.parametrize("failures_before_success", [1, 2, 3, 4]) +@patch.object(EZSP, "disconnect", new_callable=AsyncMock) +async def test_ezsp_connect_retry_success(disconnect_mock, failures_before_success): + """Test connection succeeding after N failures.""" + call_count = 0 + + async def startup_reset_mock(): + nonlocal call_count + call_count += 1 + if call_count <= failures_before_success: + raise RuntimeError(f"Startup failed (attempt {call_count})") + + with patch("bellows.uart.connect"): + ezsp = make_ezsp(version=4) + + with patch.object(ezsp, "startup_reset", side_effect=startup_reset_mock): + await ezsp.connect() + + assert call_count == failures_before_success + 1 + assert disconnect_mock.call_count == 0 + + async def test_ezsp_newer_version(ezsp_f): """Test newer version of ezsp.""" with patch.object( diff --git a/tests/test_uart.py b/tests/test_uart.py index 6908c539..7cc44b27 100644 --- a/tests/test_uart.py +++ b/tests/test_uart.py @@ -211,9 +211,18 @@ def on_transport_close(): assert len(threads) == 0 -async def test_wait_for_startup_reset(gw): +@pytest.mark.parametrize( + "reset_code", + [ + t.NcpResetCode.RESET_SOFTWARE, + t.NcpResetCode.RESET_POWER_ON, + t.NcpResetCode.RESET_WATCHDOG, + t.NcpResetCode.RESET_EXTERNAL, + ], +) +async def test_wait_for_startup_reset(gw, reset_code): loop = asyncio.get_running_loop() - loop.call_later(0.01, gw.reset_received, t.NcpResetCode.RESET_SOFTWARE) + loop.call_later(0.01, gw.reset_received, reset_code) assert gw._startup_reset_future is None await gw.wait_for_startup_reset() @@ -239,8 +248,25 @@ async def test_callbacks(gw): ] -def test_reset_propagation(gw): - gw.reset_received(t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT) +async def test_error_received_during_reset_ignored(gw): + # Set up a reset future to simulate being in the middle of a reset + loop = asyncio.get_running_loop() + gw._reset_future = loop.create_future() + + # Error should be ignored (not trigger failed state) + gw.error_received(t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT) + assert gw._api.enter_failed_state.call_count == 0 + + # Clean up + gw._reset_future.cancel() + + +def test_unexpected_reset_triggers_failed_state(gw): + # When no reset is expected, any reset should trigger failed state + assert gw._reset_future is None + assert gw._startup_reset_future is None + + gw.reset_received(t.NcpResetCode.RESET_SOFTWARE) assert gw._api.enter_failed_state.mock_calls == [ - call(t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT) + call(t.NcpResetCode.RESET_SOFTWARE) ]