mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2025-01-13 20:11:42 +01:00
Refactor connection class to resolve connect issues (#108)
This commit is contained in:
parent
e03015f8b4
commit
5b99d5c1dd
@ -197,7 +197,10 @@ class APIClient:
|
||||
async def disconnect(self, force: bool = False) -> None:
|
||||
if self._connection is None:
|
||||
return
|
||||
await self._connection.stop(force=force)
|
||||
if force:
|
||||
await self._connection.force_disconnect()
|
||||
else:
|
||||
await self._connection.disconnect()
|
||||
|
||||
def _check_connected(self) -> None:
|
||||
if self._connection is None:
|
||||
|
@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import enum
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
@ -30,10 +31,14 @@ from .core import (
|
||||
HandshakeAPIError,
|
||||
InvalidAuthAPIError,
|
||||
InvalidEncryptionKeyAPIError,
|
||||
PingFailedAPIError,
|
||||
ProtocolAPIError,
|
||||
ReadFailedAPIError,
|
||||
RequiresEncryptionAPIError,
|
||||
ResolveAPIError,
|
||||
SocketAPIError,
|
||||
SocketClosedAPIError,
|
||||
TimeoutAPIError,
|
||||
)
|
||||
from .model import APIVersion
|
||||
from .util import bytes_to_varuint, varuint_to_bytes
|
||||
@ -73,10 +78,11 @@ class APIFrameHelper:
|
||||
self._read_lock = asyncio.Lock()
|
||||
self._ready_event = asyncio.Event()
|
||||
self._proto: Optional[NoiseConnection] = None
|
||||
self._closed_event = asyncio.Event()
|
||||
|
||||
async def close(self) -> None:
|
||||
async with self._write_lock:
|
||||
self._writer.close()
|
||||
self._closed_event.set()
|
||||
self._writer.close()
|
||||
|
||||
async def _write_frame_noise(self, frame: bytes) -> None:
|
||||
try:
|
||||
@ -103,6 +109,13 @@ class APIFrameHelper:
|
||||
msg_size = (header[1] << 8) | header[2]
|
||||
frame = await self._reader.readexactly(msg_size)
|
||||
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
|
||||
if (
|
||||
isinstance(err, asyncio.IncompleteReadError)
|
||||
and self._closed_event.is_set()
|
||||
):
|
||||
raise SocketClosedAPIError(
|
||||
f"Socket closed while reading data: {err}"
|
||||
) from err
|
||||
raise SocketAPIError(f"Error while reading data: {err}") from err
|
||||
|
||||
_LOGGER.debug("Received frame %s", frame.hex())
|
||||
@ -241,6 +254,13 @@ class APIFrameHelper:
|
||||
raw_msg = await self._reader.readexactly(length_int)
|
||||
return Packet(type=msg_type_int, data=raw_msg)
|
||||
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
|
||||
if (
|
||||
isinstance(err, asyncio.IncompleteReadError)
|
||||
and self._closed_event.is_set()
|
||||
):
|
||||
raise SocketClosedAPIError(
|
||||
f"Socket closed while reading data: {err}"
|
||||
) from err
|
||||
raise SocketAPIError(f"Error while reading data: {err}") from err
|
||||
|
||||
async def read_packet(self) -> Packet:
|
||||
@ -249,79 +269,74 @@ class APIFrameHelper:
|
||||
return await self._read_packet_noise()
|
||||
|
||||
|
||||
class ConnectionState(enum.Enum):
|
||||
# The connection is initialized, but connect() wasn't called yet
|
||||
INITIALIZED = 0
|
||||
# Internal state,
|
||||
SOCKET_OPENED = 1
|
||||
# The connection has been established, data can be exchanged
|
||||
CONNECTED = 1
|
||||
CLOSED = 2
|
||||
|
||||
|
||||
class APIConnection:
|
||||
"""This class represents _one_ connection to a remote native API device.
|
||||
|
||||
An instance of this class may only be used once, for every new connection
|
||||
a new instance should be established.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, params: ConnectionParams, on_stop: Callable[[], Awaitable[None]]
|
||||
):
|
||||
self._params = params
|
||||
self.on_stop = on_stop
|
||||
self._stopped = False
|
||||
self._on_stop_called = False
|
||||
self._socket: Optional[socket.socket] = None
|
||||
self._frame_helper: Optional[APIFrameHelper] = None
|
||||
self._connected = False
|
||||
self._authenticated = False
|
||||
self._socket_connected = False
|
||||
self._state_lock = asyncio.Lock()
|
||||
self._api_version: Optional[APIVersion] = None
|
||||
|
||||
self._connection_state = ConnectionState.INITIALIZED
|
||||
self._is_authenticated = False
|
||||
# Store whether connect() has completed
|
||||
# Used so that on_stop is _not_ called if an error occurs during connect()
|
||||
self._connect_complete = False
|
||||
|
||||
# Message handlers currently subscribed to incoming messages
|
||||
self._message_handlers: List[Callable[[message.Message], None]] = []
|
||||
# The friendly name to show for this connection in the logs
|
||||
self.log_name = params.address
|
||||
self._ping_task: Optional[asyncio.Task[None]] = None
|
||||
|
||||
# Handlers currently subscribed to exceptions in the read task
|
||||
self._read_exception_handlers: List[Callable[[Exception], None]] = []
|
||||
|
||||
def _start_ping(self) -> None:
|
||||
async def func() -> None:
|
||||
while True:
|
||||
await asyncio.sleep(self._params.keepalive)
|
||||
self._ping_stop_event = asyncio.Event()
|
||||
|
||||
try:
|
||||
await self.ping()
|
||||
except APIConnectionError:
|
||||
_LOGGER.info("%s: Ping Failed!", self.log_name)
|
||||
await self._on_error()
|
||||
return
|
||||
async def _cleanup(self) -> None:
|
||||
"""Clean up all resources that have been allocated.
|
||||
|
||||
self._ping_task = asyncio.create_task(func())
|
||||
|
||||
async def _close_socket(self) -> None:
|
||||
if not self._socket_connected:
|
||||
return
|
||||
Safe to call multiple times.
|
||||
"""
|
||||
if self._frame_helper is not None:
|
||||
await self._frame_helper.close()
|
||||
self._frame_helper = None
|
||||
|
||||
if self._socket is not None:
|
||||
self._socket.close()
|
||||
self._socket = None
|
||||
if self._ping_task is not None:
|
||||
self._ping_task.cancel()
|
||||
self._ping_task = None
|
||||
self._socket_connected = False
|
||||
self._connected = False
|
||||
self._authenticated = False
|
||||
_LOGGER.debug("%s: Closed socket", self.log_name)
|
||||
|
||||
async def stop(self, force: bool = False) -> None:
|
||||
if self._stopped:
|
||||
return
|
||||
if self._connected and not force:
|
||||
try:
|
||||
await self._disconnect()
|
||||
except APIConnectionError:
|
||||
pass
|
||||
self._stopped = True
|
||||
await self._close_socket()
|
||||
await self.on_stop()
|
||||
if not self._on_stop_called and self._connect_complete:
|
||||
# Ensure on_stop is called
|
||||
asyncio.create_task(self.on_stop())
|
||||
self._on_stop_called = True
|
||||
|
||||
async def _on_error(self) -> None:
|
||||
await self.stop(force=True)
|
||||
|
||||
# pylint: disable=too-many-statements
|
||||
async def connect(self) -> None:
|
||||
if self._stopped:
|
||||
raise APIConnectionError(f"Connection is closed for {self.log_name}!")
|
||||
if self._connected:
|
||||
raise APIConnectionError(f"Already connected for {self.log_name}!")
|
||||
# Note: we don't explicitly cancel the ping/read task here
|
||||
# That's because if not written right the ping/read task could cancel
|
||||
# themself, effectively ending execution after _cleanup which may be unexpected
|
||||
self._ping_stop_event.set()
|
||||
|
||||
async def _connect_resolve_host(self) -> hr.AddrInfo:
|
||||
"""Step 1 in connect process: resolve the address."""
|
||||
try:
|
||||
coro = hr.async_resolve_host(
|
||||
self._params.eventloop,
|
||||
@ -329,16 +344,14 @@ class APIConnection:
|
||||
self._params.port,
|
||||
self._params.zeroconf_instance,
|
||||
)
|
||||
addr = await asyncio.wait_for(coro, 30.0)
|
||||
except APIConnectionError as err:
|
||||
await self._on_error()
|
||||
raise err
|
||||
except asyncio.TimeoutError:
|
||||
await self._on_error()
|
||||
return await asyncio.wait_for(coro, 30.0)
|
||||
except asyncio.TimeoutError as err:
|
||||
raise ResolveAPIError(
|
||||
f"Timeout while resolving IP address for {self.log_name}"
|
||||
)
|
||||
) from err
|
||||
|
||||
async def _connect_socket_connect(self, addr: hr.AddrInfo) -> None:
|
||||
"""Step 2 in connect process: connect the socket."""
|
||||
self._socket = socket.socket(
|
||||
family=addr.family, type=addr.type, proto=addr.proto
|
||||
)
|
||||
@ -353,36 +366,38 @@ class APIConnection:
|
||||
addr,
|
||||
)
|
||||
sockaddr = astuple(addr.sockaddr)
|
||||
|
||||
try:
|
||||
coro2 = self._params.eventloop.sock_connect(self._socket, sockaddr)
|
||||
await asyncio.wait_for(coro2, 30.0)
|
||||
coro = self._params.eventloop.sock_connect(self._socket, sockaddr)
|
||||
await asyncio.wait_for(coro, 30.0)
|
||||
except OSError as err:
|
||||
await self._on_error()
|
||||
raise SocketAPIError(f"Error connecting to {sockaddr}: {err}")
|
||||
except asyncio.TimeoutError:
|
||||
await self._on_error()
|
||||
raise SocketAPIError(f"Timeout while connecting to {sockaddr}")
|
||||
raise SocketAPIError(f"Error connecting to {sockaddr}: {err}") from err
|
||||
except asyncio.TimeoutError as err:
|
||||
raise SocketAPIError(f"Timeout while connecting to {sockaddr}") from err
|
||||
|
||||
_LOGGER.debug("%s: Opened socket for", self._params.address)
|
||||
|
||||
async def _connect_init_frame_helper(self) -> None:
|
||||
"""Step 3 in connect process: initialize the frame helper and init read loop."""
|
||||
reader, writer = await asyncio.open_connection(sock=self._socket)
|
||||
|
||||
self._frame_helper = APIFrameHelper(reader, writer, self._params)
|
||||
self._socket_connected = True
|
||||
await self._frame_helper.perform_handshake()
|
||||
|
||||
try:
|
||||
await self._frame_helper.perform_handshake()
|
||||
except APIConnectionError:
|
||||
await self._on_error()
|
||||
raise
|
||||
self._connection_state = ConnectionState.SOCKET_OPENED
|
||||
|
||||
self._params.eventloop.create_task(self.run_forever())
|
||||
# Create read loop
|
||||
asyncio.create_task(self._read_loop())
|
||||
|
||||
async def _connect_hello(self) -> None:
|
||||
"""Step 4 in connect process: send hello and get api version."""
|
||||
hello = HelloRequest()
|
||||
hello.client_info = self._params.client_info
|
||||
try:
|
||||
resp = await self.send_message_await_response(hello, HelloResponse)
|
||||
except APIConnectionError:
|
||||
await self._on_error()
|
||||
raise
|
||||
except TimeoutAPIError as err:
|
||||
raise TimeoutAPIError("Hello timed out") from err
|
||||
|
||||
_LOGGER.debug(
|
||||
"%s: Successfully connected ('%s' API=%s.%s)",
|
||||
self.log_name,
|
||||
@ -397,62 +412,145 @@ class APIConnection:
|
||||
self.log_name,
|
||||
self._api_version.major,
|
||||
)
|
||||
await self._on_error()
|
||||
raise APIConnectionError("Incompatible API version.")
|
||||
self._connected = True
|
||||
|
||||
self._start_ping()
|
||||
self._connection_state = ConnectionState.CONNECTED
|
||||
|
||||
async def _connect_start_ping(self) -> None:
|
||||
"""Step 5 in connect process: start the ping loop."""
|
||||
|
||||
async def func() -> None:
|
||||
while True:
|
||||
if not self._is_socket_open:
|
||||
return
|
||||
|
||||
# Wait for keepalive seconds, or ping stop event, whichever happens first
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._ping_stop_event.wait(), self._params.keepalive
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
# Re-check connection state
|
||||
if not self._is_socket_open:
|
||||
return
|
||||
|
||||
try:
|
||||
await self._ping()
|
||||
except TimeoutAPIError:
|
||||
_LOGGER.info("%s: Ping timed out!", self.log_name)
|
||||
await self._report_fatal_error(PingFailedAPIError())
|
||||
return
|
||||
except APIConnectionError as err:
|
||||
_LOGGER.info("%s: Ping Failed: %s", self.log_name, err)
|
||||
await self._report_fatal_error(err)
|
||||
return
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
_LOGGER.info(
|
||||
"%s: Unexpected error during ping:",
|
||||
self.log_name,
|
||||
exc_info=True,
|
||||
)
|
||||
await self._report_fatal_error(err)
|
||||
return
|
||||
|
||||
asyncio.create_task(func())
|
||||
|
||||
async def connect(self) -> None:
|
||||
if self._connection_state != ConnectionState.INITIALIZED:
|
||||
raise ValueError(
|
||||
"Connection can only be used once, connection is not in init state"
|
||||
)
|
||||
|
||||
try:
|
||||
addr = await self._connect_resolve_host()
|
||||
await self._connect_socket_connect(addr)
|
||||
await self._connect_init_frame_helper()
|
||||
await self._connect_hello()
|
||||
await self._connect_start_ping()
|
||||
except Exception: # pylint: disable=broad-except
|
||||
# Always clean up the connection if an error occured during connect
|
||||
self._connection_state = ConnectionState.CLOSED
|
||||
await self._cleanup()
|
||||
raise
|
||||
|
||||
self._connect_complete = True
|
||||
|
||||
async def login(self) -> None:
|
||||
"""Send a login (ConnectRequest) and await the response."""
|
||||
self._check_connected()
|
||||
if self._authenticated:
|
||||
if self._is_authenticated:
|
||||
raise APIConnectionError("Already logged in!")
|
||||
|
||||
connect = ConnectRequest()
|
||||
if self._params.password is not None:
|
||||
connect.password = self._params.password
|
||||
resp = await self.send_message_await_response(connect, ConnectResponse)
|
||||
try:
|
||||
resp = await self.send_message_await_response(connect, ConnectResponse)
|
||||
except TimeoutAPIError as err:
|
||||
# After a timeout for connect the connection can no longer be used
|
||||
# We don't know what state the device may be in after ConnectRequest
|
||||
# was already sent
|
||||
await self._report_fatal_error(err)
|
||||
raise
|
||||
|
||||
if resp.invalid_password:
|
||||
raise InvalidAuthAPIError("Invalid password!")
|
||||
|
||||
self._authenticated = True
|
||||
self._is_authenticated = True
|
||||
|
||||
def _check_connected(self) -> None:
|
||||
if not self._connected:
|
||||
if self._connection_state != ConnectionState.CONNECTED:
|
||||
raise APIConnectionError("Must be connected!")
|
||||
|
||||
@property
|
||||
def _is_socket_open(self) -> bool:
|
||||
return self._connection_state in (
|
||||
ConnectionState.SOCKET_OPENED,
|
||||
ConnectionState.CONNECTED,
|
||||
)
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._connected
|
||||
return self._connection_state == ConnectionState.CONNECTED
|
||||
|
||||
@property
|
||||
def is_authenticated(self) -> bool:
|
||||
return self._authenticated
|
||||
return self.is_connected and self._is_authenticated
|
||||
|
||||
async def send_message(self, msg: message.Message) -> None:
|
||||
if not self._socket_connected:
|
||||
raise APIConnectionError("Socket is not connected")
|
||||
"""Send a protobuf message to the remote."""
|
||||
if not self._is_socket_open:
|
||||
raise APIConnectionError("Connection isn't established yet")
|
||||
|
||||
for message_type, klass in MESSAGE_TYPE_TO_PROTO.items():
|
||||
if isinstance(msg, klass):
|
||||
break
|
||||
else:
|
||||
raise ValueError
|
||||
raise ValueError(f"Message type id not found for type {type(msg)}")
|
||||
|
||||
encoded = msg.SerializeToString()
|
||||
_LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg))
|
||||
# pylint: disable=undefined-loop-variable
|
||||
assert self._frame_helper is not None
|
||||
await self._frame_helper.write_packet(
|
||||
Packet(
|
||||
type=message_type,
|
||||
data=encoded,
|
||||
|
||||
try:
|
||||
assert self._frame_helper is not None
|
||||
# pylint: disable=undefined-loop-variable
|
||||
await self._frame_helper.write_packet(
|
||||
Packet(
|
||||
type=message_type,
|
||||
data=encoded,
|
||||
)
|
||||
)
|
||||
)
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
# If writing packet fails, we don't know what state the frames
|
||||
# are in anymore and we have to close the connection
|
||||
await self._report_fatal_error(err)
|
||||
|
||||
async def send_message_callback_response(
|
||||
self, send_msg: message.Message, on_message: Callable[[Any], None]
|
||||
) -> None:
|
||||
"""Send a message to the remote and register the given message handler."""
|
||||
self._message_handlers.append(on_message)
|
||||
await self.send_message(send_msg)
|
||||
|
||||
@ -463,6 +561,15 @@ class APIConnection:
|
||||
do_stop: Callable[[Any], bool],
|
||||
timeout: float = 10.0,
|
||||
) -> List[Any]:
|
||||
"""Send a message to the remote and build up a list response.
|
||||
|
||||
:param send_msg: The message (request) to send.
|
||||
:param do_append: Predicate to check if a received message is part of the response.
|
||||
:param do_stop: Predicate to check if a received message is the stop response.
|
||||
:param timeout: The maximum amount of time to wait for the stop response.
|
||||
|
||||
:raises TimeoutAPIError: if a timeout occured
|
||||
"""
|
||||
fut = self._params.eventloop.create_future()
|
||||
responses = []
|
||||
|
||||
@ -476,7 +583,10 @@ class APIConnection:
|
||||
|
||||
def on_read_exception(exc: Exception) -> None:
|
||||
if not fut.done():
|
||||
fut.set_exception(exc)
|
||||
# Wrap error so that caller gets right stacktrace
|
||||
new_exc = ReadFailedAPIError("Read failed")
|
||||
new_exc.__cause__ = exc
|
||||
fut.set_exception(new_exc)
|
||||
|
||||
self._message_handlers.append(on_message)
|
||||
self._read_exception_handlers.append(on_read_exception)
|
||||
@ -484,10 +594,10 @@ class APIConnection:
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(fut, timeout)
|
||||
except asyncio.TimeoutError:
|
||||
if self._stopped:
|
||||
raise SocketAPIError("Disconnected while waiting for API response!")
|
||||
raise SocketAPIError("Timeout while waiting for API response!")
|
||||
except asyncio.TimeoutError as err:
|
||||
raise TimeoutAPIError(
|
||||
f"Timeout waiting for response for {send_msg}"
|
||||
) from err
|
||||
finally:
|
||||
with suppress(ValueError):
|
||||
self._message_handlers.remove(on_message)
|
||||
@ -510,16 +620,28 @@ class APIConnection:
|
||||
|
||||
return res[0]
|
||||
|
||||
async def _run_once(self) -> None:
|
||||
async def _report_fatal_error(self, err: Exception) -> None:
|
||||
"""Report a fatal error that occured during an operation.
|
||||
|
||||
This should only be called for errors that mean the connection
|
||||
can no longer be used.
|
||||
|
||||
The connection will be closed, all exception handlers notified.
|
||||
This method does not log the error, the call site should do so.
|
||||
"""
|
||||
self._connection_state = ConnectionState.CLOSED
|
||||
for handler in self._read_exception_handlers[:]:
|
||||
handler(err)
|
||||
await self._cleanup()
|
||||
|
||||
async def _read_once(self) -> None:
|
||||
assert self._frame_helper is not None
|
||||
pkt = await self._frame_helper.read_packet()
|
||||
|
||||
msg_type = pkt.type
|
||||
raw_msg = pkt.data
|
||||
if msg_type not in MESSAGE_TYPE_TO_PROTO:
|
||||
_LOGGER.debug(
|
||||
"%s: Skipping message type %s", self._params.address, msg_type
|
||||
)
|
||||
_LOGGER.debug("%s: Skipping message type %s", self.log_name, msg_type)
|
||||
return
|
||||
|
||||
msg = MESSAGE_TYPE_TO_PROTO[msg_type]()
|
||||
@ -527,29 +649,33 @@ class APIConnection:
|
||||
msg.ParseFromString(raw_msg)
|
||||
except Exception as e:
|
||||
raise ProtocolAPIError(f"Invalid protobuf message: {e}") from e
|
||||
_LOGGER.debug(
|
||||
"%s: Got message of type %s: %s", self._params.address, type(msg), msg
|
||||
)
|
||||
_LOGGER.debug("%s: Got message of type %s: %s", self.log_name, type(msg), msg)
|
||||
for msg_handler in self._message_handlers[:]:
|
||||
msg_handler(msg)
|
||||
await self._handle_internal_messages(msg)
|
||||
|
||||
async def run_forever(self) -> None:
|
||||
async def _read_loop(self) -> None:
|
||||
while True:
|
||||
if self._frame_helper is None:
|
||||
# Socket closed
|
||||
if not self._is_socket_open:
|
||||
# Socket closed but task isn't cancelled yet
|
||||
break
|
||||
try:
|
||||
await self._run_once()
|
||||
await self._read_once()
|
||||
except SocketClosedAPIError as err:
|
||||
# don't log with info, if closed the site that closed the connection should log
|
||||
_LOGGER.debug(
|
||||
"%s: Socket closed, stopping read loop",
|
||||
self.log_name,
|
||||
)
|
||||
await self._report_fatal_error(err)
|
||||
break
|
||||
except APIConnectionError as err:
|
||||
_LOGGER.info(
|
||||
"%s: Error while reading incoming messages: %s",
|
||||
self.log_name,
|
||||
err,
|
||||
)
|
||||
for handler in self._read_exception_handlers[:]:
|
||||
handler(err)
|
||||
await self._on_error()
|
||||
await self._report_fatal_error(err)
|
||||
break
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
_LOGGER.warning(
|
||||
@ -558,15 +684,14 @@ class APIConnection:
|
||||
err,
|
||||
exc_info=True,
|
||||
)
|
||||
for handler in self._read_exception_handlers[:]:
|
||||
handler(err)
|
||||
await self._on_error()
|
||||
await self._report_fatal_error(err)
|
||||
break
|
||||
|
||||
async def _handle_internal_messages(self, msg: Any) -> None:
|
||||
if isinstance(msg, DisconnectRequest):
|
||||
await self.send_message(DisconnectResponse())
|
||||
await self.stop(force=True)
|
||||
self._connection_state = ConnectionState.CLOSED
|
||||
await self._cleanup()
|
||||
elif isinstance(msg, PingRequest):
|
||||
await self.send_message(PingResponse())
|
||||
elif isinstance(msg, GetTimeRequest):
|
||||
@ -574,12 +699,14 @@ class APIConnection:
|
||||
resp.epoch_seconds = int(time.time())
|
||||
await self.send_message(resp)
|
||||
|
||||
async def ping(self) -> None:
|
||||
async def _ping(self) -> None:
|
||||
self._check_connected()
|
||||
await self.send_message_await_response(PingRequest(), PingResponse)
|
||||
|
||||
async def _disconnect(self) -> None:
|
||||
self._check_connected()
|
||||
async def disconnect(self) -> None:
|
||||
if self._connection_state != ConnectionState.CONNECTED:
|
||||
# already disconnected
|
||||
return
|
||||
|
||||
try:
|
||||
await self.send_message_await_response(
|
||||
@ -588,9 +715,12 @@ class APIConnection:
|
||||
except APIConnectionError:
|
||||
pass
|
||||
|
||||
def _check_authenticated(self) -> None:
|
||||
if not self._authenticated:
|
||||
raise APIConnectionError("Must login first!")
|
||||
self._connection_state = ConnectionState.CLOSED
|
||||
await self._cleanup()
|
||||
|
||||
async def force_disconnect(self) -> None:
|
||||
self._connection_state = ConnectionState.CLOSED
|
||||
await self._cleanup()
|
||||
|
||||
@property
|
||||
def api_version(self) -> Optional[APIVersion]:
|
||||
|
@ -83,6 +83,10 @@ class SocketAPIError(APIConnectionError):
|
||||
pass
|
||||
|
||||
|
||||
class SocketClosedAPIError(SocketAPIError):
|
||||
pass
|
||||
|
||||
|
||||
class HandshakeAPIError(APIConnectionError):
|
||||
pass
|
||||
|
||||
@ -91,6 +95,18 @@ class InvalidEncryptionKeyAPIError(HandshakeAPIError):
|
||||
pass
|
||||
|
||||
|
||||
class PingFailedAPIError(APIConnectionError):
|
||||
pass
|
||||
|
||||
|
||||
class TimeoutAPIError(APIConnectionError):
|
||||
pass
|
||||
|
||||
|
||||
class ReadFailedAPIError(APIConnectionError):
|
||||
pass
|
||||
|
||||
|
||||
MESSAGE_TYPE_TO_PROTO = {
|
||||
1: HelloRequest,
|
||||
2: HelloResponse,
|
||||
|
@ -118,13 +118,15 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): # type: ignore[misc,name-d
|
||||
|
||||
try:
|
||||
await self._cli.connect(on_stop=self._on_disconnect, login=True)
|
||||
except APIConnectionError as error:
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
level = logging.WARNING if tries == 0 else logging.DEBUG
|
||||
_LOGGER.log(
|
||||
level,
|
||||
"Can't connect to ESPHome API for %s: %s",
|
||||
self._log_name,
|
||||
error,
|
||||
err,
|
||||
# Print stacktrace if unhandled (not APIConnectionError)
|
||||
exc_info=not isinstance(err, APIConnectionError),
|
||||
)
|
||||
await self._start_zc_listen()
|
||||
# Schedule re-connect in event loop in order not to delay HA
|
||||
|
@ -54,8 +54,8 @@ def socket_socket():
|
||||
async def test_connect(conn, resolve_host, socket_socket, event_loop):
|
||||
with patch.object(event_loop, "sock_connect"), patch(
|
||||
"asyncio.open_connection", return_value=(None, None)
|
||||
), patch.object(conn, "run_forever"), patch.object(
|
||||
conn, "_start_ping"
|
||||
), patch.object(conn, "_read_loop"), patch.object(
|
||||
conn, "_connect_start_ping"
|
||||
), patch.object(
|
||||
conn, "send_message_await_response", return_value=HelloResponse()
|
||||
):
|
||||
|
Loading…
Reference in New Issue
Block a user