# aioesphomeapi/aioesphomeapi/connection.py
#
# 392 lines
# 13 KiB
# Python

import asyncio
import logging
import socket
import time
from typing import Any, Awaitable, Callable, List, Optional, cast
import attr
import zeroconf
from google.protobuf import message
from aioesphomeapi.api_pb2 import ( # type: ignore
ConnectRequest,
ConnectResponse,
DisconnectRequest,
DisconnectResponse,
GetTimeRequest,
GetTimeResponse,
HelloRequest,
HelloResponse,
PingRequest,
PingResponse,
)
from aioesphomeapi.core import MESSAGE_TYPE_TO_PROTO, APIConnectionError
from aioesphomeapi.model import APIVersion
from aioesphomeapi.util import _bytes_to_varuint, _varuint_to_bytes, resolve_ip_address
# Module-level logger, named after this module, used for all connection logging.
_LOGGER = logging.getLogger(__name__)
@attr.s(auto_attribs=True)
class ConnectionParams:
    """Bundle of settings an ``APIConnection`` needs to reach and talk to a device."""

    # Event loop used to schedule background tasks, create futures and
    # perform the non-blocking socket connect.
    eventloop: asyncio.events.AbstractEventLoop
    # Address of the device; passed to the resolver and used as log prefix.
    address: str
    # TCP port passed to the resolver together with the address.
    port: int
    # Password sent during login; not sent at all when None.
    password: Optional[str]
    # Client identification string sent in the hello request.
    client_info: str
    # Number of seconds to sleep between keepalive pings.
    keepalive: float
    # Optional Zeroconf instance handed to the address resolver.
    zeroconf_instance: Optional[zeroconf.Zeroconf]
