2019-04-07 19:03:26 +02:00
|
|
|
import asyncio
|
2021-10-04 12:12:43 +02:00
|
|
|
import enum
|
2019-04-07 19:03:26 +02:00
|
|
|
import logging
|
|
|
|
import socket
|
|
|
|
import time
|
2021-09-14 12:44:52 +02:00
|
|
|
from contextlib import suppress
|
2021-06-30 17:00:22 +02:00
|
|
|
from dataclasses import astuple, dataclass
|
2022-05-18 05:39:03 +02:00
|
|
|
from typing import Any, Callable, Coroutine, List, Optional
|
2019-04-07 19:03:26 +02:00
|
|
|
|
2022-08-22 05:26:53 +02:00
|
|
|
import async_timeout
|
2019-04-07 19:03:26 +02:00
|
|
|
from google.protobuf import message
|
|
|
|
|
2021-07-12 20:09:17 +02:00
|
|
|
import aioesphomeapi.host_resolver as hr
|
|
|
|
|
2021-10-13 10:05:08 +02:00
|
|
|
from ._frame_helper import (
|
|
|
|
APIFrameHelper,
|
|
|
|
APINoiseFrameHelper,
|
|
|
|
APIPlaintextFrameHelper,
|
|
|
|
Packet,
|
|
|
|
)
|
2021-06-30 17:00:22 +02:00
|
|
|
from .api_pb2 import ( # type: ignore
|
2021-06-18 17:57:02 +02:00
|
|
|
ConnectRequest,
|
|
|
|
ConnectResponse,
|
|
|
|
DisconnectRequest,
|
|
|
|
DisconnectResponse,
|
|
|
|
GetTimeRequest,
|
|
|
|
GetTimeResponse,
|
|
|
|
HelloRequest,
|
|
|
|
HelloResponse,
|
|
|
|
PingRequest,
|
|
|
|
PingResponse,
|
|
|
|
)
|
2021-09-14 12:44:52 +02:00
|
|
|
from .core import (
|
|
|
|
MESSAGE_TYPE_TO_PROTO,
|
|
|
|
APIConnectionError,
|
2022-01-20 12:03:36 +01:00
|
|
|
BadNameAPIError,
|
2021-09-14 12:44:52 +02:00
|
|
|
InvalidAuthAPIError,
|
2021-10-04 12:12:43 +02:00
|
|
|
PingFailedAPIError,
|
2021-09-14 12:44:52 +02:00
|
|
|
ProtocolAPIError,
|
2021-10-04 12:12:43 +02:00
|
|
|
ReadFailedAPIError,
|
2021-09-14 12:44:52 +02:00
|
|
|
ResolveAPIError,
|
|
|
|
SocketAPIError,
|
2021-10-04 12:12:43 +02:00
|
|
|
SocketClosedAPIError,
|
|
|
|
TimeoutAPIError,
|
2021-09-14 12:44:52 +02:00
|
|
|
)
|
2021-06-30 17:00:22 +02:00
|
|
|
from .model import APIVersion
|
2019-04-07 19:03:26 +02:00
|
|
|
|
|
|
|
_LOGGER = logging.getLogger(__name__)

# Size in bytes (1 MiB) requested for the socket receive buffer (SO_RCVBUF)
# and used as the asyncio stream reader buffer limit.
BUFFER_SIZE = 1 << 20
|
|
|
|
|
2019-04-07 19:03:26 +02:00
|
|
|
|
2021-06-29 15:36:14 +02:00
|
|
|
@dataclass
class ConnectionParams:
    """Settings needed to open a single native API connection.

    Fields keep their declaration order, so instances may be built either
    positionally or by keyword.
    """

    # Host name or IP address of the device; also the default log name.
    address: str
    # TCP port the device's native API server listens on.
    port: int
    # Sent in the ConnectRequest during login() when not None.
    password: Optional[str]
    # Client identification string sent in the HelloRequest.
    client_info: str
    # Seconds to wait between keepalive pings.
    keepalive: float
    # Zeroconf instance handed to the host resolver.
    zeroconf_instance: hr.ZeroconfInstanceType
    # When set, the noise (encrypted) frame helper is used with this key.
    noise_psk: Optional[str]
    # When set, passed to the noise handshake and compared against a
    # non-empty device name in the HelloResponse.
    expected_name: Optional[str]
