From 000ff14ac0b32d20c9747a928ac86713e902641a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 21 Oct 2023 17:46:45 -1000 Subject: [PATCH] Make reconnect logic state machine switches check locks (#597) --- aioesphomeapi/reconnect_logic.py | 56 ++++++++--- tests/test_reconnect_logic.py | 161 ++++++++++++++++++++++++++++++- 2 files changed, 198 insertions(+), 19 deletions(-) diff --git a/aioesphomeapi/reconnect_logic.py b/aioesphomeapi/reconnect_logic.py index 55595f8..a34c251 100644 --- a/aioesphomeapi/reconnect_logic.py +++ b/aioesphomeapi/reconnect_logic.py @@ -90,6 +90,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): self._filter_alias: str | None = None # Flag to check if the device is connected self._connection_state = ReconnectLogicState.DISCONNECTED + self._accept_zeroconf_records = True self._connected_lock = asyncio.Lock() self._is_stopped = True self._zc_listening = False @@ -118,8 +119,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): # Run disconnect hook await self._on_disconnect_cb(expected_disconnect) - async with self._connected_lock: - self._connection_state = ReconnectLogicState.DISCONNECTED + await self._async_set_connection_state(ReconnectLogicState.DISCONNECTED) wait = EXPECTED_DISCONNECT_COOLDOWN if expected_disconnect else 0 # If we expected the disconnect we need @@ -128,6 +128,29 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): # before its about to reboot in the event we are too fast. self._schedule_connect(wait) + async def _async_set_connection_state(self, state: ReconnectLogicState) -> None: + """Set the connection state.""" + async with self._connected_lock: + self._async_set_connection_state_while_locked(state) + + def _async_set_connection_state_while_locked( + self, state: ReconnectLogicState + ) -> None: + """Set the connection state while holding the lock.""" + assert self._connected_lock.locked(), "connected_lock must be locked" + self._async_set_connection_state_without_lock(state) + + def _async_set_connection_state_without_lock( + self, state: ReconnectLogicState + ) -> None: + """Set the connection state without holding the lock. + + This should only be used for setting the state to DISCONNECTED + when the state is CONNECTING. + """ + self._connection_state = state + self._accept_zeroconf_records = state not in NOT_YET_CONNECTED_STATES + def _async_log_connection_error(self, err: Exception) -> None: """Log connection errors.""" # UnhandledAPIConnectionError is a special case in client @@ -155,12 +178,13 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): async def _try_connect(self) -> bool: """Try connecting to the API client.""" - assert self._connected_lock.locked(), "connected_lock must be locked" - self._connection_state = ReconnectLogicState.CONNECTING + self._async_set_connection_state_while_locked(ReconnectLogicState.CONNECTING) try: await self._cli.start_connection(on_stop=self._on_disconnect) except Exception as err: # pylint: disable=broad-except - self._connection_state = ReconnectLogicState.DISCONNECTED + self._async_set_connection_state_while_locked( + ReconnectLogicState.DISCONNECTED + ) if self._on_connect_error_cb is not None: await self._on_connect_error_cb(err) self._async_log_connection_error(err) @@ -168,11 +192,13 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): return False _LOGGER.info("Successfully connected to %s", self._log_name) self._stop_zc_listen() - self._connection_state = ReconnectLogicState.HANDSHAKING + self._async_set_connection_state_while_locked(ReconnectLogicState.HANDSHAKING) try: await self._cli.finish_connection(login=True) except Exception as err: # pylint: disable=broad-except - self._connection_state = ReconnectLogicState.DISCONNECTED + self._async_set_connection_state_while_locked( + ReconnectLogicState.DISCONNECTED + ) if self._on_connect_error_cb is not None: await self._on_connect_error_cb(err) self._async_log_connection_error(err) @@ -185,7 +211,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): return False self._tries = 0 _LOGGER.info("Successful handshake with %s", self._log_name) - self._connection_state = ReconnectLogicState.READY + self._async_set_connection_state_while_locked(ReconnectLogicState.READY) await self._on_connect_cb() return True @@ -216,7 +242,9 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): ) self._connect_task.cancel("Scheduling new connect attempt") self._connect_task = None - self._connection_state = ReconnectLogicState.DISCONNECTED + self._async_set_connection_state_without_lock( + ReconnectLogicState.DISCONNECTED + ) self._connect_task = asyncio.create_task( self._connect_once_or_reschedule(), @@ -292,6 +320,9 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): # Cancel again while holding the lock self._cancel_connect("Stopping") self._stop_zc_listen() + self._async_set_connection_state_while_locked( + ReconnectLogicState.DISCONNECTED + ) def _start_zc_listen(self) -> None: """Listen for mDNS records. @@ -325,7 +356,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): # Check if already connected, no lock needed for this access and # bail if either the already stopped or we haven't received device info yet if ( - self._connection_state not in NOT_YET_CONNECTED_STATES + not self._accept_zeroconf_records or self._is_stopped or self._filter_alias is None ): @@ -334,10 +365,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): for record_update in records: # We only consider PTR records and match using the alias name new_record = record_update.new - if ( - new_record.type != TYPE_PTR - or new_record.alias != self._filter_alias # type: ignore[attr-defined] - ): + if new_record.type != TYPE_PTR or new_record.alias != self._filter_alias: # type: ignore[attr-defined] continue # Tell connection logic to retry connection attempt now (even before connect timer finishes) diff --git a/tests/test_reconnect_logic.py b/tests/test_reconnect_logic.py index eb9162a..a3e11b1 100644 --- a/tests/test_reconnect_logic.py +++ b/tests/test_reconnect_logic.py @@ -1,10 +1,17 @@ -from unittest.mock import MagicMock +import asyncio +from unittest.mock import MagicMock, patch import pytest +from zeroconf import Zeroconf from zeroconf.asyncio import AsyncZeroconf +from aioesphomeapi import APIConnectionError from aioesphomeapi.client import APIClient -from aioesphomeapi.reconnect_logic import ReconnectLogic +from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState + + +def _get_mock_zeroconf() -> MagicMock: + return MagicMock(spec=Zeroconf) @pytest.mark.asyncio @@ -51,7 +58,7 @@ async def test_reconnect_logic_name_from_host_and_set(): client=cli, on_disconnect=on_disconnect, on_connect=on_connect, - zeroconf_instance=MagicMock(spec=AsyncZeroconf), + zeroconf_instance=_get_mock_zeroconf(), name="mydevice", ) assert rl._log_name == "mydevice" @@ -77,7 +84,7 @@ async def test_reconnect_logic_name_from_address(): client=cli, on_disconnect=on_disconnect, on_connect=on_connect, - zeroconf_instance=MagicMock(spec=AsyncZeroconf), + zeroconf_instance=_get_mock_zeroconf(), ) assert rl._log_name == "1.2.3.4" assert cli._log_name == "1.2.3.4" @@ -102,8 +109,152 @@ async def test_reconnect_logic_name_from_name(): client=cli, on_disconnect=on_disconnect, on_connect=on_connect, - zeroconf_instance=MagicMock(spec=AsyncZeroconf), + zeroconf_instance=_get_mock_zeroconf(), name="mydevice", ) assert rl._log_name == "mydevice @ 1.2.3.4" assert cli._log_name == "mydevice @ 1.2.3.4" + + +@pytest.mark.asyncio +async def test_reconnect_logic_state(): + """Test that reconnect logic state changes.""" + on_disconnect_called = [] + on_connect_called = [] + on_connect_fail_called = [] + + class PatchableAPIClient(APIClient): + pass + + cli = PatchableAPIClient( + address="1.2.3.4", + port=6052, + password=None, + ) + + async def on_disconnect(expected_disconnect: bool) -> None: + nonlocal on_disconnect_called + on_disconnect_called.append(expected_disconnect) + + async def on_connect() -> None: + nonlocal on_connect_called + on_connect_called.append(True) + + async def on_connect_fail(connect_exception: Exception) -> None: + nonlocal on_connect_called + on_connect_fail_called.append(connect_exception) + + rl = ReconnectLogic( + client=cli, + on_disconnect=on_disconnect, + on_connect=on_connect, + zeroconf_instance=_get_mock_zeroconf(), + name="mydevice", + on_connect_error=on_connect_fail, + ) + assert rl._log_name == "mydevice @ 1.2.3.4" + assert cli._log_name == "mydevice @ 1.2.3.4" + + with patch.object(cli, "start_connection", side_effect=APIConnectionError): + await rl.start() + await asyncio.sleep(0) + await asyncio.sleep(0) + + assert len(on_disconnect_called) == 0 + assert len(on_connect_called) == 0 + assert len(on_connect_fail_called) == 1 + assert isinstance(on_connect_fail_called[-1], APIConnectionError) + assert rl._connection_state is ReconnectLogicState.DISCONNECTED + + with patch.object(cli, "start_connection"), patch.object( + cli, "finish_connection", side_effect=APIConnectionError + ): + await rl.start() + await asyncio.sleep(0) + await asyncio.sleep(0) + + assert len(on_disconnect_called) == 0 + assert len(on_connect_called) == 0 + assert len(on_connect_fail_called) == 2 + assert isinstance(on_connect_fail_called[-1], APIConnectionError) + assert rl._connection_state is ReconnectLogicState.DISCONNECTED + + with patch.object(cli, "start_connection"), patch.object(cli, "finish_connection"): + await rl.start() + await asyncio.sleep(0) + await asyncio.sleep(0) + + assert len(on_disconnect_called) == 0 + assert len(on_connect_called) == 1 + assert len(on_connect_fail_called) == 2 + assert rl._connection_state is ReconnectLogicState.READY + + await rl.stop() + assert rl._connection_state is ReconnectLogicState.DISCONNECTED + + +@pytest.mark.asyncio +async def test_reconnect_retry(): + """Test that reconnect logic retry.""" + on_disconnect_called = [] + on_connect_called = [] + on_connect_fail_called = [] + + class PatchableAPIClient(APIClient): + pass + + cli = PatchableAPIClient( + address="1.2.3.4", + port=6052, + password=None, + ) + + async def on_disconnect(expected_disconnect: bool) -> None: + nonlocal on_disconnect_called + on_disconnect_called.append(expected_disconnect) + + async def on_connect() -> None: + nonlocal on_connect_called + on_connect_called.append(True) + + async def on_connect_fail(connect_exception: Exception) -> None: + nonlocal on_connect_called + on_connect_fail_called.append(connect_exception) + + rl = ReconnectLogic( + client=cli, + on_disconnect=on_disconnect, + on_connect=on_connect, + zeroconf_instance=_get_mock_zeroconf(), + name="mydevice", + on_connect_error=on_connect_fail, + ) + assert rl._log_name == "mydevice @ 1.2.3.4" + assert cli._log_name == "mydevice @ 1.2.3.4" + + with patch.object(cli, "start_connection", side_effect=APIConnectionError): + await rl.start() + await asyncio.sleep(0) + await asyncio.sleep(0) + await asyncio.sleep(0) + + assert len(on_disconnect_called) == 0 + assert len(on_connect_called) == 0 + assert len(on_connect_fail_called) == 1 + assert isinstance(on_connect_fail_called[-1], APIConnectionError) + assert rl._connection_state is ReconnectLogicState.DISCONNECTED + + with patch.object(cli, "start_connection"), patch.object(cli, "finish_connection"): + # Should now retry + assert rl._connect_timer is not None + rl._connect_timer._run() + await asyncio.sleep(0) + await asyncio.sleep(0) + + assert len(on_disconnect_called) == 0 + assert len(on_connect_called) == 1 + assert len(on_connect_fail_called) == 1 + assert rl._connection_state is ReconnectLogicState.READY + + await rl.stop() + assert rl._connection_state is ReconnectLogicState.DISCONNECTED