diff --git a/aioesphomeapi/_frame_helper/base.pxd b/aioesphomeapi/_frame_helper/base.pxd index 59f05c1..29559cb 100644 --- a/aioesphomeapi/_frame_helper/base.pxd +++ b/aioesphomeapi/_frame_helper/base.pxd @@ -2,6 +2,8 @@ import cython +cdef bint TYPE_CHECKING + cdef class APIFrameHelper: cdef object _loop @@ -10,7 +12,7 @@ cdef class APIFrameHelper: cdef object _transport cdef public object _writer cdef public object _ready_future - cdef bytearray _buffer + cdef bytes _buffer cdef cython.uint _buffer_len cdef cython.uint _pos cdef object _client_info @@ -18,4 +20,9 @@ cdef class APIFrameHelper: cdef object _debug_enabled @cython.locals(original_pos=cython.uint, new_pos=cython.uint) - cdef _read_exactly(self, int length) \ No newline at end of file + cdef bytes _read_exactly(self, int length) + + cdef _add_to_buffer(self, bytes data) + + @cython.locals(end_of_frame_pos=cython.uint) + cdef _remove_from_buffer(self) \ No newline at end of file diff --git a/aioesphomeapi/_frame_helper/base.py b/aioesphomeapi/_frame_helper/base.py index ca606a8..e893145 100644 --- a/aioesphomeapi/_frame_helper/base.py +++ b/aioesphomeapi/_frame_helper/base.py @@ -4,7 +4,7 @@ import asyncio import logging from abc import abstractmethod from functools import partial -from typing import Callable, cast +from typing import TYPE_CHECKING, Callable, cast from ..core import HandshakeAPIError, SocketClosedAPIError @@ -55,7 +55,7 @@ class APIFrameHelper: self._transport: asyncio.Transport | None = None self._writer: None | (Callable[[bytes | bytearray | memoryview], None]) = None self._ready_future = self._loop.create_future() - self._buffer = bytearray() + self._buffer: bytes | None = None self._buffer_len = 0 self._pos = 0 self._client_info = client_info @@ -66,13 +66,48 @@ class APIFrameHelper: if not self._ready_future.done(): self._ready_future.set_exception(exc) - def _read_exactly(self, length: _int) -> bytearray | None: + def _add_to_buffer(self, data: bytes) -> None: + """Add data to the buffer.""" + if self._buffer_len == 0: + # This is the best case scenario, we don't have to copy the data + # and can just use the buffer directly. This is the most common + # case as well. + self._buffer = data + else: + if TYPE_CHECKING: + assert self._buffer is not None, "Buffer should be set" + # This is the worst case scenario, we have to copy the data + # and can't just use the buffer directly. This is also very + # uncommon since we usually read the entire frame at once. + self._buffer += data + self._buffer_len += len(data) + + def _remove_from_buffer(self) -> None: + """Remove data from the buffer.""" + end_of_frame_pos = self._pos + self._buffer_len -= end_of_frame_pos + if self._buffer_len == 0: + # This is the best case scenario, we can just set the buffer to None + # and don't have to copy the data. This is the most common case as well. + self._buffer = None + return + if TYPE_CHECKING: + assert self._buffer is not None, "Buffer should be set" + # This is the worst case scenario, we have to copy the data + # and can't just use the buffer directly. This should only happen + # when we read multiple frames at once because the event loop + # is blocked and we cannot pull the data out of the buffer fast enough. + self._buffer = self._buffer[end_of_frame_pos:] + + def _read_exactly(self, length: _int) -> bytes | None: """Read exactly length bytes from the buffer or None if all the bytes are not yet available.""" original_pos = self._pos new_pos = original_pos + length if self._buffer_len < new_pos: return None self._pos = new_pos + if TYPE_CHECKING: + assert self._buffer is not None, "Buffer should be set" return self._buffer[original_pos:new_pos] async def perform_handshake(self, timeout: float) -> None: diff --git a/aioesphomeapi/_frame_helper/noise.pxd b/aioesphomeapi/_frame_helper/noise.pxd index 17990b4..0734c9c 100644 --- a/aioesphomeapi/_frame_helper/noise.pxd +++ b/aioesphomeapi/_frame_helper/noise.pxd @@ -3,7 +3,7 @@ import cython from .base cimport APIFrameHelper -cdef object TYPE_CHECKING +cdef bint TYPE_CHECKING cdef class APINoiseFrameHelper(APIFrameHelper): @@ -18,10 +18,15 @@ cdef class APINoiseFrameHelper(APIFrameHelper): cdef bint _is_ready @cython.locals( - header=bytearray, + header=bytes, preamble=cython.uint, msg_size_high=cython.uint, msg_size_low=cython.uint, - end_of_frame_pos=cython.uint, ) - cpdef data_received(self, bytes data) \ No newline at end of file + cpdef data_received(self, bytes data) + + @cython.locals( + type_high=cython.uint, + type_low=cython.uint + ) + cpdef _handle_frame(self, bytes data) diff --git a/aioesphomeapi/_frame_helper/noise.py b/aioesphomeapi/_frame_helper/noise.py index 64bd41e..b0e4d34 100644 --- a/aioesphomeapi/_frame_helper/noise.py +++ b/aioesphomeapi/_frame_helper/noise.py @@ -139,8 +139,7 @@ class APINoiseFrameHelper(APIFrameHelper): await super().perform_handshake(timeout) def data_received(self, data: bytes) -> None: - self._buffer += data - self._buffer_len += len(data) + self._add_to_buffer(data) while self._buffer: self._pos = 0 header = self._read_exactly(3) @@ -170,9 +169,7 @@ class APINoiseFrameHelper(APIFrameHelper): except Exception as err: # pylint: disable=broad-except self._handle_error_and_close(err) finally: - end_of_frame_pos = self._pos - del self._buffer[:end_of_frame_pos] - self._buffer_len -= end_of_frame_pos + self._remove_from_buffer() def _send_hello_handshake(self) -> None: """Send a ClientHello to the server.""" @@ -198,7 +195,7 @@ class APINoiseFrameHelper(APIFrameHelper): f"{self._log_name}: Error while writing data: {err}" ) from err - def _handle_hello(self, server_hello: bytearray) -> None: + def _handle_hello(self, server_hello: bytes) -> None: """Perform the handshake with the server.""" if not server_hello: self._handle_error_and_close( @@ -269,7 +266,7 @@ class APINoiseFrameHelper(APIFrameHelper): proto.start_handshake() self._proto = proto - def _handle_handshake(self, msg: bytearray) -> None: + def _handle_handshake(self, msg: bytes) -> None: _LOGGER.debug("Starting handshake...") if msg[0] != 0: explanation = msg[1:].decode() @@ -333,12 +330,12 @@ class APINoiseFrameHelper(APIFrameHelper): f"{self._log_name}: Error while writing data: {err}" ) from err - def _handle_frame(self, frame: bytearray) -> None: + def _handle_frame(self, frame: bytes) -> None: """Handle an incoming frame.""" if TYPE_CHECKING: assert self._decrypt is not None, "Handshake should be complete" try: - msg = self._decrypt(bytes(frame)) + msg = self._decrypt(frame) except InvalidTag as ex: self._handle_error_and_close( ProtocolAPIError(f"{self._log_name}: Bad encryption frame: {ex!r}") @@ -348,11 +345,11 @@ class APINoiseFrameHelper(APIFrameHelper): # 2 bytes: message type # 2 bytes: message length # N bytes: message data - self._on_pkt((msg[0] << 8) | msg[1], msg[4:]) + type_high = msg[0] + type_low = msg[1] + self._on_pkt((type_high << 8) | type_low, msg[4:]) - def _handle_closed( # pylint: disable=unused-argument - self, frame: bytearray - ) -> None: + def _handle_closed(self, frame: bytes) -> None: # pylint: disable=unused-argument """Handle a closed frame.""" self._handle_error(ProtocolAPIError(f"{self._log_name}: Connection closed")) diff --git a/aioesphomeapi/_frame_helper/plain_text.pxd b/aioesphomeapi/_frame_helper/plain_text.pxd index 656a27f..b0bd3a1 100644 --- a/aioesphomeapi/_frame_helper/plain_text.pxd +++ b/aioesphomeapi/_frame_helper/plain_text.pxd @@ -3,7 +3,7 @@ import cython from .base cimport APIFrameHelper -cdef object TYPE_CHECKING +cdef bint TYPE_CHECKING cdef object WRITE_EXCEPTIONS cdef object bytes_to_varuint, varuint_to_bytes @@ -12,12 +12,12 @@ cdef class APIPlaintextFrameHelper(APIFrameHelper): @cython.locals( msg_type=bytes, length=bytes, - init_bytes=bytearray, - add_length=bytearray, + init_bytes=bytes, + add_length=bytes, end_of_frame_pos=cython.uint, length_int=cython.uint, preamble=cython.uint, length_high=cython.uint, maybe_msg_type=cython.uint ) - cpdef data_received(self, bytes data) \ No newline at end of file + cpdef data_received(self, bytes data) diff --git a/aioesphomeapi/_frame_helper/plain_text.py b/aioesphomeapi/_frame_helper/plain_text.py index 059bca8..b9b149b 100644 --- a/aioesphomeapi/_frame_helper/plain_text.py +++ b/aioesphomeapi/_frame_helper/plain_text.py @@ -39,8 +39,7 @@ class APIPlaintextFrameHelper(APIFrameHelper): ) from err def data_received(self, data: bytes) -> None: # pylint: disable=too-many-branches - self._buffer += data - self._buffer_len += len(data) + self._add_to_buffer(data) while self._buffer: # Read preamble, which should always 0x00 # Also try to get the length and msg type @@ -80,10 +79,10 @@ class APIPlaintextFrameHelper(APIFrameHelper): msg_type_int = maybe_msg_type else: # Message type is longer than 1 byte - msg_type = bytes(init_bytes[2:3]) + msg_type = init_bytes[2:3] else: # Length is longer than 1 byte - length = bytes(init_bytes[1:3]) + length = init_bytes[1:3] # If the message is long, we need to read the rest of the length while length[-1] & 0x80 == 0x80: add_length = self._read_exactly(1) @@ -112,18 +111,16 @@ class APIPlaintextFrameHelper(APIFrameHelper): if length_int == 0: packet_data = b"" else: - packet_data_bytearray = self._read_exactly(length_int) + maybe_packet_data = self._read_exactly(length_int) # The packet data is not yet available, wait for more data # to arrive before continuing, since callback_packet has not # been called yet the buffer will not be cleared and the next # call to data_received will continue processing the packet # at the start of the frame. - if packet_data_bytearray is None: + if maybe_packet_data is None: return - packet_data = bytes(packet_data_bytearray) + packet_data = maybe_packet_data - end_of_frame_pos = self._pos - del self._buffer[:end_of_frame_pos] - self._buffer_len -= end_of_frame_pos + self._remove_from_buffer() self._on_pkt(msg_type_int, packet_data) # If we have more data, continue processing