2019-04-07 19:03:26 +02:00
|
|
|
import asyncio
|
|
|
|
import logging
|
|
|
|
import socket
|
|
|
|
import time
|
2021-06-29 15:36:14 +02:00
|
|
|
from dataclasses import dataclass
|
2021-06-18 17:57:02 +02:00
|
|
|
from typing import Any, Awaitable, Callable, List, Optional, cast
|
2019-04-07 19:03:26 +02:00
|
|
|
|
2020-08-22 05:45:29 +02:00
|
|
|
import zeroconf
|
2019-04-07 19:03:26 +02:00
|
|
|
from google.protobuf import message
|
|
|
|
|
2021-06-18 17:57:02 +02:00
|
|
|
from aioesphomeapi.api_pb2 import ( # type: ignore
|
|
|
|
ConnectRequest,
|
|
|
|
ConnectResponse,
|
|
|
|
DisconnectRequest,
|
|
|
|
DisconnectResponse,
|
|
|
|
GetTimeRequest,
|
|
|
|
GetTimeResponse,
|
|
|
|
HelloRequest,
|
|
|
|
HelloResponse,
|
|
|
|
PingRequest,
|
|
|
|
PingResponse,
|
|
|
|
)
|
|
|
|
from aioesphomeapi.core import MESSAGE_TYPE_TO_PROTO, APIConnectionError
|
2019-04-07 19:03:26 +02:00
|
|
|
from aioesphomeapi.model import APIVersion
|
|
|
|
from aioesphomeapi.util import _bytes_to_varuint, _varuint_to_bytes, resolve_ip_address
|
|
|
|
|
|
|
|
# Module-level logger, named after this module's import path.
_LOGGER: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
2021-06-29 15:36:14 +02:00
|
|
|
@dataclass
class ConnectionParams:
    """Immutable-by-convention bundle of settings for one ``APIConnection``.

    Fields, in constructor order:
        eventloop: loop used to create tasks/futures and to connect the socket.
        address: passed to the IP resolver; also used as the log-message prefix.
        port: TCP port passed to the resolver.
        password: sent in the ConnectRequest on login; ``None`` sends no password.
        client_info: string placed in the HelloRequest during ``connect()``.
        keepalive: seconds to sleep between keepalive pings.
        zeroconf_instance: optional Zeroconf object handed to the resolver.
    """

    eventloop: asyncio.AbstractEventLoop
    address: str
    port: int
    password: Optional[str]
    client_info: str
    keepalive: float
    zeroconf_instance: Optional[zeroconf.Zeroconf]
