import asyncio
import base64
import logging
import socket
import time
from contextlib import suppress
from dataclasses import astuple, dataclass
from typing import Any, Awaitable, Callable, List, Optional

from google.protobuf import message
from noise.connection import NoiseConnection  # type: ignore

import aioesphomeapi.host_resolver as hr

from .api_pb2 import (  # type: ignore
    ConnectRequest,
    ConnectResponse,
    DisconnectRequest,
    DisconnectResponse,
    GetTimeRequest,
    GetTimeResponse,
    HelloRequest,
    HelloResponse,
    PingRequest,
    PingResponse,
)
from .core import (
    MESSAGE_TYPE_TO_PROTO,
    APIConnectionError,
    HandshakeAPIError,
    InvalidAuthAPIError,
    InvalidEncryptionKeyAPIError,
    ProtocolAPIError,
    RequiresEncryptionAPIError,
    ResolveAPIError,
    SocketAPIError,
)
from .model import APIVersion
from .util import bytes_to_varuint, varuint_to_bytes

_LOGGER = logging.getLogger(__name__)


@dataclass
class ConnectionParams:
    """Settings describing how to reach and talk to one API endpoint."""

    # Event loop used for socket connects, background tasks and futures.
    eventloop: asyncio.events.AbstractEventLoop
    # Host handed to the resolver; also used as the log prefix.
    address: str
    port: int
    # Copied into the ConnectRequest by login() when not None.
    password: Optional[str]
    # Sent to the peer in the HelloRequest.
    client_info: str
    # Seconds to sleep between keep-alive pings.
    keepalive: float
    # Passed through to the host resolver.
    zeroconf_instance: hr.ZeroconfInstanceType
    # Base64 text of a 32-byte pre-shared key; None selects plaintext framing.
    noise_psk: Optional[str]


@dataclass
class Packet:
    """One framed API message: numeric message type plus serialized payload."""

    type: int
    data: bytes


