Split connection process to enable faster reconnects (#576)
This commit is contained in:
parent
dc367b67bb
commit
b7449d4ded
|
@ -105,6 +105,7 @@ from .core import (
|
|||
APIConnectionError,
|
||||
BluetoothGATTAPIError,
|
||||
TimeoutAPIError,
|
||||
UnhandledAPIConnectionError,
|
||||
to_human_readable_address,
|
||||
)
|
||||
from .host_resolver import ZeroconfInstanceType
|
||||
|
@ -297,7 +298,7 @@ class APIClient:
|
|||
|
||||
@property
|
||||
def _log_name(self) -> str:
|
||||
if self._cached_name is not None:
|
||||
if self._cached_name is not None and not self.address.endswith(".local"):
|
||||
return f"{self._cached_name} @ {self.address}"
|
||||
return self.address
|
||||
|
||||
|
@ -311,6 +312,15 @@ class APIClient:
|
|||
on_stop: Callable[[bool], Awaitable[None]] | None = None,
|
||||
login: bool = False,
|
||||
) -> None:
|
||||
"""Connect to the device."""
|
||||
await self.start_connection(on_stop)
|
||||
await self.finish_connection(login)
|
||||
|
||||
async def start_connection(
|
||||
self,
|
||||
on_stop: Callable[[bool], Awaitable[None]] | None = None,
|
||||
) -> None:
|
||||
"""Start connecting to the device."""
|
||||
if self._connection is not None:
|
||||
raise APIConnectionError(f"Already connected to {self._log_name}!")
|
||||
|
||||
|
@ -325,13 +335,30 @@ class APIClient:
|
|||
)
|
||||
|
||||
try:
|
||||
await self._connection.connect(login=login)
|
||||
await self._connection.start_connection()
|
||||
except APIConnectionError:
|
||||
self._connection = None
|
||||
raise
|
||||
except Exception as e:
|
||||
self._connection = None
|
||||
raise APIConnectionError(
|
||||
raise UnhandledAPIConnectionError(
|
||||
f"Unexpected error while connecting to {self._log_name}: {e}"
|
||||
) from e
|
||||
|
||||
async def finish_connection(
|
||||
self,
|
||||
login: bool = False,
|
||||
) -> None:
|
||||
"""Finish connecting to the device."""
|
||||
assert self._connection is not None
|
||||
try:
|
||||
await self._connection.finish_connection(login=login)
|
||||
except APIConnectionError:
|
||||
self._connection = None
|
||||
raise
|
||||
except Exception as e:
|
||||
self._connection = None
|
||||
raise UnhandledAPIConnectionError(
|
||||
f"Unexpected error while connecting to {self._log_name}: {e}"
|
||||
) from e
|
||||
|
||||
|
@ -352,18 +379,11 @@ class APIClient:
|
|||
f"Authenticated connection not ready yet for {self._log_name}; "
|
||||
f"current state is {connection.connection_state}!"
|
||||
)
|
||||
if not connection.is_authenticated:
|
||||
raise APIConnectionError(f"Not authenticated for {self._log_name}!")
|
||||
|
||||
async def device_info(self) -> DeviceInfo:
|
||||
self._check_authenticated()
|
||||
connection = self._connection
|
||||
if not connection:
|
||||
raise APIConnectionError(f"Not connected to {self._log_name}!")
|
||||
if not connection or not connection.is_connected:
|
||||
raise APIConnectionError(
|
||||
f"Connection not ready yet for {self._log_name}; "
|
||||
f"current state is {connection.connection_state}!"
|
||||
)
|
||||
assert connection is not None
|
||||
resp = await connection.send_message_await_response(
|
||||
DeviceInfoRequest(), DeviceInfoResponse
|
||||
)
|
||||
|
|
|
@ -27,8 +27,7 @@ cdef class APIConnection:
|
|||
cdef public object _socket
|
||||
cdef public object _frame_helper
|
||||
cdef public object api_version
|
||||
cdef public object _connection_state
|
||||
cdef object _connect_complete
|
||||
cdef public object connection_state
|
||||
cdef dict _message_handlers
|
||||
cdef public str log_name
|
||||
cdef set _read_exception_futures
|
||||
|
@ -36,14 +35,14 @@ cdef class APIConnection:
|
|||
cdef object _pong_timer
|
||||
cdef float _keep_alive_interval
|
||||
cdef float _keep_alive_timeout
|
||||
cdef object _connect_task
|
||||
cdef object _start_connect_task
|
||||
cdef object _finish_connect_task
|
||||
cdef object _fatal_exception
|
||||
cdef bint _expected_disconnect
|
||||
cdef object _loop
|
||||
cdef bint _send_pending_ping
|
||||
cdef public bint is_connected
|
||||
cdef public bint is_authenticated
|
||||
cdef bint _is_socket_open
|
||||
cdef bint _handshake_complete
|
||||
cdef object _debug_enabled
|
||||
|
||||
cpdef send_message(self, object msg)
|
||||
|
|
|
@ -118,16 +118,15 @@ class ConnectionParams:
|
|||
class ConnectionState(enum.Enum):
|
||||
# The connection is initialized, but connect() wasn't called yet
|
||||
INITIALIZED = 0
|
||||
# Internal state,
|
||||
# The socket has been opened, but the handshake and login haven't been completed
|
||||
SOCKET_OPENED = 1
|
||||
# The connection has been established, data can be exchanged
|
||||
# The handshake has been completed, messages can be exchanged
|
||||
HANDSHAKE_COMPLETE = 2
|
||||
# The connection has been established, authenticated data can be exchanged
|
||||
CONNECTED = 2
|
||||
CLOSED = 3
|
||||
|
||||
|
||||
OPEN_STATES = {ConnectionState.SOCKET_OPENED, ConnectionState.CONNECTED}
|
||||
|
||||
|
||||
class APIConnection:
|
||||
"""This class represents _one_ connection to a remote native API device.
|
||||
|
||||
|
@ -142,8 +141,7 @@ class APIConnection:
|
|||
"_socket",
|
||||
"_frame_helper",
|
||||
"api_version",
|
||||
"_connection_state",
|
||||
"_connect_complete",
|
||||
"connection_state",
|
||||
"_message_handlers",
|
||||
"log_name",
|
||||
"_read_exception_futures",
|
||||
|
@ -151,14 +149,14 @@ class APIConnection:
|
|||
"_pong_timer",
|
||||
"_keep_alive_interval",
|
||||
"_keep_alive_timeout",
|
||||
"_connect_task",
|
||||
"_start_connect_task",
|
||||
"_finish_connect_task",
|
||||
"_fatal_exception",
|
||||
"_expected_disconnect",
|
||||
"_loop",
|
||||
"_send_pending_ping",
|
||||
"is_connected",
|
||||
"is_authenticated",
|
||||
"_is_socket_open",
|
||||
"_handshake_complete",
|
||||
"_debug_enabled",
|
||||
)
|
||||
|
||||
|
@ -177,10 +175,7 @@ class APIConnection:
|
|||
) = None
|
||||
self.api_version: APIVersion | None = None
|
||||
|
||||
self._connection_state = ConnectionState.INITIALIZED
|
||||
# Store whether connect() has completed
|
||||
# Used so that on_stop is _not_ called if an error occurs during connect()
|
||||
self._connect_complete = False
|
||||
self.connection_state = ConnectionState.INITIALIZED
|
||||
|
||||
# Message handlers currently subscribed to incoming messages
|
||||
self._message_handlers: dict[Any, set[Callable[[message.Message], None]]] = {}
|
||||
|
@ -195,21 +190,16 @@ class APIConnection:
|
|||
self._keep_alive_interval = params.keepalive
|
||||
self._keep_alive_timeout = params.keepalive * KEEP_ALIVE_TIMEOUT_RATIO
|
||||
|
||||
self._connect_task: asyncio.Task[None] | None = None
|
||||
self._start_connect_task: asyncio.Task[None] | None = None
|
||||
self._finish_connect_task: asyncio.Task[None] | None = None
|
||||
self._fatal_exception: Exception | None = None
|
||||
self._expected_disconnect = False
|
||||
self._send_pending_ping = False
|
||||
self._loop = asyncio.get_event_loop()
|
||||
self.is_connected = False
|
||||
self.is_authenticated = False
|
||||
self._is_socket_open = False
|
||||
self._handshake_complete = False
|
||||
self._debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG)
|
||||
|
||||
@property
|
||||
def connection_state(self) -> ConnectionState:
|
||||
"""Return the current connection state."""
|
||||
return self._connection_state
|
||||
|
||||
def set_log_name(self, name: str) -> None:
|
||||
"""Set the friendly log name for this connection."""
|
||||
self.log_name = name
|
||||
|
@ -219,6 +209,10 @@ class APIConnection:
|
|||
|
||||
Safe to call multiple times.
|
||||
"""
|
||||
if self.connection_state == ConnectionState.CLOSED:
|
||||
return
|
||||
was_connected = self.is_connected
|
||||
self._set_connection_state(ConnectionState.CLOSED)
|
||||
_LOGGER.debug("Cleaning up connection to %s", self.log_name)
|
||||
for fut in self._read_exception_futures:
|
||||
if fut.done():
|
||||
|
@ -233,9 +227,13 @@ class APIConnection:
|
|||
# If we are being called from do_connect we
|
||||
# need to make sure we don't cancel the task
|
||||
# that called us
|
||||
if self._connect_task is not None and not in_do_connect.get(False):
|
||||
self._connect_task.cancel("Connection cleanup")
|
||||
self._connect_task = None
|
||||
if self._start_connect_task is not None and not in_do_connect.get(False):
|
||||
self._start_connect_task.cancel("Connection cleanup")
|
||||
self._start_connect_task = None
|
||||
|
||||
if self._finish_connect_task is not None and not in_do_connect.get(False):
|
||||
self._finish_connect_task.cancel("Connection cleanup")
|
||||
self._finish_connect_task = None
|
||||
|
||||
if self._frame_helper is not None:
|
||||
self._frame_helper.close()
|
||||
|
@ -251,7 +249,7 @@ class APIConnection:
|
|||
self._ping_timer.cancel()
|
||||
self._ping_timer = None
|
||||
|
||||
if self.on_stop and self._connect_complete:
|
||||
if self.on_stop and was_connected:
|
||||
# Ensure on_stop is called only once
|
||||
self._on_stop_task = asyncio.create_task(
|
||||
self.on_stop(self._expected_disconnect),
|
||||
|
@ -346,7 +344,8 @@ class APIConnection:
|
|||
sock=self._socket,
|
||||
)
|
||||
else:
|
||||
noise_psk = self._params.noise_psk
|
||||
# Ensure noise_psk is a string and not an EStr
|
||||
noise_psk = str(self._params.noise_psk)
|
||||
assert noise_psk is not None
|
||||
_, fh = await loop.create_connection( # type: ignore[type-var]
|
||||
lambda: APINoiseFrameHelper(
|
||||
|
@ -361,13 +360,13 @@ class APIConnection:
|
|||
)
|
||||
|
||||
self._frame_helper = fh
|
||||
self._set_connection_state(ConnectionState.SOCKET_OPENED)
|
||||
try:
|
||||
await fh.perform_handshake(HANDSHAKE_TIMEOUT)
|
||||
except asyncio.TimeoutError as err:
|
||||
raise TimeoutAPIError("Handshake timed out") from err
|
||||
except OSError as err:
|
||||
raise HandshakeAPIError(f"Handshake failed: {err}") from err
|
||||
self._set_connection_state(ConnectionState.HANDSHAKE_COMPLETE)
|
||||
|
||||
async def _connect_hello(self) -> None:
|
||||
"""Step 4 in connect process: send hello and get api version."""
|
||||
|
@ -419,7 +418,7 @@ class APIConnection:
|
|||
|
||||
def _async_send_keep_alive(self) -> None:
|
||||
"""Send a keep alive message."""
|
||||
if not self._is_socket_open:
|
||||
if not self.is_connected:
|
||||
return
|
||||
|
||||
loop = self._loop
|
||||
|
@ -461,7 +460,7 @@ class APIConnection:
|
|||
|
||||
def _async_pong_not_received(self) -> None:
|
||||
"""Ping not received."""
|
||||
if not self._is_socket_open:
|
||||
if not self.is_connected:
|
||||
return
|
||||
_LOGGER.debug(
|
||||
"%s: Ping response not received after %s seconds",
|
||||
|
@ -474,63 +473,105 @@ class APIConnection:
|
|||
)
|
||||
)
|
||||
|
||||
async def _do_connect(self, login: bool) -> None:
|
||||
async def _do_connect(self) -> None:
|
||||
"""Do the actual connect process."""
|
||||
in_do_connect.set(True)
|
||||
addr = await self._connect_resolve_host()
|
||||
await self._connect_socket_connect(addr)
|
||||
await self._connect_init_frame_helper()
|
||||
await self._connect_hello()
|
||||
if login:
|
||||
await self.login(check_connected=False)
|
||||
self._async_schedule_keep_alive(self._loop.time())
|
||||
|
||||
async def connect(self, *, login: bool) -> None:
|
||||
if self._connection_state != ConnectionState.INITIALIZED:
|
||||
async def start_connection(self) -> None:
|
||||
"""Start the connection process.
|
||||
|
||||
This part of the process establishes the socket connection but
|
||||
does not initialize the frame helper or send the hello message.
|
||||
"""
|
||||
if self.connection_state != ConnectionState.INITIALIZED:
|
||||
raise ValueError(
|
||||
"Connection can only be used once, connection is not in init state"
|
||||
)
|
||||
self._connect_task = asyncio.create_task(
|
||||
self._do_connect(login), name=f"{self.log_name}: aioesphomeapi do_connect"
|
||||
|
||||
start_connect_task = asyncio.create_task(
|
||||
self._do_connect(), name=f"{self.log_name}: aioesphomeapi do_connect"
|
||||
)
|
||||
self._start_connect_task = start_connect_task
|
||||
try:
|
||||
# Allow 2 minutes for connect and setup; this is only as a last measure
|
||||
# to protect from issues if some part of the connect process mistakenly
|
||||
# does not have a timeout
|
||||
async with asyncio_timeout(CONNECT_AND_SETUP_TIMEOUT):
|
||||
await self._connect_task
|
||||
except asyncio.CancelledError:
|
||||
await start_connect_task
|
||||
except (Exception, asyncio.CancelledError) as ex:
|
||||
# If the task was cancelled, we need to clean up the connection
|
||||
# and raise the CancelledError
|
||||
self._set_connection_state(ConnectionState.CLOSED)
|
||||
self._cleanup()
|
||||
raise self._fatal_exception or APIConnectionError("Connection cancelled")
|
||||
except Exception: # pylint: disable=broad-except
|
||||
# Always clean up the connection if an error occurred during connect
|
||||
self._set_connection_state(ConnectionState.CLOSED)
|
||||
self._cleanup()
|
||||
if isinstance(ex, asyncio.CancelledError):
|
||||
raise self._fatal_exception or APIConnectionError(
|
||||
"Connection cancelled"
|
||||
)
|
||||
if not start_connect_task.cancelled() and (
|
||||
task_exc := start_connect_task.exception()
|
||||
):
|
||||
raise task_exc
|
||||
raise
|
||||
finally:
|
||||
self._start_connect_task = None
|
||||
self._set_connection_state(ConnectionState.SOCKET_OPENED)
|
||||
|
||||
self._connect_task = None
|
||||
async def _do_finish_connect(self, login: bool) -> None:
|
||||
"""Finish the connection process."""
|
||||
in_do_connect.set(True)
|
||||
await self._connect_init_frame_helper()
|
||||
await self._connect_hello()
|
||||
if login:
|
||||
await self._login()
|
||||
self._async_schedule_keep_alive(self._loop.time())
|
||||
|
||||
async def finish_connection(self, *, login: bool) -> None:
|
||||
"""Finish the connection process.
|
||||
|
||||
This part of the process initializes the frame helper and sends the hello message
|
||||
than starts the keep alive process.
|
||||
"""
|
||||
if self.connection_state != ConnectionState.SOCKET_OPENED:
|
||||
raise ValueError(
|
||||
"Connection must be in SOCKET_OPENED state to finish connection"
|
||||
)
|
||||
finish_connect_task = asyncio.create_task(
|
||||
self._do_finish_connect(login),
|
||||
name=f"{self.log_name}: aioesphomeapi _do_finish_connect",
|
||||
)
|
||||
self._finish_connect_task = finish_connect_task
|
||||
try:
|
||||
# Allow 2 minutes for connect and setup; this is only as a last measure
|
||||
# to protect from issues if some part of the connect process mistakenly
|
||||
# does not have a timeout
|
||||
async with asyncio_timeout(CONNECT_AND_SETUP_TIMEOUT):
|
||||
await self._finish_connect_task
|
||||
except (Exception, asyncio.CancelledError) as ex:
|
||||
# If the task was cancelled, we need to clean up the connection
|
||||
# and raise the CancelledError
|
||||
self._cleanup()
|
||||
if isinstance(ex, asyncio.CancelledError):
|
||||
raise self._fatal_exception or APIConnectionError(
|
||||
"Connection cancelled"
|
||||
)
|
||||
if not finish_connect_task.cancelled() and (
|
||||
task_exc := finish_connect_task.exception()
|
||||
):
|
||||
raise task_exc
|
||||
raise
|
||||
finally:
|
||||
self._finish_connect_task = None
|
||||
self._set_connection_state(ConnectionState.CONNECTED)
|
||||
self._connect_complete = True
|
||||
|
||||
def _set_connection_state(self, state: ConnectionState) -> None:
|
||||
"""Set the connection state and log the change."""
|
||||
self._connection_state = state
|
||||
self.connection_state = state
|
||||
self.is_connected = state == ConnectionState.CONNECTED
|
||||
self._is_socket_open = state in OPEN_STATES
|
||||
self._handshake_complete = state == ConnectionState.HANDSHAKE_COMPLETE
|
||||
|
||||
async def login(self, check_connected: bool = True) -> None:
|
||||
async def _login(self) -> None:
|
||||
"""Send a login (ConnectRequest) and await the response."""
|
||||
if check_connected and self.is_connected:
|
||||
# On first connect, we don't want to check if we're connected
|
||||
# because we don't set the connection state until after login
|
||||
# is complete
|
||||
raise APIConnectionError("Must be connected!")
|
||||
if self.is_authenticated:
|
||||
raise APIConnectionError("Already logged in!")
|
||||
|
||||
connect = ConnectRequest()
|
||||
if self._params.password is not None:
|
||||
connect.password = self._params.password
|
||||
|
@ -549,18 +590,16 @@ class APIConnection:
|
|||
if resp.invalid_password:
|
||||
raise InvalidAuthAPIError("Invalid password!")
|
||||
|
||||
self.is_authenticated = True
|
||||
|
||||
def send_message(self, msg: message.Message) -> None:
|
||||
"""Send a protobuf message to the remote."""
|
||||
if not self._is_socket_open:
|
||||
if not self._handshake_complete:
|
||||
if in_do_connect.get(False):
|
||||
# If we are in the do_connect task, we can't raise an error
|
||||
# because it would obscure the original exception (ie encrypt error).
|
||||
_LOGGER.debug("%s: Connection isn't established yet", self.log_name)
|
||||
return
|
||||
raise ConnectionNotEstablishedAPIError(
|
||||
f"Connection isn't established yet ({self._connection_state})"
|
||||
f"Connection isn't established yet ({self.connection_state})"
|
||||
)
|
||||
|
||||
msg_type = type(msg)
|
||||
|
@ -731,7 +770,6 @@ class APIConnection:
|
|||
exc_info=not str(err), # Log the full stack on empty error string
|
||||
)
|
||||
self._fatal_exception = err
|
||||
self._set_connection_state(ConnectionState.CLOSED)
|
||||
self._cleanup()
|
||||
|
||||
def _process_packet(self, msg_type_proto: _int, data: _bytes) -> None:
|
||||
|
@ -793,7 +831,6 @@ class APIConnection:
|
|||
|
||||
if msg_type is DisconnectRequest:
|
||||
self.send_message(DisconnectResponse())
|
||||
self._set_connection_state(ConnectionState.CLOSED)
|
||||
self._expected_disconnect = True
|
||||
self._cleanup()
|
||||
elif msg_type is PingRequest:
|
||||
|
@ -805,11 +842,11 @@ class APIConnection:
|
|||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from the API."""
|
||||
if self._connect_task:
|
||||
if self._finish_connect_task:
|
||||
# Try to wait for the handshake to finish so we can send
|
||||
# a disconnect request. If it doesn't finish in time
|
||||
# we will just close the socket.
|
||||
_, pending = await asyncio.wait([self._connect_task], timeout=5.0)
|
||||
_, pending = await asyncio.wait([self._finish_connect_task], timeout=5.0)
|
||||
if pending:
|
||||
_LOGGER.debug(
|
||||
"%s: Connect task didn't finish before disconnect",
|
||||
|
@ -817,7 +854,7 @@ class APIConnection:
|
|||
)
|
||||
|
||||
self._expected_disconnect = True
|
||||
if self._is_socket_open and self._frame_helper:
|
||||
if self._handshake_complete:
|
||||
# We still want to send a disconnect request even
|
||||
# if the hello phase isn't finished to ensure we
|
||||
# the esp will clean up the connection as soon
|
||||
|
@ -831,13 +868,12 @@ class APIConnection:
|
|||
"%s: Failed to send disconnect request: %s", self.log_name, err
|
||||
)
|
||||
|
||||
self._set_connection_state(ConnectionState.CLOSED)
|
||||
self._cleanup()
|
||||
|
||||
async def force_disconnect(self) -> None:
|
||||
"""Forcefully disconnect from the API."""
|
||||
self._expected_disconnect = True
|
||||
if self._is_socket_open and self._frame_helper:
|
||||
if self._handshake_complete:
|
||||
# Still try to tell the esp to disconnect gracefully
|
||||
# but don't wait for it to finish
|
||||
try:
|
||||
|
@ -849,5 +885,4 @@ class APIConnection:
|
|||
err,
|
||||
)
|
||||
|
||||
self._set_connection_state(ConnectionState.CLOSED)
|
||||
self._cleanup()
|
||||
|
|
|
@ -217,6 +217,10 @@ class ReadFailedAPIError(APIConnectionError):
|
|||
pass
|
||||
|
||||
|
||||
class UnhandledAPIConnectionError(APIConnectionError):
|
||||
pass
|
||||
|
||||
|
||||
def to_human_readable_address(address: int) -> str:
|
||||
"""Convert a MAC address to a human readable format."""
|
||||
return ":".join(TWO_CHAR.findall(f"{address:012X}"))
|
||||
|
|
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
import logging
|
||||
from collections.abc import Awaitable
|
||||
from enum import Enum
|
||||
from typing import Callable
|
||||
|
||||
import zeroconf
|
||||
|
@ -13,6 +14,7 @@ from .core import (
|
|||
InvalidAuthAPIError,
|
||||
InvalidEncryptionKeyAPIError,
|
||||
RequiresEncryptionAPIError,
|
||||
UnhandledAPIConnectionError,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
@ -22,6 +24,26 @@ MAXIMUM_BACKOFF_TRIES = 100
|
|||
TYPE_PTR = 12
|
||||
|
||||
|
||||
class ReconnectLogicState(Enum):
|
||||
CONNECTING = 0
|
||||
HANDSHAKING = 1
|
||||
READY = 2
|
||||
DISCONNECTED = 3
|
||||
|
||||
|
||||
NOT_YET_CONNECTED_STATES = {
|
||||
ReconnectLogicState.DISCONNECTED,
|
||||
ReconnectLogicState.CONNECTING,
|
||||
}
|
||||
|
||||
|
||||
AUTH_EXCEPTIONS = (
|
||||
RequiresEncryptionAPIError,
|
||||
InvalidEncryptionKeyAPIError,
|
||||
InvalidAuthAPIError,
|
||||
)
|
||||
|
||||
|
||||
class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
"""Reconnectiong logic handler for ESPHome config entries.
|
||||
|
||||
|
@ -50,14 +72,23 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
|||
"""
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self._cli = client
|
||||
self.name = name
|
||||
self.name: str | None
|
||||
if client.address.endswith(".local"):
|
||||
self.name = client.address[:-6]
|
||||
self._log_name = self.name
|
||||
elif name:
|
||||
self.name = name
|
||||
self._log_name = f"{self.name} @ {self._cli.address}"
|
||||
else:
|
||||
self.name = None
|
||||
self._log_name = client.address
|
||||
self._on_connect_cb = on_connect
|
||||
self._on_disconnect_cb = on_disconnect
|
||||
self._on_connect_error_cb = on_connect_error
|
||||
self._zc = zeroconf_instance
|
||||
self._filter_alias: str | None = None
|
||||
# Flag to check if the device is connected
|
||||
self._connected = False
|
||||
self._connection_state = ReconnectLogicState.DISCONNECTED
|
||||
self._connected_lock = asyncio.Lock()
|
||||
self._is_stopped = True
|
||||
self._zc_listening = False
|
||||
|
@ -68,12 +99,6 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
|||
self._connect_timer: asyncio.TimerHandle | None = None
|
||||
self._stop_task: asyncio.Task[None] | None = None
|
||||
|
||||
@property
|
||||
def _log_name(self) -> str:
|
||||
if self.name is not None:
|
||||
return f"{self.name} @ {self._cli.address}"
|
||||
return self._cli.address
|
||||
|
||||
async def _on_disconnect(self, expected_disconnect: bool) -> None:
|
||||
"""Log and issue callbacks when disconnecting."""
|
||||
if self._is_stopped:
|
||||
|
@ -93,7 +118,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
|||
await self._on_disconnect_cb(expected_disconnect)
|
||||
|
||||
async with self._connected_lock:
|
||||
self._connected = False
|
||||
self._connection_state = ReconnectLogicState.DISCONNECTED
|
||||
|
||||
wait = EXPECTED_DISCONNECT_COOLDOWN if expected_disconnect else 0
|
||||
# If we expected the disconnect we need
|
||||
|
@ -102,41 +127,64 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
|||
# before its about to reboot in the event we are too fast.
|
||||
self._schedule_connect(wait)
|
||||
|
||||
def _async_log_connection_error(self, err: Exception) -> None:
|
||||
"""Log connection errors."""
|
||||
# UnhandledAPIConnectionError is a special case in client
|
||||
# for when the connection raises an exception that is not
|
||||
# handled by the client. This is usually a bug in the connection
|
||||
# code and should be logged as an error.
|
||||
is_handled_exception = not isinstance(
|
||||
err, UnhandledAPIConnectionError
|
||||
) and isinstance(err, APIConnectionError)
|
||||
if not is_handled_exception:
|
||||
level = logging.ERROR
|
||||
elif self._tries == 0:
|
||||
level = logging.WARNING
|
||||
else:
|
||||
level = logging.DEBUG
|
||||
_LOGGER.log(
|
||||
level,
|
||||
"Can't connect to ESPHome API for %s: %s (%s)",
|
||||
self._log_name,
|
||||
err,
|
||||
type(err).__name__,
|
||||
# Print stacktrace if unhandled
|
||||
exc_info=not is_handled_exception,
|
||||
)
|
||||
|
||||
async def _try_connect(self) -> bool:
|
||||
"""Try connecting to the API client."""
|
||||
assert self._connected_lock.locked(), "connected_lock must be locked"
|
||||
self._connection_state = ReconnectLogicState.CONNECTING
|
||||
try:
|
||||
await self._cli.connect(on_stop=self._on_disconnect, login=True)
|
||||
await self._cli.start_connection(on_stop=self._on_disconnect)
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
self._connection_state = ReconnectLogicState.DISCONNECTED
|
||||
if self._on_connect_error_cb is not None:
|
||||
await self._on_connect_error_cb(err)
|
||||
level = logging.WARNING if self._tries == 0 else logging.DEBUG
|
||||
_LOGGER.log(
|
||||
level,
|
||||
"Can't connect to ESPHome API for %s: %s (%s)",
|
||||
self._log_name,
|
||||
err,
|
||||
type(err).__name__,
|
||||
# Print stacktrace if unhandled (not APIConnectionError)
|
||||
exc_info=not isinstance(err, APIConnectionError),
|
||||
)
|
||||
if isinstance(
|
||||
err,
|
||||
(
|
||||
RequiresEncryptionAPIError,
|
||||
InvalidEncryptionKeyAPIError,
|
||||
InvalidAuthAPIError,
|
||||
),
|
||||
):
|
||||
self._async_log_connection_error(err)
|
||||
self._tries += 1
|
||||
return False
|
||||
_LOGGER.info("Successfully connected to %s", self._log_name)
|
||||
self._stop_zc_listen()
|
||||
self._connection_state = ReconnectLogicState.HANDSHAKING
|
||||
try:
|
||||
await self._cli.finish_connection(login=True)
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
self._connection_state = ReconnectLogicState.DISCONNECTED
|
||||
if self._on_connect_error_cb is not None:
|
||||
await self._on_connect_error_cb(err)
|
||||
self._async_log_connection_error(err)
|
||||
if isinstance(err, AUTH_EXCEPTIONS):
|
||||
# If we get an encryption or password error,
|
||||
# backoff for the maximum amount of time
|
||||
self._tries = MAXIMUM_BACKOFF_TRIES
|
||||
else:
|
||||
self._tries += 1
|
||||
return False
|
||||
_LOGGER.info("Successfully connected to %s", self._log_name)
|
||||
self._connected = True
|
||||
self._tries = 0
|
||||
_LOGGER.info("Successful handshake with %s", self._log_name)
|
||||
self._connection_state = ReconnectLogicState.READY
|
||||
await self._on_connect_cb()
|
||||
return True
|
||||
|
||||
|
@ -156,6 +204,19 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
|||
|
||||
Must only be called from _schedule_connect.
|
||||
"""
|
||||
if self._connect_task:
|
||||
if self._connection_state != ReconnectLogicState.CONNECTING:
|
||||
# Connection state is far enough along that we should
|
||||
# not restart the connect task
|
||||
return
|
||||
_LOGGER.debug(
|
||||
"%s: Cancelling existing connect task, to try again now!",
|
||||
self._log_name,
|
||||
)
|
||||
self._connect_task.cancel("Scheduling new connect attempt")
|
||||
self._connect_task = None
|
||||
self._connection_state = ReconnectLogicState.DISCONNECTED
|
||||
|
||||
self._connect_task = asyncio.create_task(
|
||||
self._connect_once_or_reschedule(),
|
||||
name=f"{self._log_name}: aioesphomeapi connect",
|
||||
|
@ -178,8 +239,10 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
|||
_LOGGER.debug("Trying to connect to %s", self._log_name)
|
||||
async with self._connected_lock:
|
||||
_LOGGER.debug("Connected lock acquired for %s", self._log_name)
|
||||
self._stop_zc_listen()
|
||||
if self._connected or self._is_stopped:
|
||||
if (
|
||||
self._connection_state != ReconnectLogicState.DISCONNECTED
|
||||
or self._is_stopped
|
||||
):
|
||||
return
|
||||
if await self._try_connect():
|
||||
return
|
||||
|
@ -195,22 +258,21 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
|||
self._start_zc_listen()
|
||||
self._schedule_connect(wait_time)
|
||||
|
||||
def _remove_stop_task(self, _fut: asyncio.Future[None]) -> None:
|
||||
"""Remove the stop task from the connect loop.
|
||||
We need to do this because the asyncio does not hold
|
||||
a strong reference to the task, so it can be garbage
|
||||
collected unexpectedly.
|
||||
"""
|
||||
self._stop_task = None
|
||||
|
||||
def stop_callback(self) -> None:
|
||||
"""Stop the connect logic."""
|
||||
|
||||
def _remove_stop_task(_fut: asyncio.Future[None]) -> None:
|
||||
"""Remove the stop task from the connect loop.
|
||||
We need to do this because the asyncio does not hold
|
||||
a strong reference to the task, so it can be garbage
|
||||
collected unexpectedly.
|
||||
"""
|
||||
self._stop_task = None
|
||||
|
||||
self._stop_task = asyncio.create_task(
|
||||
self.stop(),
|
||||
name=f"{self._log_name}: aioesphomeapi reconnect_logic stop_callback",
|
||||
)
|
||||
self._stop_task.add_done_callback(_remove_stop_task)
|
||||
self._stop_task.add_done_callback(self._remove_stop_task)
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the connecting logic background task."""
|
||||
|
@ -218,7 +280,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
|||
self._cli.set_cached_name_if_unset(self.name)
|
||||
async with self._connected_lock:
|
||||
self._is_stopped = False
|
||||
if self._connected:
|
||||
if self._connection_state != ReconnectLogicState.DISCONNECTED:
|
||||
return
|
||||
self._tries = 0
|
||||
self._schedule_connect(0.0)
|
||||
|
@ -261,10 +323,13 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
|||
|
||||
This is a mDNS record from the device and could mean it just woke up.
|
||||
"""
|
||||
|
||||
# Check if already connected, no lock needed for this access and
|
||||
# bail if either the already stopped or we haven't received device info yet
|
||||
if self._connected or self._is_stopped or self._filter_alias is None:
|
||||
if (
|
||||
self._connection_state not in NOT_YET_CONNECTED_STATES
|
||||
or self._is_stopped
|
||||
or self._filter_alias is None
|
||||
):
|
||||
return
|
||||
|
||||
for record_update in records:
|
||||
|
|
|
@ -54,7 +54,6 @@ def auth_client():
|
|||
)
|
||||
with patch.object(client, "_connection") as conn:
|
||||
conn.is_connected = True
|
||||
conn.is_authenticated = True
|
||||
yield client
|
||||
|
||||
|
||||
|
|
|
@ -12,6 +12,12 @@ from aioesphomeapi.core import RequiresEncryptionAPIError
|
|||
from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr
|
||||
|
||||
|
||||
async def connect(conn: APIConnection, login: bool = True):
|
||||
"""Wrapper for connection logic to do both parts."""
|
||||
await conn.start_connection()
|
||||
await conn.finish_connection(login=login)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def connection_params() -> ConnectionParams:
|
||||
return ConnectionParams(
|
||||
|
@ -81,7 +87,7 @@ async def test_connect(conn, resolve_host, socket_socket, event_loop):
|
|||
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||
):
|
||||
connect_task = asyncio.create_task(conn.connect(login=False))
|
||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||
await connected.wait()
|
||||
protocol.data_received(
|
||||
bytes.fromhex(
|
||||
|
@ -114,7 +120,7 @@ async def test_requires_encryption_propagates(conn: APIConnection):
|
|||
conn._socket = MagicMock()
|
||||
await conn._connect_init_frame_helper()
|
||||
loop.call_soon(conn._frame_helper._ready_future.set_result, None)
|
||||
conn._connection_state = ConnectionState.CONNECTED
|
||||
conn.connection_state = ConnectionState.CONNECTED
|
||||
|
||||
with pytest.raises(RequiresEncryptionAPIError):
|
||||
task = asyncio.create_task(conn._connect_hello())
|
||||
|
@ -149,7 +155,7 @@ async def test_plaintext_connection(conn: APIConnection, resolve_host, socket_so
|
|||
with patch.object(
|
||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||
):
|
||||
connect_task = asyncio.create_task(conn.connect(login=False))
|
||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||
await connected.wait()
|
||||
|
||||
protocol.data_received(
|
||||
|
|
Loading…
Reference in New Issue