|
2021-06-18 17:57:02 +02:00
|
|
|
|
2019-04-07 19:03:26 +02:00
|
|
|
|
|
|
|
class APIConnection:
    """A single TCP connection to an API server described by ``ConnectionParams``.

    Handles resolving and connecting the socket, the hello/login handshake,
    framing protobuf messages on the wire, dispatching incoming messages to
    registered handlers, keepalive pings and teardown.
    """

    def __init__(
        self, params: ConnectionParams, on_stop: Callable[[], Awaitable[None]]
    ):
        self._params = params
        # Awaited once when the connection is stopped (see ``stop``).
        self.on_stop = on_stop
        self._stopped = False
        self._socket: Optional[socket.socket] = None
        self._socket_reader: Optional[asyncio.StreamReader] = None
        self._socket_writer: Optional[asyncio.StreamWriter] = None
        # Serializes writes so frames from concurrent senders never interleave.
        self._write_lock = asyncio.Lock()
        self._connected = False
        self._authenticated = False
        self._socket_connected = False
        self._state_lock = asyncio.Lock()
        self._api_version: Optional[APIVersion] = None

        # Callbacks invoked, in registration order, for every decoded message.
        self._message_handlers: List[Callable[[message.Message], None]] = []

    def _start_ping(self) -> None:
        """Spawn a background task that pings every ``keepalive`` seconds.

        The task exits once the connection is no longer marked connected; a
        failed ping force-stops the connection via ``_on_error``.
        """

        async def func() -> None:
            while self._connected:
                await asyncio.sleep(self._params.keepalive)

                if not self._connected:
                    return

                try:
                    await self.ping()
                except APIConnectionError:
                    _LOGGER.info("%s: Ping Failed!", self._params.address)
                    await self._on_error()
                    return

        self._params.eventloop.create_task(func())

    async def _close_socket(self) -> None:
        """Close the socket and stream wrappers and reset connection state.

        If the raw socket was created but never finished connecting (e.g.
        ``sock_connect`` failed or timed out in ``connect``), it is still
        closed here so its file descriptor does not leak.
        """
        if not self._socket_connected:
            if self._socket is not None:
                self._socket.close()
                self._socket = None
            return
        async with self._write_lock:
            if self._socket_writer is not None:
                self._socket_writer.close()
            self._socket_writer = None
            self._socket_reader = None
            if self._socket is not None:
                self._socket.close()
            self._socket_connected = False
            self._connected = False
            self._authenticated = False
            _LOGGER.debug("%s: Closed socket", self._params.address)

    async def stop(self, force: bool = False) -> None:
        """Tear down the connection and await ``on_stop``.

        Args:
            force: if True, skip the graceful DisconnectRequest exchange.

        Calling this again after the connection is stopped is a no-op, so the
        socket is closed and ``on_stop`` awaited at most once.
        """
        if self._stopped:
            return
        if self._connected and not force:
            try:
                await self._disconnect()
            except APIConnectionError:
                pass
            # While awaiting the disconnect handshake, another path (reader
            # task, ping task) may already have force-stopped us. Don't close
            # twice or await on_stop a second time.
            if self._stopped:
                return
        self._stopped = True
        await self._close_socket()
        await self.on_stop()

    async def _on_error(self) -> None:
        """React to a fatal connection error by force-stopping."""
        await self.stop(force=True)

    async def connect(self) -> None:
        """Resolve the address, open the socket and perform the hello handshake.

        Starts the background reader task, and on success the keepalive task.

        Raises:
            APIConnectionError: if the connection was already stopped or is
                already connected; on resolve/connect errors or timeouts; if
                the hello exchange fails; or if the server reports an API
                major version above 2. In those failure cases (other than the
                first two) the connection is force-stopped before raising.
        """
        if self._stopped:
            raise APIConnectionError("Connection is closed!")
        if self._connected:
            raise APIConnectionError("Already connected!")

        try:
            coro = resolve_ip_address(
                self._params.eventloop,
                self._params.address,
                self._params.port,
                self._params.zeroconf_instance,
            )
            sockaddr = await asyncio.wait_for(coro, 30.0)
        except APIConnectionError as err:
            await self._on_error()
            raise err
        except asyncio.TimeoutError as err:
            await self._on_error()
            raise APIConnectionError("Timeout while resolving IP address") from err

        # NOTE(review): AF_INET only; an IPv6 sockaddr from the resolver would
        # fail to connect here — confirm the resolver only yields IPv4.
        self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._socket.setblocking(False)
        # Disable Nagle's algorithm so each written frame is sent immediately.
        self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

        _LOGGER.debug(
            "%s: Connecting to %s:%s (%s)",
            self._params.address,
            self._params.address,
            self._params.port,
            sockaddr,
        )
        try:
            coro2 = self._params.eventloop.sock_connect(self._socket, sockaddr)
            await asyncio.wait_for(coro2, 30.0)
        # TimeoutError must be handled first: since Python 3.11
        # asyncio.TimeoutError is the builtin TimeoutError, an OSError subclass.
        except asyncio.TimeoutError as err:
            await self._on_error()
            raise APIConnectionError(
                "Timeout while connecting to {}".format(sockaddr)
            ) from err
        except OSError as err:
            await self._on_error()
            raise APIConnectionError(
                "Error connecting to {}: {}".format(sockaddr, err)
            ) from err

        _LOGGER.debug("%s: Opened socket for", self._params.address)
        self._socket_reader, self._socket_writer = await asyncio.open_connection(
            sock=self._socket
        )
        self._socket_connected = True
        self._params.eventloop.create_task(self.run_forever())

        hello = HelloRequest()
        hello.client_info = self._params.client_info
        try:
            resp = await self.send_message_await_response(hello, HelloResponse)
        except APIConnectionError as err:
            await self._on_error()
            raise err
        _LOGGER.debug(
            "%s: Successfully connected ('%s' API=%s.%s)",
            self._params.address,
            resp.server_info,
            resp.api_version_major,
            resp.api_version_minor,
        )
        self._api_version = APIVersion(resp.api_version_major, resp.api_version_minor)
        if self._api_version.major > 2:
            _LOGGER.error(
                "%s: Incompatible version %s! Closing connection",
                self._params.address,
                self._api_version.major,
            )
            await self._on_error()
            raise APIConnectionError("Incompatible API version.")
        self._connected = True

        self._start_ping()

    async def login(self) -> None:
        """Send a ConnectRequest (with the password, if configured).

        Raises:
            APIConnectionError: if not connected, already logged in, the
                request fails, or the server flags the password as invalid.
        """
        self._check_connected()
        if self._authenticated:
            raise APIConnectionError("Already logged in!")

        connect = ConnectRequest()
        if self._params.password is not None:
            connect.password = self._params.password
        resp = await self.send_message_await_response(connect, ConnectResponse)
        if resp.invalid_password:
            raise APIConnectionError("Invalid password!")

        self._authenticated = True

    def _check_connected(self) -> None:
        """Raise APIConnectionError unless the hello handshake has completed."""
        if not self._connected:
            raise APIConnectionError("Must be connected!")

    @property
    def is_connected(self) -> bool:
        """Whether ``connect`` completed and the connection is not yet closed."""
        return self._connected

    @property
    def is_authenticated(self) -> bool:
        """Whether ``login`` succeeded on this connection."""
        return self._authenticated

    async def _write(self, data: bytes) -> None:
        """Write raw bytes under the write lock and wait for the buffer to drain.

        Raises:
            APIConnectionError: if the socket is not connected, or the write
                fails with OSError (the connection is force-stopped first).
        """
        if not self._socket_connected:
            raise APIConnectionError("Socket is not connected")
        try:
            async with self._write_lock:
                if self._socket_writer is not None:
                    self._socket_writer.write(data)
                    await self._socket_writer.drain()
        except OSError as err:
            await self._on_error()
            raise APIConnectionError(
                "Error while writing data: {}".format(err)
            ) from err

    async def send_message(self, msg: message.Message) -> None:
        """Frame and send a protobuf message.

        Frame layout: a 0x00 preamble byte, varint payload length, varint
        message type id (looked up in MESSAGE_TYPE_TO_PROTO), then the
        serialized payload.

        Raises:
            ValueError: if ``msg`` is not an instance of any registered class.
            APIConnectionError: propagated from ``_write``.
        """
        for message_type, klass in MESSAGE_TYPE_TO_PROTO.items():
            if isinstance(msg, klass):
                break
        else:
            raise ValueError(
                "No message type id registered for {}".format(type(msg).__name__)
            )

        encoded = msg.SerializeToString()
        _LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg))
        req = bytes([0])
        req += _varuint_to_bytes(len(encoded))
        # pylint: disable=undefined-loop-variable
        req += _varuint_to_bytes(message_type)
        req += encoded
        await self._write(req)

    async def send_message_callback_response(
        self, send_msg: message.Message, on_message: Callable[[Any], None]
    ) -> None:
        """Register ``on_message`` for all future incoming messages, then send.

        The handler is never unregistered by this class.
        """
        self._message_handlers.append(on_message)
        await self.send_message(send_msg)

    async def send_message_await_response_complex(
        self,
        send_msg: message.Message,
        do_append: Callable[[Any], bool],
        do_stop: Callable[[Any], bool],
        timeout: float = 5.0,
    ) -> List[Any]:
        """Send a message and collect incoming messages until ``do_stop`` fires.

        Args:
            send_msg: message to send.
            do_append: predicate selecting which incoming messages to collect.
            do_stop: predicate that completes the wait when it returns True
                (checked after ``do_append`` for the same message).
            timeout: seconds to wait for completion.

        Returns:
            The collected messages, in arrival order.

        Raises:
            APIConnectionError: on send failure, or on timeout (with a
                distinct message if the connection was stopped meanwhile).
        """
        fut = self._params.eventloop.create_future()
        responses: List[Any] = []

        def on_message(resp: message.Message) -> None:
            if fut.done():
                return
            if do_append(resp):
                responses.append(resp)
            if do_stop(resp):
                fut.set_result(responses)

        self._message_handlers.append(on_message)
        try:
            await self.send_message(send_msg)
            try:
                await asyncio.wait_for(fut, timeout)
            except asyncio.TimeoutError as err:
                if self._stopped:
                    raise APIConnectionError(
                        "Disconnected while waiting for API response!"
                    ) from err
                raise APIConnectionError(
                    "Timeout while waiting for API response!"
                ) from err
        finally:
            # Unregister on every exit path; otherwise timed-out or failed
            # requests would leave dead handlers in the list forever.
            try:
                self._message_handlers.remove(on_message)
            except ValueError:
                pass

        return responses

    async def send_message_await_response(
        self, send_msg: message.Message, response_type: Any, timeout: float = 5.0
    ) -> Any:
        """Send a message and return the first incoming ``response_type`` instance.

        Raises:
            APIConnectionError: on send failure, timeout, or an unexpected
                number of results.
        """

        def is_response(msg: message.Message) -> bool:
            return isinstance(msg, response_type)

        res = await self.send_message_await_response_complex(
            send_msg, is_response, is_response, timeout=timeout
        )
        if len(res) != 1:
            raise APIConnectionError("Expected one result, got {}".format(len(res)))

        return res[0]

    async def _recv(self, amount: int) -> bytes:
        """Read exactly ``amount`` bytes from the stream.

        Raises:
            APIConnectionError: if there is no reader or the read fails/ends early.
        """
        if amount == 0:
            return bytes()

        reader = self._socket_reader
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if reader is None:
            raise APIConnectionError("Socket is not connected")
        try:
            ret = await reader.readexactly(amount)
        except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
            raise APIConnectionError(
                "Error while receiving data: {}".format(err)
            ) from err

        return ret

    async def _recv_varint(self) -> int:
        """Read a varint one byte at a time until a byte without the 0x80 bit."""
        raw = bytes()
        while not raw or raw[-1] & 0x80:
            raw += await self._recv(1)
        return cast(int, _bytes_to_varuint(raw))

    async def _run_once(self) -> None:
        """Read one frame, decode it and dispatch it to handlers.

        Frames with an unknown type id are read and skipped.

        Raises:
            APIConnectionError: on a bad preamble, read failure, or a payload
                that fails to parse.
        """
        preamble = await self._recv(1)
        if preamble[0] != 0x00:
            raise APIConnectionError("Invalid preamble")

        length = await self._recv_varint()
        msg_type = await self._recv_varint()

        raw_msg = await self._recv(length)
        if msg_type not in MESSAGE_TYPE_TO_PROTO:
            _LOGGER.debug(
                "%s: Skipping message type %s", self._params.address, msg_type
            )
            return

        msg = MESSAGE_TYPE_TO_PROTO[msg_type]()
        try:
            msg.ParseFromString(raw_msg)
        except Exception as e:
            raise APIConnectionError("Invalid protobuf message: {}".format(e)) from e
        _LOGGER.debug(
            "%s: Got message of type %s: %s", self._params.address, type(msg), msg
        )
        # Iterate a copy: handlers may be removed while we dispatch.
        for msg_handler in self._message_handlers[:]:
            msg_handler(msg)
        await self._handle_internal_messages(msg)

    async def run_forever(self) -> None:
        """Reader loop: process frames until an error occurs, then force-stop."""
        while True:
            try:
                await self._run_once()
            except APIConnectionError as err:
                _LOGGER.info(
                    "%s: Error while reading incoming messages: %s",
                    self._params.address,
                    err,
                )
                await self._on_error()
                break
            except Exception as err:  # pylint: disable=broad-except
                _LOGGER.info(
                    "%s: Unexpected error while reading incoming messages: %s",
                    self._params.address,
                    err,
                )
                await self._on_error()
                break

    async def _handle_internal_messages(self, msg: Any) -> None:
        """Answer protocol-level requests sent by the server.

        DisconnectRequest: reply, then force-stop. PingRequest: reply with
        PingResponse. GetTimeRequest: reply with ``int(time.time())``.
        """
        if isinstance(msg, DisconnectRequest):
            await self.send_message(DisconnectResponse())
            await self.stop(force=True)
        elif isinstance(msg, PingRequest):
            await self.send_message(PingResponse())
        elif isinstance(msg, GetTimeRequest):
            resp = GetTimeResponse()
            resp.epoch_seconds = int(time.time())
            await self.send_message(resp)

    async def ping(self) -> None:
        """Send a PingRequest and wait for the PingResponse.

        Raises:
            APIConnectionError: if not connected, or on send failure/timeout.
        """
        self._check_connected()
        await self.send_message_await_response(PingRequest(), PingResponse)

    async def _disconnect(self) -> None:
        """Best-effort DisconnectRequest/DisconnectResponse exchange.

        Raises:
            APIConnectionError: only if not connected; exchange errors are ignored.
        """
        self._check_connected()

        try:
            await self.send_message_await_response(
                DisconnectRequest(), DisconnectResponse
            )
        except APIConnectionError:
            pass

    def _check_authenticated(self) -> None:
        """Raise APIConnectionError unless ``login`` has succeeded."""
        if not self._authenticated:
            raise APIConnectionError("Must login first!")

    @property
    def api_version(self) -> Optional[APIVersion]:
        """API version reported in the HelloResponse, or None before connect."""
        return self._api_version
|