mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-22 12:05:12 +01:00
Optimize the frame helpers by improving cython typing (#691)
This commit is contained in:
parent
33d1d3d8c4
commit
7a57f1fa89
@ -23,7 +23,7 @@ cdef class APIFrameHelper:
|
||||
cpdef set_log_name(self, str log_name)
|
||||
|
||||
@cython.locals(original_pos="unsigned int", new_pos="unsigned int")
|
||||
cdef bytes _read_exactly(self, int length)
|
||||
cdef bytes _read(self, int length)
|
||||
|
||||
@cython.locals(bytes_data=bytes)
|
||||
cdef _add_to_buffer(self, object data)
|
||||
|
@ -108,7 +108,7 @@ class APIFrameHelper:
|
||||
# is blocked and we cannot pull the data out of the buffer fast enough.
|
||||
self._buffer = self._buffer[end_of_frame_pos:]
|
||||
|
||||
def _read_exactly(self, length: _int) -> bytes | None:
|
||||
def _read(self, length: _int) -> bytes | None:
|
||||
"""Read exactly length bytes from the buffer or None if all the bytes are not yet available."""
|
||||
original_pos = self._pos
|
||||
new_pos = original_pos + length
|
||||
|
@ -11,10 +11,12 @@ cdef unsigned int NOISE_STATE_HANDSHAKE
|
||||
cdef unsigned int NOISE_STATE_READY
|
||||
cdef unsigned int NOISE_STATE_CLOSED
|
||||
|
||||
cdef bytes NOISE_HELLO
|
||||
|
||||
cdef class APINoiseFrameHelper(APIFrameHelper):
|
||||
|
||||
cdef object _noise_psk
|
||||
cdef object _expected_name
|
||||
cdef str _expected_name
|
||||
cdef unsigned int _state
|
||||
cdef object _dispatch
|
||||
cdef object _server_name
|
||||
@ -37,12 +39,24 @@ cdef class APINoiseFrameHelper(APIFrameHelper):
|
||||
)
|
||||
cdef _handle_frame(self, bytes frame)
|
||||
|
||||
@cython.locals(
|
||||
chosen_proto=char,
|
||||
server_name_i="unsigned int"
|
||||
)
|
||||
cdef _handle_hello(self, bytes server_hello)
|
||||
|
||||
cdef _handle_handshake(self, bytes msg)
|
||||
|
||||
cdef _handle_closed(self, bytes frame)
|
||||
|
||||
@cython.locals(handshake_frame=bytearray, frame_len="unsigned int")
|
||||
cdef _send_hello_handshake(self)
|
||||
|
||||
cdef _setup_proto(self)
|
||||
|
||||
@cython.locals(psk_bytes=bytes)
|
||||
cdef _decode_noise_psk(self)
|
||||
|
||||
@cython.locals(
|
||||
type_="unsigned int",
|
||||
data=bytes,
|
||||
@ -50,7 +64,6 @@ cdef class APINoiseFrameHelper(APIFrameHelper):
|
||||
data_len=cython.uint,
|
||||
frame=bytes,
|
||||
frame_len=cython.uint,
|
||||
type_=object
|
||||
)
|
||||
cpdef write_packets(self, list packets, bint debug_enabled)
|
||||
|
||||
|
@ -138,7 +138,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
self._add_to_buffer(data)
|
||||
while self._buffer:
|
||||
self._pos = 0
|
||||
if (header := self._read_exactly(3)) is None:
|
||||
if (header := self._read(3)) is None:
|
||||
return
|
||||
preamble = header[0]
|
||||
msg_size_high = header[1]
|
||||
@ -150,13 +150,12 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
)
|
||||
)
|
||||
return
|
||||
frame = self._read_exactly((msg_size_high << 8) | msg_size_low)
|
||||
# The complete frame is not yet available, wait for more data
|
||||
# to arrive before continuing, since callback_packet has not
|
||||
# been called yet the buffer will not be cleared and the next
|
||||
# call to data_received will continue processing the packet
|
||||
# at the start of the frame.
|
||||
if frame is None:
|
||||
if (frame := self._read((msg_size_high << 8) | msg_size_low)) is None:
|
||||
# The complete frame is not yet available, wait for more data
|
||||
# to arrive before continuing, since callback_packet has not
|
||||
# been called yet the buffer will not be cleared and the next
|
||||
# call to data_received will continue processing the packet
|
||||
# at the start of the frame.
|
||||
return
|
||||
|
||||
# asyncio already runs data_received in a try block
|
||||
@ -174,11 +173,13 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
|
||||
def _send_hello_handshake(self) -> None:
|
||||
"""Send a ClientHello to the server."""
|
||||
handshake_frame = b"\x00" + self._proto.write_message()
|
||||
frame_len = len(handshake_frame)
|
||||
handshake_frame = self._proto.write_message()
|
||||
frame_len = len(handshake_frame) + 1
|
||||
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
|
||||
hello_handshake = NOISE_HELLO + header + handshake_frame
|
||||
self._write_bytes(hello_handshake, _LOGGER.isEnabledFor(logging.DEBUG))
|
||||
self._write_bytes(
|
||||
b"".join((NOISE_HELLO, header, b"\x00", handshake_frame)),
|
||||
_LOGGER.isEnabledFor(logging.DEBUG),
|
||||
)
|
||||
|
||||
def _handle_hello(self, server_hello: bytes) -> None:
|
||||
"""Perform the handshake with the server."""
|
||||
|
@ -83,9 +83,9 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
while self._buffer:
|
||||
# Read preamble, which should always 0x00
|
||||
# Also try to get the length and msg type
|
||||
# to avoid multiple calls to _read_exactly
|
||||
# to avoid multiple calls to _read
|
||||
self._pos = 0
|
||||
if (init_bytes := self._read_exactly(3)) is None:
|
||||
if (init_bytes := self._read(3)) is None:
|
||||
return
|
||||
msg_type_int: int | None = None
|
||||
length_int = 0
|
||||
@ -100,7 +100,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
# Length is only 1 byte
|
||||
#
|
||||
# This is the most common case needing a single byte for
|
||||
# length and type which means we avoid 2 calls to _read_exactly
|
||||
# length and type which means we avoid 2 calls to _read
|
||||
length_int = length_high
|
||||
if maybe_msg_type & 0x80 != 0x80:
|
||||
# Message type is also only 1 byte
|
||||
@ -113,13 +113,13 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
length = init_bytes[1:3]
|
||||
# If the message is long, we need to read the rest of the length
|
||||
while length[-1] & 0x80 == 0x80:
|
||||
if (add_length := self._read_exactly(1)) is None:
|
||||
if (add_length := self._read(1)) is None:
|
||||
return
|
||||
length += add_length
|
||||
length_int = bytes_to_varuint(length) or 0
|
||||
# Since the length is longer than 1 byte we do not have the
|
||||
# message type yet.
|
||||
if (msg_type_byte := self._read_exactly(1)) is None:
|
||||
if (msg_type_byte := self._read(1)) is None:
|
||||
return
|
||||
msg_type = msg_type_byte
|
||||
if msg_type[-1] & 0x80 != 0x80:
|
||||
@ -131,7 +131,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
# to read the (rest) of the message type
|
||||
if msg_type_int is None:
|
||||
while msg_type[-1] & 0x80 == 0x80:
|
||||
if (add_msg_type := self._read_exactly(1)) is None:
|
||||
if (add_msg_type := self._read(1)) is None:
|
||||
return
|
||||
msg_type += add_msg_type
|
||||
msg_type_int = bytes_to_varuint(msg_type)
|
||||
@ -147,7 +147,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
# been called yet the buffer will not be cleared and the next
|
||||
# call to data_received will continue processing the packet
|
||||
# at the start of the frame.
|
||||
if (maybe_packet_data := self._read_exactly(length_int)) is None:
|
||||
if (maybe_packet_data := self._read(length_int)) is None:
|
||||
return
|
||||
packet_data = maybe_packet_data
|
||||
|
||||
|
@ -11,7 +11,7 @@ from noise.connection import NoiseConnection # type: ignore[import-untyped]
|
||||
|
||||
from aioesphomeapi import APIConnection
|
||||
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
|
||||
from aioesphomeapi._frame_helper.noise import ESPHOME_NOISE_BACKEND, NOISE_HELLO
|
||||
from aioesphomeapi._frame_helper.noise import ESPHOME_NOISE_BACKEND
|
||||
from aioesphomeapi._frame_helper.plain_text import _bytes_to_varuint as bytes_to_varuint
|
||||
from aioesphomeapi._frame_helper.plain_text import (
|
||||
_cached_bytes_to_varuint as cached_bytes_to_varuint,
|
||||
@ -40,6 +40,8 @@ from .conftest import get_mock_connection_params
|
||||
|
||||
PREAMBLE = b"\x00"
|
||||
|
||||
NOISE_HELLO = b"\x01\x00\x00"
|
||||
|
||||
|
||||
def _extract_encrypted_payload_from_handshake(handshake_pkt: bytes) -> bytes:
|
||||
noise_hello = handshake_pkt[0:3]
|
||||
|
Loading…
Reference in New Issue
Block a user