mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2025-02-23 02:52:37 +01:00
Refactor reading varuints to significant simplify plaintext frame helper (#718)
This commit is contained in:
parent
cf3ada3deb
commit
cd5ad769f0
@ -25,6 +25,14 @@ cdef class APIFrameHelper:
|
||||
@cython.locals(original_pos="unsigned int", new_pos="unsigned int")
|
||||
cdef bytes _read(self, int length)
|
||||
|
||||
@cython.locals(
|
||||
result="unsigned int",
|
||||
bitpos="unsigned int",
|
||||
val="unsigned char",
|
||||
current_pos="unsigned int"
|
||||
)
|
||||
cdef int _read_varuint(self)
|
||||
|
||||
@cython.locals(bytes_data=bytes)
|
||||
cdef void _add_to_buffer(self, object data)
|
||||
|
||||
|
@ -119,6 +119,21 @@ class APIFrameHelper:
|
||||
assert self._buffer is not None, "Buffer should be set"
|
||||
return self._buffer[original_pos:new_pos]
|
||||
|
||||
def _read_varuint(self) -> _int:
|
||||
"""Read a varuint from the buffer or -1 if the buffer runs out of bytes."""
|
||||
if TYPE_CHECKING:
|
||||
assert self._buffer is not None, "Buffer should be set"
|
||||
result = 0
|
||||
bitpos = 0
|
||||
while self._buffer_len > self._pos:
|
||||
val = self._buffer[self._pos]
|
||||
self._pos += 1
|
||||
result |= (val & 0x7F) << bitpos
|
||||
if (val & 0x80) == 0:
|
||||
return result
|
||||
bitpos += 7
|
||||
return -1
|
||||
|
||||
async def perform_handshake(self, timeout: float) -> None:
|
||||
"""Perform the handshake with the server."""
|
||||
handshake_handle = self._loop.call_at(
|
||||
|
@ -4,30 +4,15 @@ from ..connection cimport APIConnection
|
||||
from .base cimport APIFrameHelper
|
||||
|
||||
|
||||
cdef bint TYPE_CHECKING
|
||||
cdef object bytes_to_varuint, varuint_to_bytes
|
||||
cdef object varuint_to_bytes
|
||||
|
||||
cpdef _varuint_to_bytes(cython.int value)
|
||||
|
||||
@cython.locals(result=cython.int, bitpos=cython.int, val=cython.int)
|
||||
cpdef _bytes_to_varuint(cython.bytes value)
|
||||
|
||||
cdef class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
|
||||
@cython.locals(
|
||||
msg_type=bytes,
|
||||
length=bytes,
|
||||
init_bytes=bytes,
|
||||
add_length=bytes,
|
||||
end_of_frame_pos=cython.uint,
|
||||
length_int=cython.uint,
|
||||
preamble="unsigned char",
|
||||
length_high="unsigned char",
|
||||
maybe_msg_type="unsigned char"
|
||||
)
|
||||
cpdef data_received(self, object data)
|
||||
|
||||
cdef void _error_on_incorrect_preamble(self, object preamble)
|
||||
cdef void _error_on_incorrect_preamble(self, int preamble)
|
||||
|
||||
@cython.locals(
|
||||
type_="unsigned int",
|
||||
|
@ -2,13 +2,11 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ..core import ProtocolAPIError, RequiresEncryptionAPIError
|
||||
from .base import APIFrameHelper
|
||||
|
||||
_int = int
|
||||
_bytes = bytes
|
||||
|
||||
|
||||
def _varuint_to_bytes(value: _int) -> bytes:
|
||||
@ -32,22 +30,6 @@ _cached_varuint_to_bytes = lru_cache(maxsize=1024)(_varuint_to_bytes)
|
||||
varuint_to_bytes = _cached_varuint_to_bytes
|
||||
|
||||
|
||||
def _bytes_to_varuint(value: _bytes) -> _int | None:
|
||||
"""Convert bytes to a varuint."""
|
||||
result = 0
|
||||
bitpos = 0
|
||||
for val in value:
|
||||
result |= (val & 0x7F) << bitpos
|
||||
if (val & 0x80) == 0:
|
||||
return result
|
||||
bitpos += 7
|
||||
return None
|
||||
|
||||
|
||||
_cached_bytes_to_varuint = lru_cache(maxsize=1024)(_bytes_to_varuint)
|
||||
bytes_to_varuint = _cached_bytes_to_varuint
|
||||
|
||||
|
||||
class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
"""Frame helper for plaintext API connections."""
|
||||
|
||||
@ -81,65 +63,17 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
) -> None:
|
||||
self._add_to_buffer(data)
|
||||
while self._buffer_len:
|
||||
# Read preamble, which should always 0x00
|
||||
# Also try to get the length and msg type
|
||||
# to avoid multiple calls to _read
|
||||
self._pos = 0
|
||||
if (init_bytes := self._read(3)) is None:
|
||||
return
|
||||
msg_type_int: int | None = None
|
||||
length_int = 0
|
||||
preamble = init_bytes[0]
|
||||
length_high = init_bytes[1]
|
||||
maybe_msg_type = init_bytes[2]
|
||||
if preamble != 0x00:
|
||||
# Read preamble, which should always 0x00
|
||||
if (preamble := self._read_varuint()) != 0x00:
|
||||
self._error_on_incorrect_preamble(preamble)
|
||||
return
|
||||
if (length := self._read_varuint()) == -1:
|
||||
return
|
||||
if (msg_type := self._read_varuint()) == -1:
|
||||
return
|
||||
|
||||
if length_high & 0x80 != 0x80:
|
||||
# Length is only 1 byte
|
||||
#
|
||||
# This is the most common case needing a single byte for
|
||||
# length and type which means we avoid 2 calls to _read
|
||||
length_int = length_high
|
||||
if maybe_msg_type & 0x80 != 0x80:
|
||||
# Message type is also only 1 byte
|
||||
msg_type_int = maybe_msg_type
|
||||
else:
|
||||
# Message type is longer than 1 byte
|
||||
msg_type = init_bytes[2:3]
|
||||
else:
|
||||
# Length is longer than 1 byte
|
||||
length = init_bytes[1:3]
|
||||
# If the message is long, we need to read the rest of the length
|
||||
while length[-1] & 0x80 == 0x80:
|
||||
if (add_length := self._read(1)) is None:
|
||||
return
|
||||
length += add_length
|
||||
length_int = bytes_to_varuint(length) or 0
|
||||
# Since the length is longer than 1 byte we do not have the
|
||||
# message type yet.
|
||||
if (msg_type_byte := self._read(1)) is None:
|
||||
return
|
||||
msg_type = msg_type_byte
|
||||
if msg_type[-1] & 0x80 != 0x80:
|
||||
# Message type is only 1 byte
|
||||
msg_type_int = msg_type[0]
|
||||
|
||||
# If the we do not have the message type yet because the message
|
||||
# length was so long it did not fit into the first byte we need
|
||||
# to read the (rest) of the message type
|
||||
if msg_type_int is None:
|
||||
while msg_type[-1] & 0x80 == 0x80:
|
||||
if (add_msg_type := self._read(1)) is None:
|
||||
return
|
||||
msg_type += add_msg_type
|
||||
msg_type_int = bytes_to_varuint(msg_type)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
assert msg_type_int is not None
|
||||
|
||||
if length_int == 0:
|
||||
if length == 0:
|
||||
packet_data = b""
|
||||
else:
|
||||
# The packet data is not yet available, wait for more data
|
||||
@ -147,12 +81,12 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
# been called yet the buffer will not be cleared and the next
|
||||
# call to data_received will continue processing the packet
|
||||
# at the start of the frame.
|
||||
if (maybe_packet_data := self._read(length_int)) is None:
|
||||
if (maybe_packet_data := self._read(length)) is None:
|
||||
return
|
||||
packet_data = maybe_packet_data
|
||||
|
||||
self._remove_from_buffer()
|
||||
self._connection.process_packet(msg_type_int, packet_data)
|
||||
self._connection.process_packet(msg_type, packet_data)
|
||||
# If we have more data, continue processing
|
||||
|
||||
def _error_on_incorrect_preamble(self, preamble: _int) -> None:
|
||||
|
@ -12,10 +12,6 @@ from noise.connection import NoiseConnection # type: ignore[import-untyped]
|
||||
from aioesphomeapi import APIConnection
|
||||
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
|
||||
from aioesphomeapi._frame_helper.noise import ESPHOME_NOISE_BACKEND
|
||||
from aioesphomeapi._frame_helper.plain_text import _bytes_to_varuint as bytes_to_varuint
|
||||
from aioesphomeapi._frame_helper.plain_text import (
|
||||
_cached_bytes_to_varuint as cached_bytes_to_varuint,
|
||||
)
|
||||
from aioesphomeapi._frame_helper.plain_text import (
|
||||
_cached_varuint_to_bytes as cached_varuint_to_bytes,
|
||||
)
|
||||
@ -459,16 +455,6 @@ def test_varuint_to_bytes(val, encoded):
|
||||
assert cached_varuint_to_bytes(val) == encoded
|
||||
|
||||
|
||||
@pytest.mark.parametrize("val, encoded", VARUINT_TESTCASES)
|
||||
def test_bytes_to_varuint(val, encoded):
|
||||
assert bytes_to_varuint(encoded) == val
|
||||
assert cached_bytes_to_varuint(encoded) == val
|
||||
|
||||
|
||||
def test_bytes_to_varuint_invalid():
|
||||
assert bytes_to_varuint(b"\xFF") is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_noise_frame_helper_handshake_failure():
|
||||
"""Test the noise frame helper handshake failure."""
|
||||
|
Loading…
Reference in New Issue
Block a user