mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-23 12:15:13 +01:00
Ensure we hold strong references to tasks (#382)
This commit is contained in:
parent
f99db3577c
commit
0656b65ca1
@ -96,6 +96,7 @@ class APIConnection:
|
||||
) -> None:
|
||||
self._params = params
|
||||
self.on_stop: Optional[Callable[[], Coroutine[Any, Any, None]]] = on_stop
|
||||
self._on_stop_task: Optional[asyncio.Task[None]] = None
|
||||
self._socket: Optional[socket.socket] = None
|
||||
self._frame_helper: Optional[APIFrameHelper] = None
|
||||
self._api_version: Optional[APIVersion] = None
|
||||
@ -117,6 +118,7 @@ class APIConnection:
|
||||
self._ping_stop_event = asyncio.Event()
|
||||
|
||||
self._connect_task: Optional[asyncio.Task[None]] = None
|
||||
self._keep_alive_task: Optional[asyncio.Task[None]] = None
|
||||
self._fatal_exception: Optional[Exception] = None
|
||||
self._expected_disconnect = False
|
||||
|
||||
@ -142,6 +144,10 @@ class APIConnection:
|
||||
self._connect_task.cancel()
|
||||
self._connect_task = None
|
||||
|
||||
if self._keep_alive_task is not None:
|
||||
self._keep_alive_task.cancel()
|
||||
self._keep_alive_task = None
|
||||
|
||||
if self._frame_helper is not None:
|
||||
self._frame_helper.close()
|
||||
self._frame_helper = None
|
||||
@ -151,8 +157,19 @@ class APIConnection:
|
||||
self._socket = None
|
||||
|
||||
if self.on_stop and self._connect_complete:
|
||||
|
||||
def _remove_on_stop_task(_fut: asyncio.Future[None]) -> None:
|
||||
"""Remove the stop task.
|
||||
|
||||
We need to do this because the asyncio does not hold
|
||||
a strong reference to the task, so it can be garbage
|
||||
collected unexpectedly.
|
||||
"""
|
||||
self._on_stop_task = None
|
||||
|
||||
# Ensure on_stop is called only once
|
||||
asyncio.create_task(self.on_stop())
|
||||
self._on_stop_task = asyncio.create_task(self.on_stop())
|
||||
self._on_stop_task.add_done_callback(_remove_on_stop_task)
|
||||
self.on_stop = None
|
||||
|
||||
# Note: we don't explicitly cancel the ping/read task here
|
||||
@ -318,7 +335,7 @@ class APIConnection:
|
||||
self._report_fatal_error(err)
|
||||
return
|
||||
|
||||
asyncio.create_task(_keep_alive_loop())
|
||||
self._keep_alive_task = asyncio.create_task(_keep_alive_loop())
|
||||
|
||||
async def connect(self, *, login: bool) -> None:
|
||||
if self._connection_state != ConnectionState.INITIALIZED:
|
||||
|
@ -59,6 +59,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
self._wait_task_lock = asyncio.Lock()
|
||||
# Event for tracking when logic should stop
|
||||
self._stop_event = asyncio.Event()
|
||||
self._stop_task: Optional[asyncio.Task[None]] = None
|
||||
|
||||
@property
|
||||
def _is_stopped(self) -> bool:
|
||||
@ -200,7 +201,17 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
await self._stop_zc_listen()
|
||||
|
||||
def stop_callback(self) -> None:
|
||||
asyncio.create_task(self.stop())
|
||||
def _remove_stop_task(_fut: asyncio.Future[None]) -> None:
|
||||
"""Remove the stop task from the reconnect loop.
|
||||
|
||||
We need to do this because the asyncio does not hold
|
||||
a strong reference to the task, so it can be garbage
|
||||
collected unexpectedly.
|
||||
"""
|
||||
self._stop_task = None
|
||||
|
||||
self._stop_task = asyncio.create_task(self.stop())
|
||||
self._stop_task.add_done_callback(_remove_stop_task)
|
||||
|
||||
async def _start_zc_listen(self) -> None:
|
||||
"""Listen for mDNS records.
|
||||
|
Loading…
Reference in New Issue
Block a user