From 6e08933a752a2a2b9a4585ca6aa997e390949e5b Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 26 Nov 2023 10:32:16 -0600 Subject: [PATCH] Fix race scheduling reconnect from zeroconf records (#731) --- aioesphomeapi/reconnect_logic.py | 19 +++++++++++++++---- tests/test_reconnect_logic.py | 4 ++++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/aioesphomeapi/reconnect_logic.py b/aioesphomeapi/reconnect_logic.py index a27be75..f6cbd4e 100644 --- a/aioesphomeapi/reconnect_logic.py +++ b/aioesphomeapi/reconnect_logic.py @@ -93,7 +93,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): self._a_name: str | None = None # Flag to check if the device is connected self._connection_state = ReconnectLogicState.DISCONNECTED - self._accept_zeroconf_records = True + self._accept_zeroconf_records: bool = True self._connected_lock = asyncio.Lock() self._is_stopped = True self._zc_listening = False @@ -378,6 +378,11 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): ) self._zc_listening = False + def _connect_from_zeroconf(self) -> None: + """Connect from zeroconf.""" + self._stop_zc_listen() + self._schedule_connect(0.0) + def async_update_records( self, zc: zeroconf.Zeroconf, # pylint: disable=unused-argument @@ -411,7 +416,13 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): # We can't stop the zeroconf listener here because we are in the middle of # a zeroconf callback which is iterating the listeners. # - # So we schedule a stop for the next event loop iteration. - self.loop.call_soon(self._stop_zc_listen) - self._schedule_connect(0.0) + # So we schedule a stop for the next event loop iteration as well as the + # connect attempt. + # + # If we scheduled the connect attempt immediately, the listener could fire + # again before the connect attempt and we cancel and reschedule the connect + # attempt again. + # + self.loop.call_soon(self._connect_from_zeroconf) + self._accept_zeroconf_records = False return diff --git a/tests/test_reconnect_logic.py b/tests/test_reconnect_logic.py index dba1eae..0b605a0 100644 --- a/tests/test_reconnect_logic.py +++ b/tests/test_reconnect_logic.py @@ -426,6 +426,7 @@ async def test_reconnect_zeroconf( assert rl._accept_zeroconf_records is True assert not rl._is_stopped + caplog.clear() with patch.object(cli, "start_connection") as mock_start_connection, patch.object( cli, "finish_connection" ): @@ -436,10 +437,13 @@ async def test_reconnect_zeroconf( assert ( "Triggering connect because of received mDNS record" in caplog.text ) is should_trigger_zeroconf + assert rl._accept_zeroconf_records is not should_trigger_zeroconf assert rl._zc_listening is True # should change after one iteration of the loop await asyncio.sleep(0) assert rl._zc_listening is not should_trigger_zeroconf + # The reconnect is scheduled to run in the next loop iteration + await asyncio.sleep(0) assert mock_start_connection.call_count == int(should_trigger_zeroconf) assert log_text in caplog.text