Refactor to avoid creating tasks for starting/finishing the connection

This commit is contained in:
J. Nick Koston 2024-02-16 16:31:24 -06:00
parent 41eeabcc08
commit 4d93e694e8
No known key found for this signature in database
3 changed files with 30 additions and 35 deletions

View File

@ -98,8 +98,8 @@ cdef class APIConnection:
cdef object _pong_timer
cdef float _keep_alive_interval
cdef float _keep_alive_timeout
cdef object _start_connect_task
cdef object _finish_connect_task
cdef object _start_connect_future
cdef object _finish_connect_future
cdef public Exception _fatal_exception
cdef bint _expected_disconnect
cdef object _loop

View File

@ -16,6 +16,7 @@ from functools import lru_cache, partial
from typing import TYPE_CHECKING, Any, Callable
import aiohappyeyeballs
from async_interrupt import interrupt
from google.protobuf import message
import aioesphomeapi.host_resolver as hr
@ -106,6 +107,10 @@ _bytes = bytes
_float = float
class ConnectionInterruptedError(Exception):
"""An error that is raised when a connection is interrupted."""
@dataclass
class ConnectionParams:
addresses: list[str]
@ -198,8 +203,8 @@ class APIConnection:
"_pong_timer",
"_keep_alive_interval",
"_keep_alive_timeout",
"_start_connect_task",
"_finish_connect_task",
"_start_connect_future",
"_finish_connect_future",
"_fatal_exception",
"_expected_disconnect",
"_loop",
@ -242,8 +247,8 @@ class APIConnection:
self._keep_alive_interval = keepalive
self._keep_alive_timeout = keepalive * KEEP_ALIVE_TIMEOUT_RATIO
self._start_connect_task: asyncio.Task[None] | None = None
self._finish_connect_task: asyncio.Task[None] | None = None
self._start_connect_future: asyncio.Future[None] | None = None
self._finish_connect_future: asyncio.Future[None] | None = None
self._fatal_exception: Exception | None = None
self._expected_disconnect = False
self._send_pending_ping = False
@ -280,24 +285,14 @@ class APIConnection:
new_exc.__cause__ = err
fut.set_exception(new_exc)
self._read_exception_futures.clear()
# If we are being called from do_connect we
# need to make sure we don't cancel the task
# that called us
current_task = asyncio.current_task()
if (
self._start_connect_task is not None
and self._start_connect_task is not current_task
):
self._start_connect_task.cancel("Connection cleanup")
self._start_connect_task = None
if self._start_connect_future is not None:
self._start_connect_future.set_result(None)
self._start_connect_future = None
if (
self._finish_connect_task is not None
and self._finish_connect_task is not current_task
):
self._finish_connect_task.cancel("Connection cleanup")
self._finish_connect_task = None
if self._finish_connect_future is not None:
self._finish_connect_future.set_result(None)
self._finish_connect_future = None
if self._frame_helper is not None:
self._frame_helper.close()
@ -605,19 +600,19 @@ class APIConnection:
"Connection can only be used once, connection is not in init state"
)
start_connect_task = asyncio.create_task(
self._do_connect(), name=f"{self.log_name}: aioesphomeapi do_connect"
)
self._start_connect_task = start_connect_task
self._start_connect_future = self._loop.create_future()
try:
await start_connect_task
async with interrupt(
self._start_connect_future, ConnectionInterruptedError, None
):
await self._do_connect()
except (Exception, CancelledError) as ex:
# If the task was cancelled, we need to clean up the connection
# and raise the CancelledError as APIConnectionError
self._cleanup()
raise self._wrap_fatal_connection_exception("starting", ex)
finally:
self._start_connect_task = None
self._start_connect_future = None
self._set_connection_state(CONNECTION_STATE_SOCKET_OPENED)
def _wrap_fatal_connection_exception(
@ -664,20 +659,19 @@ class APIConnection:
raise RuntimeError(
"Connection must be in SOCKET_OPENED state to finish connection"
)
finish_connect_task = asyncio.create_task(
self._do_finish_connect(login),
name=f"{self.log_name}: aioesphomeapi _do_finish_connect",
)
self._finish_connect_task = finish_connect_task
self._finish_connect_future = self._loop.create_future()
try:
await self._finish_connect_task
async with interrupt(
self._finish_connect_future, ConnectionInterruptedError, None
):
await self._do_finish_connect(login)
except (Exception, CancelledError) as ex:
# If the task was cancelled, we need to clean up the connection
# and raise the CancelledError as APIConnectionError
self._cleanup()
raise self._wrap_fatal_connection_exception("finishing", ex)
finally:
self._finish_connect_task = None
self._finish_connect_future = None
self._set_connection_state(CONNECTION_STATE_CONNECTED)
def _set_connection_state(self, state: ConnectionState) -> None:

View File

@ -1,4 +1,5 @@
aiohappyeyeballs>=2.3.0
async-interrupt>=1.1.1
protobuf>=3.19.0
zeroconf>=0.128.4,<1.0
chacha20poly1305-reuseable>=0.12.1