Refactor to avoid creating tasks for starting/finishing the connection
This commit is contained in:
parent
41eeabcc08
commit
4d93e694e8
|
@ -98,8 +98,8 @@ cdef class APIConnection:
|
|||
cdef object _pong_timer
|
||||
cdef float _keep_alive_interval
|
||||
cdef float _keep_alive_timeout
|
||||
cdef object _start_connect_task
|
||||
cdef object _finish_connect_task
|
||||
cdef object _start_connect_future
|
||||
cdef object _finish_connect_future
|
||||
cdef public Exception _fatal_exception
|
||||
cdef bint _expected_disconnect
|
||||
cdef object _loop
|
||||
|
|
|
@ -16,6 +16,7 @@ from functools import lru_cache, partial
|
|||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
import aiohappyeyeballs
|
||||
from async_interrupt import interrupt
|
||||
from google.protobuf import message
|
||||
|
||||
import aioesphomeapi.host_resolver as hr
|
||||
|
@ -106,6 +107,10 @@ _bytes = bytes
|
|||
_float = float
|
||||
|
||||
|
||||
class ConnectionInterruptedError(Exception):
|
||||
"""An error that is raised when a connection is interrupted."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectionParams:
|
||||
addresses: list[str]
|
||||
|
@ -198,8 +203,8 @@ class APIConnection:
|
|||
"_pong_timer",
|
||||
"_keep_alive_interval",
|
||||
"_keep_alive_timeout",
|
||||
"_start_connect_task",
|
||||
"_finish_connect_task",
|
||||
"_start_connect_future",
|
||||
"_finish_connect_future",
|
||||
"_fatal_exception",
|
||||
"_expected_disconnect",
|
||||
"_loop",
|
||||
|
@ -242,8 +247,8 @@ class APIConnection:
|
|||
self._keep_alive_interval = keepalive
|
||||
self._keep_alive_timeout = keepalive * KEEP_ALIVE_TIMEOUT_RATIO
|
||||
|
||||
self._start_connect_task: asyncio.Task[None] | None = None
|
||||
self._finish_connect_task: asyncio.Task[None] | None = None
|
||||
self._start_connect_future: asyncio.Future[None] | None = None
|
||||
self._finish_connect_future: asyncio.Future[None] | None = None
|
||||
self._fatal_exception: Exception | None = None
|
||||
self._expected_disconnect = False
|
||||
self._send_pending_ping = False
|
||||
|
@ -280,24 +285,14 @@ class APIConnection:
|
|||
new_exc.__cause__ = err
|
||||
fut.set_exception(new_exc)
|
||||
self._read_exception_futures.clear()
|
||||
# If we are being called from do_connect we
|
||||
# need to make sure we don't cancel the task
|
||||
# that called us
|
||||
current_task = asyncio.current_task()
|
||||
|
||||
if (
|
||||
self._start_connect_task is not None
|
||||
and self._start_connect_task is not current_task
|
||||
):
|
||||
self._start_connect_task.cancel("Connection cleanup")
|
||||
self._start_connect_task = None
|
||||
if self._start_connect_future is not None:
|
||||
self._start_connect_future.set_result(None)
|
||||
self._start_connect_future = None
|
||||
|
||||
if (
|
||||
self._finish_connect_task is not None
|
||||
and self._finish_connect_task is not current_task
|
||||
):
|
||||
self._finish_connect_task.cancel("Connection cleanup")
|
||||
self._finish_connect_task = None
|
||||
if self._finish_connect_future is not None:
|
||||
self._finish_connect_future.set_result(None)
|
||||
self._finish_connect_future = None
|
||||
|
||||
if self._frame_helper is not None:
|
||||
self._frame_helper.close()
|
||||
|
@ -605,19 +600,19 @@ class APIConnection:
|
|||
"Connection can only be used once, connection is not in init state"
|
||||
)
|
||||
|
||||
start_connect_task = asyncio.create_task(
|
||||
self._do_connect(), name=f"{self.log_name}: aioesphomeapi do_connect"
|
||||
)
|
||||
self._start_connect_task = start_connect_task
|
||||
self._start_connect_future = self._loop.create_future()
|
||||
try:
|
||||
await start_connect_task
|
||||
async with interrupt(
|
||||
self._start_connect_future, ConnectionInterruptedError, None
|
||||
):
|
||||
await self._do_connect()
|
||||
except (Exception, CancelledError) as ex:
|
||||
# If the task was cancelled, we need to clean up the connection
|
||||
# and raise the CancelledError as APIConnectionError
|
||||
self._cleanup()
|
||||
raise self._wrap_fatal_connection_exception("starting", ex)
|
||||
finally:
|
||||
self._start_connect_task = None
|
||||
self._start_connect_future = None
|
||||
self._set_connection_state(CONNECTION_STATE_SOCKET_OPENED)
|
||||
|
||||
def _wrap_fatal_connection_exception(
|
||||
|
@ -664,20 +659,19 @@ class APIConnection:
|
|||
raise RuntimeError(
|
||||
"Connection must be in SOCKET_OPENED state to finish connection"
|
||||
)
|
||||
finish_connect_task = asyncio.create_task(
|
||||
self._do_finish_connect(login),
|
||||
name=f"{self.log_name}: aioesphomeapi _do_finish_connect",
|
||||
)
|
||||
self._finish_connect_task = finish_connect_task
|
||||
self._finish_connect_future = self._loop.create_future()
|
||||
try:
|
||||
await self._finish_connect_task
|
||||
async with interrupt(
|
||||
self._finish_connect_future, ConnectionInterruptedError, None
|
||||
):
|
||||
await self._do_finish_connect(login)
|
||||
except (Exception, CancelledError) as ex:
|
||||
# If the task was cancelled, we need to clean up the connection
|
||||
# and raise the CancelledError as APIConnectionError
|
||||
self._cleanup()
|
||||
raise self._wrap_fatal_connection_exception("finishing", ex)
|
||||
finally:
|
||||
self._finish_connect_task = None
|
||||
self._finish_connect_future = None
|
||||
self._set_connection_state(CONNECTION_STATE_CONNECTED)
|
||||
|
||||
def _set_connection_state(self, state: ConnectionState) -> None:
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
aiohappyeyeballs>=2.3.0
|
||||
async-interrupt>=1.1.1
|
||||
protobuf>=3.19.0
|
||||
zeroconf>=0.128.4,<1.0
|
||||
chacha20poly1305-reuseable>=0.12.1
|
||||
|
|
Loading…
Reference in New Issue