|
2021-09-08 23:12:07 +02:00
|
|
|
|
|
|
|
|
2021-10-04 12:12:43 +02:00
|
|
|
@enum.unique
class ConnectionState(enum.Enum):
    """Lifecycle states of an APIConnection.

    Every member must have a distinct value: duplicate values would turn a
    member into an alias of another (e.g. CONNECTED silently becoming
    SOCKET_OPENED), which breaks every state comparison. ``enum.unique``
    enforces this at import time.
    """

    # The connection is initialized, but connect() wasn't called yet
    INITIALIZED = 0
    # Internal state: socket connected and frame helper set up,
    # but the hello exchange has not completed yet
    SOCKET_OPENED = 1
    # The connection has been established, data can be exchanged
    CONNECTED = 2
    # The connection was closed (error, disconnect or force_disconnect);
    # the instance cannot be connected again
    CLOSED = 3
|
|
|
|
|
|
|
|
|
2019-04-07 19:03:26 +02:00
|
|
|
class APIConnection:
    """This class represents _one_ connection to a remote native API device.

    An instance of this class may only be used once, for every new connection
    a new instance should be established.
    """

    def __init__(
        self, params: ConnectionParams, on_stop: Callable[[], Coroutine[Any, Any, None]]
    ):
        """Create an unconnected instance; no I/O happens until connect().

        :param params: Static settings for this connection.
        :param on_stop: Coroutine factory scheduled at most once by the cleanup
            of a connection whose connect() call had completed.
        """
        self._params = params
        self.on_stop = on_stop
        self._on_stop_called = False
        self._socket: Optional[socket.socket] = None
        self._frame_helper: Optional[APIFrameHelper] = None
        self._api_version: Optional[APIVersion] = None

        self._connection_state = ConnectionState.INITIALIZED
        self._is_authenticated = False
        # Store whether connect() has completed
        # Used so that on_stop is _not_ called if an error occurs during connect()
        self._connect_complete = False

        # Message handlers currently subscribed to incoming messages
        self._message_handlers: List[Callable[[message.Message], None]] = []
        # The friendly name to show for this connection in the logs
        self.log_name = params.address

        # Handlers currently subscribed to exceptions in the read task
        self._read_exception_handlers: List[Callable[[Exception], None]] = []

        self._ping_stop_event = asyncio.Event()

        # Decoded messages handed from the read loop to the process loop
        self._to_process: asyncio.Queue[message.Message] = asyncio.Queue()

        self._process_task: Optional[asyncio.Task[None]] = None
        # Strong references to background tasks. The event loop only keeps
        # weak references to tasks, so an unreferenced task may be garbage
        # collected before it finishes.
        self._read_task: Optional[asyncio.Task[None]] = None
        self._ping_task: Optional[asyncio.Task[None]] = None
        self._on_stop_task: Optional[asyncio.Task[None]] = None

        self._connect_lock: asyncio.Lock = asyncio.Lock()
        self._cleanup_task: Optional[asyncio.Task[None]] = None

    async def _cleanup(self) -> None:
        """Clean up all resources that have been allocated.

        Safe to call multiple times. The teardown itself runs in a background
        task that first acquires the connect lock, so this may be called while
        connect() holds that lock without deadlocking.
        """

        async def _do_cleanup() -> None:
            async with self._connect_lock:
                if self._frame_helper is not None:
                    await self._frame_helper.close()
                    self._frame_helper = None

                if self._process_task is not None:
                    self._process_task.cancel()
                    with suppress(asyncio.CancelledError):
                        await self._process_task
                    self._process_task = None

                if self._socket is not None:
                    self._socket.close()
                    self._socket = None

                if not self._on_stop_called and self._connect_complete:
                    # Ensure on_stop is called (only for connections whose
                    # connect() finished successfully)
                    self._on_stop_task = asyncio.create_task(self.on_stop())
                    self._on_stop_called = True

                # Note: we don't explicitly cancel the ping/read task here
                # That's because if not written right the ping/read task could cancel
                # themself, effectively ending execution after _cleanup which may be unexpected
                self._ping_stop_event.set()

        # NOTE(review): a new cleanup task is scheduled unless a previous one
        # already finished; a repeat run is harmless because every step above
        # checks for None / the on_stop flag first.
        if not self._cleanup_task or not self._cleanup_task.done():
            self._cleanup_task = asyncio.create_task(_do_cleanup())

    async def _connect_resolve_host(self) -> hr.AddrInfo:
        """Step 1 in connect process: resolve the address.

        :raises ResolveAPIError: if resolving takes longer than 30 seconds.
        """
        try:
            coro = hr.async_resolve_host(
                self._params.address,
                self._params.port,
                self._params.zeroconf_instance,
            )
            async with async_timeout.timeout(30.0):
                return await coro
        except asyncio.TimeoutError as err:
            raise ResolveAPIError(
                f"Timeout while resolving IP address for {self.log_name}"
            ) from err

    async def _connect_socket_connect(self, addr: hr.AddrInfo) -> None:
        """Step 2 in connect process: connect the socket.

        :param addr: Resolved address to connect to.
        :raises SocketAPIError: on an OS-level connect error or after a
            30 second timeout.
        """
        self._socket = socket.socket(
            family=addr.family, type=addr.type, proto=addr.proto
        )
        self._socket.setblocking(False)
        self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        # Try to reduce the pressure on esphome device as it measures
        # ram in bytes and we measure ram in megabytes.
        try:
            self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, BUFFER_SIZE)
        except OSError as err:
            _LOGGER.warning(
                "%s: Failed to set socket receive buffer size: %s",
                self.log_name,
                err,
            )

        _LOGGER.debug(
            "%s: Connecting to %s:%s (%s)",
            self.log_name,
            self._params.address,
            self._params.port,
            addr,
        )
        sockaddr = astuple(addr.sockaddr)

        try:
            coro = asyncio.get_running_loop().sock_connect(self._socket, sockaddr)
            async with async_timeout.timeout(30.0):
                await coro
        # The timeout clause must come first: since Python 3.11
        # asyncio.TimeoutError is the builtin TimeoutError, a subclass of
        # OSError, and would otherwise be reported as a generic connect error.
        except asyncio.TimeoutError as err:
            raise SocketAPIError(f"Timeout while connecting to {sockaddr}") from err
        except OSError as err:
            raise SocketAPIError(f"Error connecting to {sockaddr}: {err}") from err

        _LOGGER.debug("%s: Opened socket", self.log_name)

    async def _connect_init_frame_helper(self) -> None:
        """Step 3 in connect process: initialize the frame helper and init read loop.

        Uses the noise frame helper (including its handshake) when a noise PSK
        is configured, the plaintext helper otherwise, then starts the read
        and process loops.
        """
        reader, writer = await asyncio.open_connection(
            sock=self._socket, limit=BUFFER_SIZE
        )  # Set buffer limit to 1MB

        if self._params.noise_psk is None:
            self._frame_helper = APIPlaintextFrameHelper(reader, writer)
        else:
            fh = self._frame_helper = APINoiseFrameHelper(
                reader, writer, self._params.noise_psk
            )
            await fh.perform_handshake(self._params.expected_name)

        self._connection_state = ConnectionState.SOCKET_OPENED

        # Create read loop
        self._read_task = asyncio.create_task(self._read_loop())
        # Create process loop
        self._process_task = asyncio.create_task(self._process_loop())

    async def _connect_hello(self) -> None:
        """Step 4 in connect process: send hello and get api version.

        :raises TimeoutAPIError: if no HelloResponse arrives in time.
        :raises APIConnectionError: if the server reports a major API version > 2.
        :raises BadNameAPIError: if ``expected_name`` is set and the server
            reports a different, non-empty name.
        """
        hello = HelloRequest()
        hello.client_info = self._params.client_info
        hello.api_version_major = 1
        hello.api_version_minor = 7
        try:
            resp = await self.send_message_await_response(hello, HelloResponse)
        except TimeoutAPIError as err:
            raise TimeoutAPIError("Hello timed out") from err

        _LOGGER.debug(
            "%s: Successfully connected ('%s' API=%s.%s)",
            self.log_name,
            resp.server_info,
            resp.api_version_major,
            resp.api_version_minor,
        )
        self._api_version = APIVersion(resp.api_version_major, resp.api_version_minor)
        if self._api_version.major > 2:
            _LOGGER.error(
                "%s: Incompatible version %s! Closing connection",
                self.log_name,
                self._api_version.major,
            )
            raise APIConnectionError("Incompatible API version.")

        if (
            self._params.expected_name is not None
            and resp.name != ""
            and resp.name != self._params.expected_name
        ):
            raise BadNameAPIError(
                f"Server sent a different name '{resp.name}'", resp.name
            )

        self._connection_state = ConnectionState.CONNECTED

    async def _connect_start_ping(self) -> None:
        """Step 5 in connect process: start the ping loop.

        The loop pings every ``keepalive`` seconds until the socket is closed
        or the ping stop event is set; any ping failure is reported as fatal.
        """

        async def func() -> None:
            while True:
                if not self._is_socket_open:
                    return

                # Wait for keepalive seconds, or ping stop event, whichever happens first
                try:
                    async with async_timeout.timeout(self._params.keepalive):
                        await self._ping_stop_event.wait()
                except asyncio.TimeoutError:
                    pass

                # Re-check connection state
                if not self._is_socket_open:
                    return  # type: ignore[unreachable]

                try:
                    await self._ping()
                except TimeoutAPIError:
                    _LOGGER.info("%s: Ping timed out!", self.log_name)
                    await self._report_fatal_error(PingFailedAPIError())
                    return
                except APIConnectionError as err:
                    _LOGGER.info("%s: Ping Failed: %s", self.log_name, err)
                    await self._report_fatal_error(err)
                    return
                except Exception as err:  # pylint: disable=broad-except
                    _LOGGER.info(
                        "%s: Unexpected error during ping:",
                        self.log_name,
                        exc_info=True,
                    )
                    await self._report_fatal_error(err)
                    return

        self._ping_task = asyncio.create_task(func())

    async def connect(self, *, login: bool) -> None:
        """Establish the connection and optionally log in.

        Runs host resolution, socket connect, frame-helper setup, hello and
        ping-loop start (followed by login() when ``login`` is true) under a
        120 second overall timeout. If any step fails, or the call is
        cancelled, the connection is marked closed, cleaned up and the error
        is re-raised.

        :param login: Whether to call login() after the hello exchange.
        :raises ValueError: if connect() was already called on this instance.
        """
        if self._connection_state != ConnectionState.INITIALIZED:
            raise ValueError(
                "Connection can only be used once, connection is not in init state"
            )

        async def _do_connect() -> None:
            addr = await self._connect_resolve_host()
            await self._connect_socket_connect(addr)
            await self._connect_init_frame_helper()
            await self._connect_hello()
            await self._connect_start_ping()
            if login:
                await self.login()

        # A connection lock must be created to avoid potential issues where
        # connect has succeeded but not yet returned, followed by a disconnect.
        # See esphome/aioesphomeapi#258 for more information
        async with self._connect_lock:
            try:
                # Allow 2 minutes for connect; this is only as a last measure
                # to protect from issues if some part of the connect process mistakenly
                # does not have a timeout
                async with async_timeout.timeout(120.0):
                    await _do_connect()
            except (Exception, asyncio.CancelledError):  # pylint: disable=broad-except
                # Always clean up the connection if an error occured during
                # connect. CancelledError is included (it is not an Exception
                # subclass) so an externally cancelled connect doesn't leak the
                # socket; _cleanup() only schedules a task and never suspends.
                self._connection_state = ConnectionState.CLOSED
                await self._cleanup()
                raise

            # Set while still holding the lock so a queued cleanup always
            # observes the completed connect and schedules on_stop.
            self._connect_complete = True

    async def login(self) -> None:
        """Send a login (ConnectRequest) and await the response.

        :raises APIConnectionError: if not connected or already logged in.
        :raises TimeoutAPIError: if no ConnectResponse arrives in time; the
            connection is closed in that case.
        :raises InvalidAuthAPIError: if the device rejects the password.
        """
        self._check_connected()
        if self._is_authenticated:
            raise APIConnectionError("Already logged in!")

        connect = ConnectRequest()
        if self._params.password is not None:
            connect.password = self._params.password
        try:
            resp = await self.send_message_await_response(connect, ConnectResponse)
        except TimeoutAPIError as err:
            # After a timeout for connect the connection can no longer be used
            # We don't know what state the device may be in after ConnectRequest
            # was already sent
            await self._report_fatal_error(err)
            raise

        if resp.invalid_password:
            raise InvalidAuthAPIError("Invalid password!")

        self._is_authenticated = True

    def _check_connected(self) -> None:
        """Raise APIConnectionError unless the hello exchange has completed."""
        if self._connection_state != ConnectionState.CONNECTED:
            raise APIConnectionError("Must be connected!")

    @property
    def _is_socket_open(self) -> bool:
        """Whether the socket is open (before or after the hello exchange)."""
        return self._connection_state in (
            ConnectionState.SOCKET_OPENED,
            ConnectionState.CONNECTED,
        )

    @property
    def is_connected(self) -> bool:
        """Whether the connection completed the hello exchange and is not closed."""
        return self._connection_state == ConnectionState.CONNECTED

    @property
    def is_authenticated(self) -> bool:
        """Whether the connection is connected and login() succeeded."""
        return self.is_connected and self._is_authenticated

    async def send_message(self, msg: message.Message) -> None:
        """Send a protobuf message to the remote.

        :raises APIConnectionError: if the socket is not open.
        :raises ValueError: if ``msg`` has no registered message type id.
        :raises SocketAPIError: if writing fails; the connection is closed.
        """
        if not self._is_socket_open:
            raise APIConnectionError("Connection isn't established yet")

        for message_type, klass in MESSAGE_TYPE_TO_PROTO.items():
            if isinstance(msg, klass):
                break
        else:
            raise ValueError(f"Message type id not found for type {type(msg)}")

        encoded = msg.SerializeToString()
        # Pass msg itself so it is only stringified when debug logging is on
        _LOGGER.debug("%s: Sending %s: %s", self.log_name, type(msg), msg)

        try:
            assert self._frame_helper is not None
            # pylint: disable=undefined-loop-variable
            await self._frame_helper.write_packet(
                Packet(
                    type=message_type,
                    data=encoded,
                )
            )
        except SocketAPIError as err:
            # If writing packet fails, we don't know what state the frames
            # are in anymore and we have to close the connection
            await self._report_fatal_error(err)
            raise

    def add_message_callback(
        self, on_message: Callable[[Any], None]
    ) -> Callable[[], None]:
        """Add a message callback.

        :returns: A function that removes the callback again.
        """
        self._message_handlers.append(on_message)

        def unsub() -> None:
            self._message_handlers.remove(on_message)

        return unsub

    def remove_message_callback(self, on_message: Callable[[Any], None]) -> None:
        """Remove a message callback."""
        self._message_handlers.remove(on_message)

    async def send_message_callback_response(
        self, send_msg: message.Message, on_message: Callable[[Any], None]
    ) -> None:
        """Send a message to the remote and register the given message handler."""
        self._message_handlers.append(on_message)
        await self.send_message(send_msg)

    async def send_message_await_response_complex(
        self,
        send_msg: message.Message,
        do_append: Callable[[message.Message], bool],
        do_stop: Callable[[message.Message], bool],
        timeout: float = 10.0,
    ) -> List[message.Message]:
        """Send a message to the remote and build up a list response.

        :param send_msg: The message (request) to send.
        :param do_append: Predicate to check if a received message is part of the response.
        :param do_stop: Predicate to check if a received message is the stop response.
        :param timeout: The maximum amount of time to wait for the stop response.

        :raises TimeoutAPIError: if a timeout occured
        :raises Exception: any error raised by send_message(), or the error
            reported to the read-exception handlers (non-APIConnectionError
            errors are wrapped in ReadFailedAPIError).
        """
        fut = asyncio.get_running_loop().create_future()
        responses: List[message.Message] = []

        def on_message(resp: message.Message) -> None:
            if fut.done():
                return
            if do_append(resp):
                responses.append(resp)
            if do_stop(resp):
                fut.set_result(responses)

        def on_read_exception(exc: Exception) -> None:
            if not fut.done():
                new_exc = exc
                if not isinstance(exc, APIConnectionError):
                    new_exc = ReadFailedAPIError("Read failed")
                    new_exc.__cause__ = exc
                fut.set_exception(new_exc)

        self._message_handlers.append(on_message)
        self._read_exception_handlers.append(on_read_exception)
        try:
            # Sending happens inside the try so the handlers registered above
            # are removed again even when send_message() raises.
            await self.send_message(send_msg)
            try:
                async with async_timeout.timeout(timeout):
                    await fut
            except asyncio.TimeoutError as err:
                raise TimeoutAPIError(
                    f"Timeout waiting for response for {type(send_msg)} after {timeout}s"
                ) from err
        finally:
            with suppress(ValueError):
                self._message_handlers.remove(on_message)
            with suppress(ValueError):
                self._read_exception_handlers.remove(on_read_exception)
            # A failed send reports its error through on_read_exception before
            # re-raising; mark that exception as retrieved so asyncio doesn't
            # log "Future exception was never retrieved".
            if fut.done() and not fut.cancelled():
                fut.exception()

        return responses

    async def send_message_await_response(
        self, send_msg: message.Message, response_type: Any, timeout: float = 10.0
    ) -> Any:
        """Send a message and return the first received message of ``response_type``.

        :raises TimeoutAPIError: if no matching response arrives within ``timeout``.
        :raises APIConnectionError: if not exactly one response was collected.
        """

        def is_response(msg: message.Message) -> bool:
            return isinstance(msg, response_type)

        res = await self.send_message_await_response_complex(
            send_msg, is_response, is_response, timeout=timeout
        )
        if len(res) != 1:
            raise APIConnectionError(f"Expected one result, got {len(res)}")

        return res[0]

    async def _report_fatal_error(self, err: Exception) -> None:
        """Report a fatal error that occured during an operation.

        This should only be called for errors that mean the connection
        can no longer be used.

        The connection will be closed, all exception handlers notified.
        This method does not log the error, the call site should do so.
        """
        self._connection_state = ConnectionState.CLOSED
        # Iterate over a copy: handlers may unsubscribe themselves
        for handler in self._read_exception_handlers[:]:
            handler(err)
        await self._cleanup()

    async def _read_once(self) -> None:
        """Read one packet, decode it and queue it for the process loop.

        Packets with an unknown message type are logged and skipped.

        :raises ProtocolAPIError: if the payload is not a valid protobuf message.
        """
        assert self._frame_helper is not None
        pkt = await self._frame_helper.read_packet()

        msg_type = pkt.type
        raw_msg = pkt.data
        if msg_type not in MESSAGE_TYPE_TO_PROTO:
            _LOGGER.debug("%s: Skipping message type %s", self.log_name, msg_type)
            return

        msg = MESSAGE_TYPE_TO_PROTO[msg_type]()
        try:
            msg.ParseFromString(raw_msg)
        except Exception as e:
            raise ProtocolAPIError(f"Invalid protobuf message: {e}") from e
        _LOGGER.debug("%s: Got message of type %s: %s", self.log_name, type(msg), msg)
        self._to_process.put_nowait(msg)

    async def _process_loop(self) -> None:
        """Dispatch queued messages to subscribed handlers, then handle them internally."""
        while True:
            if not self._is_socket_open:
                # Socket closed but task isn't cancelled yet
                break

            try:
                msg = await self._to_process.get()
            except RuntimeError:
                break

            # Iterate over a copy: handlers may unsubscribe during dispatch
            for handler in self._message_handlers[:]:
                handler(msg)
            await self._handle_internal_messages(msg)

    async def _read_loop(self) -> None:
        """Read packets until the socket closes; any read error is reported as fatal."""
        while True:
            if not self._is_socket_open:
                # Socket closed but task isn't cancelled yet
                break
            try:
                await self._read_once()
            except SocketClosedAPIError as err:
                # don't log with info, if closed the site that closed the connection should log
                _LOGGER.debug(
                    "%s: Socket closed, stopping read loop",
                    self.log_name,
                )
                await self._report_fatal_error(err)
                break
            except APIConnectionError as err:
                _LOGGER.info(
                    "%s: Error while reading incoming messages: %s",
                    self.log_name,
                    err,
                )
                await self._report_fatal_error(err)
                break
            except Exception as err:  # pylint: disable=broad-except
                _LOGGER.warning(
                    "%s: Unexpected error while reading incoming messages: %s",
                    self.log_name,
                    err,
                    exc_info=True,
                )
                await self._report_fatal_error(err)
                break

    async def _handle_internal_messages(self, msg: Any) -> None:
        """Answer protocol-level requests sent by the device.

        DisconnectRequest is acknowledged and the connection is closed;
        PingRequest is answered with PingResponse; GetTimeRequest is answered
        with the current Unix time in whole seconds.
        """
        if isinstance(msg, DisconnectRequest):
            await self.send_message(DisconnectResponse())
            self._connection_state = ConnectionState.CLOSED
            await self._cleanup()
        elif isinstance(msg, PingRequest):
            await self.send_message(PingResponse())
        elif isinstance(msg, GetTimeRequest):
            resp = GetTimeResponse()
            resp.epoch_seconds = int(time.time())
            await self.send_message(resp)

    async def _ping(self) -> None:
        """Send a PingRequest and wait for the PingResponse."""
        self._check_connected()
        await self.send_message_await_response(PingRequest(), PingResponse)

    async def disconnect(self) -> None:
        """Disconnect gracefully.

        Sends a DisconnectRequest and waits for the response (connection errors
        are ignored), then closes the connection. Does nothing unless connected.
        """
        if self._connection_state != ConnectionState.CONNECTED:
            # already disconnected
            return

        try:
            await self.send_message_await_response(
                DisconnectRequest(), DisconnectResponse
            )
        except APIConnectionError:
            pass

        self._connection_state = ConnectionState.CLOSED
        await self._cleanup()

    async def force_disconnect(self) -> None:
        """Close the connection immediately without notifying the device."""
        self._connection_state = ConnectionState.CLOSED
        await self._cleanup()

    @property
    def api_version(self) -> Optional[APIVersion]:
        """API version reported by the device in the HelloResponse, if received."""
        return self._api_version
|