class APIFrameHelper:
    """Read and write ``Packet`` objects over a stream.

    The framing is chosen by ``params.noise_psk``: ``None`` selects the
    plaintext framing, any other value selects the Noise framing, which
    requires ``perform_handshake`` to complete before packets can flow.
    """

    def __init__(
        self,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
        params: ConnectionParams,
    ):
        self._reader = reader
        self._writer = writer
        self._params = params

        # Separate locks so a pending read never blocks a write (and vice versa).
        self._write_lock = asyncio.Lock()
        self._read_lock = asyncio.Lock()
        # Set once the Noise handshake has finished; gates encrypted packet I/O.
        self._ready_event = asyncio.Event()
        self._proto: Optional[NoiseConnection] = None

    async def close(self) -> None:
        """Close the underlying stream writer once no write is in progress."""
        async with self._write_lock:
            self._writer.close()

    async def _write_frame_noise(self, frame: bytes) -> None:
        """Send one Noise frame.

        The frame is prefixed with a 3-byte header: the marker byte ``0x01``
        followed by the frame length as a big-endian 16-bit integer.

        Raises:
            SocketAPIError: if the write fails with an ``OSError``.
        """
        try:
            async with self._write_lock:
                _LOGGER.debug("Sending frame %s", frame.hex())
                header = bytes(
                    [
                        0x01,
                        (len(frame) >> 8) & 0xFF,
                        len(frame) & 0xFF,
                    ]
                )
                self._writer.write(header + frame)
                await self._writer.drain()
        except OSError as err:
            raise SocketAPIError(f"Error while writing data: {err}") from err

    async def _read_frame_noise(self) -> bytes:
        """Read one Noise frame and return its body (header stripped).

        Raises:
            ProtocolAPIError: if the marker byte is not ``0x01``.
            SocketAPIError: if the stream ends early or the read fails.
        """
        try:
            async with self._read_lock:
                header = await self._reader.readexactly(3)
                if header[0] != 0x01:
                    raise ProtocolAPIError(f"Marker byte invalid: {header[0]}")
                msg_size = (header[1] << 8) | header[2]
                frame = await self._reader.readexactly(msg_size)
        except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
            raise SocketAPIError(f"Error while reading data: {err}") from err

        _LOGGER.debug("Received frame %s", frame.hex())
        return frame

    async def perform_handshake(self) -> None:
        """Run the Noise handshake; a no-op when no PSK is configured.

        Sends an empty ClientHello frame, validates the ServerHello, then
        alternates writing and reading handshake messages (starting with a
        write) until the Noise state machine reports completion.

        Raises:
            HandshakeAPIError: on an empty/unknown ServerHello or a handshake
                failure reported by the peer.
            InvalidEncryptionKeyAPIError: if the PSK is not valid base64, is
                not 32 bytes long, or the peer reports a MAC failure.
            SocketAPIError / ProtocolAPIError: propagated from frame I/O.
        """
        if self._params.noise_psk is None:
            return
        await self._write_frame_noise(b"")  # ClientHello
        prologue = b"NoiseAPIInit" + b"\x00\x00"

        server_hello = await self._read_frame_noise()  # ServerHello
        if not server_hello:
            raise HandshakeAPIError("ServerHello is empty")
        chosen_proto = server_hello[0]
        if chosen_proto != 0x01:
            raise HandshakeAPIError(
                f"Unknown protocol selected by client {chosen_proto}"
            )

        self._proto = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_SHA256")
        self._proto.set_as_initiator()

        try:
            noise_psk_bytes = base64.b64decode(self._params.noise_psk)
        except ValueError as err:
            raise InvalidEncryptionKeyAPIError(
                f"Malformed PSK {self._params.noise_psk}, expected base64-encoded value"
            ) from err
        if len(noise_psk_bytes) != 32:
            raise InvalidEncryptionKeyAPIError(
                f"Malformed PSK {self._params.noise_psk}, expected 32-bytes of base64 data"
            )

        self._proto.set_psks(noise_psk_bytes)
        self._proto.set_prologue(prologue)
        self._proto.start_handshake()

        _LOGGER.debug("Starting handshake...")
        do_write = True
        while not self._proto.handshake_finished:
            if do_write:
                msg = self._proto.write_message()
                # Leading 0x00 byte marks a successful handshake message.
                await self._write_frame_noise(b"\x00" + msg)
            else:
                msg = await self._read_frame_noise()
                if not msg:
                    raise HandshakeAPIError("Handshake message too short")
                if msg[0] != 0:
                    # Non-zero status: the rest of the frame is a textual reason.
                    # Decode leniently so a malformed reason still surfaces as a
                    # HandshakeAPIError instead of a stray UnicodeDecodeError.
                    explanation = msg[1:].decode(errors="replace")
                    if explanation == "Handshake MAC failure":
                        raise InvalidEncryptionKeyAPIError("Invalid encryption key")
                    raise HandshakeAPIError(f"Handshake failure: {explanation}")
                self._proto.read_message(msg[1:])

            do_write = not do_write

        _LOGGER.debug("Handshake complete!")
        self._ready_event.set()

    async def _write_packet_noise(self, packet: Packet) -> None:
        """Encrypt and send ``packet`` once the handshake has completed.

        The cleartext layout is: 16-bit big-endian type, 16-bit big-endian
        payload length, payload, then ``padding`` zero bytes (currently 0).
        """
        await self._ready_event.wait()
        padding = 0
        data = (
            bytes(
                [
                    (packet.type >> 8) & 0xFF,
                    (packet.type >> 0) & 0xFF,
                    (len(packet.data) >> 8) & 0xFF,
                    (len(packet.data) >> 0) & 0xFF,
                ]
            )
            + packet.data
            + b"\x00" * padding
        )
        assert self._proto is not None
        frame = self._proto.encrypt(data)
        await self._write_frame_noise(frame)

    async def _write_packet_plaintext(self, packet: Packet) -> None:
        """Send ``packet`` using the plaintext framing.

        Layout: ``0x00`` preamble, varuint payload length, varuint message
        type, payload.

        Raises:
            SocketAPIError: if the write fails with an ``OSError``.
        """
        data = b"".join(
            (
                b"\0",
                varuint_to_bytes(len(packet.data)),
                varuint_to_bytes(packet.type),
                packet.data,
            )
        )
        try:
            async with self._write_lock:
                _LOGGER.debug("Sending frame %s", data.hex())
                self._writer.write(data)
                await self._writer.drain()
        except OSError as err:
            raise SocketAPIError(f"Error while writing data: {err}") from err

    async def write_packet(self, packet: Packet) -> None:
        """Send ``packet`` using the framing selected by the connection params."""
        if self._params.noise_psk is None:
            await self._write_packet_plaintext(packet)
        else:
            await self._write_packet_noise(packet)

    async def _read_packet_noise(self) -> Packet:
        """Read, decrypt and parse one packet after the handshake has completed.

        Raises:
            ProtocolAPIError: if the decrypted message is shorter than the
                4-byte header or shorter than its declared payload length.
        """
        await self._ready_event.wait()
        frame = await self._read_frame_noise()
        assert self._proto is not None
        msg = self._proto.decrypt(frame)
        if len(msg) < 4:
            raise ProtocolAPIError(f"Bad packet frame: {msg}")
        pkt_type = (msg[0] << 8) | msg[1]
        data_len = (msg[2] << 8) | msg[3]
        if data_len + 4 > len(msg):
            raise ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}")
        # Anything after the declared payload is ignored.
        data = msg[4 : 4 + data_len]
        return Packet(type=pkt_type, data=data)

    async def _read_packet_plaintext(self) -> Packet:
        """Read one packet using the plaintext framing.

        Raises:
            RequiresEncryptionAPIError: if the preamble is ``0x01``, i.e. the
                marker byte used by the Noise framing.
            ProtocolAPIError: on any other non-zero preamble.
            SocketAPIError: if the stream ends early or the read fails.
        """
        async with self._read_lock:
            try:
                preamble = await self._reader.readexactly(1)
                if preamble[0] != 0x00:
                    if preamble[0] == 0x01:
                        raise RequiresEncryptionAPIError(
                            "Connection requires encryption"
                        )
                    raise ProtocolAPIError(f"Invalid preamble {preamble[0]:02x}")

                # Varuints are read byte by byte until one has its high bit clear.
                length = b""
                while not length or (length[-1] & 0x80) == 0x80:
                    length += await self._reader.readexactly(1)
                length_int = bytes_to_varuint(length)
                assert length_int is not None
                msg_type = b""
                while not msg_type or (msg_type[-1] & 0x80) == 0x80:
                    msg_type += await self._reader.readexactly(1)
                msg_type_int = bytes_to_varuint(msg_type)
                assert msg_type_int is not None

                raw_msg = b""
                if length_int != 0:
                    raw_msg = await self._reader.readexactly(length_int)
                return Packet(type=msg_type_int, data=raw_msg)
            except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
                raise SocketAPIError(f"Error while reading data: {err}") from err

    async def read_packet(self) -> Packet:
        """Read one packet using the framing selected by the connection params."""
        if self._params.noise_psk is None:
            return await self._read_packet_plaintext()
        return await self._read_packet_noise()