class APIConnection:
    """One TCP connection to a device API server.

    Handles address resolution, the socket connect, the hello/login
    handshake, message framing (zero preamble byte, varint length, varint
    message type, protobuf payload), keepalive pings and dispatching of
    incoming messages to registered handlers.

    Once the connection has been stopped (explicitly or because of an
    error) it cannot be reused; ``connect()`` raises on a stopped instance.
    """

    def __init__(
        self, params: ConnectionParams, on_stop: Callable[[], Awaitable[None]]
    ):
        """Create an unconnected connection object.

        Args:
            params: Connection settings (address, port, credentials, ...).
            on_stop: Coroutine function awaited once when the connection
                is stopped.
        """
        self._params = params
        self.on_stop = on_stop
        self._stopped = False
        self._socket: Optional[socket.socket] = None
        self._socket_reader: Optional[asyncio.StreamReader] = None
        self._socket_writer: Optional[asyncio.StreamWriter] = None
        # Serializes writes and protects the writer while it is being closed.
        self._write_lock = asyncio.Lock()
        self._connected = False
        self._authenticated = False
        self._socket_connected = False
        self._state_lock = asyncio.Lock()
        self._api_version: Optional[APIVersion] = None
        self._message_handlers: List[Callable[[message.Message], None]] = []

    def _start_ping(self) -> None:
        """Schedule a background task that pings the peer every keepalive interval."""

        async def ping_loop() -> None:
            while self._connected:
                await asyncio.sleep(self._params.keepalive)

                # The connection may have gone away while we were sleeping.
                if not self._connected:
                    return

                try:
                    await self.ping()
                except APIConnectionError:
                    _LOGGER.info("%s: Ping Failed!", self._params.address)
                    await self._on_error()
                    return

        self._params.eventloop.create_task(ping_loop())

    async def _close_socket(self) -> None:
        """Close the stream writer and the underlying socket, resetting state flags."""
        if not self._socket_connected:
            # connect() may have created the raw socket and then failed before
            # the streams were opened; release that descriptor instead of
            # leaking it.
            if self._socket is not None:
                self._socket.close()
                self._socket = None
            return

        async with self._write_lock:
            if self._socket_writer is not None:
                self._socket_writer.close()
            self._socket_writer = None
            self._socket_reader = None

        if self._socket is not None:
            self._socket.close()

        self._socket_connected = False
        self._connected = False
        self._authenticated = False
        _LOGGER.debug("%s: Closed socket", self._params.address)

    async def stop(self, force: bool = False) -> None:
        """Stop the connection and notify ``on_stop``.

        Args:
            force: When False and the connection is established, a
                disconnect request is sent first; when True the socket is
                closed right away.

        Calling this on an already stopped connection is a no-op.
        """
        if self._stopped:
            return

        if self._connected and not force:
            try:
                await self._disconnect()
            except APIConnectionError:
                # Best effort only: the socket is closed below regardless.
                pass

        self._stopped = True
        await self._close_socket()
        await self.on_stop()

    async def _on_error(self) -> None:
        """Tear the connection down without attempting a graceful disconnect."""
        await self.stop(force=True)

    async def connect(self) -> None:
        """Resolve the address, open the socket and perform the hello handshake.

        Raises:
            APIConnectionError: If the connection is already stopped or
                connected, the address cannot be resolved, the TCP connect
                fails or times out, the hello exchange fails, or the peer
                reports an API major version greater than 2.
        """
        if self._stopped:
            raise APIConnectionError("Connection is closed!")
        if self._connected:
            raise APIConnectionError("Already connected!")

        try:
            coro = resolve_ip_address(
                self._params.eventloop,
                self._params.address,
                self._params.port,
                self._params.zeroconf_instance,
            )
            sockaddr = await asyncio.wait_for(coro, 30.0)
        except APIConnectionError:
            await self._on_error()
            raise
        except asyncio.TimeoutError as err:
            await self._on_error()
            raise APIConnectionError("Timeout while resolving IP address") from err

        self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._socket.setblocking(False)
        self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

        _LOGGER.debug(
            "%s: Connecting to %s:%s (%s)",
            self._params.address,
            self._params.address,
            self._params.port,
            sockaddr,
        )
        try:
            coro2 = self._params.eventloop.sock_connect(self._socket, sockaddr)
            await asyncio.wait_for(coro2, 30.0)
        except OSError as err:
            await self._on_error()
            raise APIConnectionError(
                "Error connecting to {}: {}".format(sockaddr, err)
            ) from err
        except asyncio.TimeoutError as err:
            await self._on_error()
            raise APIConnectionError(
                "Timeout while connecting to {}".format(sockaddr)
            ) from err

        _LOGGER.debug("%s: Opened socket for", self._params.address)
        self._socket_reader, self._socket_writer = await asyncio.open_connection(
            sock=self._socket
        )
        self._socket_connected = True
        # The reader task must be running before the hello request is sent,
        # otherwise the response would never be dispatched.
        self._params.eventloop.create_task(self.run_forever())

        hello = HelloRequest()
        hello.client_info = self._params.client_info
        try:
            resp = await self.send_message_await_response(hello, HelloResponse)
        except APIConnectionError:
            await self._on_error()
            raise
        _LOGGER.debug(
            "%s: Successfully connected ('%s' API=%s.%s)",
            self._params.address,
            resp.server_info,
            resp.api_version_major,
            resp.api_version_minor,
        )
        self._api_version = APIVersion(resp.api_version_major, resp.api_version_minor)
        if self._api_version.major > 2:
            _LOGGER.error(
                "%s: Incompatible version %s! Closing connection",
                self._params.address,
                self._api_version.major,
            )
            await self._on_error()
            raise APIConnectionError("Incompatible API version.")

        self._connected = True
        self._start_ping()

    async def login(self) -> None:
        """Authenticate with the configured password (if any).

        Raises:
            APIConnectionError: If not connected, already logged in, the
                request fails, or the peer reports an invalid password.
        """
        self._check_connected()
        if self._authenticated:
            raise APIConnectionError("Already logged in!")

        connect = ConnectRequest()
        if self._params.password is not None:
            connect.password = self._params.password
        resp = await self.send_message_await_response(connect, ConnectResponse)
        if resp.invalid_password:
            raise APIConnectionError("Invalid password!")

        self._authenticated = True

    def _check_connected(self) -> None:
        """Raise APIConnectionError unless the hello handshake has completed."""
        if not self._connected:
            raise APIConnectionError("Must be connected!")

    @property
    def is_connected(self) -> bool:
        """Whether the hello handshake has completed and the socket is still open."""
        return self._connected

    @property
    def is_authenticated(self) -> bool:
        """Whether ``login()`` has completed successfully."""
        return self._authenticated

    async def _write(self, data: bytes) -> None:
        """Write raw bytes to the socket and wait until they are flushed.

        Raises:
            APIConnectionError: If the socket is not connected or the write
                fails; in the latter case the connection is stopped first.
        """
        if not self._socket_connected:
            raise APIConnectionError("Socket is not connected")

        try:
            async with self._write_lock:
                if self._socket_writer is not None:
                    self._socket_writer.write(data)
                    await self._socket_writer.drain()
        except OSError as err:
            # Handled outside the lock: stopping re-acquires the write lock.
            await self._on_error()
            raise APIConnectionError(
                "Error while writing data: {}".format(err)
            ) from err

    async def send_message(self, msg: message.Message) -> None:
        """Frame and send a protobuf message.

        Raises:
            ValueError: If the message class is not in MESSAGE_TYPE_TO_PROTO.
            APIConnectionError: If the write fails.
        """
        for message_type, klass in MESSAGE_TYPE_TO_PROTO.items():
            if isinstance(msg, klass):
                break
        else:
            raise ValueError("Unknown message type: {}".format(type(msg)))

        encoded = msg.SerializeToString()
        _LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg))
        # Frame layout: zero preamble, payload length, message type, payload.
        req = b"".join(
            (
                bytes([0]),
                _varuint_to_bytes(len(encoded)),
                _varuint_to_bytes(message_type),
                encoded,
            )
        )
        await self._write(req)

    async def send_message_callback_response(
        self, send_msg: message.Message, on_message: Callable[[Any], None]
    ) -> None:
        """Register ``on_message`` as a permanent handler, then send ``send_msg``.

        The handler is called for every message received afterwards and is
        never removed by this class.
        """
        self._message_handlers.append(on_message)
        await self.send_message(send_msg)

    async def send_message_await_response_complex(
        self,
        send_msg: message.Message,
        do_append: Callable[[Any], bool],
        do_stop: Callable[[Any], bool],
        timeout: float = 5.0,
    ) -> List[Any]:
        """Send a message and collect responses until ``do_stop`` matches.

        Args:
            send_msg: Message to send.
            do_append: Predicate deciding whether an incoming message is
                added to the result list.
            do_stop: Predicate deciding whether an incoming message
                completes the exchange.
            timeout: Seconds to wait for the stopping message.

        Returns:
            The collected responses, in arrival order.

        Raises:
            APIConnectionError: On send failure, or when no stopping
                message arrives within ``timeout``.
        """
        fut = self._params.eventloop.create_future()
        responses: List[Any] = []

        def on_message(resp: message.Message) -> None:
            if fut.done():
                return
            if do_append(resp):
                responses.append(resp)
            if do_stop(resp):
                fut.set_result(responses)

        self._message_handlers.append(on_message)
        try:
            await self.send_message(send_msg)

            try:
                await asyncio.wait_for(fut, timeout)
            except asyncio.TimeoutError as err:
                if self._stopped:
                    raise APIConnectionError(
                        "Disconnected while waiting for API response!"
                    ) from err
                raise APIConnectionError(
                    "Timeout while waiting for API response!"
                ) from err
        finally:
            # Always unregister the temporary handler, including on timeout
            # or send failure, so handlers do not pile up over time.
            try:
                self._message_handlers.remove(on_message)
            except ValueError:
                pass

        return responses

    async def send_message_await_response(
        self, send_msg: message.Message, response_type: Any, timeout: float = 5.0
    ) -> Any:
        """Send a message and return the first response of ``response_type``.

        Raises:
            APIConnectionError: On send failure, timeout, or if the number
                of collected responses is not exactly one.
        """

        def is_response(msg: message.Message) -> bool:
            return isinstance(msg, response_type)

        res = await self.send_message_await_response_complex(
            send_msg, is_response, is_response, timeout=timeout
        )
        if len(res) != 1:
            raise APIConnectionError("Expected one result, got {}".format(len(res)))

        return res[0]

    async def _recv(self, amount: int) -> bytes:
        """Read exactly ``amount`` bytes from the socket.

        Raises:
            APIConnectionError: If the reader is gone or the read fails or
                ends early.
        """
        if amount == 0:
            return bytes()

        # Explicit check instead of an assert: asserts vanish under -O and
        # would otherwise surface as AssertionError instead of the
        # connection error type used everywhere else.
        if self._socket_reader is None:
            raise APIConnectionError("Socket is not connected")

        try:
            ret = await self._socket_reader.readexactly(amount)
        except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
            raise APIConnectionError(
                "Error while receiving data: {}".format(err)
            ) from err

        return ret

    async def _recv_varint(self) -> int:
        """Read one variable-length unsigned integer from the socket."""
        raw = bytes()
        # Bytes with the high bit set signal that another byte follows.
        while not raw or raw[-1] & 0x80:
            raw += await self._recv(1)
        return cast(int, _bytes_to_varuint(raw))

    async def _run_once(self) -> None:
        """Read one frame, decode it and dispatch it to all handlers.

        Frames with an unknown message type are read completely and then
        skipped.

        Raises:
            APIConnectionError: On read errors, a non-zero preamble or an
                undecodable payload.
        """
        preamble = await self._recv(1)
        if preamble[0] != 0x00:
            raise APIConnectionError("Invalid preamble")

        length = await self._recv_varint()
        msg_type = await self._recv_varint()

        raw_msg = await self._recv(length)
        if msg_type not in MESSAGE_TYPE_TO_PROTO:
            _LOGGER.debug(
                "%s: Skipping message type %s", self._params.address, msg_type
            )
            return

        msg = MESSAGE_TYPE_TO_PROTO[msg_type]()
        try:
            msg.ParseFromString(raw_msg)
        except Exception as e:
            raise APIConnectionError(
                "Invalid protobuf message: {}".format(e)
            ) from e
        _LOGGER.debug(
            "%s: Got message of type %s: %s", self._params.address, type(msg), msg
        )
        # Iterate over a copy: handlers may unregister themselves while running.
        for msg_handler in self._message_handlers[:]:
            msg_handler(msg)
        await self._handle_internal_messages(msg)

    async def run_forever(self) -> None:
        """Receive and dispatch messages until an error stops the connection."""
        while True:
            try:
                await self._run_once()
            except APIConnectionError as err:
                _LOGGER.info(
                    "%s: Error while reading incoming messages: %s",
                    self._params.address,
                    err,
                )
                await self._on_error()
                break
            except Exception as err:  # pylint: disable=broad-except
                _LOGGER.info(
                    "%s: Unexpected error while reading incoming messages: %s",
                    self._params.address,
                    err,
                )
                await self._on_error()
                break

    async def _handle_internal_messages(self, msg: Any) -> None:
        """Answer requests initiated by the peer (disconnect, ping, time)."""
        if isinstance(msg, DisconnectRequest):
            await self.send_message(DisconnectResponse())
            await self.stop(force=True)
        elif isinstance(msg, PingRequest):
            await self.send_message(PingResponse())
        elif isinstance(msg, GetTimeRequest):
            resp = GetTimeResponse()
            resp.epoch_seconds = int(time.time())
            await self.send_message(resp)

    async def ping(self) -> None:
        """Send a ping request and wait for the matching response.

        Raises:
            APIConnectionError: If not connected or no response arrives.
        """
        self._check_connected()
        await self.send_message_await_response(PingRequest(), PingResponse)

    async def _disconnect(self) -> None:
        """Send a disconnect request; failures of the exchange are ignored.

        Raises:
            APIConnectionError: Only if the connection is not established.
        """
        self._check_connected()

        try:
            await self.send_message_await_response(
                DisconnectRequest(), DisconnectResponse
            )
        except APIConnectionError:
            # Best effort: the caller closes the socket afterwards anyway.
            pass

    def _check_authenticated(self) -> None:
        """Raise APIConnectionError unless ``login()`` has succeeded."""
        if not self._authenticated:
            raise APIConnectionError("Must login first!")

    @property
    def api_version(self) -> Optional[APIVersion]:
        """API version reported in the hello response, or None before connecting."""
        return self._api_version