From f3f5bd6b55e9fc30c951037e2af84bc7f642fddf Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 3 Jul 2023 11:57:04 -0500 Subject: [PATCH] Reduce protocol overhead (#454) --- aioesphomeapi/_frame_helper.py | 178 +++++++++++++++++++++------------ tests/test__frame_helper.py | 13 +++ 2 files changed, 126 insertions(+), 65 deletions(-) diff --git a/aioesphomeapi/_frame_helper.py b/aioesphomeapi/_frame_helper.py index 9b64540..c6d37c5 100644 --- a/aioesphomeapi/_frame_helper.py +++ b/aioesphomeapi/_frame_helper.py @@ -3,7 +3,8 @@ import base64 import logging from abc import abstractmethod from enum import Enum -from typing import Any, Callable, Optional, Union, cast +from functools import partial +from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast import async_timeout from chacha20poly1305_reuseable import ChaCha20Poly1305Reusable @@ -149,55 +150,69 @@ class APIPlaintextFrameHelper(APIFrameHelper): """Perform the handshake.""" await self._connected_event.wait() - def data_received(self, data: bytes) -> None: + def data_received(self, data: bytes) -> None: # pylint: disable=too-many-branches self._buffer += data - while len(self._buffer) >= 3: + while self._buffer: # Read preamble, which should always 0x00 # Also try to get the length and msg type - # to avoid multiple calls to readexactly + # to avoid multiple calls to _read_exactly init_bytes = self._init_read(3) - assert init_bytes is not None, "Buffer should have at least 3 bytes" - if init_bytes[0] != 0x00: - if init_bytes[0] == 0x01: + if init_bytes is None: + return + msg_type_int: Optional[int] = None + length_int: Optional[int] = None + preamble, length_high, maybe_msg_type = init_bytes + if preamble != 0x00: + if preamble == 0x01: self._handle_error_and_close( RequiresEncryptionAPIError("Connection requires encryption") ) return self._handle_error_and_close( - ProtocolAPIError(f"Invalid preamble {init_bytes[0]:02x}") + ProtocolAPIError(f"Invalid preamble {preamble:02x}") ) return - if init_bytes[1] & 0x80 == 0x80: - # Length is longer than 1 byte - length = init_bytes[1:3] - msg_type = b"" + if length_high & 0x80 != 0x80: + # Length is only 1 byte + # + # This is the most common case needing a single byte for + # length and type which means we avoid 2 calls to _read_exactly + length_int = length_high + if maybe_msg_type & 0x80 != 0x80: + # Message type is also only 1 byte + msg_type_int = maybe_msg_type + else: + # Message type is longer than 1 byte + msg_type = bytes(init_bytes[2:3]) else: - # This is the most common case with 99% of messages - # needing a single byte for length and type which means - # we avoid 2 calls to readexactly - length = init_bytes[1:2] - msg_type = init_bytes[2:3] + # Length is longer than 1 byte + length = bytes(init_bytes[1:3]) + # If the message is long, we need to read the rest of the length + while length[-1] & 0x80 == 0x80: + add_length = self._read_exactly(1) + if add_length is None: + return + length += add_length + length_int = bytes_to_varuint(length) + # Since the length is longer than 1 byte we do not have the + # message type yet. + msg_type = b"" - # If the message is long, we need to read the rest of the length - while length[-1] & 0x80 == 0x80: - add_length = self._read_exactly(1) - if add_length is None: - return - length += add_length + # If the we do not have the message type yet because the message + # length was so long it did not fit into the first byte we need + # to read the (rest) of the message type + if msg_type_int is None: + while not msg_type or msg_type[-1] & 0x80 == 0x80: + add_msg_type = self._read_exactly(1) + if add_msg_type is None: + return + msg_type += add_msg_type + msg_type_int = bytes_to_varuint(msg_type) - # If the message length was longer than 1 byte, we need to read the - # message type - while not msg_type or (msg_type[-1] & 0x80) == 0x80: - add_msg_type = self._read_exactly(1) - if add_msg_type is None: - return - msg_type += add_msg_type - - length_int = bytes_to_varuint(bytes(length)) - assert length_int is not None - msg_type_int = bytes_to_varuint(bytes(msg_type)) - assert msg_type_int is not None + if TYPE_CHECKING: + assert length_int is not None + assert msg_type_int is not None if length_int == 0: self._callback_packet(msg_type_int, b"") @@ -205,6 +220,11 @@ class APIPlaintextFrameHelper(APIFrameHelper): continue packet_data = self._read_exactly(length_int) + # The packet data is not yet available, wait for more data + # to arrive before continuing, since callback_packet has not + # been called yet the buffer will not be cleared and the next + # call to data_received will continue processing the packet + # at the start of the frame. if packet_data is None: return @@ -246,6 +266,8 @@ class APINoiseFrameHelper(APIFrameHelper): "_state", "_server_name", "_proto", + "_decrypt", + "_encrypt", ) def __init__( @@ -262,6 +284,8 @@ class APINoiseFrameHelper(APIFrameHelper): self._expected_name = expected_name self._state = NoiseConnectionState.HELLO self._server_name: Optional[str] = None + self._decrypt: Optional[Callable[[bytes], bytes]] = None + self._encrypt: Optional[Callable[[bytes], bytes]] = None self._setup_proto() def _set_ready_future_exception(self, exc: Exception) -> None: @@ -291,12 +315,13 @@ class APINoiseFrameHelper(APIFrameHelper): if _LOGGER.isEnabledFor(logging.DEBUG): _LOGGER.debug("Sending frame: [%s]", frame.hex()) + frame_len = len(frame) try: header = bytes( [ 0x01, - (len(frame) >> 8) & 0xFF, - len(frame) & 0xFF, + (frame_len >> 8) & 0xFF, + frame_len & 0xFF, ] ) self._transport.write(header + frame) @@ -314,17 +339,22 @@ class APINoiseFrameHelper(APIFrameHelper): def data_received(self, data: bytes) -> None: self._buffer += data - while len(self._buffer) >= 3: + while self._buffer: header = self._init_read(3) - assert header is not None, "Buffer should have at least 3 bytes" - if header[0] != 0x01: + if header is None: + return + preamble, msg_size_high, msg_size_low = header + if preamble != 0x01: self._handle_error_and_close( ProtocolAPIError(f"Marker byte invalid: {header[0]}") ) return - msg_size = (header[1] << 8) | header[2] - frame = self._read_exactly(msg_size) - + frame = self._read_exactly((msg_size_high << 8) | msg_size_low) + # The complete frame is not yet available, wait for more data + # to arrive before continuing, since callback_packet has not + # been called yet the buffer will not be cleared and the next + # call to data_received will continue processing the packet + # at the start of the frame. if frame is None: return @@ -415,44 +445,62 @@ class APINoiseFrameHelper(APIFrameHelper): return _LOGGER.debug("Handshake complete") self._state = NoiseConnectionState.READY + noise_protocol = self._proto.noise_protocol + self._decrypt = partial( + noise_protocol.cipher_state_decrypt.decrypt_with_ad, # pylint: disable=no-member + None, + ) + self._encrypt = partial( + noise_protocol.cipher_state_encrypt.encrypt_with_ad, # pylint: disable=no-member + None, + ) self._ready_future.set_result(None) def write_packet(self, type_: int, data: bytes) -> None: """Write a packet to the socket.""" if self._state != NoiseConnectionState.READY: raise HandshakeAPIError("Noise connection is not ready") + if TYPE_CHECKING: + assert self._encrypt is not None, "Handshake should be complete" + data_len = len(data) self._write_frame( - self._proto.encrypt( - ( - bytes( - [ - (type_ >> 8) & 0xFF, - (type_ >> 0) & 0xFF, - (len(data) >> 8) & 0xFF, - (len(data) >> 0) & 0xFF, - ] - ) - + data + self._encrypt( + bytes( + [ + (type_ >> 8) & 0xFF, + (type_ >> 0) & 0xFF, + (data_len >> 8) & 0xFF, + (data_len >> 0) & 0xFF, + ] ) + + data ) ) def _handle_frame(self, frame: bytearray) -> None: """Handle an incoming frame.""" - assert self._proto is not None - msg = self._proto.decrypt(bytes(frame)) - if len(msg) < 4: - self._handle_error_and_close(ProtocolAPIError(f"Bad packet frame: {msg}")) - return - pkt_type = (msg[0] << 8) | msg[1] - data_len = (msg[2] << 8) | msg[3] - if data_len + 4 > len(msg): + if TYPE_CHECKING: + assert self._decrypt is not None, "Handshake should be complete" + try: + msg = self._decrypt(bytes(frame)) + except InvalidTag as ex: self._handle_error_and_close( - ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}") + ProtocolAPIError(f"Bad encryption frame: {ex!r}") ) return - data = msg[4 : 4 + data_len] - self._on_pkt(pkt_type, data) + msg_len = len(msg) + if msg_len < 4: + self._handle_error_and_close(ProtocolAPIError(f"Bad packet frame: {msg!r}")) + return + msg_type_high, msg_type_low, data_len_high, data_len_low = msg[:4] + msg_type = (msg_type_high << 8) | msg_type_low + data_len = (data_len_high << 8) | data_len_low + if data_len + 4 != msg_len: + self._handle_error_and_close( + ProtocolAPIError(f"Bad data len: {data_len} vs {msg_len}") + ) + return + self._on_pkt(msg_type, msg[4:]) def _handle_closed( # pylint: disable=unused-argument self, frame: bytearray diff --git a/tests/test__frame_helper.py b/tests/test__frame_helper.py index 4eda656..32bcd61 100644 --- a/tests/test__frame_helper.py +++ b/tests/test__frame_helper.py @@ -43,6 +43,19 @@ PREAMBLE = b"\x00" (b"\x42" * 256), 256, ), + ( + PREAMBLE + varuint_to_bytes(1) + varuint_to_bytes(32768) + b"\x42", + b"\x42", + 32768, + ), + ( + PREAMBLE + + varuint_to_bytes(32768) + + varuint_to_bytes(32768) + + (b"\x42" * 32768), + (b"\x42" * 32768), + 32768, + ), ], ) async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):