mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-09-28 04:27:27 +02:00
Refactor reading varuints to significant simplify plaintext frame helper (#718)
This commit is contained in:
parent
cf3ada3deb
commit
cd5ad769f0
@ -25,6 +25,14 @@ cdef class APIFrameHelper:
|
|||||||
@cython.locals(original_pos="unsigned int", new_pos="unsigned int")
|
@cython.locals(original_pos="unsigned int", new_pos="unsigned int")
|
||||||
cdef bytes _read(self, int length)
|
cdef bytes _read(self, int length)
|
||||||
|
|
||||||
|
@cython.locals(
|
||||||
|
result="unsigned int",
|
||||||
|
bitpos="unsigned int",
|
||||||
|
val="unsigned char",
|
||||||
|
current_pos="unsigned int"
|
||||||
|
)
|
||||||
|
cdef int _read_varuint(self)
|
||||||
|
|
||||||
@cython.locals(bytes_data=bytes)
|
@cython.locals(bytes_data=bytes)
|
||||||
cdef void _add_to_buffer(self, object data)
|
cdef void _add_to_buffer(self, object data)
|
||||||
|
|
||||||
|
@ -119,6 +119,21 @@ class APIFrameHelper:
|
|||||||
assert self._buffer is not None, "Buffer should be set"
|
assert self._buffer is not None, "Buffer should be set"
|
||||||
return self._buffer[original_pos:new_pos]
|
return self._buffer[original_pos:new_pos]
|
||||||
|
|
||||||
|
def _read_varuint(self) -> _int:
|
||||||
|
"""Read a varuint from the buffer or -1 if the buffer runs out of bytes."""
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
assert self._buffer is not None, "Buffer should be set"
|
||||||
|
result = 0
|
||||||
|
bitpos = 0
|
||||||
|
while self._buffer_len > self._pos:
|
||||||
|
val = self._buffer[self._pos]
|
||||||
|
self._pos += 1
|
||||||
|
result |= (val & 0x7F) << bitpos
|
||||||
|
if (val & 0x80) == 0:
|
||||||
|
return result
|
||||||
|
bitpos += 7
|
||||||
|
return -1
|
||||||
|
|
||||||
async def perform_handshake(self, timeout: float) -> None:
|
async def perform_handshake(self, timeout: float) -> None:
|
||||||
"""Perform the handshake with the server."""
|
"""Perform the handshake with the server."""
|
||||||
handshake_handle = self._loop.call_at(
|
handshake_handle = self._loop.call_at(
|
||||||
|
@ -4,30 +4,15 @@ from ..connection cimport APIConnection
|
|||||||
from .base cimport APIFrameHelper
|
from .base cimport APIFrameHelper
|
||||||
|
|
||||||
|
|
||||||
cdef bint TYPE_CHECKING
|
cdef object varuint_to_bytes
|
||||||
cdef object bytes_to_varuint, varuint_to_bytes
|
|
||||||
|
|
||||||
cpdef _varuint_to_bytes(cython.int value)
|
cpdef _varuint_to_bytes(cython.int value)
|
||||||
|
|
||||||
@cython.locals(result=cython.int, bitpos=cython.int, val=cython.int)
|
|
||||||
cpdef _bytes_to_varuint(cython.bytes value)
|
|
||||||
|
|
||||||
cdef class APIPlaintextFrameHelper(APIFrameHelper):
|
cdef class APIPlaintextFrameHelper(APIFrameHelper):
|
||||||
|
|
||||||
@cython.locals(
|
|
||||||
msg_type=bytes,
|
|
||||||
length=bytes,
|
|
||||||
init_bytes=bytes,
|
|
||||||
add_length=bytes,
|
|
||||||
end_of_frame_pos=cython.uint,
|
|
||||||
length_int=cython.uint,
|
|
||||||
preamble="unsigned char",
|
|
||||||
length_high="unsigned char",
|
|
||||||
maybe_msg_type="unsigned char"
|
|
||||||
)
|
|
||||||
cpdef data_received(self, object data)
|
cpdef data_received(self, object data)
|
||||||
|
|
||||||
cdef void _error_on_incorrect_preamble(self, object preamble)
|
cdef void _error_on_incorrect_preamble(self, int preamble)
|
||||||
|
|
||||||
@cython.locals(
|
@cython.locals(
|
||||||
type_="unsigned int",
|
type_="unsigned int",
|
||||||
|
@ -2,13 +2,11 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
from ..core import ProtocolAPIError, RequiresEncryptionAPIError
|
from ..core import ProtocolAPIError, RequiresEncryptionAPIError
|
||||||
from .base import APIFrameHelper
|
from .base import APIFrameHelper
|
||||||
|
|
||||||
_int = int
|
_int = int
|
||||||
_bytes = bytes
|
|
||||||
|
|
||||||
|
|
||||||
def _varuint_to_bytes(value: _int) -> bytes:
|
def _varuint_to_bytes(value: _int) -> bytes:
|
||||||
@ -32,22 +30,6 @@ _cached_varuint_to_bytes = lru_cache(maxsize=1024)(_varuint_to_bytes)
|
|||||||
varuint_to_bytes = _cached_varuint_to_bytes
|
varuint_to_bytes = _cached_varuint_to_bytes
|
||||||
|
|
||||||
|
|
||||||
def _bytes_to_varuint(value: _bytes) -> _int | None:
|
|
||||||
"""Convert bytes to a varuint."""
|
|
||||||
result = 0
|
|
||||||
bitpos = 0
|
|
||||||
for val in value:
|
|
||||||
result |= (val & 0x7F) << bitpos
|
|
||||||
if (val & 0x80) == 0:
|
|
||||||
return result
|
|
||||||
bitpos += 7
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
_cached_bytes_to_varuint = lru_cache(maxsize=1024)(_bytes_to_varuint)
|
|
||||||
bytes_to_varuint = _cached_bytes_to_varuint
|
|
||||||
|
|
||||||
|
|
||||||
class APIPlaintextFrameHelper(APIFrameHelper):
|
class APIPlaintextFrameHelper(APIFrameHelper):
|
||||||
"""Frame helper for plaintext API connections."""
|
"""Frame helper for plaintext API connections."""
|
||||||
|
|
||||||
@ -81,65 +63,17 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
|||||||
) -> None:
|
) -> None:
|
||||||
self._add_to_buffer(data)
|
self._add_to_buffer(data)
|
||||||
while self._buffer_len:
|
while self._buffer_len:
|
||||||
# Read preamble, which should always 0x00
|
|
||||||
# Also try to get the length and msg type
|
|
||||||
# to avoid multiple calls to _read
|
|
||||||
self._pos = 0
|
self._pos = 0
|
||||||
if (init_bytes := self._read(3)) is None:
|
# Read preamble, which should always 0x00
|
||||||
return
|
if (preamble := self._read_varuint()) != 0x00:
|
||||||
msg_type_int: int | None = None
|
|
||||||
length_int = 0
|
|
||||||
preamble = init_bytes[0]
|
|
||||||
length_high = init_bytes[1]
|
|
||||||
maybe_msg_type = init_bytes[2]
|
|
||||||
if preamble != 0x00:
|
|
||||||
self._error_on_incorrect_preamble(preamble)
|
self._error_on_incorrect_preamble(preamble)
|
||||||
return
|
return
|
||||||
|
if (length := self._read_varuint()) == -1:
|
||||||
if length_high & 0x80 != 0x80:
|
|
||||||
# Length is only 1 byte
|
|
||||||
#
|
|
||||||
# This is the most common case needing a single byte for
|
|
||||||
# length and type which means we avoid 2 calls to _read
|
|
||||||
length_int = length_high
|
|
||||||
if maybe_msg_type & 0x80 != 0x80:
|
|
||||||
# Message type is also only 1 byte
|
|
||||||
msg_type_int = maybe_msg_type
|
|
||||||
else:
|
|
||||||
# Message type is longer than 1 byte
|
|
||||||
msg_type = init_bytes[2:3]
|
|
||||||
else:
|
|
||||||
# Length is longer than 1 byte
|
|
||||||
length = init_bytes[1:3]
|
|
||||||
# If the message is long, we need to read the rest of the length
|
|
||||||
while length[-1] & 0x80 == 0x80:
|
|
||||||
if (add_length := self._read(1)) is None:
|
|
||||||
return
|
return
|
||||||
length += add_length
|
if (msg_type := self._read_varuint()) == -1:
|
||||||
length_int = bytes_to_varuint(length) or 0
|
|
||||||
# Since the length is longer than 1 byte we do not have the
|
|
||||||
# message type yet.
|
|
||||||
if (msg_type_byte := self._read(1)) is None:
|
|
||||||
return
|
return
|
||||||
msg_type = msg_type_byte
|
|
||||||
if msg_type[-1] & 0x80 != 0x80:
|
|
||||||
# Message type is only 1 byte
|
|
||||||
msg_type_int = msg_type[0]
|
|
||||||
|
|
||||||
# If the we do not have the message type yet because the message
|
if length == 0:
|
||||||
# length was so long it did not fit into the first byte we need
|
|
||||||
# to read the (rest) of the message type
|
|
||||||
if msg_type_int is None:
|
|
||||||
while msg_type[-1] & 0x80 == 0x80:
|
|
||||||
if (add_msg_type := self._read(1)) is None:
|
|
||||||
return
|
|
||||||
msg_type += add_msg_type
|
|
||||||
msg_type_int = bytes_to_varuint(msg_type)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
assert msg_type_int is not None
|
|
||||||
|
|
||||||
if length_int == 0:
|
|
||||||
packet_data = b""
|
packet_data = b""
|
||||||
else:
|
else:
|
||||||
# The packet data is not yet available, wait for more data
|
# The packet data is not yet available, wait for more data
|
||||||
@ -147,12 +81,12 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
|||||||
# been called yet the buffer will not be cleared and the next
|
# been called yet the buffer will not be cleared and the next
|
||||||
# call to data_received will continue processing the packet
|
# call to data_received will continue processing the packet
|
||||||
# at the start of the frame.
|
# at the start of the frame.
|
||||||
if (maybe_packet_data := self._read(length_int)) is None:
|
if (maybe_packet_data := self._read(length)) is None:
|
||||||
return
|
return
|
||||||
packet_data = maybe_packet_data
|
packet_data = maybe_packet_data
|
||||||
|
|
||||||
self._remove_from_buffer()
|
self._remove_from_buffer()
|
||||||
self._connection.process_packet(msg_type_int, packet_data)
|
self._connection.process_packet(msg_type, packet_data)
|
||||||
# If we have more data, continue processing
|
# If we have more data, continue processing
|
||||||
|
|
||||||
def _error_on_incorrect_preamble(self, preamble: _int) -> None:
|
def _error_on_incorrect_preamble(self, preamble: _int) -> None:
|
||||||
|
@ -12,10 +12,6 @@ from noise.connection import NoiseConnection # type: ignore[import-untyped]
|
|||||||
from aioesphomeapi import APIConnection
|
from aioesphomeapi import APIConnection
|
||||||
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
|
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
|
||||||
from aioesphomeapi._frame_helper.noise import ESPHOME_NOISE_BACKEND
|
from aioesphomeapi._frame_helper.noise import ESPHOME_NOISE_BACKEND
|
||||||
from aioesphomeapi._frame_helper.plain_text import _bytes_to_varuint as bytes_to_varuint
|
|
||||||
from aioesphomeapi._frame_helper.plain_text import (
|
|
||||||
_cached_bytes_to_varuint as cached_bytes_to_varuint,
|
|
||||||
)
|
|
||||||
from aioesphomeapi._frame_helper.plain_text import (
|
from aioesphomeapi._frame_helper.plain_text import (
|
||||||
_cached_varuint_to_bytes as cached_varuint_to_bytes,
|
_cached_varuint_to_bytes as cached_varuint_to_bytes,
|
||||||
)
|
)
|
||||||
@ -459,16 +455,6 @@ def test_varuint_to_bytes(val, encoded):
|
|||||||
assert cached_varuint_to_bytes(val) == encoded
|
assert cached_varuint_to_bytes(val) == encoded
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("val, encoded", VARUINT_TESTCASES)
|
|
||||||
def test_bytes_to_varuint(val, encoded):
|
|
||||||
assert bytes_to_varuint(encoded) == val
|
|
||||||
assert cached_bytes_to_varuint(encoded) == val
|
|
||||||
|
|
||||||
|
|
||||||
def test_bytes_to_varuint_invalid():
|
|
||||||
assert bytes_to_varuint(b"\xFF") is None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_noise_frame_helper_handshake_failure():
|
async def test_noise_frame_helper_handshake_failure():
|
||||||
"""Test the noise frame helper handshake failure."""
|
"""Test the noise frame helper handshake failure."""
|
||||||
|
Loading…
Reference in New Issue
Block a user