Refactor reading varuints to significant simplify plaintext frame helper (#718)

This commit is contained in:
J. Nick Koston 2023-11-25 14:17:24 -06:00 committed by GitHub
parent cf3ada3deb
commit cd5ad769f0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 34 additions and 106 deletions

View File

@ -25,6 +25,14 @@ cdef class APIFrameHelper:
@cython.locals(original_pos="unsigned int", new_pos="unsigned int")
cdef bytes _read(self, int length)
@cython.locals(
result="unsigned int",
bitpos="unsigned int",
val="unsigned char",
current_pos="unsigned int"
)
cdef int _read_varuint(self)
@cython.locals(bytes_data=bytes)
cdef void _add_to_buffer(self, object data)

View File

@ -119,6 +119,21 @@ class APIFrameHelper:
assert self._buffer is not None, "Buffer should be set"
return self._buffer[original_pos:new_pos]
def _read_varuint(self) -> _int:
"""Read a varuint from the buffer or -1 if the buffer runs out of bytes."""
if TYPE_CHECKING:
assert self._buffer is not None, "Buffer should be set"
result = 0
bitpos = 0
while self._buffer_len > self._pos:
val = self._buffer[self._pos]
self._pos += 1
result |= (val & 0x7F) << bitpos
if (val & 0x80) == 0:
return result
bitpos += 7
return -1
async def perform_handshake(self, timeout: float) -> None:
"""Perform the handshake with the server."""
handshake_handle = self._loop.call_at(

View File

@ -4,30 +4,15 @@ from ..connection cimport APIConnection
from .base cimport APIFrameHelper
cdef bint TYPE_CHECKING
cdef object bytes_to_varuint, varuint_to_bytes
cdef object varuint_to_bytes
cpdef _varuint_to_bytes(cython.int value)
@cython.locals(result=cython.int, bitpos=cython.int, val=cython.int)
cpdef _bytes_to_varuint(cython.bytes value)
cdef class APIPlaintextFrameHelper(APIFrameHelper):
@cython.locals(
msg_type=bytes,
length=bytes,
init_bytes=bytes,
add_length=bytes,
end_of_frame_pos=cython.uint,
length_int=cython.uint,
preamble="unsigned char",
length_high="unsigned char",
maybe_msg_type="unsigned char"
)
cpdef data_received(self, object data)
cdef void _error_on_incorrect_preamble(self, object preamble)
cdef void _error_on_incorrect_preamble(self, int preamble)
@cython.locals(
type_="unsigned int",

View File

@ -2,13 +2,11 @@ from __future__ import annotations
import asyncio
from functools import lru_cache
from typing import TYPE_CHECKING
from ..core import ProtocolAPIError, RequiresEncryptionAPIError
from .base import APIFrameHelper
_int = int
_bytes = bytes
def _varuint_to_bytes(value: _int) -> bytes:
@ -32,22 +30,6 @@ _cached_varuint_to_bytes = lru_cache(maxsize=1024)(_varuint_to_bytes)
varuint_to_bytes = _cached_varuint_to_bytes
def _bytes_to_varuint(value: _bytes) -> _int | None:
"""Convert bytes to a varuint."""
result = 0
bitpos = 0
for val in value:
result |= (val & 0x7F) << bitpos
if (val & 0x80) == 0:
return result
bitpos += 7
return None
_cached_bytes_to_varuint = lru_cache(maxsize=1024)(_bytes_to_varuint)
bytes_to_varuint = _cached_bytes_to_varuint
class APIPlaintextFrameHelper(APIFrameHelper):
"""Frame helper for plaintext API connections."""
@ -81,65 +63,17 @@ class APIPlaintextFrameHelper(APIFrameHelper):
) -> None:
self._add_to_buffer(data)
while self._buffer_len:
# Read preamble, which should always 0x00
# Also try to get the length and msg type
# to avoid multiple calls to _read
self._pos = 0
if (init_bytes := self._read(3)) is None:
return
msg_type_int: int | None = None
length_int = 0
preamble = init_bytes[0]
length_high = init_bytes[1]
maybe_msg_type = init_bytes[2]
if preamble != 0x00:
# Read preamble, which should always 0x00
if (preamble := self._read_varuint()) != 0x00:
self._error_on_incorrect_preamble(preamble)
return
if (length := self._read_varuint()) == -1:
return
if (msg_type := self._read_varuint()) == -1:
return
if length_high & 0x80 != 0x80:
# Length is only 1 byte
#
# This is the most common case needing a single byte for
# length and type which means we avoid 2 calls to _read
length_int = length_high
if maybe_msg_type & 0x80 != 0x80:
# Message type is also only 1 byte
msg_type_int = maybe_msg_type
else:
# Message type is longer than 1 byte
msg_type = init_bytes[2:3]
else:
# Length is longer than 1 byte
length = init_bytes[1:3]
# If the message is long, we need to read the rest of the length
while length[-1] & 0x80 == 0x80:
if (add_length := self._read(1)) is None:
return
length += add_length
length_int = bytes_to_varuint(length) or 0
# Since the length is longer than 1 byte we do not have the
# message type yet.
if (msg_type_byte := self._read(1)) is None:
return
msg_type = msg_type_byte
if msg_type[-1] & 0x80 != 0x80:
# Message type is only 1 byte
msg_type_int = msg_type[0]
# If the we do not have the message type yet because the message
# length was so long it did not fit into the first byte we need
# to read the (rest) of the message type
if msg_type_int is None:
while msg_type[-1] & 0x80 == 0x80:
if (add_msg_type := self._read(1)) is None:
return
msg_type += add_msg_type
msg_type_int = bytes_to_varuint(msg_type)
if TYPE_CHECKING:
assert msg_type_int is not None
if length_int == 0:
if length == 0:
packet_data = b""
else:
# The packet data is not yet available, wait for more data
@ -147,12 +81,12 @@ class APIPlaintextFrameHelper(APIFrameHelper):
# been called yet the buffer will not be cleared and the next
# call to data_received will continue processing the packet
# at the start of the frame.
if (maybe_packet_data := self._read(length_int)) is None:
if (maybe_packet_data := self._read(length)) is None:
return
packet_data = maybe_packet_data
self._remove_from_buffer()
self._connection.process_packet(msg_type_int, packet_data)
self._connection.process_packet(msg_type, packet_data)
# If we have more data, continue processing
def _error_on_incorrect_preamble(self, preamble: _int) -> None:

View File

@ -12,10 +12,6 @@ from noise.connection import NoiseConnection # type: ignore[import-untyped]
from aioesphomeapi import APIConnection
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
from aioesphomeapi._frame_helper.noise import ESPHOME_NOISE_BACKEND
from aioesphomeapi._frame_helper.plain_text import _bytes_to_varuint as bytes_to_varuint
from aioesphomeapi._frame_helper.plain_text import (
_cached_bytes_to_varuint as cached_bytes_to_varuint,
)
from aioesphomeapi._frame_helper.plain_text import (
_cached_varuint_to_bytes as cached_varuint_to_bytes,
)
@ -459,16 +455,6 @@ def test_varuint_to_bytes(val, encoded):
assert cached_varuint_to_bytes(val) == encoded
@pytest.mark.parametrize("val, encoded", VARUINT_TESTCASES)
def test_bytes_to_varuint(val, encoded):
assert bytes_to_varuint(encoded) == val
assert cached_bytes_to_varuint(encoded) == val
def test_bytes_to_varuint_invalid():
assert bytes_to_varuint(b"\xFF") is None
@pytest.mark.asyncio
async def test_noise_frame_helper_handshake_failure():
"""Test the noise frame helper handshake failure."""