From d6293d9177d7ed0d133426522a9c12babbc7a938 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 16 Nov 2023 10:31:02 -0600 Subject: [PATCH] Refactor frame helper to allow sending multiple packets at once (#640) --- aioesphomeapi/_frame_helper/base.pxd | 2 + aioesphomeapi/_frame_helper/base.py | 7 +++- aioesphomeapi/_frame_helper/noise.pxd | 6 ++- aioesphomeapi/_frame_helper/noise.py | 48 ++++++++++++++-------- aioesphomeapi/_frame_helper/plain_text.pxd | 7 ++++ aioesphomeapi/_frame_helper/plain_text.py | 24 ++++++++--- aioesphomeapi/connection.pxd | 8 ++-- aioesphomeapi/connection.py | 39 +++++++++--------- tests/test__frame_helper.py | 2 +- 9 files changed, 95 insertions(+), 48 deletions(-) diff --git a/aioesphomeapi/_frame_helper/base.pxd b/aioesphomeapi/_frame_helper/base.pxd index 96a2afb..78ad9c6 100644 --- a/aioesphomeapi/_frame_helper/base.pxd +++ b/aioesphomeapi/_frame_helper/base.pxd @@ -26,3 +26,5 @@ cdef class APIFrameHelper: @cython.locals(end_of_frame_pos=cython.uint) cdef _remove_from_buffer(self) + + cpdef write_packets(self, list packets) diff --git a/aioesphomeapi/_frame_helper/base.py b/aioesphomeapi/_frame_helper/base.py index afa0b11..fa94759 100644 --- a/aioesphomeapi/_frame_helper/base.py +++ b/aioesphomeapi/_frame_helper/base.py @@ -127,8 +127,11 @@ class APIFrameHelper: handshake_handle.cancel() @abstractmethod - def write_packet(self, type_: int, data: bytes) -> None: - """Write a packet to the socket.""" + def write_packets(self, packets: list[tuple[int, bytes]]) -> None: + """Write a packets to the socket. + + Packets are in the format of tuple[protobuf_type, protobuf_data] + """ def connection_made(self, transport: asyncio.BaseTransport) -> None: """Handle a new connection.""" diff --git a/aioesphomeapi/_frame_helper/noise.pxd b/aioesphomeapi/_frame_helper/noise.pxd index 87da902..674ac18 100644 --- a/aioesphomeapi/_frame_helper/noise.pxd +++ b/aioesphomeapi/_frame_helper/noise.pxd @@ -32,7 +32,11 @@ cdef class APINoiseFrameHelper(APIFrameHelper): cpdef _handle_frame(self, bytes data) @cython.locals( + type_="unsigned int", + data=bytes, + packet=tuple, data_len=cython.uint, + frame=bytes, frame_len=cython.uint ) - cpdef write_packet(self, cython.uint type_, bytes data) + cpdef write_packets(self, list packets) diff --git a/aioesphomeapi/_frame_helper/noise.py b/aioesphomeapi/_frame_helper/noise.py index 33e7a91..79ea85f 100644 --- a/aioesphomeapi/_frame_helper/noise.py +++ b/aioesphomeapi/_frame_helper/noise.py @@ -1,6 +1,6 @@ from __future__ import annotations -import base64 +import binascii import logging from enum import Enum from functools import partial @@ -241,16 +241,16 @@ class APINoiseFrameHelper(APIFrameHelper): psk = self._noise_psk server_name = self._server_name try: - psk_bytes = base64.b64decode(psk) + psk_bytes = binascii.a2b_base64(psk) except ValueError: raise InvalidEncryptionKeyAPIError( - f"{self._log_name}: Malformed PSK {psk}, expected " + f"{self._log_name}: Malformed PSK `{psk}`, expected " "base64-encoded value", server_name, ) if len(psk_bytes) != 32: raise InvalidEncryptionKeyAPIError( - f"{self._log_name}:Malformed PSK {psk}, expected" + f"{self._log_name}:Malformed PSK `{psk}`, expected" f" 32-bytes of base64 data", server_name, ) @@ -304,8 +304,11 @@ class APINoiseFrameHelper(APIFrameHelper): ) self._ready_future.set_result(None) - def write_packet(self, type_: int_, data: bytes) -> None: - """Write a packet to the socket.""" + def write_packets(self, packets: list[tuple[int, bytes]]) -> None: + """Write a packets to the socket. + + Packets are in the format of tuple[protobuf_type, protobuf_data] + """ if not self._is_ready: raise HandshakeAPIError(f"{self._log_name}: Noise connection is not ready") @@ -313,19 +316,32 @@ class APINoiseFrameHelper(APIFrameHelper): assert self._encrypt is not None, "Handshake should be complete" assert self._writer is not None, "Writer is not set" - data_len = len(data) - data_header = bytes( - ((type_ >> 8) & 0xFF, type_ & 0xFF, (data_len >> 8) & 0xFF, data_len & 0xFF) - ) - frame = self._encrypt(data_header + data) + out: list[bytes] = [] + debug_enabled = self._debug_enabled() + for packet in packets: + type_: int = packet[0] + data: bytes = packet[1] + data_len = len(data) + data_header = bytes( + ( + (type_ >> 8) & 0xFF, + type_ & 0xFF, + (data_len >> 8) & 0xFF, + data_len & 0xFF, + ) + ) + frame = self._encrypt(data_header + data) - if self._debug_enabled(): - _LOGGER.debug("%s: Sending frame: [%s]", self._log_name, frame.hex()) + if debug_enabled is True: + _LOGGER.debug("%s: Sending frame: [%s]", self._log_name, frame.hex()) + + frame_len = len(frame) + header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF)) + out.append(header) + out.append(frame) - frame_len = len(frame) - header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF)) try: - self._writer(header + frame) + self._writer(b"".join(out)) except WRITE_EXCEPTIONS as err: raise SocketAPIError( f"{self._log_name}: Error while writing data: {err}" diff --git a/aioesphomeapi/_frame_helper/plain_text.pxd b/aioesphomeapi/_frame_helper/plain_text.pxd index 0ef6fa9..164e710 100644 --- a/aioesphomeapi/_frame_helper/plain_text.pxd +++ b/aioesphomeapi/_frame_helper/plain_text.pxd @@ -28,3 +28,10 @@ cdef class APIPlaintextFrameHelper(APIFrameHelper): cpdef data_received(self, bytes data) cpdef _error_on_incorrect_preamble(self, object preamble) + + @cython.locals( + type_="unsigned int", + data=bytes, + packet=tuple + ) + cpdef write_packets(self, list packets) diff --git a/aioesphomeapi/_frame_helper/plain_text.py b/aioesphomeapi/_frame_helper/plain_text.py index bb31c28..43bdf2c 100644 --- a/aioesphomeapi/_frame_helper/plain_text.py +++ b/aioesphomeapi/_frame_helper/plain_text.py @@ -59,20 +59,32 @@ class APIPlaintextFrameHelper(APIFrameHelper): super().connection_made(transport) self._ready_future.set_result(None) - def write_packet(self, type_: int, data: bytes) -> None: - """Write a packet to the socket. + def write_packets(self, packets: list[tuple[int, bytes]]) -> None: + """Write a packets to the socket. + + Packets are in the format of tuple[protobuf_type, protobuf_data] The entire packet must be written in a single call. """ if TYPE_CHECKING: assert self._writer is not None, "Writer should be set" - data = b"\0" + varuint_to_bytes(len(data)) + varuint_to_bytes(type_) + data - if self._debug_enabled(): - _LOGGER.debug("%s: Sending plaintext frame %s", self._log_name, data.hex()) + out: list[bytes] = [] + debug_enabled = self._debug_enabled() + for packet in packets: + type_: int = packet[0] + data: bytes = packet[1] + out.append(b"\0") + out.append(varuint_to_bytes(len(data))) + out.append(varuint_to_bytes(type_)) + out.append(data) + if debug_enabled is True: + _LOGGER.debug( + "%s: Sending plaintext frame %s", self._log_name, data.hex() + ) try: - self._writer(data) + self._writer(b"".join(out)) except WRITE_EXCEPTIONS as err: raise SocketAPIError( f"{self._log_name}: Error while writing data: {err}" diff --git a/aioesphomeapi/connection.pxd b/aioesphomeapi/connection.pxd index c301342..e347349 100644 --- a/aioesphomeapi/connection.pxd +++ b/aioesphomeapi/connection.pxd @@ -1,5 +1,7 @@ import cython +from ._frame_helper.base cimport APIFrameHelper + cdef dict MESSAGE_TYPE_TO_PROTO cdef dict PROTO_TO_MESSAGE_TYPE @@ -47,7 +49,7 @@ cdef class APIConnection: cdef public object on_stop cdef object _on_stop_task cdef public object _socket - cdef public object _frame_helper + cdef public APIFrameHelper _frame_helper cdef public object api_version cdef public object connection_state cdef dict _message_handlers @@ -69,6 +71,8 @@ cdef class APIConnection: cpdef send_message(self, object msg) + cdef send_messages(self, tuple messages) + @cython.locals(handlers=set, handlers_copy=set) cpdef _process_packet(self, object msg_type_proto, object data) @@ -89,5 +93,3 @@ cdef class APIConnection: @cython.locals(handlers=set) cpdef _remove_message_callback(self, object on_message, tuple msg_types) - - cdef _send_messages(self, tuple messages) diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index 289aa37..7d53fa2 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -613,17 +613,11 @@ class APIConnection: connect.password = self._params.password return connect - def _send_messages(self, messages: tuple[message.Message, ...]) -> None: - """Send a message to the remote. - - Currently this is a wrapper around send_message - but may be changed in the future to batch messages - together. - """ - for msg in messages: - self.send_message(msg) - def send_message(self, msg: message.Message) -> None: + """Send a message to the remote.""" + self.send_messages((msg,)) + + def send_messages(self, msgs: tuple[message.Message, ...]) -> None: """Send a protobuf message to the remote.""" if not self._handshake_complete: if in_do_connect.get(False): @@ -635,23 +629,30 @@ class APIConnection: f"Connection isn't established yet ({self.connection_state})" ) - msg_type = type(msg) - if (message_type := PROTO_TO_MESSAGE_TYPE.get(msg_type)) is None: - raise ValueError(f"Message type id not found for type {msg_type}") + packets: list[tuple[int, bytes]] = [] + debug_enabled = self._debug_enabled() - if self._debug_enabled() is True: - _LOGGER.debug("%s: Sending %s: %s", self.log_name, msg_type.__name__, msg) + for msg in msgs: + msg_type = type(msg) + if (message_type := PROTO_TO_MESSAGE_TYPE.get(msg_type)) is None: + raise ValueError(f"Message type id not found for type {msg_type}") + + if debug_enabled is True: + _LOGGER.debug( + "%s: Sending %s: %s", self.log_name, msg_type.__name__, msg + ) + + packets.append((message_type, msg.SerializeToString())) if TYPE_CHECKING: assert self._frame_helper is not None - encoded = msg.SerializeToString() try: - self._frame_helper.write_packet(message_type, encoded) + self._frame_helper.write_packets(packets) except SocketAPIError as err: # If writing packet fails, we don't know what state the frames # are in anymore and we have to close the connection - _LOGGER.info("%s: Error writing packet: %s", self.log_name, err) + _LOGGER.info("%s: Error writing packets: %s", self.log_name, err) self._report_fatal_error(err) raise @@ -738,7 +739,7 @@ class APIConnection: # Send the message right away to reduce latency. # This is safe because we are not awaiting between # sending the message and registering the handler - self._send_messages(messages) + self.send_messages(messages) loop = self._loop # Unsafe to await between sending the message and registering the handler fut: asyncio.Future[None] = loop.create_future() diff --git a/tests/test__frame_helper.py b/tests/test__frame_helper.py index 18dbd35..8217bb7 100644 --- a/tests/test__frame_helper.py +++ b/tests/test__frame_helper.py @@ -475,7 +475,7 @@ async def test_noise_frame_helper_handshake_success_with_single_packet(): assert not writes await handshake_task - helper.write_packet(1, b"to device") + helper.write_packets([(1, b"to device")]) encrypted_packet = writes.pop() header = encrypted_packet[0:1] assert header == b"\x01"