From cd5ad769f0d254af8393956434d46c9c4eaafa84 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 25 Nov 2023 14:17:24 -0600 Subject: [PATCH] Refactor reading varuints to significant simplify plaintext frame helper (#718) --- aioesphomeapi/_frame_helper/base.pxd | 8 +++ aioesphomeapi/_frame_helper/base.py | 15 ++++ aioesphomeapi/_frame_helper/plain_text.pxd | 19 +---- aioesphomeapi/_frame_helper/plain_text.py | 84 +++------------------- tests/test__frame_helper.py | 14 ---- 5 files changed, 34 insertions(+), 106 deletions(-) diff --git a/aioesphomeapi/_frame_helper/base.pxd b/aioesphomeapi/_frame_helper/base.pxd index dce21f1..e70acab 100644 --- a/aioesphomeapi/_frame_helper/base.pxd +++ b/aioesphomeapi/_frame_helper/base.pxd @@ -25,6 +25,14 @@ cdef class APIFrameHelper: @cython.locals(original_pos="unsigned int", new_pos="unsigned int") cdef bytes _read(self, int length) + @cython.locals( + result="unsigned int", + bitpos="unsigned int", + val="unsigned char", + current_pos="unsigned int" + ) + cdef int _read_varuint(self) + @cython.locals(bytes_data=bytes) cdef void _add_to_buffer(self, object data) diff --git a/aioesphomeapi/_frame_helper/base.py b/aioesphomeapi/_frame_helper/base.py index 9adb521..194bf1c 100644 --- a/aioesphomeapi/_frame_helper/base.py +++ b/aioesphomeapi/_frame_helper/base.py @@ -119,6 +119,21 @@ class APIFrameHelper: assert self._buffer is not None, "Buffer should be set" return self._buffer[original_pos:new_pos] + def _read_varuint(self) -> _int: + """Read a varuint from the buffer or -1 if the buffer runs out of bytes.""" + if TYPE_CHECKING: + assert self._buffer is not None, "Buffer should be set" + result = 0 + bitpos = 0 + while self._buffer_len > self._pos: + val = self._buffer[self._pos] + self._pos += 1 + result |= (val & 0x7F) << bitpos + if (val & 0x80) == 0: + return result + bitpos += 7 + return -1 + async def perform_handshake(self, timeout: float) -> None: """Perform the handshake with the server.""" handshake_handle = self._loop.call_at( diff --git a/aioesphomeapi/_frame_helper/plain_text.pxd b/aioesphomeapi/_frame_helper/plain_text.pxd index f1c1e6a..382b3b8 100644 --- a/aioesphomeapi/_frame_helper/plain_text.pxd +++ b/aioesphomeapi/_frame_helper/plain_text.pxd @@ -4,30 +4,15 @@ from ..connection cimport APIConnection from .base cimport APIFrameHelper -cdef bint TYPE_CHECKING -cdef object bytes_to_varuint, varuint_to_bytes +cdef object varuint_to_bytes cpdef _varuint_to_bytes(cython.int value) -@cython.locals(result=cython.int, bitpos=cython.int, val=cython.int) -cpdef _bytes_to_varuint(cython.bytes value) - cdef class APIPlaintextFrameHelper(APIFrameHelper): - @cython.locals( - msg_type=bytes, - length=bytes, - init_bytes=bytes, - add_length=bytes, - end_of_frame_pos=cython.uint, - length_int=cython.uint, - preamble="unsigned char", - length_high="unsigned char", - maybe_msg_type="unsigned char" - ) cpdef data_received(self, object data) - cdef void _error_on_incorrect_preamble(self, object preamble) + cdef void _error_on_incorrect_preamble(self, int preamble) @cython.locals( type_="unsigned int", diff --git a/aioesphomeapi/_frame_helper/plain_text.py b/aioesphomeapi/_frame_helper/plain_text.py index 3d873c2..b250bf5 100644 --- a/aioesphomeapi/_frame_helper/plain_text.py +++ b/aioesphomeapi/_frame_helper/plain_text.py @@ -2,13 +2,11 @@ from __future__ import annotations import asyncio from functools import lru_cache -from typing import TYPE_CHECKING from ..core import ProtocolAPIError, RequiresEncryptionAPIError from .base import APIFrameHelper _int = int -_bytes = bytes def _varuint_to_bytes(value: _int) -> bytes: @@ -32,22 +30,6 @@ _cached_varuint_to_bytes = lru_cache(maxsize=1024)(_varuint_to_bytes) varuint_to_bytes = _cached_varuint_to_bytes -def _bytes_to_varuint(value: _bytes) -> _int | None: - """Convert bytes to a varuint.""" - result = 0 - bitpos = 0 - for val in value: - result |= (val & 0x7F) << bitpos - if (val & 0x80) == 0: - return result - bitpos += 7 - return None - - -_cached_bytes_to_varuint = lru_cache(maxsize=1024)(_bytes_to_varuint) -bytes_to_varuint = _cached_bytes_to_varuint - - class APIPlaintextFrameHelper(APIFrameHelper): """Frame helper for plaintext API connections.""" @@ -81,65 +63,17 @@ class APIPlaintextFrameHelper(APIFrameHelper): ) -> None: self._add_to_buffer(data) while self._buffer_len: - # Read preamble, which should always 0x00 - # Also try to get the length and msg type - # to avoid multiple calls to _read self._pos = 0 - if (init_bytes := self._read(3)) is None: - return - msg_type_int: int | None = None - length_int = 0 - preamble = init_bytes[0] - length_high = init_bytes[1] - maybe_msg_type = init_bytes[2] - if preamble != 0x00: + # Read preamble, which should always 0x00 + if (preamble := self._read_varuint()) != 0x00: self._error_on_incorrect_preamble(preamble) return + if (length := self._read_varuint()) == -1: + return + if (msg_type := self._read_varuint()) == -1: + return - if length_high & 0x80 != 0x80: - # Length is only 1 byte - # - # This is the most common case needing a single byte for - # length and type which means we avoid 2 calls to _read - length_int = length_high - if maybe_msg_type & 0x80 != 0x80: - # Message type is also only 1 byte - msg_type_int = maybe_msg_type - else: - # Message type is longer than 1 byte - msg_type = init_bytes[2:3] - else: - # Length is longer than 1 byte - length = init_bytes[1:3] - # If the message is long, we need to read the rest of the length - while length[-1] & 0x80 == 0x80: - if (add_length := self._read(1)) is None: - return - length += add_length - length_int = bytes_to_varuint(length) or 0 - # Since the length is longer than 1 byte we do not have the - # message type yet. - if (msg_type_byte := self._read(1)) is None: - return - msg_type = msg_type_byte - if msg_type[-1] & 0x80 != 0x80: - # Message type is only 1 byte - msg_type_int = msg_type[0] - - # If the we do not have the message type yet because the message - # length was so long it did not fit into the first byte we need - # to read the (rest) of the message type - if msg_type_int is None: - while msg_type[-1] & 0x80 == 0x80: - if (add_msg_type := self._read(1)) is None: - return - msg_type += add_msg_type - msg_type_int = bytes_to_varuint(msg_type) - - if TYPE_CHECKING: - assert msg_type_int is not None - - if length_int == 0: + if length == 0: packet_data = b"" else: # The packet data is not yet available, wait for more data @@ -147,12 +81,12 @@ class APIPlaintextFrameHelper(APIFrameHelper): # been called yet the buffer will not be cleared and the next # call to data_received will continue processing the packet # at the start of the frame. - if (maybe_packet_data := self._read(length_int)) is None: + if (maybe_packet_data := self._read(length)) is None: return packet_data = maybe_packet_data self._remove_from_buffer() - self._connection.process_packet(msg_type_int, packet_data) + self._connection.process_packet(msg_type, packet_data) # If we have more data, continue processing def _error_on_incorrect_preamble(self, preamble: _int) -> None: diff --git a/tests/test__frame_helper.py b/tests/test__frame_helper.py index 38d5235..875e9ca 100644 --- a/tests/test__frame_helper.py +++ b/tests/test__frame_helper.py @@ -12,10 +12,6 @@ from noise.connection import NoiseConnection # type: ignore[import-untyped] from aioesphomeapi import APIConnection from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper from aioesphomeapi._frame_helper.noise import ESPHOME_NOISE_BACKEND -from aioesphomeapi._frame_helper.plain_text import _bytes_to_varuint as bytes_to_varuint -from aioesphomeapi._frame_helper.plain_text import ( - _cached_bytes_to_varuint as cached_bytes_to_varuint, -) from aioesphomeapi._frame_helper.plain_text import ( _cached_varuint_to_bytes as cached_varuint_to_bytes, ) @@ -459,16 +455,6 @@ def test_varuint_to_bytes(val, encoded): assert cached_varuint_to_bytes(val) == encoded -@pytest.mark.parametrize("val, encoded", VARUINT_TESTCASES) -def test_bytes_to_varuint(val, encoded): - assert bytes_to_varuint(encoded) == val - assert cached_bytes_to_varuint(encoded) == val - - -def test_bytes_to_varuint_invalid(): - assert bytes_to_varuint(b"\xFF") is None - - @pytest.mark.asyncio async def test_noise_frame_helper_handshake_failure(): """Test the noise frame helper handshake failure."""