aioesphomeapi/aioesphomeapi/connection.py

556 lines
19 KiB
Python
Raw Normal View History

2019-04-07 19:03:26 +02:00
import asyncio
2021-09-08 23:12:07 +02:00
import base64
2019-04-07 19:03:26 +02:00
import logging
import socket
import time
from dataclasses import astuple, dataclass
2021-09-08 23:12:07 +02:00
from typing import Any, Awaitable, Callable, List, Optional
2019-04-07 19:03:26 +02:00
from google.protobuf import message
2021-09-08 23:12:07 +02:00
from noise.connection import NoiseConnection # type: ignore
2019-04-07 19:03:26 +02:00
2021-07-12 20:09:17 +02:00
import aioesphomeapi.host_resolver as hr
from .api_pb2 import ( # type: ignore
ConnectRequest,
ConnectResponse,
DisconnectRequest,
DisconnectResponse,
GetTimeRequest,
GetTimeResponse,
HelloRequest,
HelloResponse,
PingRequest,
PingResponse,
)
from .core import MESSAGE_TYPE_TO_PROTO, APIConnectionError
from .model import APIVersion
from .util import bytes_to_varuint, varuint_to_bytes
2019-04-07 19:03:26 +02:00
# Module-level logger; records are emitted under this module's dotted name.
_LOGGER = logging.getLogger(__name__)
2021-06-29 15:36:14 +02:00
@dataclass
class ConnectionParams:
    """Settings that describe how one API connection is opened and kept alive."""

    # Event loop used for the socket connect, background tasks and futures.
    eventloop: asyncio.events.AbstractEventLoop
    # Host name or address handed to the host resolver.
    address: str
    # TCP port handed to the host resolver.
    port: int
    # Password sent in the ConnectRequest; omitted from the request when None.
    password: Optional[str]
    # Client description sent in the HelloRequest.
    client_info: str
    # Delay in seconds between two keepalive pings.
    keepalive: float
    # Passed through unchanged to the host resolver.
    zeroconf_instance: hr.ZeroconfInstanceType
    # Base64-encoded pre-shared key; None selects the plaintext framing.
    noise_psk: Optional[str]

    @property
    def noise_psk_bytes(self) -> Optional[bytes]:
        """Return ``noise_psk`` base64-decoded, or ``None`` when no key is set."""
        encoded_key = self.noise_psk
        return None if encoded_key is None else base64.b64decode(encoded_key)
@dataclass
class Packet:
    """One framed API message: a numeric type id plus its serialized payload."""

    # Numeric message type id as carried in the frame header.
    type: int
    # Serialized message body.
    data: bytes
