diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index 6b08bc8..b40742a 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -96,6 +96,7 @@ class APIConnection: ) -> None: self._params = params self.on_stop: Optional[Callable[[], Coroutine[Any, Any, None]]] = on_stop + self._on_stop_task: Optional[asyncio.Task[None]] = None self._socket: Optional[socket.socket] = None self._frame_helper: Optional[APIFrameHelper] = None self._api_version: Optional[APIVersion] = None @@ -117,6 +118,7 @@ class APIConnection: self._ping_stop_event = asyncio.Event() self._connect_task: Optional[asyncio.Task[None]] = None + self._keep_alive_task: Optional[asyncio.Task[None]] = None self._fatal_exception: Optional[Exception] = None self._expected_disconnect = False @@ -142,6 +144,10 @@ class APIConnection: self._connect_task.cancel() self._connect_task = None + if self._keep_alive_task is not None: + self._keep_alive_task.cancel() + self._keep_alive_task = None + if self._frame_helper is not None: self._frame_helper.close() self._frame_helper = None @@ -151,8 +157,19 @@ class APIConnection: self._socket = None if self.on_stop and self._connect_complete: + + def _remove_on_stop_task(_fut: asyncio.Future[None]) -> None: + """Remove the stop task. + + We need to do this because the asyncio does not hold + a strong reference to the task, so it can be garbage + collected unexpectedly. + """ + self._on_stop_task = None + # Ensure on_stop is called only once - asyncio.create_task(self.on_stop()) + self._on_stop_task = asyncio.create_task(self.on_stop()) + self._on_stop_task.add_done_callback(_remove_on_stop_task) self.on_stop = None # Note: we don't explicitly cancel the ping/read task here @@ -318,7 +335,7 @@ class APIConnection: self._report_fatal_error(err) return - asyncio.create_task(_keep_alive_loop()) + self._keep_alive_task = asyncio.create_task(_keep_alive_loop()) async def connect(self, *, login: bool) -> None: if self._connection_state != ConnectionState.INITIALIZED: diff --git a/aioesphomeapi/reconnect_logic.py b/aioesphomeapi/reconnect_logic.py index f6500da..717631c 100644 --- a/aioesphomeapi/reconnect_logic.py +++ b/aioesphomeapi/reconnect_logic.py @@ -59,6 +59,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): self._wait_task_lock = asyncio.Lock() # Event for tracking when logic should stop self._stop_event = asyncio.Event() + self._stop_task: Optional[asyncio.Task[None]] = None @property def _is_stopped(self) -> bool: @@ -200,7 +201,17 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): await self._stop_zc_listen() def stop_callback(self) -> None: - asyncio.create_task(self.stop()) + def _remove_stop_task(_fut: asyncio.Future[None]) -> None: + """Remove the stop task from the reconnect loop. + + We need to do this because the asyncio does not hold + a strong reference to the task, so it can be garbage + collected unexpectedly. + """ + self._stop_task = None + + self._stop_task = asyncio.create_task(self.stop()) + self._stop_task.add_done_callback(_remove_stop_task) async def _start_zc_listen(self) -> None: """Listen for mDNS records.