class APIConnection:
    """A single connection to an API endpoint.

    ``connect()`` resolves the host, opens the socket, performs the optional
    Noise handshake, starts the background reader and exchanges the hello
    messages. Any failure tears the connection down through
    ``stop(force=True)``; once stopped, ``connect()`` refuses to run again.
    """

    def __init__(
        self, params: ConnectionParams, on_stop: Callable[[], Awaitable[None]]
    ):
        self._params = params
        # Awaited exactly once, at the end of the first effective stop().
        self.on_stop = on_stop
        self._stopped = False
        self._socket: Optional[socket.socket] = None
        self._frame_helper: Optional[APIFrameHelper] = None
        self._connected = False
        self._authenticated = False
        self._socket_connected = False
        self._state_lock = asyncio.Lock()
        self._api_version: Optional[APIVersion] = None

        # Called with every decoded incoming message.
        self._message_handlers: List[Callable[[message.Message], None]] = []
        self.log_name = params.address
        self._ping_task: Optional[asyncio.Task[None]] = None
        # Called with the exception that terminated the read loop.
        self._read_exception_handlers: List[Callable[[Exception], None]] = []

    def _start_ping(self) -> None:
        """Start the background task that pings every ``keepalive`` seconds.

        The task ends, after tearing the connection down, on the first ping
        that fails with an ``APIConnectionError``.
        """

        async def func() -> None:
            while True:
                await asyncio.sleep(self._params.keepalive)

                try:
                    await self.ping()
                except APIConnectionError:
                    _LOGGER.info("%s: Ping Failed!", self.log_name)
                    await self._on_error()
                    return

        self._ping_task = asyncio.create_task(func())

    async def _close_socket(self) -> None:
        """Release the stream, socket and ping task and reset the state flags."""
        if not self._socket_connected:
            # connect() may have created the raw socket but failed before the
            # stream was set up; close it here instead of leaving it to the GC.
            if self._socket is not None:
                self._socket.close()
                self._socket = None
            return
        if self._frame_helper is not None:
            await self._frame_helper.close()
            self._frame_helper = None
        if self._socket is not None:
            self._socket.close()
            self._socket = None
        if self._ping_task is not None:
            # When the ping task itself triggered this shutdown (failed ping ->
            # _on_error -> stop -> here), cancelling it would raise
            # CancelledError inside the on_stop callback awaited right after.
            # That task returns on its own, so only cancel from other tasks.
            if self._ping_task is not asyncio.current_task():
                self._ping_task.cancel()
            self._ping_task = None
        self._socket_connected = False
        self._connected = False
        self._authenticated = False
        _LOGGER.debug("%s: Closed socket", self.log_name)

    async def stop(self, force: bool = False) -> None:
        """Shut the connection down; subsequent calls are no-ops.

        Args:
            force: when False and the connection is established, a
                DisconnectRequest is exchanged first (errors ignored); when
                True the socket is closed immediately.
        """
        if self._stopped:
            return
        if self._connected and not force:
            with suppress(APIConnectionError):
                await self._disconnect()
        self._stopped = True
        await self._close_socket()
        await self.on_stop()

    async def _on_error(self) -> None:
        """Tear the connection down without the disconnect exchange."""
        await self.stop(force=True)

    # pylint: disable=too-many-statements
    async def connect(self) -> None:
        """Resolve, connect, handshake and exchange hello messages.

        On any failure the connection is stopped (which also awaits
        ``on_stop``) before the error is raised.

        Raises:
            APIConnectionError: if already stopped or connected, if the peer
                reports an API major version above 2, or propagated from
                resolution, the handshake or the hello exchange.
            ResolveAPIError: if host resolution takes longer than 30 seconds.
            SocketAPIError: if the TCP connect fails or takes longer than 30
                seconds.
        """
        if self._stopped:
            raise APIConnectionError(f"Connection is closed for {self.log_name}!")
        if self._connected:
            raise APIConnectionError(f"Already connected for {self.log_name}!")

        try:
            coro = hr.async_resolve_host(
                self._params.eventloop,
                self._params.address,
                self._params.port,
                self._params.zeroconf_instance,
            )
            addr = await asyncio.wait_for(coro, 30.0)
        except APIConnectionError:
            await self._on_error()
            raise
        except asyncio.TimeoutError as err:
            await self._on_error()
            raise ResolveAPIError(
                f"Timeout while resolving IP address for {self.log_name}"
            ) from err

        self._socket = socket.socket(
            family=addr.family, type=addr.type, proto=addr.proto
        )
        self._socket.setblocking(False)
        self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

        _LOGGER.debug(
            "%s: Connecting to %s:%s (%s)",
            self.log_name,
            self._params.address,
            self._params.port,
            addr,
        )
        sockaddr = astuple(addr.sockaddr)

        try:
            coro2 = self._params.eventloop.sock_connect(self._socket, sockaddr)
            await asyncio.wait_for(coro2, 30.0)
        except OSError as err:
            await self._on_error()
            raise SocketAPIError(f"Error connecting to {sockaddr}: {err}") from err
        except asyncio.TimeoutError as err:
            await self._on_error()
            raise SocketAPIError(f"Timeout while connecting to {sockaddr}") from err

        _LOGGER.debug("%s: Opened socket for", self._params.address)

        reader, writer = await asyncio.open_connection(sock=self._socket)
        self._frame_helper = APIFrameHelper(reader, writer, self._params)
        self._socket_connected = True

        try:
            await self._frame_helper.perform_handshake()
        except APIConnectionError:
            await self._on_error()
            raise

        # The reader must be running before the hello exchange so the
        # HelloResponse can be dispatched to the waiting handler.
        self._params.eventloop.create_task(self.run_forever())

        hello = HelloRequest()
        hello.client_info = self._params.client_info
        try:
            resp = await self.send_message_await_response(hello, HelloResponse)
        except APIConnectionError:
            await self._on_error()
            raise
        _LOGGER.debug(
            "%s: Successfully connected ('%s' API=%s.%s)",
            self.log_name,
            resp.server_info,
            resp.api_version_major,
            resp.api_version_minor,
        )
        self._api_version = APIVersion(resp.api_version_major, resp.api_version_minor)
        if self._api_version.major > 2:
            _LOGGER.error(
                "%s: Incompatible version %s! Closing connection",
                self.log_name,
                self._api_version.major,
            )
            await self._on_error()
            raise APIConnectionError("Incompatible API version.")

        self._connected = True

        self._start_ping()

    async def login(self) -> None:
        """Authenticate with the configured password (if any).

        Raises:
            APIConnectionError: if not connected or already logged in.
            InvalidAuthAPIError: if the peer reports an invalid password.
        """
        self._check_connected()
        if self._authenticated:
            raise APIConnectionError("Already logged in!")

        connect = ConnectRequest()
        if self._params.password is not None:
            connect.password = self._params.password
        resp = await self.send_message_await_response(connect, ConnectResponse)
        if resp.invalid_password:
            raise InvalidAuthAPIError("Invalid password!")

        self._authenticated = True

    def _check_connected(self) -> None:
        """Raise ``APIConnectionError`` unless ``connect()`` has completed."""
        if not self._connected:
            raise APIConnectionError("Must be connected!")

    @property
    def is_connected(self) -> bool:
        """Whether ``connect()`` has completed and the socket is still open."""
        return self._connected

    @property
    def is_authenticated(self) -> bool:
        """Whether ``login()`` has succeeded on this connection."""
        return self._authenticated

    async def send_message(self, msg: message.Message) -> None:
        """Serialize ``msg`` and send it as one packet.

        Raises:
            APIConnectionError: if the socket is not connected.
            ValueError: if ``msg`` is not an instance of any class registered
                in ``MESSAGE_TYPE_TO_PROTO``.
        """
        if not self._socket_connected:
            raise APIConnectionError("Socket is not connected")

        # First registered type id whose class matches the message.
        message_type = next(
            (
                type_id
                for type_id, klass in MESSAGE_TYPE_TO_PROTO.items()
                if isinstance(msg, klass)
            ),
            None,
        )
        if message_type is None:
            raise ValueError(f"Message type id not found for type {type(msg)}")

        encoded = msg.SerializeToString()
        _LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg))
        assert self._frame_helper is not None
        await self._frame_helper.write_packet(
            Packet(
                type=message_type,
                data=encoded,
            )
        )

    async def send_message_callback_response(
        self, send_msg: message.Message, on_message: Callable[[Any], None]
    ) -> None:
        """Register ``on_message`` for all incoming messages, then send ``send_msg``.

        The handler stays registered after a successful send. If the send
        fails, the handler is removed again so it does not linger on a
        request that never went out.
        """
        self._message_handlers.append(on_message)
        try:
            await self.send_message(send_msg)
        except Exception:
            with suppress(ValueError):
                self._message_handlers.remove(on_message)
            raise

    async def send_message_await_response_complex(
        self,
        send_msg: message.Message,
        do_append: Callable[[Any], bool],
        do_stop: Callable[[Any], bool],
        timeout: float = 10.0,
    ) -> List[Any]:
        """Send ``send_msg`` and collect responses until ``do_stop`` matches.

        Args:
            send_msg: the request to send.
            do_append: predicate selecting which incoming messages to collect.
            do_stop: predicate marking the message that completes the wait.
            timeout: seconds to wait for the completing message.

        Returns:
            The collected messages, in arrival order.

        Raises:
            SocketAPIError: on timeout.
            Exception: whatever terminated the read loop while waiting, or
                whatever ``send_message`` raised.
        """
        fut = self._params.eventloop.create_future()
        responses = []

        def on_message(resp: message.Message) -> None:
            if fut.done():
                return
            if do_append(resp):
                responses.append(resp)
            if do_stop(resp):
                fut.set_result(responses)

        def on_read_exception(exc: Exception) -> None:
            if not fut.done():
                fut.set_exception(exc)

        # Handlers are registered before sending so a fast reply is not missed.
        self._message_handlers.append(on_message)
        self._read_exception_handlers.append(on_read_exception)
        try:
            # The send sits inside the try so the handlers are also removed
            # when it fails, not only after the wait.
            await self.send_message(send_msg)
            try:
                await asyncio.wait_for(fut, timeout)
            except asyncio.TimeoutError as err:
                if self._stopped:
                    raise SocketAPIError(
                        "Disconnected while waiting for API response!"
                    ) from err
                raise SocketAPIError("Timeout while waiting for API response!") from err
        finally:
            with suppress(ValueError):
                self._message_handlers.remove(on_message)
            with suppress(ValueError):
                self._read_exception_handlers.remove(on_read_exception)

        return responses

    async def send_message_await_response(
        self, send_msg: message.Message, response_type: Any, timeout: float = 10.0
    ) -> Any:
        """Send ``send_msg`` and return the first reply of ``response_type``.

        Raises:
            APIConnectionError: if the wait does not yield exactly one result.
        """

        def is_response(msg: message.Message) -> bool:
            return isinstance(msg, response_type)

        res = await self.send_message_await_response_complex(
            send_msg, is_response, is_response, timeout=timeout
        )
        if len(res) != 1:
            raise APIConnectionError(f"Expected one result, got {len(res)}")

        return res[0]

    async def _run_once(self) -> None:
        """Read one packet, decode it and dispatch it to all handlers.

        Packets with a type not present in ``MESSAGE_TYPE_TO_PROTO`` are
        skipped.

        Raises:
            ProtocolAPIError: if the payload cannot be parsed as protobuf.
        """
        assert self._frame_helper is not None
        pkt = await self._frame_helper.read_packet()

        msg_type = pkt.type
        raw_msg = pkt.data
        if msg_type not in MESSAGE_TYPE_TO_PROTO:
            _LOGGER.debug(
                "%s: Skipping message type %s", self._params.address, msg_type
            )
            return

        msg = MESSAGE_TYPE_TO_PROTO[msg_type]()
        try:
            msg.ParseFromString(raw_msg)
        except Exception as e:
            raise ProtocolAPIError(f"Invalid protobuf message: {e}") from e
        _LOGGER.debug(
            "%s: Got message of type %s: %s", self._params.address, type(msg), msg
        )
        # Iterate over a copy: handlers may unregister themselves while running.
        for msg_handler in self._message_handlers[:]:
            msg_handler(msg)
        await self._handle_internal_messages(msg)

    async def run_forever(self) -> None:
        """Read and dispatch messages until the socket closes or an error occurs.

        On any error the registered read-exception handlers are notified and
        the connection is torn down.
        """
        while True:
            if self._frame_helper is None:
                # Socket closed
                break

            try:
                await self._run_once()
            except APIConnectionError as err:
                _LOGGER.info(
                    "%s: Error while reading incoming messages: %s",
                    self.log_name,
                    err,
                )
                for handler in self._read_exception_handlers[:]:
                    handler(err)
                await self._on_error()
                break
            except Exception as err:  # pylint: disable=broad-except
                _LOGGER.warning(
                    "%s: Unexpected error while reading incoming messages: %s",
                    self.log_name,
                    err,
                    exc_info=True,
                )
                for handler in self._read_exception_handlers[:]:
                    handler(err)
                await self._on_error()
                break

    async def _handle_internal_messages(self, msg: Any) -> None:
        """Answer the requests the connection handles on its own.

        Disconnect requests are acknowledged and followed by a forced stop;
        ping requests get a PingResponse; time requests get the current epoch
        seconds.
        """
        if isinstance(msg, DisconnectRequest):
            await self.send_message(DisconnectResponse())
            await self.stop(force=True)
        elif isinstance(msg, PingRequest):
            await self.send_message(PingResponse())
        elif isinstance(msg, GetTimeRequest):
            resp = GetTimeResponse()
            resp.epoch_seconds = int(time.time())
            await self.send_message(resp)

    async def ping(self) -> None:
        """Send a PingRequest and wait for the PingResponse."""
        self._check_connected()
        await self.send_message_await_response(PingRequest(), PingResponse)

    async def _disconnect(self) -> None:
        """Exchange DisconnectRequest/Response, ignoring connection errors.

        Raises:
            APIConnectionError: only if the connection is not established.
        """
        self._check_connected()

        with suppress(APIConnectionError):
            await self.send_message_await_response(
                DisconnectRequest(), DisconnectResponse
            )

    def _check_authenticated(self) -> None:
        """Raise ``APIConnectionError`` unless ``login()`` has succeeded."""
        if not self._authenticated:
            raise APIConnectionError("Must login first!")

    @property
    def api_version(self) -> Optional[APIVersion]:
        """API version reported in the HelloResponse, or None before connect."""
        return self._api_version