class APIFrameHelper:
    """Read and write framed packets on a stream, in plaintext or noise mode.

    The mode is chosen from ``params.noise_psk``: when it is ``None`` the
    plaintext framing is used, otherwise packets are encrypted with the noise
    protocol state established by :meth:`perform_handshake`.
    """

    # Largest length that fits the 16-bit length fields of the noise framing.
    _MAX_NOISE_LENGTH = 0xFFFF

    def __init__(
        self,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
        params: ConnectionParams,
    ) -> None:
        """Store the stream pair and connection parameters.

        Args:
            reader: Stream that incoming frames are read from.
            writer: Stream that outgoing frames are written to.
            params: Connection parameters; ``noise_psk`` selects the framing.
        """
        self._reader = reader
        self._writer = writer
        self._params = params
        # Separate locks so a pending read never blocks a write and vice versa.
        self._write_lock = asyncio.Lock()
        self._read_lock = asyncio.Lock()
        # Set once the noise handshake has completed.
        self._ready_event = asyncio.Event()
        self._proto: Optional[NoiseConnection] = None

    async def close(self) -> None:
        """Close the writer once any in-flight write has released the lock."""
        async with self._write_lock:
            self._writer.close()

    async def _write_frame_noise(self, frame: bytes) -> None:
        """Write one noise frame: marker byte 0x01, 16-bit length, payload.

        Raises:
            APIConnectionError: If the frame does not fit the 16-bit length
                field or the underlying write fails.
        """
        frame_len = len(frame)
        if frame_len > self._MAX_NOISE_LENGTH:
            # The header only carries 16 bits; a larger frame would be sent
            # with a truncated length and corrupt the stream.
            raise APIConnectionError(
                f"Frame of {frame_len} bytes is too large for the noise framing"
            )
        try:
            async with self._write_lock:
                _LOGGER.debug("Sending frame %s", frame.hex())
                header = bytes([0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF])
                self._writer.write(header + frame)
                await self._writer.drain()
        except OSError as err:
            raise APIConnectionError(f"Error while writing data: {err}") from err

    async def _read_frame_noise(self) -> bytes:
        """Read one noise frame and return its payload.

        Raises:
            APIConnectionError: If the marker byte is not 0x01 or the read
                fails or ends prematurely.
        """
        try:
            async with self._read_lock:
                header = await self._reader.readexactly(3)
                if header[0] != 0x01:
                    raise APIConnectionError(f"Marker byte invalid: {header[0]}")
                msg_size = (header[1] << 8) | header[2]
                frame = await self._reader.readexactly(msg_size)
        except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
            raise APIConnectionError(f"Error while reading data: {err}") from err

        _LOGGER.debug("Received frame %s", frame.hex())
        return frame

    async def perform_handshake(self) -> None:
        """Run the noise handshake; a no-op when no pre-shared key is set.

        Raises:
            APIConnectionError: If the peer answers with an empty or unknown
                ServerHello, reports a handshake failure, or I/O fails.
        """
        if self._params.noise_psk is None:
            return

        await self._write_frame_noise(b"")  # ClientHello
        prologue = b"NoiseAPIInit" + b"\x00\x00"

        server_hello = await self._read_frame_noise()  # ServerHello
        if not server_hello:
            raise APIConnectionError("ServerHello is empty")

        chosen_proto = server_hello[0]
        if chosen_proto != 0x01:
            raise APIConnectionError(
                f"Unknown protocol selected by client {chosen_proto}"
            )

        self._proto = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_SHA256")
        self._proto.set_as_initiator()
        self._proto.set_psks(self._params.noise_psk_bytes)
        self._proto.set_prologue(prologue)
        self._proto.start_handshake()

        _LOGGER.debug("Starting handshake...")
        # As initiator we write first, then alternate with reads until done.
        do_write = True
        while not self._proto.handshake_finished:
            if do_write:
                msg = self._proto.write_message()
                await self._write_frame_noise(b"\x00" + msg)
            else:
                msg = await self._read_frame_noise()
                if not msg or msg[0] != 0:
                    # The reason text comes from the peer and may not be valid
                    # UTF-8; never let a decode error replace the
                    # APIConnectionError that callers rely on for cleanup.
                    reason = msg[1:].decode(errors="replace")
                    raise APIConnectionError(f"Handshake failure: {reason}")
                self._proto.read_message(msg[1:])

            do_write = not do_write

        _LOGGER.debug("Handshake complete!")
        self._ready_event.set()

    async def _write_packet_noise(self, packet: Packet) -> None:
        """Encrypt ``packet`` and send it as one noise frame.

        The plaintext is a 16-bit type, a 16-bit payload length, then the
        payload. Waits until the handshake has completed.

        Raises:
            APIConnectionError: If the payload does not fit the 16-bit length
                field or the write fails.
        """
        await self._ready_event.wait()
        payload_len = len(packet.data)
        if payload_len > self._MAX_NOISE_LENGTH:
            raise APIConnectionError(
                f"Packet of {payload_len} bytes is too large for the noise framing"
            )
        header = bytes(
            [
                (packet.type >> 8) & 0xFF,
                packet.type & 0xFF,
                (payload_len >> 8) & 0xFF,
                payload_len & 0xFF,
            ]
        )
        assert self._proto is not None
        frame = self._proto.encrypt(header + packet.data)
        await self._write_frame_noise(frame)

    async def _write_packet_plaintext(self, packet: Packet) -> None:
        """Send ``packet`` as: zero preamble, varuint length, varuint type, payload.

        Raises:
            APIConnectionError: If the underlying write fails.
        """
        data = b"".join(
            (
                b"\0",
                varuint_to_bytes(len(packet.data)),
                varuint_to_bytes(packet.type),
                packet.data,
            )
        )
        try:
            async with self._write_lock:
                _LOGGER.debug("Sending frame %s", data.hex())
                self._writer.write(data)
                await self._writer.drain()
        except OSError as err:
            raise APIConnectionError(f"Error while writing data: {err}") from err

    async def write_packet(self, packet: Packet) -> None:
        """Send ``packet`` using the framing selected by ``noise_psk``."""
        if self._params.noise_psk is None:
            await self._write_packet_plaintext(packet)
        else:
            await self._write_packet_noise(packet)

    async def _read_packet_noise(self) -> Packet:
        """Read one noise frame, decrypt it and parse the packet header.

        Raises:
            APIConnectionError: If the decrypted message is shorter than the
                4-byte header or declares more payload than it contains.
        """
        await self._ready_event.wait()
        frame = await self._read_frame_noise()
        assert self._proto is not None
        msg = self._proto.decrypt(frame)
        if len(msg) < 4:
            raise APIConnectionError(f"Bad packet frame: {msg}")
        pkt_type = (msg[0] << 8) | msg[1]
        data_len = (msg[2] << 8) | msg[3]
        if data_len + 4 > len(msg):
            raise APIConnectionError(f"Bad data len: {data_len} vs {len(msg)}")
        # Anything after the declared payload is ignored.
        data = msg[4 : 4 + data_len]
        return Packet(type=pkt_type, data=data)

    async def _read_varuint(self) -> int:
        """Read a varuint byte by byte; the caller must hold the read lock.

        Bytes are consumed while the high (continuation) bit is set.

        Raises:
            APIConnectionError: If ``bytes_to_varuint`` cannot decode the bytes.
        """
        raw = b""
        while not raw or (raw[-1] & 0x80) == 0x80:
            raw += await self._reader.readexactly(1)
        value = bytes_to_varuint(raw)
        if value is None:
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            raise APIConnectionError(f"Invalid varuint: {raw.hex()}")
        return value

    async def _read_packet_plaintext(self) -> Packet:
        """Read one plaintext packet.

        Raises:
            APIConnectionError: If the preamble byte is not zero, a varuint is
                invalid, or the read fails or ends prematurely.
        """
        try:
            async with self._read_lock:
                preamble = await self._reader.readexactly(1)
                if preamble[0] != 0x00:
                    raise APIConnectionError("Invalid preamble")

                length_int = await self._read_varuint()
                msg_type_int = await self._read_varuint()

                raw_msg = b""
                if length_int != 0:
                    raw_msg = await self._reader.readexactly(length_int)
        except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
            # Same wrapping as the noise read path, so callers see one
            # exception type for transport failures.
            raise APIConnectionError(f"Error while reading data: {err}") from err
        return Packet(type=msg_type_int, data=raw_msg)

    async def read_packet(self) -> Packet:
        """Read the next packet using the framing selected by ``noise_psk``."""
        if self._params.noise_psk is None:
            return await self._read_packet_plaintext()
        return await self._read_packet_noise()
