diff --git a/aioesphomeapi/reconnect_logic.py b/aioesphomeapi/reconnect_logic.py index fca5191..bb59e9d 100644 --- a/aioesphomeapi/reconnect_logic.py +++ b/aioesphomeapi/reconnect_logic.py @@ -106,13 +106,21 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): async def _on_disconnect(self, expected_disconnect: bool) -> None: """Log and issue callbacks when disconnecting.""" - if self._is_stopped: - return # This can happen often depending on WiFi signal strength. # So therefore all these connection warnings are logged # as infos. The "unavailable" logic will still trigger so the # user knows if the device is not connected. - disconnect_type = "expected" if expected_disconnect else "unexpected" + if expected_disconnect: + # If we expected the disconnect we need + # to cooldown before connecting in case the remote + # is rebooting so we don't establish a connection right + # before its about to reboot in the event we are too fast. + disconnect_type = "expected" + wait = EXPECTED_DISCONNECT_COOLDOWN + else: + disconnect_type = "unexpected" + wait = 0 + _LOGGER.info( "Processing %s disconnect from ESPHome API for %s", disconnect_type, @@ -120,21 +128,14 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): ) # Run disconnect hook - await self._on_disconnect_cb(expected_disconnect) - - await self._async_set_connection_state(ReconnectLogicState.DISCONNECTED) - - wait = EXPECTED_DISCONNECT_COOLDOWN if expected_disconnect else 0 - # If we expected the disconnect we need - # to cooldown before connecting in case the remote - # is rebooting so we don't establish a connection right - # before its about to reboot in the event we are too fast. - self._schedule_connect(wait) - - async def _async_set_connection_state(self, state: ReconnectLogicState) -> None: - """Set the connection state.""" async with self._connected_lock: - self._async_set_connection_state_while_locked(state) + self._async_set_connection_state_while_locked( + ReconnectLogicState.DISCONNECTED + ) + await self._on_disconnect_cb(expected_disconnect) + + if not self._is_stopped: + self._schedule_connect(wait) def _async_set_connection_state_while_locked( self, state: ReconnectLogicState diff --git a/tests/test_reconnect_logic.py b/tests/test_reconnect_logic.py index b145c32..97610f5 100644 --- a/tests/test_reconnect_logic.py +++ b/tests/test_reconnect_logic.py @@ -18,10 +18,17 @@ from zeroconf.asyncio import AsyncZeroconf from zeroconf.const import _CLASS_IN, _TYPE_A, _TYPE_PTR from aioesphomeapi import APIConnectionError +from aioesphomeapi._frame_helper.plain_text import APIPlaintextFrameHelper from aioesphomeapi.client import APIClient +from aioesphomeapi.connection import APIConnection from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState -from .common import get_mock_zeroconf +from .common import ( + get_mock_async_zeroconf, + get_mock_zeroconf, + send_plaintext_connect_response, + send_plaintext_hello, +) logging.getLogger("aioesphomeapi").setLevel(logging.DEBUG) @@ -443,3 +450,77 @@ async def test_reconnect_logic_stop_callback_waits_for_handshake(): await asyncio.sleep(0) assert rl._is_stopped is True assert rl._connection_state is ReconnectLogicState.DISCONNECTED + + +@pytest.mark.asyncio +async def test_handling_unexpected_disconnect(event_loop: asyncio.AbstractEventLoop): + """Test the disconnect callback fires with expected_disconnect=False.""" + loop = asyncio.get_event_loop() + protocol: APIPlaintextFrameHelper | None = None + transport = MagicMock() + connected = asyncio.Event() + + class PatchableAPIClient(APIClient): + pass + + async_zeroconf = get_mock_async_zeroconf() + + cli = PatchableAPIClient( + address="1.2.3.4", + port=6052, + password=None, + noise_psk=None, + expected_name="fake", + zeroconf_instance=async_zeroconf.zeroconf, + ) + + def _create_mock_transport_protocol(create_func, **kwargs): + nonlocal protocol + protocol = create_func() + protocol.connection_made(transport) + connected.set() + return transport, protocol + + connected = asyncio.Event() + on_disconnect_calls = [] + + async def on_disconnect(expected_disconnect: bool) -> None: + on_disconnect_calls.append(expected_disconnect) + + async def on_connect() -> None: + connected.set() + + logic = ReconnectLogic( + client=cli, + on_connect=on_connect, + on_disconnect=on_disconnect, + zeroconf_instance=async_zeroconf, + name="fake", + ) + + with patch.object(event_loop, "sock_connect"), patch.object( + loop, "create_connection", side_effect=_create_mock_transport_protocol + ): + await logic.start() + await connected.wait() + protocol = cli._connection._frame_helper + send_plaintext_hello(protocol) + send_plaintext_connect_response(protocol, False) + await connected.wait() + + assert cli._connection.is_connected is True + await asyncio.sleep(0) + + with patch.object(event_loop, "sock_connect"), patch.object( + loop, "create_connection", side_effect=_create_mock_transport_protocol + ) as mock_create_connection: + protocol.eof_received() + # Wait for the task to run + await asyncio.sleep(0) + # Ensure we try to reconnect immediately + # since its an unexpected disconnect + assert mock_create_connection.call_count == 0 + + assert len(on_disconnect_calls) == 1 + assert on_disconnect_calls[0] is False + await logic.stop()