Split connection process to enable faster reconnects (#576)

This commit is contained in:
J. Nick Koston 2023-10-14 16:03:12 -10:00 committed by GitHub
parent dc367b67bb
commit b7449d4ded
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 267 additions and 139 deletions

View File

@ -105,6 +105,7 @@ from .core import (
APIConnectionError,
BluetoothGATTAPIError,
TimeoutAPIError,
UnhandledAPIConnectionError,
to_human_readable_address,
)
from .host_resolver import ZeroconfInstanceType
@ -297,7 +298,7 @@ class APIClient:
@property
def _log_name(self) -> str:
if self._cached_name is not None:
if self._cached_name is not None and not self.address.endswith(".local"):
return f"{self._cached_name} @ {self.address}"
return self.address
@ -311,6 +312,15 @@ class APIClient:
on_stop: Callable[[bool], Awaitable[None]] | None = None,
login: bool = False,
) -> None:
"""Connect to the device."""
await self.start_connection(on_stop)
await self.finish_connection(login)
async def start_connection(
self,
on_stop: Callable[[bool], Awaitable[None]] | None = None,
) -> None:
"""Start connecting to the device."""
if self._connection is not None:
raise APIConnectionError(f"Already connected to {self._log_name}!")
@ -325,13 +335,30 @@ class APIClient:
)
try:
await self._connection.connect(login=login)
await self._connection.start_connection()
except APIConnectionError:
self._connection = None
raise
except Exception as e:
self._connection = None
raise APIConnectionError(
raise UnhandledAPIConnectionError(
f"Unexpected error while connecting to {self._log_name}: {e}"
) from e
async def finish_connection(
self,
login: bool = False,
) -> None:
"""Finish connecting to the device."""
assert self._connection is not None
try:
await self._connection.finish_connection(login=login)
except APIConnectionError:
self._connection = None
raise
except Exception as e:
self._connection = None
raise UnhandledAPIConnectionError(
f"Unexpected error while connecting to {self._log_name}: {e}"
) from e
@ -352,18 +379,11 @@ class APIClient:
f"Authenticated connection not ready yet for {self._log_name}; "
f"current state is {connection.connection_state}!"
)
if not connection.is_authenticated:
raise APIConnectionError(f"Not authenticated for {self._log_name}!")
async def device_info(self) -> DeviceInfo:
self._check_authenticated()
connection = self._connection
if not connection:
raise APIConnectionError(f"Not connected to {self._log_name}!")
if not connection or not connection.is_connected:
raise APIConnectionError(
f"Connection not ready yet for {self._log_name}; "
f"current state is {connection.connection_state}!"
)
assert connection is not None
resp = await connection.send_message_await_response(
DeviceInfoRequest(), DeviceInfoResponse
)

View File

@ -27,8 +27,7 @@ cdef class APIConnection:
cdef public object _socket
cdef public object _frame_helper
cdef public object api_version
cdef public object _connection_state
cdef object _connect_complete
cdef public object connection_state
cdef dict _message_handlers
cdef public str log_name
cdef set _read_exception_futures
@ -36,14 +35,14 @@ cdef class APIConnection:
cdef object _pong_timer
cdef float _keep_alive_interval
cdef float _keep_alive_timeout
cdef object _connect_task
cdef object _start_connect_task
cdef object _finish_connect_task
cdef object _fatal_exception
cdef bint _expected_disconnect
cdef object _loop
cdef bint _send_pending_ping
cdef public bint is_connected
cdef public bint is_authenticated
cdef bint _is_socket_open
cdef bint _handshake_complete
cdef object _debug_enabled
cpdef send_message(self, object msg)

View File

