mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2025-01-29 22:51:25 +01:00
Refactor connection class to resolve connect issues (#108)
This commit is contained in:
parent
e03015f8b4
commit
5b99d5c1dd
@ -197,7 +197,10 @@ class APIClient:
|
|||||||
async def disconnect(self, force: bool = False) -> None:
|
async def disconnect(self, force: bool = False) -> None:
|
||||||
if self._connection is None:
|
if self._connection is None:
|
||||||
return
|
return
|
||||||
await self._connection.stop(force=force)
|
if force:
|
||||||
|
await self._connection.force_disconnect()
|
||||||
|
else:
|
||||||
|
await self._connection.disconnect()
|
||||||
|
|
||||||
def _check_connected(self) -> None:
|
def _check_connected(self) -> None:
|
||||||
if self._connection is None:
|
if self._connection is None:
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
|
import enum
|
||||||
import logging
|
import logging
|
||||||
import socket
|
import socket
|
||||||
import time
|
import time
|
||||||
@ -30,10 +31,14 @@ from .core import (
|
|||||||
HandshakeAPIError,
|
HandshakeAPIError,
|
||||||
InvalidAuthAPIError,
|
InvalidAuthAPIError,
|
||||||
InvalidEncryptionKeyAPIError,
|
InvalidEncryptionKeyAPIError,
|
||||||
|
PingFailedAPIError,
|
||||||
ProtocolAPIError,
|
ProtocolAPIError,
|
||||||
|
ReadFailedAPIError,
|
||||||
RequiresEncryptionAPIError,
|
RequiresEncryptionAPIError,
|
||||||
ResolveAPIError,
|
ResolveAPIError,
|
||||||
SocketAPIError,
|
SocketAPIError,
|
||||||
|
SocketClosedAPIError,
|
||||||
|
TimeoutAPIError,
|
||||||
)
|
)
|
||||||
from .model import APIVersion
|
from .model import APIVersion
|
||||||
from .util import bytes_to_varuint, varuint_to_bytes
|
from .util import bytes_to_varuint, varuint_to_bytes
|
||||||
@ -73,9 +78,10 @@ class APIFrameHelper:
|
|||||||
self._read_lock = asyncio.Lock()
|
self._read_lock = asyncio.Lock()
|
||||||
self._ready_event = asyncio.Event()
|
self._ready_event = asyncio.Event()
|
||||||
self._proto: Optional[NoiseConnection] = None
|
self._proto: Optional[NoiseConnection] = None
|
||||||
|
self._closed_event = asyncio.Event()
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
async with self._write_lock:
|
self._closed_event.set()
|
||||||
self._writer.close()
|
self._writer.close()
|
||||||
|
|
||||||
async def _write_frame_noise(self, frame: bytes) -> None:
|
async def _write_frame_noise(self, frame: bytes) -> None:
|
||||||
@ -103,6 +109,13 @@ class APIFrameHelper:
|
|||||||
msg_size = (header[1] << 8) | header[2]
|
msg_size = (header[1] << 8) | header[2]
|
||||||
frame = await self._reader.readexactly(msg_size)
|
frame = await self._reader.readexactly(msg_size)
|
||||||
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
|
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
|
||||||
|
if (
|
||||||
|
isinstance(err, asyncio.IncompleteReadError)
|
||||||
|
and self._closed_event.is_set()
|
||||||
|
):
|
||||||
|
raise SocketClosedAPIError(
|
||||||
|
f"Socket closed while reading data: {err}"
|
||||||
|
) from err
|
||||||
raise SocketAPIError(f"Error while reading data: {err}") from err
|
raise SocketAPIError(f"Error while reading data: {err}") from err
|
||||||
|
|
||||||
_LOGGER.debug("Received frame %s", frame.hex())
|
_LOGGER.debug("Received frame %s", frame.hex())
|
||||||
@ -241,6 +254,13 @@ class APIFrameHelper:
|
|||||||
raw_msg = await self._reader.readexactly(length_int)
|
raw_msg = await self._reader.readexactly(length_int)
|
||||||
return Packet(type=msg_type_int, data=raw_msg)
|
return Packet(type=msg_type_int, data=raw_msg)
|
||||||
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
|
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
|
||||||
|
if (
|
||||||
|
isinstance(err, asyncio.IncompleteReadError)
|
||||||
|
and self._closed_event.is_set()
|
||||||
|
):
|
||||||
|
raise SocketClosedAPIError(
|
||||||
|
f"Socket closed while reading data: {err}"
|
||||||
|
) from err
|
||||||
raise SocketAPIError(f"Error while reading data: {err}") from err
|
raise SocketAPIError(f"Error while reading data: {err}") from err
|
||||||
|
|
||||||
async def read_packet(self) -> Packet:
|
async def read_packet(self) -> Packet:
|
||||||
@ -249,79 +269,74 @@ class APIFrameHelper:
|
|||||||
return await self._read_packet_noise()
|
return await self._read_packet_noise()
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionState(enum.Enum):
|
||||||
|
# The connection is initialized, but connect() wasn't called yet
|
||||||
|
INITIALIZED = 0
|
||||||
|
# Internal state,
|
||||||
|
SOCKET_OPENED = 1
|
||||||
|
# The connection has been established, data can be exchanged
|
||||||
|
CONNECTED = 1
|
||||||
|
CLOSED = 2
|
||||||
|
|
||||||
|
|
||||||
class APIConnection:
|
class APIConnection:
|
||||||
|
"""This class represents _one_ connection to a remote native API device.
|
||||||
|
|
||||||
|
An instance of this class may only be used once, for every new connection
|
||||||
|
a new instance should be established.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, params: ConnectionParams, on_stop: Callable[[], Awaitable[None]]
|
self, params: ConnectionParams, on_stop: Callable[[], Awaitable[None]]
|
||||||
):
|
):
|
||||||
self._params = params
|
self._params = params
|
||||||
self.on_stop = on_stop
|
self.on_stop = on_stop
|
||||||
self._stopped = False
|
self._on_stop_called = False
|
||||||
self._socket: Optional[socket.socket] = None
|
self._socket: Optional[socket.socket] = None
|
||||||
self._frame_helper: Optional[APIFrameHelper] = None
|
self._frame_helper: Optional[APIFrameHelper] = None
|
||||||
self._connected = False
|
|
||||||
self._authenticated = False
|
|
||||||
self._socket_connected = False
|
|
||||||
self._state_lock = asyncio.Lock()
|
|
||||||
self._api_version: Optional[APIVersion] = None
|
self._api_version: Optional[APIVersion] = None
|
||||||
|
|
||||||
|
self._connection_state = ConnectionState.INITIALIZED
|
||||||
|
self._is_authenticated = False
|
||||||
|
# Store whether connect() has completed
|
||||||
|
# Used so that on_stop is _not_ called if an error occurs during connect()
|
||||||
|
self._connect_complete = False
|
||||||
|
|
||||||
|
# Message handlers currently subscribed to incoming messages
|
||||||
self._message_handlers: List[Callable[[message.Message], None]] = []
|
self._message_handlers: List[Callable[[message.Message], None]] = []
|
||||||
|
# The friendly name to show for this connection in the logs
|
||||||
self.log_name = params.address
|
self.log_name = params.address
|
||||||
self._ping_task: Optional[asyncio.Task[None]] = None
|
|
||||||
|
# Handlers currently subscribed to exceptions in the read task
|
||||||
self._read_exception_handlers: List[Callable[[Exception], None]] = []
|
self._read_exception_handlers: List[Callable[[Exception], None]] = []
|
||||||
|
|
||||||
def _start_ping(self) -> None:
|
self._ping_stop_event = asyncio.Event()
|
||||||
async def func() -> None:
|
|
||||||
while True:
|
|
||||||
await asyncio.sleep(self._params.keepalive)
|
|
||||||
|
|
||||||
try:
|
async def _cleanup(self) -> None:
|
||||||
await self.ping()
|
"""Clean up all resources that have been allocated.
|
||||||
except APIConnectionError:
|
|
||||||
_LOGGER.info("%s: Ping Failed!", self.log_name)
|
|
||||||
await self._on_error()
|
|
||||||
return
|
|
||||||
|
|
||||||
self._ping_task = asyncio.create_task(func())
|
Safe to call multiple times.
|
||||||
|
"""
|
||||||
async def _close_socket(self) -> None:
|
|
||||||
if not self._socket_connected:
|
|
||||||
return
|
|
||||||
if self._frame_helper is not None:
|
if self._frame_helper is not None:
|
||||||
await self._frame_helper.close()
|
await self._frame_helper.close()
|
||||||
self._frame_helper = None
|
self._frame_helper = None
|
||||||
|
|
||||||
if self._socket is not None:
|
if self._socket is not None:
|
||||||
self._socket.close()
|
self._socket.close()
|
||||||
self._socket = None
|
self._socket = None
|
||||||
if self._ping_task is not None:
|
|
||||||
self._ping_task.cancel()
|
|
||||||
self._ping_task = None
|
|
||||||
self._socket_connected = False
|
|
||||||
self._connected = False
|
|
||||||
self._authenticated = False
|
|
||||||
_LOGGER.debug("%s: Closed socket", self.log_name)
|
|
||||||
|
|
||||||
async def stop(self, force: bool = False) -> None:
|
if not self._on_stop_called and self._connect_complete:
|
||||||
if self._stopped:
|
# Ensure on_stop is called
|
||||||
return
|
asyncio.create_task(self.on_stop())
|
||||||
if self._connected and not force:
|
self._on_stop_called = True
|
||||||
try:
|
|
||||||
await self._disconnect()
|
|
||||||
except APIConnectionError:
|
|
||||||
pass
|
|
||||||
self._stopped = True
|
|
||||||
await self._close_socket()
|
|
||||||
await self.on_stop()
|
|
||||||
|
|
||||||
async def _on_error(self) -> None:
|
# Note: we don't explicitly cancel the ping/read task here
|
||||||
await self.stop(force=True)
|
# That's because if not written right the ping/read task could cancel
|
||||||
|
# themself, effectively ending execution after _cleanup which may be unexpected
|
||||||
# pylint: disable=too-many-statements
|
self._ping_stop_event.set()
|
||||||
async def connect(self) -> None:
|
|
||||||
if self._stopped:
|
|
||||||
raise APIConnectionError(f"Connection is closed for {self.log_name}!")
|
|
||||||
if self._connected:
|
|
||||||
raise APIConnectionError(f"Already connected for {self.log_name}!")
|
|
||||||
|
|
||||||
|
async def _connect_resolve_host(self) -> hr.AddrInfo:
|
||||||
|
"""Step 1 in connect process: resolve the address."""
|
||||||
try:
|
try:
|
||||||
coro = hr.async_resolve_host(
|
coro = hr.async_resolve_host(
|
||||||
self._params.eventloop,
|
self._params.eventloop,
|
||||||
@ -329,16 +344,14 @@ class APIConnection:
|
|||||||
self._params.port,
|
self._params.port,
|
||||||
self._params.zeroconf_instance,
|
self._params.zeroconf_instance,
|
||||||
)
|
)
|
||||||
addr = await asyncio.wait_for(coro, 30.0)
|
return await asyncio.wait_for(coro, 30.0)
|
||||||
except APIConnectionError as err:
|
except asyncio.TimeoutError as err:
|
||||||
await self._on_error()
|
|
||||||
raise err
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
await self._on_error()
|
|
||||||
raise ResolveAPIError(
|
raise ResolveAPIError(
|
||||||
f"Timeout while resolving IP address for {self.log_name}"
|
f"Timeout while resolving IP address for {self.log_name}"
|
||||||
)
|
) from err
|
||||||
|
|
||||||
|
async def _connect_socket_connect(self, addr: hr.AddrInfo) -> None:
|
||||||
|
"""Step 2 in connect process: connect the socket."""
|
||||||
self._socket = socket.socket(
|
self._socket = socket.socket(
|
||||||
family=addr.family, type=addr.type, proto=addr.proto
|
family=addr.family, type=addr.type, proto=addr.proto
|
||||||
)
|
)
|
||||||
@ -353,36 +366,38 @@ class APIConnection:
|
|||||||
addr,
|
addr,
|
||||||
)
|
)
|
||||||
sockaddr = astuple(addr.sockaddr)
|
sockaddr = astuple(addr.sockaddr)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
coro2 = self._params.eventloop.sock_connect(self._socket, sockaddr)
|
coro = self._params.eventloop.sock_connect(self._socket, sockaddr)
|
||||||
await asyncio.wait_for(coro2, 30.0)
|
await asyncio.wait_for(coro, 30.0)
|
||||||
except OSError as err:
|
except OSError as err:
|
||||||
await self._on_error()
|
raise SocketAPIError(f"Error connecting to {sockaddr}: {err}") from err
|
||||||
raise SocketAPIError(f"Error connecting to {sockaddr}: {err}")
|
except asyncio.TimeoutError as err:
|
||||||
except asyncio.TimeoutError:
|
raise SocketAPIError(f"Timeout while connecting to {sockaddr}") from err
|
||||||
await self._on_error()
|
|
||||||
raise SocketAPIError(f"Timeout while connecting to {sockaddr}")
|
|
||||||
|
|
||||||
_LOGGER.debug("%s: Opened socket for", self._params.address)
|
_LOGGER.debug("%s: Opened socket for", self._params.address)
|
||||||
|
|
||||||
|
async def _connect_init_frame_helper(self) -> None:
|
||||||
|
"""Step 3 in connect process: initialize the frame helper and init read loop."""
|
||||||
reader, writer = await asyncio.open_connection(sock=self._socket)
|
reader, writer = await asyncio.open_connection(sock=self._socket)
|
||||||
|
|
||||||
self._frame_helper = APIFrameHelper(reader, writer, self._params)
|
self._frame_helper = APIFrameHelper(reader, writer, self._params)
|
||||||
self._socket_connected = True
|
|
||||||
|
|
||||||
try:
|
|
||||||
await self._frame_helper.perform_handshake()
|
await self._frame_helper.perform_handshake()
|
||||||
except APIConnectionError:
|
|
||||||
await self._on_error()
|
|
||||||
raise
|
|
||||||
|
|
||||||
self._params.eventloop.create_task(self.run_forever())
|
self._connection_state = ConnectionState.SOCKET_OPENED
|
||||||
|
|
||||||
|
# Create read loop
|
||||||
|
asyncio.create_task(self._read_loop())
|
||||||
|
|
||||||
|
async def _connect_hello(self) -> None:
|
||||||
|
"""Step 4 in connect process: send hello and get api version."""
|
||||||
hello = HelloRequest()
|
hello = HelloRequest()
|
||||||
hello.client_info = self._params.client_info
|
hello.client_info = self._params.client_info
|
||||||
try:
|
try:
|
||||||
resp = await self.send_message_await_response(hello, HelloResponse)
|
resp = await self.send_message_await_response(hello, HelloResponse)
|
||||||
except APIConnectionError:
|
except TimeoutAPIError as err:
|
||||||
await self._on_error()
|
raise TimeoutAPIError("Hello timed out") from err
|
||||||
raise
|
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"%s: Successfully connected ('%s' API=%s.%s)",
|
"%s: Successfully connected ('%s' API=%s.%s)",
|
||||||
self.log_name,
|
self.log_name,
|
||||||
@ -397,62 +412,145 @@ class APIConnection:
|
|||||||
self.log_name,
|
self.log_name,
|
||||||
self._api_version.major,
|
self._api_version.major,
|
||||||
)
|
)
|
||||||
await self._on_error()
|
|
||||||
raise APIConnectionError("Incompatible API version.")
|
raise APIConnectionError("Incompatible API version.")
|
||||||
self._connected = True
|
|
||||||
|
|
||||||
self._start_ping()
|
self._connection_state = ConnectionState.CONNECTED
|
||||||
|
|
||||||
|
async def _connect_start_ping(self) -> None:
|
||||||
|
"""Step 5 in connect process: start the ping loop."""
|
||||||
|
|
||||||
|
async def func() -> None:
|
||||||
|
while True:
|
||||||
|
if not self._is_socket_open:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Wait for keepalive seconds, or ping stop event, whichever happens first
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
self._ping_stop_event.wait(), self._params.keepalive
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Re-check connection state
|
||||||
|
if not self._is_socket_open:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._ping()
|
||||||
|
except TimeoutAPIError:
|
||||||
|
_LOGGER.info("%s: Ping timed out!", self.log_name)
|
||||||
|
await self._report_fatal_error(PingFailedAPIError())
|
||||||
|
return
|
||||||
|
except APIConnectionError as err:
|
||||||
|
_LOGGER.info("%s: Ping Failed: %s", self.log_name, err)
|
||||||
|
await self._report_fatal_error(err)
|
||||||
|
return
|
||||||
|
except Exception as err: # pylint: disable=broad-except
|
||||||
|
_LOGGER.info(
|
||||||
|
"%s: Unexpected error during ping:",
|
||||||
|
self.log_name,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
await self._report_fatal_error(err)
|
||||||
|
return
|
||||||
|
|
||||||
|
asyncio.create_task(func())
|
||||||
|
|
||||||
|
async def connect(self) -> None:
|
||||||
|
if self._connection_state != ConnectionState.INITIALIZED:
|
||||||
|
raise ValueError(
|
||||||
|
"Connection can only be used once, connection is not in init state"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
addr = await self._connect_resolve_host()
|
||||||
|
await self._connect_socket_connect(addr)
|
||||||
|
await self._connect_init_frame_helper()
|
||||||
|
await self._connect_hello()
|
||||||
|
await self._connect_start_ping()
|
||||||
|
except Exception: # pylint: disable=broad-except
|
||||||
|
# Always clean up the connection if an error occured during connect
|
||||||
|
self._connection_state = ConnectionState.CLOSED
|
||||||
|
await self._cleanup()
|
||||||
|
raise
|
||||||
|
|
||||||
|
self._connect_complete = True
|
||||||
|
|
||||||
async def login(self) -> None:
|
async def login(self) -> None:
|
||||||
|
"""Send a login (ConnectRequest) and await the response."""
|
||||||
self._check_connected()
|
self._check_connected()
|
||||||
if self._authenticated:
|
if self._is_authenticated:
|
||||||
raise APIConnectionError("Already logged in!")
|
raise APIConnectionError("Already logged in!")
|
||||||
|
|
||||||
connect = ConnectRequest()
|
connect = ConnectRequest()
|
||||||
if self._params.password is not None:
|
if self._params.password is not None:
|
||||||
connect.password = self._params.password
|
connect.password = self._params.password
|
||||||
|
try:
|
||||||
resp = await self.send_message_await_response(connect, ConnectResponse)
|
resp = await self.send_message_await_response(connect, ConnectResponse)
|
||||||
|
except TimeoutAPIError as err:
|
||||||
|
# After a timeout for connect the connection can no longer be used
|
||||||
|
# We don't know what state the device may be in after ConnectRequest
|
||||||
|
# was already sent
|
||||||
|
await self._report_fatal_error(err)
|
||||||
|
raise
|
||||||
|
|
||||||
if resp.invalid_password:
|
if resp.invalid_password:
|
||||||
raise InvalidAuthAPIError("Invalid password!")
|
raise InvalidAuthAPIError("Invalid password!")
|
||||||
|
|
||||||
self._authenticated = True
|
self._is_authenticated = True
|
||||||
|
|
||||||
def _check_connected(self) -> None:
|
def _check_connected(self) -> None:
|
||||||
if not self._connected:
|
if self._connection_state != ConnectionState.CONNECTED:
|
||||||
raise APIConnectionError("Must be connected!")
|
raise APIConnectionError("Must be connected!")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _is_socket_open(self) -> bool:
|
||||||
|
return self._connection_state in (
|
||||||
|
ConnectionState.SOCKET_OPENED,
|
||||||
|
ConnectionState.CONNECTED,
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_connected(self) -> bool:
|
def is_connected(self) -> bool:
|
||||||
return self._connected
|
return self._connection_state == ConnectionState.CONNECTED
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_authenticated(self) -> bool:
|
def is_authenticated(self) -> bool:
|
||||||
return self._authenticated
|
return self.is_connected and self._is_authenticated
|
||||||
|
|
||||||
async def send_message(self, msg: message.Message) -> None:
|
async def send_message(self, msg: message.Message) -> None:
|
||||||
if not self._socket_connected:
|
"""Send a protobuf message to the remote."""
|
||||||
raise APIConnectionError("Socket is not connected")
|
if not self._is_socket_open:
|
||||||
|
raise APIConnectionError("Connection isn't established yet")
|
||||||
|
|
||||||
for message_type, klass in MESSAGE_TYPE_TO_PROTO.items():
|
for message_type, klass in MESSAGE_TYPE_TO_PROTO.items():
|
||||||
if isinstance(msg, klass):
|
if isinstance(msg, klass):
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
raise ValueError
|
raise ValueError(f"Message type id not found for type {type(msg)}")
|
||||||
|
|
||||||
encoded = msg.SerializeToString()
|
encoded = msg.SerializeToString()
|
||||||
_LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg))
|
_LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg))
|
||||||
# pylint: disable=undefined-loop-variable
|
|
||||||
|
try:
|
||||||
assert self._frame_helper is not None
|
assert self._frame_helper is not None
|
||||||
|
# pylint: disable=undefined-loop-variable
|
||||||
await self._frame_helper.write_packet(
|
await self._frame_helper.write_packet(
|
||||||
Packet(
|
Packet(
|
||||||
type=message_type,
|
type=message_type,
|
||||||
data=encoded,
|
data=encoded,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
except Exception as err: # pylint: disable=broad-except
|
||||||
|
# If writing packet fails, we don't know what state the frames
|
||||||
|
# are in anymore and we have to close the connection
|
||||||
|
await self._report_fatal_error(err)
|
||||||
|
|
||||||
async def send_message_callback_response(
|
async def send_message_callback_response(
|
||||||
self, send_msg: message.Message, on_message: Callable[[Any], None]
|
self, send_msg: message.Message, on_message: Callable[[Any], None]
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Send a message to the remote and register the given message handler."""
|
||||||
self._message_handlers.append(on_message)
|
self._message_handlers.append(on_message)
|
||||||
await self.send_message(send_msg)
|
await self.send_message(send_msg)
|
||||||
|
|
||||||
@ -463,6 +561,15 @@ class APIConnection:
|
|||||||
do_stop: Callable[[Any], bool],
|
do_stop: Callable[[Any], bool],
|
||||||
timeout: float = 10.0,
|
timeout: float = 10.0,
|
||||||
) -> List[Any]:
|
) -> List[Any]:
|
||||||
|
"""Send a message to the remote and build up a list response.
|
||||||
|
|
||||||
|
:param send_msg: The message (request) to send.
|
||||||
|
:param do_append: Predicate to check if a received message is part of the response.
|
||||||
|
:param do_stop: Predicate to check if a received message is the stop response.
|
||||||
|
:param timeout: The maximum amount of time to wait for the stop response.
|
||||||
|
|
||||||
|
:raises TimeoutAPIError: if a timeout occured
|
||||||
|
"""
|
||||||
fut = self._params.eventloop.create_future()
|
fut = self._params.eventloop.create_future()
|
||||||
responses = []
|
responses = []
|
||||||
|
|
||||||
@ -476,7 +583,10 @@ class APIConnection:
|
|||||||
|
|
||||||
def on_read_exception(exc: Exception) -> None:
|
def on_read_exception(exc: Exception) -> None:
|
||||||
if not fut.done():
|
if not fut.done():
|
||||||
fut.set_exception(exc)
|
# Wrap error so that caller gets right stacktrace
|
||||||
|
new_exc = ReadFailedAPIError("Read failed")
|
||||||
|
new_exc.__cause__ = exc
|
||||||
|
fut.set_exception(new_exc)
|
||||||
|
|
||||||
self._message_handlers.append(on_message)
|
self._message_handlers.append(on_message)
|
||||||
self._read_exception_handlers.append(on_read_exception)
|
self._read_exception_handlers.append(on_read_exception)
|
||||||
@ -484,10 +594,10 @@ class APIConnection:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(fut, timeout)
|
await asyncio.wait_for(fut, timeout)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError as err:
|
||||||
if self._stopped:
|
raise TimeoutAPIError(
|
||||||
raise SocketAPIError("Disconnected while waiting for API response!")
|
f"Timeout waiting for response for {send_msg}"
|
||||||
raise SocketAPIError("Timeout while waiting for API response!")
|
) from err
|
||||||
finally:
|
finally:
|
||||||
with suppress(ValueError):
|
with suppress(ValueError):
|
||||||
self._message_handlers.remove(on_message)
|
self._message_handlers.remove(on_message)
|
||||||
@ -510,16 +620,28 @@ class APIConnection:
|
|||||||
|
|
||||||
return res[0]
|
return res[0]
|
||||||
|
|
||||||
async def _run_once(self) -> None:
|
async def _report_fatal_error(self, err: Exception) -> None:
|
||||||
|
"""Report a fatal error that occured during an operation.
|
||||||
|
|
||||||
|
This should only be called for errors that mean the connection
|
||||||
|
can no longer be used.
|
||||||
|
|
||||||
|
The connection will be closed, all exception handlers notified.
|
||||||
|
This method does not log the error, the call site should do so.
|
||||||
|
"""
|
||||||
|
self._connection_state = ConnectionState.CLOSED
|
||||||
|
for handler in self._read_exception_handlers[:]:
|
||||||
|
handler(err)
|
||||||
|
await self._cleanup()
|
||||||
|
|
||||||
|
async def _read_once(self) -> None:
|
||||||
assert self._frame_helper is not None
|
assert self._frame_helper is not None
|
||||||
pkt = await self._frame_helper.read_packet()
|
pkt = await self._frame_helper.read_packet()
|
||||||
|
|
||||||
msg_type = pkt.type
|
msg_type = pkt.type
|
||||||
raw_msg = pkt.data
|
raw_msg = pkt.data
|
||||||
if msg_type not in MESSAGE_TYPE_TO_PROTO:
|
if msg_type not in MESSAGE_TYPE_TO_PROTO:
|
||||||
_LOGGER.debug(
|
_LOGGER.debug("%s: Skipping message type %s", self.log_name, msg_type)
|
||||||
"%s: Skipping message type %s", self._params.address, msg_type
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
msg = MESSAGE_TYPE_TO_PROTO[msg_type]()
|
msg = MESSAGE_TYPE_TO_PROTO[msg_type]()
|
||||||
@ -527,29 +649,33 @@ class APIConnection:
|
|||||||
msg.ParseFromString(raw_msg)
|
msg.ParseFromString(raw_msg)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ProtocolAPIError(f"Invalid protobuf message: {e}") from e
|
raise ProtocolAPIError(f"Invalid protobuf message: {e}") from e
|
||||||
_LOGGER.debug(
|
_LOGGER.debug("%s: Got message of type %s: %s", self.log_name, type(msg), msg)
|
||||||
"%s: Got message of type %s: %s", self._params.address, type(msg), msg
|
|
||||||
)
|
|
||||||
for msg_handler in self._message_handlers[:]:
|
for msg_handler in self._message_handlers[:]:
|
||||||
msg_handler(msg)
|
msg_handler(msg)
|
||||||
await self._handle_internal_messages(msg)
|
await self._handle_internal_messages(msg)
|
||||||
|
|
||||||
async def run_forever(self) -> None:
|
async def _read_loop(self) -> None:
|
||||||
while True:
|
while True:
|
||||||
if self._frame_helper is None:
|
if not self._is_socket_open:
|
||||||
# Socket closed
|
# Socket closed but task isn't cancelled yet
|
||||||
break
|
break
|
||||||
try:
|
try:
|
||||||
await self._run_once()
|
await self._read_once()
|
||||||
|
except SocketClosedAPIError as err:
|
||||||
|
# don't log with info, if closed the site that closed the connection should log
|
||||||
|
_LOGGER.debug(
|
||||||
|
"%s: Socket closed, stopping read loop",
|
||||||
|
self.log_name,
|
||||||
|
)
|
||||||
|
await self._report_fatal_error(err)
|
||||||
|
break
|
||||||
except APIConnectionError as err:
|
except APIConnectionError as err:
|
||||||
_LOGGER.info(
|
_LOGGER.info(
|
||||||
"%s: Error while reading incoming messages: %s",
|
"%s: Error while reading incoming messages: %s",
|
||||||
self.log_name,
|
self.log_name,
|
||||||
err,
|
err,
|
||||||
)
|
)
|
||||||
for handler in self._read_exception_handlers[:]:
|
await self._report_fatal_error(err)
|
||||||
handler(err)
|
|
||||||
await self._on_error()
|
|
||||||
break
|
break
|
||||||
except Exception as err: # pylint: disable=broad-except
|
except Exception as err: # pylint: disable=broad-except
|
||||||
_LOGGER.warning(
|
_LOGGER.warning(
|
||||||
@ -558,15 +684,14 @@ class APIConnection:
|
|||||||
err,
|
err,
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
for handler in self._read_exception_handlers[:]:
|
await self._report_fatal_error(err)
|
||||||
handler(err)
|
|
||||||
await self._on_error()
|
|
||||||
break
|
break
|
||||||
|
|
||||||
async def _handle_internal_messages(self, msg: Any) -> None:
|
async def _handle_internal_messages(self, msg: Any) -> None:
|
||||||
if isinstance(msg, DisconnectRequest):
|
if isinstance(msg, DisconnectRequest):
|
||||||
await self.send_message(DisconnectResponse())
|
await self.send_message(DisconnectResponse())
|
||||||
await self.stop(force=True)
|
self._connection_state = ConnectionState.CLOSED
|
||||||
|
await self._cleanup()
|
||||||
elif isinstance(msg, PingRequest):
|
elif isinstance(msg, PingRequest):
|
||||||
await self.send_message(PingResponse())
|
await self.send_message(PingResponse())
|
||||||
elif isinstance(msg, GetTimeRequest):
|
elif isinstance(msg, GetTimeRequest):
|
||||||
@ -574,12 +699,14 @@ class APIConnection:
|
|||||||
resp.epoch_seconds = int(time.time())
|
resp.epoch_seconds = int(time.time())
|
||||||
await self.send_message(resp)
|
await self.send_message(resp)
|
||||||
|
|
||||||
async def ping(self) -> None:
|
async def _ping(self) -> None:
|
||||||
self._check_connected()
|
self._check_connected()
|
||||||
await self.send_message_await_response(PingRequest(), PingResponse)
|
await self.send_message_await_response(PingRequest(), PingResponse)
|
||||||
|
|
||||||
async def _disconnect(self) -> None:
|
async def disconnect(self) -> None:
|
||||||
self._check_connected()
|
if self._connection_state != ConnectionState.CONNECTED:
|
||||||
|
# already disconnected
|
||||||
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.send_message_await_response(
|
await self.send_message_await_response(
|
||||||
@ -588,9 +715,12 @@ class APIConnection:
|
|||||||
except APIConnectionError:
|
except APIConnectionError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _check_authenticated(self) -> None:
|
self._connection_state = ConnectionState.CLOSED
|
||||||
if not self._authenticated:
|
await self._cleanup()
|
||||||
raise APIConnectionError("Must login first!")
|
|
||||||
|
async def force_disconnect(self) -> None:
|
||||||
|
self._connection_state = ConnectionState.CLOSED
|
||||||
|
await self._cleanup()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def api_version(self) -> Optional[APIVersion]:
|
def api_version(self) -> Optional[APIVersion]:
|
||||||
|
@ -83,6 +83,10 @@ class SocketAPIError(APIConnectionError):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SocketClosedAPIError(SocketAPIError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class HandshakeAPIError(APIConnectionError):
|
class HandshakeAPIError(APIConnectionError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -91,6 +95,18 @@ class InvalidEncryptionKeyAPIError(HandshakeAPIError):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class PingFailedAPIError(APIConnectionError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TimeoutAPIError(APIConnectionError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ReadFailedAPIError(APIConnectionError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
MESSAGE_TYPE_TO_PROTO = {
|
MESSAGE_TYPE_TO_PROTO = {
|
||||||
1: HelloRequest,
|
1: HelloRequest,
|
||||||
2: HelloResponse,
|
2: HelloResponse,
|
||||||
|
@ -118,13 +118,15 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): # type: ignore[misc,name-d
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
await self._cli.connect(on_stop=self._on_disconnect, login=True)
|
await self._cli.connect(on_stop=self._on_disconnect, login=True)
|
||||||
except APIConnectionError as error:
|
except Exception as err: # pylint: disable=broad-except
|
||||||
level = logging.WARNING if tries == 0 else logging.DEBUG
|
level = logging.WARNING if tries == 0 else logging.DEBUG
|
||||||
_LOGGER.log(
|
_LOGGER.log(
|
||||||
level,
|
level,
|
||||||
"Can't connect to ESPHome API for %s: %s",
|
"Can't connect to ESPHome API for %s: %s",
|
||||||
self._log_name,
|
self._log_name,
|
||||||
error,
|
err,
|
||||||
|
# Print stacktrace if unhandled (not APIConnectionError)
|
||||||
|
exc_info=not isinstance(err, APIConnectionError),
|
||||||
)
|
)
|
||||||
await self._start_zc_listen()
|
await self._start_zc_listen()
|
||||||
# Schedule re-connect in event loop in order not to delay HA
|
# Schedule re-connect in event loop in order not to delay HA
|
||||||
|
@ -54,8 +54,8 @@ def socket_socket():
|
|||||||
async def test_connect(conn, resolve_host, socket_socket, event_loop):
|
async def test_connect(conn, resolve_host, socket_socket, event_loop):
|
||||||
with patch.object(event_loop, "sock_connect"), patch(
|
with patch.object(event_loop, "sock_connect"), patch(
|
||||||
"asyncio.open_connection", return_value=(None, None)
|
"asyncio.open_connection", return_value=(None, None)
|
||||||
), patch.object(conn, "run_forever"), patch.object(
|
), patch.object(conn, "_read_loop"), patch.object(
|
||||||
conn, "_start_ping"
|
conn, "_connect_start_ping"
|
||||||
), patch.object(
|
), patch.object(
|
||||||
conn, "send_message_await_response", return_value=HelloResponse()
|
conn, "send_message_await_response", return_value=HelloResponse()
|
||||||
):
|
):
|
||||||
|
Loading…
Reference in New Issue
Block a user