From 9ca228cd1e12cd342fd067b6bbab9d627c4bbb7d Mon Sep 17 00:00:00 2001 From: Otto Winter Date: Wed, 13 Oct 2021 10:05:08 +0200 Subject: [PATCH] Refactor frame_helper into new module (#109) --- aioesphomeapi/_frame_helper.py | 254 +++++++++++++++++++++++++++++++++ aioesphomeapi/connection.py | 232 ++---------------------------- 2 files changed, 267 insertions(+), 219 deletions(-) create mode 100644 aioesphomeapi/_frame_helper.py diff --git a/aioesphomeapi/_frame_helper.py b/aioesphomeapi/_frame_helper.py new file mode 100644 index 0000000..987a6a3 --- /dev/null +++ b/aioesphomeapi/_frame_helper.py @@ -0,0 +1,254 @@ +import asyncio +import base64 +import logging +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional + +from noise.connection import NoiseConnection # type: ignore + +from .core import ( + HandshakeAPIError, + InvalidEncryptionKeyAPIError, + ProtocolAPIError, + RequiresEncryptionAPIError, + SocketAPIError, + SocketClosedAPIError, +) +from .util import bytes_to_varuint, varuint_to_bytes + +_LOGGER = logging.getLogger(__name__) + + +@dataclass +class Packet: + type: int + data: bytes + + +class APIFrameHelper(ABC): + @abstractmethod + async def close(self) -> None: + pass + + @abstractmethod + async def write_packet(self, packet: Packet) -> None: + pass + + @abstractmethod + async def read_packet(self) -> Packet: + pass + + +class APIPlaintextFrameHelper(APIFrameHelper): + def __init__( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + ): + self._reader = reader + self._writer = writer + self._write_lock = asyncio.Lock() + self._read_lock = asyncio.Lock() + self._closed_event = asyncio.Event() + + async def close(self) -> None: + self._closed_event.set() + self._writer.close() + + async def write_packet(self, packet: Packet) -> None: + data = b"\0" + data += varuint_to_bytes(len(packet.data)) + data += varuint_to_bytes(packet.type) + data += packet.data + try: + async with self._write_lock: + _LOGGER.debug("Sending plaintext frame %s", data.hex()) + self._writer.write(data) + await self._writer.drain() + except OSError as err: + raise SocketAPIError(f"Error while writing data: {err}") from err + + async def read_packet(self) -> Packet: + async with self._read_lock: + try: + preamble = await self._reader.readexactly(1) + if preamble[0] != 0x00: + if preamble[0] == 0x01: + raise RequiresEncryptionAPIError( + "Connection requires encryption" + ) + raise ProtocolAPIError(f"Invalid preamble {preamble[0]:02x}") + + length = b"" + while not length or (length[-1] & 0x80) == 0x80: + length += await self._reader.readexactly(1) + length_int = bytes_to_varuint(length) + assert length_int is not None + msg_type = b"" + while not msg_type or (msg_type[-1] & 0x80) == 0x80: + msg_type += await self._reader.readexactly(1) + msg_type_int = bytes_to_varuint(msg_type) + assert msg_type_int is not None + + raw_msg = b"" + if length_int != 0: + raw_msg = await self._reader.readexactly(length_int) + return Packet(type=msg_type_int, data=raw_msg) + except (asyncio.IncompleteReadError, OSError, TimeoutError) as err: + if ( + isinstance(err, asyncio.IncompleteReadError) + and self._closed_event.is_set() + ): + raise SocketClosedAPIError( + f"Socket closed while reading data: {err}" + ) from err + raise SocketAPIError(f"Error while reading data: {err}") from err + + +def _decode_noise_psk(psk: str) -> bytes: + """Decode the given noise psk from base64 format to raw bytes.""" + try: + psk_bytes = base64.b64decode(psk) + except ValueError: + raise InvalidEncryptionKeyAPIError( + f"Malformed PSK {psk}, expected base64-encoded value" + ) + if len(psk_bytes) != 32: + raise InvalidEncryptionKeyAPIError( + f"Malformed PSK {psk}, expected 32-bytes of base64 data" + ) + return psk_bytes + + +class APINoiseFrameHelper(APIFrameHelper): + def __init__( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + noise_psk: str, + ): + self._reader = reader + self._writer = writer + self._write_lock = asyncio.Lock() + self._read_lock = asyncio.Lock() + self._ready_event = asyncio.Event() + self._closed_event = asyncio.Event() + self._proto: Optional[NoiseConnection] = None + self._noise_psk = noise_psk + + async def close(self) -> None: + self._closed_event.set() + self._writer.close() + + async def _write_frame(self, frame: bytes) -> None: + try: + async with self._write_lock: + _LOGGER.debug("Sending frame %s", frame.hex()) + header = bytes( + [ + 0x01, + (len(frame) >> 8) & 0xFF, + len(frame) & 0xFF, + ] + ) + self._writer.write(header + frame) + await self._writer.drain() + except OSError as err: + raise SocketAPIError(f"Error while writing data: {err}") from err + + async def _read_frame(self) -> bytes: + try: + async with self._read_lock: + header = await self._reader.readexactly(3) + if header[0] != 0x01: + raise ProtocolAPIError(f"Marker byte invalid: {header[0]}") + msg_size = (header[1] << 8) | header[2] + frame = await self._reader.readexactly(msg_size) + except (asyncio.IncompleteReadError, OSError, TimeoutError) as err: + if ( + isinstance(err, asyncio.IncompleteReadError) + and self._closed_event.is_set() + ): + raise SocketClosedAPIError( + f"Socket closed while reading data: {err}" + ) from err + raise SocketAPIError(f"Error while reading data: {err}") from err + + _LOGGER.debug("Received frame %s", frame.hex()) + return frame + + async def perform_handshake(self) -> None: + await self._write_frame(b"") # ClientHello + prologue = b"NoiseAPIInit" + b"\x00\x00" + server_hello = await self._read_frame() # ServerHello + if not server_hello: + raise HandshakeAPIError("ServerHello is empty") + chosen_proto = server_hello[0] + if chosen_proto != 0x01: + raise HandshakeAPIError( + f"Unknown protocol selected by client {chosen_proto}" + ) + + self._proto = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_SHA256") + self._proto.set_as_initiator() + self._proto.set_psks(_decode_noise_psk(self._noise_psk)) + self._proto.set_prologue(prologue) + self._proto.start_handshake() + + _LOGGER.debug("Starting handshake...") + do_write = True + while not self._proto.handshake_finished: + if do_write: + msg = self._proto.write_message() + await self._write_frame(b"\x00" + msg) + else: + msg = await self._read_frame() + if not msg: + raise HandshakeAPIError("Handshake message too short") + if msg[0] != 0: + explanation = msg[1:].decode() + if explanation == "Handshake MAC failure": + raise InvalidEncryptionKeyAPIError("Invalid encryption key") + raise HandshakeAPIError(f"Handshake failure: {explanation}") + self._proto.read_message(msg[1:]) + + do_write = not do_write + + _LOGGER.debug("Handshake complete!") + self._ready_event.set() + + async def write_packet(self, packet: Packet) -> None: + # Wait for handshake to complete + await self._ready_event.wait() + padding = 0 + data = ( + bytes( + [ + (packet.type >> 8) & 0xFF, + (packet.type >> 0) & 0xFF, + (len(packet.data) >> 8) & 0xFF, + (len(packet.data) >> 0) & 0xFF, + ] + ) + + packet.data + + b"\x00" * padding + ) + assert self._proto is not None + frame = self._proto.encrypt(data) + await self._write_frame(frame) + + async def read_packet(self) -> Packet: + # Wait for handshake to complete + await self._ready_event.wait() + frame = await self._read_frame() + assert self._proto is not None + msg = self._proto.decrypt(frame) + if len(msg) < 4: + raise ProtocolAPIError(f"Bad packet frame: {msg}") + pkt_type = (msg[0] << 8) | msg[1] + data_len = (msg[2] << 8) | msg[3] + if data_len + 4 > len(msg): + raise ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}") + data = msg[4 : 4 + data_len] + return Packet(type=pkt_type, data=data) diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index af5f4c7..12c575a 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -1,5 +1,4 @@ import asyncio -import base64 import enum import logging import socket @@ -9,10 +8,15 @@ from dataclasses import astuple, dataclass from typing import Any, Awaitable, Callable, List, Optional from google.protobuf import message -from noise.connection import NoiseConnection # type: ignore import aioesphomeapi.host_resolver as hr +from ._frame_helper import ( + APIFrameHelper, + APINoiseFrameHelper, + APIPlaintextFrameHelper, + Packet, +) from .api_pb2 import ( # type: ignore ConnectRequest, ConnectResponse, @@ -28,20 +32,16 @@ from .api_pb2 import ( # type: ignore from .core import ( MESSAGE_TYPE_TO_PROTO, APIConnectionError, - HandshakeAPIError, InvalidAuthAPIError, - InvalidEncryptionKeyAPIError, PingFailedAPIError, ProtocolAPIError, ReadFailedAPIError, - RequiresEncryptionAPIError, ResolveAPIError, SocketAPIError, SocketClosedAPIError, TimeoutAPIError, ) from .model import APIVersion -from .util import bytes_to_varuint, varuint_to_bytes _LOGGER = logging.getLogger(__name__) @@ -58,217 +58,6 @@ class ConnectionParams: noise_psk: Optional[str] -@dataclass -class Packet: - type: int - data: bytes - - -class APIFrameHelper: - def __init__( - self, - reader: asyncio.StreamReader, - writer: asyncio.StreamWriter, - params: ConnectionParams, - ): - self._reader = reader - self._writer = writer - self._params = params - self._write_lock = asyncio.Lock() - self._read_lock = asyncio.Lock() - self._ready_event = asyncio.Event() - self._proto: Optional[NoiseConnection] = None - self._closed_event = asyncio.Event() - - async def close(self) -> None: - self._closed_event.set() - self._writer.close() - - async def _write_frame_noise(self, frame: bytes) -> None: - try: - async with self._write_lock: - _LOGGER.debug("Sending frame %s", frame.hex()) - header = bytes( - [ - 0x01, - (len(frame) >> 8) & 0xFF, - len(frame) & 0xFF, - ] - ) - self._writer.write(header + frame) - await self._writer.drain() - except OSError as err: - raise SocketAPIError(f"Error while writing data: {err}") from err - - async def _read_frame_noise(self) -> bytes: - try: - async with self._read_lock: - header = await self._reader.readexactly(3) - if header[0] != 0x01: - raise ProtocolAPIError(f"Marker byte invalid: {header[0]}") - msg_size = (header[1] << 8) | header[2] - frame = await self._reader.readexactly(msg_size) - except (asyncio.IncompleteReadError, OSError, TimeoutError) as err: - if ( - isinstance(err, asyncio.IncompleteReadError) - and self._closed_event.is_set() - ): - raise SocketClosedAPIError( - f"Socket closed while reading data: {err}" - ) from err - raise SocketAPIError(f"Error while reading data: {err}") from err - - _LOGGER.debug("Received frame %s", frame.hex()) - return frame - - async def perform_handshake(self) -> None: - if self._params.noise_psk is None: - return - await self._write_frame_noise(b"") # ClientHello - prologue = b"NoiseAPIInit" + b"\x00\x00" - server_hello = await self._read_frame_noise() # ServerHello - if not server_hello: - raise HandshakeAPIError("ServerHello is empty") - chosen_proto = server_hello[0] - if chosen_proto != 0x01: - raise HandshakeAPIError( - f"Unknown protocol selected by client {chosen_proto}" - ) - - self._proto = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_SHA256") - self._proto.set_as_initiator() - - try: - noise_psk_bytes = base64.b64decode(self._params.noise_psk) - except ValueError: - raise InvalidEncryptionKeyAPIError( - f"Malformed PSK {self._params.noise_psk}, expected base64-encoded value" - ) - if len(noise_psk_bytes) != 32: - raise InvalidEncryptionKeyAPIError( - f"Malformed PSK {self._params.noise_psk}, expected 32-bytes of base64 data" - ) - - self._proto.set_psks(noise_psk_bytes) - self._proto.set_prologue(prologue) - self._proto.start_handshake() - - _LOGGER.debug("Starting handshake...") - do_write = True - while not self._proto.handshake_finished: - if do_write: - msg = self._proto.write_message() - await self._write_frame_noise(b"\x00" + msg) - else: - msg = await self._read_frame_noise() - if not msg: - raise HandshakeAPIError("Handshake message too short") - if msg[0] != 0: - explanation = msg[1:].decode() - if explanation == "Handshake MAC failure": - raise InvalidEncryptionKeyAPIError("Invalid encryption key") - raise HandshakeAPIError(f"Handshake failure: {explanation}") - self._proto.read_message(msg[1:]) - - do_write = not do_write - - _LOGGER.debug("Handshake complete!") - self._ready_event.set() - - async def _write_packet_noise(self, packet: Packet) -> None: - await self._ready_event.wait() - padding = 0 - data = ( - bytes( - [ - (packet.type >> 8) & 0xFF, - (packet.type >> 0) & 0xFF, - (len(packet.data) >> 8) & 0xFF, - (len(packet.data) >> 0) & 0xFF, - ] - ) - + packet.data - + b"\x00" * padding - ) - assert self._proto is not None - frame = self._proto.encrypt(data) - await self._write_frame_noise(frame) - - async def _write_packet_plaintext(self, packet: Packet) -> None: - data = b"\0" - data += varuint_to_bytes(len(packet.data)) - data += varuint_to_bytes(packet.type) - data += packet.data - try: - async with self._write_lock: - _LOGGER.debug("Sending frame %s", data.hex()) - self._writer.write(data) - await self._writer.drain() - except OSError as err: - raise SocketAPIError(f"Error while writing data: {err}") from err - - async def write_packet(self, packet: Packet) -> None: - if self._params.noise_psk is None: - await self._write_packet_plaintext(packet) - else: - await self._write_packet_noise(packet) - - async def _read_packet_noise(self) -> Packet: - await self._ready_event.wait() - frame = await self._read_frame_noise() - assert self._proto is not None - msg = self._proto.decrypt(frame) - if len(msg) < 4: - raise ProtocolAPIError(f"Bad packet frame: {msg}") - pkt_type = (msg[0] << 8) | msg[1] - data_len = (msg[2] << 8) | msg[3] - if data_len + 4 > len(msg): - raise ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}") - data = msg[4 : 4 + data_len] - return Packet(type=pkt_type, data=data) - - async def _read_packet_plaintext(self) -> Packet: - async with self._read_lock: - try: - preamble = await self._reader.readexactly(1) - if preamble[0] != 0x00: - if preamble[0] == 0x01: - raise RequiresEncryptionAPIError( - "Connection requires encryption" - ) - raise ProtocolAPIError(f"Invalid preamble {preamble[0]:02x}") - - length = b"" - while not length or (length[-1] & 0x80) == 0x80: - length += await self._reader.readexactly(1) - length_int = bytes_to_varuint(length) - assert length_int is not None - msg_type = b"" - while not msg_type or (msg_type[-1] & 0x80) == 0x80: - msg_type += await self._reader.readexactly(1) - msg_type_int = bytes_to_varuint(msg_type) - assert msg_type_int is not None - - raw_msg = b"" - if length_int != 0: - raw_msg = await self._reader.readexactly(length_int) - return Packet(type=msg_type_int, data=raw_msg) - except (asyncio.IncompleteReadError, OSError, TimeoutError) as err: - if ( - isinstance(err, asyncio.IncompleteReadError) - and self._closed_event.is_set() - ): - raise SocketClosedAPIError( - f"Socket closed while reading data: {err}" - ) from err - raise SocketAPIError(f"Error while reading data: {err}") from err - - async def read_packet(self) -> Packet: - if self._params.noise_psk is None: - return await self._read_packet_plaintext() - return await self._read_packet_noise() - - class ConnectionState(enum.Enum): # The connection is initialized, but connect() wasn't called yet INITIALIZED = 0 @@ -381,8 +170,13 @@ class APIConnection: """Step 3 in connect process: initialize the frame helper and init read loop.""" reader, writer = await asyncio.open_connection(sock=self._socket) - self._frame_helper = APIFrameHelper(reader, writer, self._params) - await self._frame_helper.perform_handshake() + if self._params.noise_psk is None: + self._frame_helper = APIPlaintextFrameHelper(reader, writer) + else: + fh = self._frame_helper = APINoiseFrameHelper( + reader, writer, self._params.noise_psk + ) + await fh.perform_handshake() self._connection_state = ConnectionState.SOCKET_OPENED