@ -118,16 +118,15 @@ class ConnectionParams:
class ConnectionState(enum.Enum):
# The connection is initialized, but connect() wasn't called yet
INITIALIZED = 0
# Internal state,
# The socket has been opened, but the handshake and login haven't been completed
SOCKET_OPENED = 1
# The connection has been established, data can be exchanged
# The handshake has been completed, messages can be exchanged
HANDSHAKE_COMPLETE = 2
# The connection has been established, authenticated data can be exchanged
CONNECTED = 2
CLOSED = 3
OPEN_STATES = {ConnectionState.SOCKET_OPENED, ConnectionState.CONNECTED}
class APIConnection:
"""This class represents _one_ connection to a remote native API device.
@ -142,8 +141,7 @@ class APIConnection:
"_socket",
"_frame_helper",
"api_version",
"_connection_state",
"_connect_complete",
"connection_state",
"_message_handlers",
"log_name",
"_read_exception_futures",
@ -151,14 +149,14 @@ class APIConnection:
"_pong_timer",
"_keep_alive_interval",
"_keep_alive_timeout",
"_connect_task",
"_start_connect_task",
"_finish_connect_task",
"_fatal_exception",
"_expected_disconnect",
"_loop",
"_send_pending_ping",
"is_connected",
"is_authenticated",
"_is_socket_open",
"_handshake_complete",
"_debug_enabled",
)
@ -177,10 +175,7 @@ class APIConnection:
) = None
self.api_version: APIVersion | None = None
self._connection_state = ConnectionState.INITIALIZED
# Store whether connect() has completed
# Used so that on_stop is _not_ called if an error occurs during connect()
self._connect_complete = False
self.connection_state = ConnectionState.INITIALIZED
# Message handlers currently subscribed to incoming messages
self._message_handlers: dict[Any, set[Callable[[message.Message], None]]] = {}
@ -195,21 +190,16 @@ class APIConnection:
self._keep_alive_interval = params.keepalive
self._keep_alive_timeout = params.keepalive * KEEP_ALIVE_TIMEOUT_RATIO
self._connect_task: asyncio.Task[None] | None = None
self._start_connect_task: asyncio.Task[None] | None = None
self._finish_connect_task: asyncio.Task[None] | None = None
self._fatal_exception: Exception | None = None
self._expected_disconnect = False
self._send_pending_ping = False
self._loop = asyncio.get_event_loop()
self.is_connected = False
self.is_authenticated = False
self._is_socket_open = False
self._handshake_complete = False
self._debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG)
@property
def connection_state(self) -> ConnectionState:
"""Return the current connection state."""
return self._connection_state
def set_log_name(self, name: str) -> None:
"""Set the friendly log name for this connection."""
self.log_name = name
@ -219,6 +209,10 @@ class APIConnection:
Safe to call multiple times.
"""
if self.connection_state == ConnectionState.CLOSED:
return
was_connected = self.is_connected
self._set_connection_state(ConnectionState.CLOSED)
_LOGGER.debug("Cleaning up connection to %s", self.log_name)
for fut in self._read_exception_futures:
if fut.done():
@ -233,9 +227,13 @@ class APIConnection:
# If we are being called from do_connect we
# need to make sure we don't cancel the task
# that called us
if self._connect_task is not None and not in_do_connect.get(False):
self._connect_task.cancel("Connection cleanup")
self._connect_task = None
if self._start_connect_task is not None and not in_do_connect.get(False):
self._start_connect_task.cancel("Connection cleanup")
self._start_connect_task = None
if self._finish_connect_task is not None and not in_do_connect.get(False):
self._finish_connect_task.cancel("Connection cleanup")
self._finish_connect_task = None
if self._frame_helper is not None:
self._frame_helper.close()
@ -251,7 +249,7 @@ class APIConnection:
self._ping_timer.cancel()
self._ping_timer = None
if self.on_stop and self._connect_complete:
if self.on_stop and was_connected:
# Ensure on_stop is called only once
self._on_stop_task = asyncio.create_task(
self.on_stop(self._expected_disconnect),
@ -346,7 +344,8 @@ class APIConnection:
sock=self._socket,
)
else:
noise_psk = self._params.noise_psk
# Ensure noise_psk is a string and not an EStr
noise_psk = str(self._params.noise_psk)
assert noise_psk is not None
_, fh = await loop.create_connection( # type: ignore[type-var]
lambda: APINoiseFrameHelper(
@ -361,13 +360,13 @@ class APIConnection:
)
self._frame_helper = fh
self._set_connection_state(ConnectionState.SOCKET_OPENED)
try:
await fh.perform_handshake(HANDSHAKE_TIMEOUT)
except asyncio.TimeoutError as err:
raise TimeoutAPIError("Handshake timed out") from err
except OSError as err:
raise HandshakeAPIError(f"Handshake failed: {err}") from err
self._set_connection_state(ConnectionState.HANDSHAKE_COMPLETE)
async def _connect_hello(self) -> None:
"""Step 4 in connect process: send hello and get api version."""
@ -419,7 +418,7 @@ class APIConnection:
def _async_send_keep_alive(self) -> None:
"""Send a keep alive message."""
if not self._is_socket_open:
if not self.is_connected:
return
loop = self._loop
@ -461,7 +460,7 @@ class APIConnection:
def _async_pong_not_received(self) -> None:
"""Ping not received."""
if not self._is_socket_open:
if not self.is_connected:
return
_LOGGER.debug(
"%s: Ping response not received after %s seconds",
@ -474,63 +473,105 @@ class APIConnection:
)
)
async def _do_connect(self, login: bool) -> None:
async def _do_connect(self) -> None:
"""Do the actual connect process."""
in_do_connect.set(True)
addr = await self._connect_resolve_host()
await self._connect_socket_connect(addr)
await self._connect_init_frame_helper()
await self._connect_hello()
if login:
await self.login(check_connected=False)
self._async_schedule_keep_alive(self._loop.time())
async def connect(self, *, login: bool) -> None:
if self._connection_state != ConnectionState.INITIALIZED:
async def start_connection(self) -> None:
"""Start the connection process.
This part of the process establishes the socket connection but
does not initialize the frame helper or send the hello message.
"""
if self.connection_state != ConnectionState.INITIALIZED:
raise ValueError(
"Connection can only be used once, connection is not in init state"
)
self._connect_task = asyncio.create_task(
self._do_connect(login), name=f"{self.log_name}: aioesphomeapi do_connect"
start_connect_task = asyncio.create_task(
self._do_connect(), name=f"{self.log_name}: aioesphomeapi do_connect"
)
self._start_connect_task = start_connect_task
try:
# Allow 2 minutes for connect and setup; this is only as a last measure
# to protect from issues if some part of the connect process mistakenly
# does not have a timeout
async with asyncio_timeout(CONNECT_AND_SETUP_TIMEOUT):
await self._connect_task
except asyncio.CancelledError:
await start_connect_task
except (Exception, asyncio.CancelledError) as ex:
# If the task was cancelled, we need to clean up the connection
# and raise the CancelledError
self._set_connection_state(ConnectionState.CLOSED)
self._cleanup()
raise self._fatal_exception or APIConnectionError("Connection cancelled")
except Exception: # pylint: disable=broad-except
# Always clean up the connection if an error occurred during connect
self._set_connection_state(ConnectionState.CLOSED)
self._cleanup()
if isinstance(ex, asyncio.CancelledError):
raise self._fatal_exception or APIConnectionError(
"Connection cancelled"
)
if not start_connect_task.cancelled() and (
task_exc := start_connect_task.exception()
):
raise task_exc
raise
finally:
self._start_connect_task = None
self._set_connection_state(ConnectionState.SOCKET_OPENED)
self._connect_task = None
async def _do_finish_connect(self, login: bool) -> None:
"""Finish the connection process."""
in_do_connect.set(True)
await self._connect_init_frame_helper()
await self._connect_hello()
if login:
await self._login()
self._async_schedule_keep_alive(self._loop.time())
async def finish_connection(self, *, login: bool) -> None:
"""Finish the connection process.
This part of the process initializes the frame helper and sends the hello message
than starts the keep alive process.
"""
if self.connection_state != ConnectionState.SOCKET_OPENED:
raise ValueError(
"Connection must be in SOCKET_OPENED state to finish connection"
)
finish_connect_task = asyncio.create_task(
self._do_finish_connect(login),
name=f"{self.log_name}: aioesphomeapi _do_finish_connect",
)
self._finish_connect_task = finish_connect_task
try:
# Allow 2 minutes for connect and setup; this is only as a last measure
# to protect from issues if some part of the connect process mistakenly
# does not have a timeout
async with asyncio_timeout(CONNECT_AND_SETUP_TIMEOUT):
await self._finish_connect_task
except (Exception, asyncio.CancelledError) as ex:
# If the task was cancelled, we need to clean up the connection
# and raise the CancelledError
self._cleanup()
if isinstance(ex, asyncio.CancelledError):
raise self._fatal_exception or APIConnectionError(
"Connection cancelled"
)
if not finish_connect_task.cancelled() and (
task_exc := finish_connect_task.exception()
):
raise task_exc
raise
finally:
self._finish_connect_task = None
self._set_connection_state(ConnectionState.CONNECTED)
self._connect_complete = True
def _set_connection_state(self, state: ConnectionState) -> None:
"""Set the connection state and log the change."""
self._connection_state = state
self.connection_state = state
self.is_connected = state == ConnectionState.CONNECTED
self._is_socket_open = state in OPEN_STATES
self._handshake_complete = state == ConnectionState.HANDSHAKE_COMPLETE
async def login(self, check_connected: bool = True) -> None:
async def _login(self) -> None:
"""Send a login (ConnectRequest) and await the response."""
if check_connected and self.is_connected:
# On first connect, we don't want to check if we're connected
# because we don't set the connection state until after login
# is complete
raise APIConnectionError("Must be connected!")
if self.is_authenticated:
raise APIConnectionError("Already logged in!")
connect = ConnectRequest()
if self._params.password is not None:
connect.password = self._params.password
@ -549,18 +590,16 @@ class APIConnection:
if resp.invalid_password:
raise InvalidAuthAPIError("Invalid password!")
self.is_authenticated = True
def send_message(self, msg: message.Message) -> None:
"""Send a protobuf message to the remote."""
if not self._is_socket_open:
if not self._handshake_complete:
if in_do_connect.get(False):
# If we are in the do_connect task, we can't raise an error
# because it would obscure the original exception (ie encrypt error).
_LOGGER.debug("%s: Connection isn't established yet", self.log_name)
return
raise ConnectionNotEstablishedAPIError(
f"Connection isn't established yet ({self._connection_state})"
f"Connection isn't established yet ({self.connection_state})"
)
msg_type = type(msg)
@ -731,7 +770,6 @@ class APIConnection:
exc_info=not str(err), # Log the full stack on empty error string
)
self._fatal_exception = err
self._set_connection_state(ConnectionState.CLOSED)
self._cleanup()
def _process_packet(self, msg_type_proto: _int, data: _bytes) -> None:
@ -793,7 +831,6 @@ class APIConnection:
if msg_type is DisconnectRequest:
self.send_message(DisconnectResponse())
self._set_connection_state(ConnectionState.CLOSED)
self._expected_disconnect = True
self._cleanup()
elif msg_type is PingRequest:
@ -805,11 +842,11 @@ class APIConnection:
async def disconnect(self) -> None:
"""Disconnect from the API."""
if self._connect_task:
if self._finish_connect_task:
# Try to wait for the handshake to finish so we can send
# a disconnect request. If it doesn't finish in time
# we will just close the socket.
_, pending = await asyncio.wait([self._connect_task], timeout=5.0)
_, pending = await asyncio.wait([self._finish_connect_task], timeout=5.0)
if pending:
_LOGGER.debug(
"%s: Connect task didn't finish before disconnect",
@ -817,7 +854,7 @@ class APIConnection:
)
self._expected_disconnect = True
if self._is_socket_open and self._frame_helper:
if self._handshake_complete:
# We still want to send a disconnect request even
# if the hello phase isn't finished to ensure we
# the esp will clean up the connection as soon
@ -831,13 +868,12 @@ class APIConnection:
"%s: Failed to send disconnect request: %s", self.log_name, err
)
self._set_connection_state(ConnectionState.CLOSED)
self._cleanup()
async def force_disconnect(self) -> None:
"""Forcefully disconnect from the API."""
self._expected_disconnect = True
if self._is_socket_open and self._frame_helper:
if self._handshake_complete:
# Still try to tell the esp to disconnect gracefully
# but don't wait for it to finish
try:
@ -849,5 +885,4 @@ class APIConnection:
err,
)
self._set_connection_state(ConnectionState.CLOSED)
self._cleanup()

