Ensure we hold strong references to tasks (#382)

This commit is contained in:
J. Nick Koston 2023-02-12 19:11:58 -06:00 committed by GitHub
parent f99db3577c
commit 0656b65ca1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 31 additions and 3 deletions

View File

@ -96,6 +96,7 @@ class APIConnection:
) -> None:
self._params = params
self.on_stop: Optional[Callable[[], Coroutine[Any, Any, None]]] = on_stop
self._on_stop_task: Optional[asyncio.Task[None]] = None
self._socket: Optional[socket.socket] = None
self._frame_helper: Optional[APIFrameHelper] = None
self._api_version: Optional[APIVersion] = None
@ -117,6 +118,7 @@ class APIConnection:
self._ping_stop_event = asyncio.Event()
self._connect_task: Optional[asyncio.Task[None]] = None
self._keep_alive_task: Optional[asyncio.Task[None]] = None
self._fatal_exception: Optional[Exception] = None
self._expected_disconnect = False
@ -142,6 +144,10 @@ class APIConnection:
self._connect_task.cancel()
self._connect_task = None
if self._keep_alive_task is not None:
self._keep_alive_task.cancel()
self._keep_alive_task = None
if self._frame_helper is not None:
self._frame_helper.close()
self._frame_helper = None
@ -151,8 +157,19 @@ class APIConnection:
self._socket = None
if self.on_stop and self._connect_complete:
def _remove_on_stop_task(_fut: asyncio.Future[None]) -> None:
"""Remove the stop task.
We need to do this because the asyncio does not hold
a strong reference to the task, so it can be garbage
collected unexpectedly.
"""
self._on_stop_task = None
# Ensure on_stop is called only once
asyncio.create_task(self.on_stop())
self._on_stop_task = asyncio.create_task(self.on_stop())
self._on_stop_task.add_done_callback(_remove_on_stop_task)
self.on_stop = None
# Note: we don't explicitly cancel the ping/read task here
@ -318,7 +335,7 @@ class APIConnection:
self._report_fatal_error(err)
return
asyncio.create_task(_keep_alive_loop())
self._keep_alive_task = asyncio.create_task(_keep_alive_loop())
async def connect(self, *, login: bool) -> None:
if self._connection_state != ConnectionState.INITIALIZED:

View File

@ -59,6 +59,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
self._wait_task_lock = asyncio.Lock()
# Event for tracking when logic should stop
self._stop_event = asyncio.Event()
self._stop_task: Optional[asyncio.Task[None]] = None
@property
def _is_stopped(self) -> bool:
@ -200,7 +201,17 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
await self._stop_zc_listen()
def stop_callback(self) -> None:
asyncio.create_task(self.stop())
def _remove_stop_task(_fut: asyncio.Future[None]) -> None:
"""Remove the stop task from the reconnect loop.
We need to do this because the asyncio does not hold
a strong reference to the task, so it can be garbage
collected unexpectedly.
"""
self._stop_task = None
self._stop_task = asyncio.create_task(self.stop())
self._stop_task.add_done_callback(_remove_stop_task)
async def _start_zc_listen(self) -> None:
"""Listen for mDNS records.