2019-04-07 19:03:26 +02:00
class APIConnection:
    """One client connection to an API server.

    Handles host resolution, socket setup, the optional noise handshake, the
    hello exchange, keepalive pings and dispatch of incoming messages to
    registered handlers. Once stopped, an instance cannot be reconnected.
    """

    def __init__(
        self, params: ConnectionParams, on_stop: Callable[[], Awaitable[None]]
    ) -> None:
        """Create an unconnected instance.

        Args:
            params: Connection settings.
            on_stop: Awaited once, after the connection has been torn down.
        """
        self._params = params
        self.on_stop = on_stop
        self._stopped = False
        self._socket: Optional[socket.socket] = None
        self._frame_helper: Optional[APIFrameHelper] = None
        self._connected = False
        self._authenticated = False
        self._socket_connected = False
        self._state_lock = asyncio.Lock()
        self._api_version: Optional[APIVersion] = None

        self._message_handlers: List[Callable[[message.Message], None]] = []
        self.log_name = params.address
        self._ping_task: Optional[asyncio.Task[None]] = None

    def _start_ping(self) -> None:
        """Start the background task that pings every ``keepalive`` seconds."""

        async def func() -> None:
            while True:
                await asyncio.sleep(self._params.keepalive)

                try:
                    await self.ping()
                except APIConnectionError:
                    _LOGGER.info("%s: Ping Failed!", self.log_name)
                    await self._on_error()
                    return

        self._ping_task = asyncio.create_task(func())

    async def _close_socket(self) -> None:
        """Release the frame helper, socket and ping task and reset state.

        Cleanup also runs when the socket was created but never fully
        connected, so a failed ``connect()`` does not leak it.
        """
        nothing_to_release = (
            not self._socket_connected
            and self._frame_helper is None
            and self._socket is None
            and self._ping_task is None
        )
        if nothing_to_release:
            return

        if self._frame_helper is not None:
            await self._frame_helper.close()
            self._frame_helper = None
        if self._socket is not None:
            self._socket.close()
            self._socket = None

        ping_task = self._ping_task
        self._ping_task = None
        # A failed ping reaches this method from inside the ping task itself.
        # Cancelling the running task would raise CancelledError at its next
        # suspension point and could abort ``on_stop()`` halfway through.
        if ping_task is not None and ping_task is not asyncio.current_task():
            ping_task.cancel()

        self._socket_connected = False
        self._connected = False
        self._authenticated = False
        _LOGGER.debug("%s: Closed socket", self.log_name)

    async def stop(self, force: bool = False) -> None:
        """Shut the connection down and await ``on_stop``; idempotent.

        Args:
            force: When true, skip the disconnect exchange with the peer.
        """
        if self._stopped:
            return
        if self._connected and not force:
            try:
                await self._disconnect()
            except APIConnectionError:
                # Best effort: the connection is torn down regardless.
                pass
        self._stopped = True
        await self._close_socket()
        await self.on_stop()

    async def _on_error(self) -> None:
        """Tear the connection down without the disconnect exchange."""
        await self.stop(force=True)

    # pylint: disable=too-many-statements
    async def connect(self) -> None:
        """Resolve the host, open the socket, handshake and exchange hello.

        On any failure the connection is stopped before the error is raised.

        Raises:
            APIConnectionError: If the instance is stopped or already
                connected, on resolve/connect errors or timeouts, handshake or
                hello failures, or an API major version above 2.
        """
        if self._stopped:
            raise APIConnectionError(f"Connection is closed for {self.log_name}!")
        if self._connected:
            raise APIConnectionError(f"Already connected for {self.log_name}!")

        try:
            coro = hr.async_resolve_host(
                self._params.eventloop,
                self._params.address,
                self._params.port,
                self._params.zeroconf_instance,
            )
            addr = await asyncio.wait_for(coro, 30.0)
        except APIConnectionError:
            await self._on_error()
            raise
        except asyncio.TimeoutError as err:
            await self._on_error()
            raise APIConnectionError(
                f"Timeout while resolving IP address for {self.log_name}"
            ) from err

        self._socket = socket.socket(
            family=addr.family, type=addr.type, proto=addr.proto
        )
        self._socket.setblocking(False)
        self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

        _LOGGER.debug(
            "%s: Connecting to %s:%s (%s)",
            self.log_name,
            self._params.address,
            self._params.port,
            addr,
        )
        sockaddr = astuple(addr.sockaddr)

        try:
            coro2 = self._params.eventloop.sock_connect(self._socket, sockaddr)
            await asyncio.wait_for(coro2, 30.0)
        except OSError as err:
            await self._on_error()
            raise APIConnectionError(
                f"Error connecting to {sockaddr}: {err}"
            ) from err
        except asyncio.TimeoutError as err:
            await self._on_error()
            raise APIConnectionError(
                f"Timeout while connecting to {sockaddr}"
            ) from err

        _LOGGER.debug("%s: Opened socket for", self._params.address)
        try:
            reader, writer = await asyncio.open_connection(sock=self._socket)
        except OSError as err:
            await self._on_error()
            raise APIConnectionError(
                f"Error opening stream to {sockaddr}: {err}"
            ) from err
        self._frame_helper = APIFrameHelper(reader, writer, self._params)
        self._socket_connected = True

        try:
            await self._frame_helper.perform_handshake()
        except APIConnectionError:
            await self._on_error()
            raise

        # The reader loop must run before the hello request so that the
        # response can be dispatched to the waiting handler.
        self._params.eventloop.create_task(self.run_forever())

        hello = HelloRequest()
        hello.client_info = self._params.client_info
        try:
            resp = await self.send_message_await_response(hello, HelloResponse)
        except APIConnectionError:
            await self._on_error()
            raise
        _LOGGER.debug(
            "%s: Successfully connected ('%s' API=%s.%s)",
            self.log_name,
            resp.server_info,
            resp.api_version_major,
            resp.api_version_minor,
        )
        self._api_version = APIVersion(resp.api_version_major, resp.api_version_minor)
        if self._api_version.major > 2:
            _LOGGER.error(
                "%s: Incompatible version %s! Closing connection",
                self.log_name,
                self._api_version.major,
            )
            await self._on_error()
            raise APIConnectionError("Incompatible API version.")
        self._connected = True

        self._start_ping()

    async def login(self) -> None:
        """Authenticate with the configured password, if any.

        Raises:
            APIConnectionError: If not connected, already logged in, or the
                peer reports an invalid password.
        """
        self._check_connected()
        if self._authenticated:
            raise APIConnectionError("Already logged in!")

        connect = ConnectRequest()
        if self._params.password is not None:
            connect.password = self._params.password
        resp = await self.send_message_await_response(connect, ConnectResponse)
        if resp.invalid_password:
            raise APIConnectionError("Invalid password!")

        self._authenticated = True

    def _check_connected(self) -> None:
        """Raise ``APIConnectionError`` unless the hello exchange has completed."""
        if not self._connected:
            raise APIConnectionError("Must be connected!")

    @property
    def is_connected(self) -> bool:
        """Whether the hello exchange has completed and the link is up."""
        return self._connected

    @property
    def is_authenticated(self) -> bool:
        """Whether ``login()`` has succeeded on this connection."""
        return self._authenticated

    async def send_message(self, msg: message.Message) -> None:
        """Serialize ``msg`` and write it as one packet.

        Raises:
            APIConnectionError: If the socket is not connected or the write fails.
            ValueError: If the class of ``msg`` is not in ``MESSAGE_TYPE_TO_PROTO``.
        """
        if not self._socket_connected:
            raise APIConnectionError("Socket is not connected")

        for message_type, klass in MESSAGE_TYPE_TO_PROTO.items():
            if isinstance(msg, klass):
                break
        else:
            raise ValueError(f"Unknown message type: {type(msg)}")

        encoded = msg.SerializeToString()
        _LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg))
        # pylint: disable=undefined-loop-variable
        assert self._frame_helper is not None
        await self._frame_helper.write_packet(
            Packet(
                type=message_type,
                data=encoded,
            )
        )

    async def send_message_callback_response(
        self, send_msg: message.Message, on_message: Callable[[Any], None]
    ) -> None:
        """Register ``on_message`` for all incoming messages, then send ``send_msg``.

        The handler stays registered for the lifetime of the connection.
        """
        self._message_handlers.append(on_message)
        await self.send_message(send_msg)

    async def send_message_await_response_complex(
        self,
        send_msg: message.Message,
        do_append: Callable[[Any], bool],
        do_stop: Callable[[Any], bool],
        timeout: float = 10.0,
    ) -> List[Any]:
        """Send ``send_msg`` and collect responses until ``do_stop`` matches.

        Args:
            send_msg: Message to send.
            do_append: Returns true for messages to include in the result.
            do_stop: Returns true for the message that ends the wait.
            timeout: Seconds to wait for the stop message.

        Returns:
            The collected messages, in arrival order.

        Raises:
            APIConnectionError: If sending fails or no stop message arrives
                within ``timeout``.
        """
        fut = self._params.eventloop.create_future()
        responses = []

        def on_message(resp: message.Message) -> None:
            if fut.done():
                return
            if do_append(resp):
                responses.append(resp)
            if do_stop(resp):
                fut.set_result(responses)

        self._message_handlers.append(on_message)
        try:
            await self.send_message(send_msg)
            try:
                await asyncio.wait_for(fut, timeout)
            except asyncio.TimeoutError as err:
                if self._stopped:
                    raise APIConnectionError(
                        "Disconnected while waiting for API response!"
                    ) from err
                raise APIConnectionError(
                    "Timeout while waiting for API response!"
                ) from err
        finally:
            # Always unregister, including on timeout or send failure, so
            # stale handlers do not pile up over the connection's lifetime.
            try:
                self._message_handlers.remove(on_message)
            except ValueError:
                pass

        return responses

    async def send_message_await_response(
        self, send_msg: message.Message, response_type: Any, timeout: float = 10.0
    ) -> Any:
        """Send ``send_msg`` and return the first ``response_type`` message.

        Raises:
            APIConnectionError: On send failure, timeout, or if the number of
                collected responses is not exactly one.
        """

        def is_response(msg: message.Message) -> bool:
            return isinstance(msg, response_type)

        res = await self.send_message_await_response_complex(
            send_msg, is_response, is_response, timeout=timeout
        )
        if len(res) != 1:
            raise APIConnectionError(f"Expected one result, got {len(res)}")

        return res[0]

    async def _run_once(self) -> None:
        """Read one packet, decode it and dispatch it to all handlers.

        Packets whose type is not in ``MESSAGE_TYPE_TO_PROTO`` are skipped.

        Raises:
            APIConnectionError: If reading fails or the payload cannot be parsed.
        """
        assert self._frame_helper is not None
        pkt = await self._frame_helper.read_packet()

        msg_type = pkt.type
        raw_msg = pkt.data
        if msg_type not in MESSAGE_TYPE_TO_PROTO:
            _LOGGER.debug(
                "%s: Skipping message type %s", self._params.address, msg_type
            )
            return

        msg = MESSAGE_TYPE_TO_PROTO[msg_type]()
        try:
            msg.ParseFromString(raw_msg)
        except Exception as err:
            raise APIConnectionError(f"Invalid protobuf message: {err}") from err
        _LOGGER.debug(
            "%s: Got message of type %s: %s", self._params.address, type(msg), msg
        )
        # Iterate over a copy: handlers may unregister themselves while running.
        for msg_handler in self._message_handlers[:]:
            msg_handler(msg)
        await self._handle_internal_messages(msg)

    async def run_forever(self) -> None:
        """Process incoming packets until an error stops the connection."""
        while True:
            try:
                await self._run_once()
            except APIConnectionError as err:
                _LOGGER.info(
                    "%s: Error while reading incoming messages: %s",
                    self.log_name,
                    err,
                )
                await self._on_error()
                break
            except Exception as err:  # pylint: disable=broad-except
                _LOGGER.info(
                    "%s: Unexpected error while reading incoming messages: %s",
                    self.log_name,
                    err,
                    exc_info=True,
                )
                await self._on_error()
                break

    async def _handle_internal_messages(self, msg: Any) -> None:
        """Answer disconnect, ping and time requests sent by the peer."""
        if isinstance(msg, DisconnectRequest):
            await self.send_message(DisconnectResponse())
            await self.stop(force=True)
        elif isinstance(msg, PingRequest):
            await self.send_message(PingResponse())
        elif isinstance(msg, GetTimeRequest):
            resp = GetTimeResponse()
            resp.epoch_seconds = int(time.time())
            await self.send_message(resp)

    async def ping(self) -> None:
        """Send a ping and wait for the matching response.

        Raises:
            APIConnectionError: If not connected or no response arrives in time.
        """
        self._check_connected()
        await self.send_message_await_response(PingRequest(), PingResponse)

    async def _disconnect(self) -> None:
        """Ask the peer to disconnect; failures of the exchange are ignored.

        Raises:
            APIConnectionError: If the connection is not established.
        """
        self._check_connected()

        try:
            await self.send_message_await_response(
                DisconnectRequest(), DisconnectResponse
            )
        except APIConnectionError:
            # Best effort: the caller closes the socket either way.
            pass

    def _check_authenticated(self) -> None:
        """Raise ``APIConnectionError`` unless ``login()`` has succeeded."""
        if not self._authenticated:
            raise APIConnectionError("Must login first!")

    @property
    def api_version(self) -> Optional[APIVersion]:
        """API version reported in the hello response, or ``None`` before it."""
        return self._api_version