View File

@ -217,6 +217,10 @@ class ReadFailedAPIError(APIConnectionError):
pass
class UnhandledAPIConnectionError(APIConnectionError):
pass
def to_human_readable_address(address: int) -> str:
"""Convert a MAC address to a human readable format."""
return ":".join(TWO_CHAR.findall(f"{address:012X}"))

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
import logging
from collections.abc import Awaitable
from enum import Enum
from typing import Callable
import zeroconf
@ -13,6 +14,7 @@ from .core import (
InvalidAuthAPIError,
InvalidEncryptionKeyAPIError,
RequiresEncryptionAPIError,
UnhandledAPIConnectionError,
)
_LOGGER = logging.getLogger(__name__)
@ -22,6 +24,26 @@ MAXIMUM_BACKOFF_TRIES = 100
TYPE_PTR = 12
class ReconnectLogicState(Enum):
CONNECTING = 0
HANDSHAKING = 1
READY = 2
DISCONNECTED = 3
NOT_YET_CONNECTED_STATES = {
ReconnectLogicState.DISCONNECTED,
ReconnectLogicState.CONNECTING,
}
AUTH_EXCEPTIONS = (
RequiresEncryptionAPIError,
InvalidEncryptionKeyAPIError,
InvalidAuthAPIError,
)
class ReconnectLogic(zeroconf.RecordUpdateListener):
"""Reconnectiong logic handler for ESPHome config entries.
@ -50,14 +72,23 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
"""
self.loop = asyncio.get_event_loop()
self._cli = client
self.name = name
self.name: str | None
if client.address.endswith(".local"):
self.name = client.address[:-6]
self._log_name = self.name
elif name:
self.name = name
self._log_name = f"{self.name} @ {self._cli.address}"
else:
self.name = None
self._log_name = client.address
self._on_connect_cb = on_connect
self._on_disconnect_cb = on_disconnect
self._on_connect_error_cb = on_connect_error
self._zc = zeroconf_instance
self._filter_alias: str | None = None
# Flag to check if the device is connected
self._connected = False
self._connection_state = ReconnectLogicState.DISCONNECTED
self._connected_lock = asyncio.Lock()
self._is_stopped = True
self._zc_listening = False
@ -68,12 +99,6 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
self._connect_timer: asyncio.TimerHandle | None = None
self._stop_task: asyncio.Task[None] | None = None
@property
def _log_name(self) -> str:
if self.name is not None:
return f"{self.name} @ {self._cli.address}"
return self._cli.address
async def _on_disconnect(self, expected_disconnect: bool) -> None:
"""Log and issue callbacks when disconnecting."""
if self._is_stopped:
@ -93,7 +118,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
await self._on_disconnect_cb(expected_disconnect)
async with self._connected_lock:
self._connected = False
self._connection_state = ReconnectLogicState.DISCONNECTED
wait = EXPECTED_DISCONNECT_COOLDOWN if expected_disconnect else 0
# If we expected the disconnect we need
@ -102,41 +127,64 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
# before its about to reboot in the event we are too fast.
self._schedule_connect(wait)
def _async_log_connection_error(self, err: Exception) -> None:
"""Log connection errors."""
# UnhandledAPIConnectionError is a special case in client
# for when the connection raises an exception that is not
# handled by the client. This is usually a bug in the connection
# code and should be logged as an error.
is_handled_exception = not isinstance(
err, UnhandledAPIConnectionError
) and isinstance(err, APIConnectionError)
if not is_handled_exception:
level = logging.ERROR
elif self._tries == 0:
level = logging.WARNING
else:
level = logging.DEBUG
_LOGGER.log(
level,
"Can't connect to ESPHome API for %s: %s (%s)",
self._log_name,
err,
type(err).__name__,
# Print stacktrace if unhandled
exc_info=not is_handled_exception,
)
async def _try_connect(self) -> bool:
"""Try connecting to the API client."""
assert self._connected_lock.locked(), "connected_lock must be locked"
self._connection_state = ReconnectLogicState.CONNECTING
try:
await self._cli.connect(on_stop=self._on_disconnect, login=True)
await self._cli.start_connection(on_stop=self._on_disconnect)
except Exception as err: # pylint: disable=broad-except
self._connection_state = ReconnectLogicState.DISCONNECTED
if self._on_connect_error_cb is not None:
await self._on_connect_error_cb(err)
level = logging.WARNING if self._tries == 0 else logging.DEBUG
_LOGGER.log(
level,
"Can't connect to ESPHome API for %s: %s (%s)",
self._log_name,
err,
type(err).__name__,
# Print stacktrace if unhandled (not APIConnectionError)
exc_info=not isinstance(err, APIConnectionError),
)
if isinstance(
err,
(
RequiresEncryptionAPIError,
InvalidEncryptionKeyAPIError,
InvalidAuthAPIError,
),
):
self._async_log_connection_error(err)
self._tries += 1
return False
_LOGGER.info("Successfully connected to %s", self._log_name)
self._stop_zc_listen()
self._connection_state = ReconnectLogicState.HANDSHAKING
try:
await self._cli.finish_connection(login=True)
except Exception as err: # pylint: disable=broad-except
self._connection_state = ReconnectLogicState.DISCONNECTED
if self._on_connect_error_cb is not None:
await self._on_connect_error_cb(err)
self._async_log_connection_error(err)
if isinstance(err, AUTH_EXCEPTIONS):
# If we get an encryption or password error,
# backoff for the maximum amount of time
self._tries = MAXIMUM_BACKOFF_TRIES
else:
self._tries += 1
return False
_LOGGER.info("Successfully connected to %s", self._log_name)
self._connected = True
self._tries = 0
_LOGGER.info("Successful handshake with %s", self._log_name)
self._connection_state = ReconnectLogicState.READY
await self._on_connect_cb()
return True
@ -156,6 +204,19 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
Must only be called from _schedule_connect.
"""
if self._connect_task:
if self._connection_state != ReconnectLogicState.CONNECTING:
# Connection state is far enough along that we should
# not restart the connect task
return
_LOGGER.debug(
"%s: Cancelling existing connect task, to try again now!",
self._log_name,
)
self._connect_task.cancel("Scheduling new connect attempt")
self._connect_task = None
self._connection_state = ReconnectLogicState.DISCONNECTED
self._connect_task = asyncio.create_task(
self._connect_once_or_reschedule(),
name=f"{self._log_name}: aioesphomeapi connect",
@ -178,8 +239,10 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
_LOGGER.debug("Trying to connect to %s", self._log_name)
async with self._connected_lock:
_LOGGER.debug("Connected lock acquired for %s", self._log_name)
self._stop_zc_listen()
if self._connected or self._is_stopped:
if (
self._connection_state != ReconnectLogicState.DISCONNECTED
or self._is_stopped
):
return
if await self._try_connect():
return
@ -195,22 +258,21 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
self._start_zc_listen()
self._schedule_connect(wait_time)
def _remove_stop_task(self, _fut: asyncio.Future[None]) -> None:
"""Remove the stop task from the connect loop.
We need to do this because the asyncio does not hold
a strong reference to the task, so it can be garbage
collected unexpectedly.
"""
self._stop_task = None
def stop_callback(self) -> None:
"""Stop the connect logic."""
def _remove_stop_task(_fut: asyncio.Future[None]) -> None:
"""Remove the stop task from the connect loop.
We need to do this because the asyncio does not hold
a strong reference to the task, so it can be garbage
collected unexpectedly.
"""
self._stop_task = None
self._stop_task = asyncio.create_task(
self.stop(),
name=f"{self._log_name}: aioesphomeapi reconnect_logic stop_callback",
)
self._stop_task.add_done_callback(_remove_stop_task)
self._stop_task.add_done_callback(self._remove_stop_task)
async def start(self) -> None:
"""Start the connecting logic background task."""
@ -218,7 +280,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
self._cli.set_cached_name_if_unset(self.name)
async with self._connected_lock:
self._is_stopped = False
if self._connected:
if self._connection_state != ReconnectLogicState.DISCONNECTED:
return
self._tries = 0
self._schedule_connect(0.0)
@ -261,10 +323,13 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
This is a mDNS record from the device and could mean it just woke up.
"""
# Check if already connected, no lock needed for this access and
# bail if either the already stopped or we haven't received device info yet
if self._connected or self._is_stopped or self._filter_alias is None:
if (
self._connection_state not in NOT_YET_CONNECTED_STATES
or self._is_stopped
or self._filter_alias is None
):
return
for record_update in records:

View File

@ -54,7 +54,6 @@ def auth_client():
)
with patch.object(client, "_connection") as conn:
conn.is_connected = True
conn.is_authenticated = True
yield client

View File

@ -12,6 +12,12 @@ from aioesphomeapi.core import RequiresEncryptionAPIError
from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr
async def connect(conn: APIConnection, login: bool = True):
"""Wrapper for connection logic to do both parts."""
await conn.start_connection()
await conn.finish_connection(login=login)
@pytest.fixture
def connection_params() -> ConnectionParams:
return ConnectionParams(
@ -81,7 +87,7 @@ async def test_connect(conn, resolve_host, socket_socket, event_loop):
with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
):
connect_task = asyncio.create_task(conn.connect(login=False))
connect_task = asyncio.create_task(connect(conn, login=False))
await connected.wait()
protocol.data_received(
bytes.fromhex(
@ -114,7 +120,7 @@ async def test_requires_encryption_propagates(conn: APIConnection):
conn._socket = MagicMock()
await conn._connect_init_frame_helper()
loop.call_soon(conn._frame_helper._ready_future.set_result, None)
conn._connection_state = ConnectionState.CONNECTED
conn.connection_state = ConnectionState.CONNECTED
with pytest.raises(RequiresEncryptionAPIError):
task = asyncio.create_task(conn._connect_hello())
@ -149,7 +155,7 @@ async def test_plaintext_connection(conn: APIConnection, resolve_host, socket_so
with patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
):
connect_task = asyncio.create_task(conn.connect(login=False))
connect_task = asyncio.create_task(connect(conn, login=False))
await connected.wait()
protocol.data_received(