Optimize the frame helpers by improving cython typing (#691)

This commit is contained in:
J. Nick Koston 2023-11-24 12:12:32 -06:00 committed by GitHub
parent 33d1d3d8c4
commit 7a57f1fa89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 40 additions and 24 deletions

View File

@ -23,7 +23,7 @@ cdef class APIFrameHelper:
cpdef set_log_name(self, str log_name)
@cython.locals(original_pos="unsigned int", new_pos="unsigned int")
cdef bytes _read_exactly(self, int length)
cdef bytes _read(self, int length)
@cython.locals(bytes_data=bytes)
cdef _add_to_buffer(self, object data)

View File

@ -108,7 +108,7 @@ class APIFrameHelper:
# is blocked and we cannot pull the data out of the buffer fast enough.
self._buffer = self._buffer[end_of_frame_pos:]
def _read_exactly(self, length: _int) -> bytes | None:
def _read(self, length: _int) -> bytes | None:
"""Read exactly length bytes from the buffer or None if all the bytes are not yet available."""
original_pos = self._pos
new_pos = original_pos + length

View File

@ -11,10 +11,12 @@ cdef unsigned int NOISE_STATE_HANDSHAKE
cdef unsigned int NOISE_STATE_READY
cdef unsigned int NOISE_STATE_CLOSED
cdef bytes NOISE_HELLO
cdef class APINoiseFrameHelper(APIFrameHelper):
cdef object _noise_psk
cdef object _expected_name
cdef str _expected_name
cdef unsigned int _state
cdef object _dispatch
cdef object _server_name
@ -37,12 +39,24 @@ cdef class APINoiseFrameHelper(APIFrameHelper):
)
cdef _handle_frame(self, bytes frame)
@cython.locals(
chosen_proto=char,
server_name_i="unsigned int"
)
cdef _handle_hello(self, bytes server_hello)
cdef _handle_handshake(self, bytes msg)
cdef _handle_closed(self, bytes frame)
@cython.locals(handshake_frame=bytearray, frame_len="unsigned int")
cdef _send_hello_handshake(self)
cdef _setup_proto(self)
@cython.locals(psk_bytes=bytes)
cdef _decode_noise_psk(self)
@cython.locals(
type_="unsigned int",
data=bytes,
@ -50,7 +64,6 @@ cdef class APINoiseFrameHelper(APIFrameHelper):
data_len=cython.uint,
frame=bytes,
frame_len=cython.uint,
type_=object
)
cpdef write_packets(self, list packets, bint debug_enabled)

View File

@ -138,7 +138,7 @@ class APINoiseFrameHelper(APIFrameHelper):
self._add_to_buffer(data)
while self._buffer:
self._pos = 0
if (header := self._read_exactly(3)) is None:
if (header := self._read(3)) is None:
return
preamble = header[0]
msg_size_high = header[1]
@ -150,13 +150,12 @@ class APINoiseFrameHelper(APIFrameHelper):
)
)
return
frame = self._read_exactly((msg_size_high << 8) | msg_size_low)
if (frame := self._read((msg_size_high << 8) | msg_size_low)) is None:
# The complete frame is not yet available, wait for more data
# to arrive before continuing, since callback_packet has not
# been called yet the buffer will not be cleared and the next
# call to data_received will continue processing the packet
# at the start of the frame.
if frame is None:
return
# asyncio already runs data_received in a try block
@ -174,11 +173,13 @@ class APINoiseFrameHelper(APIFrameHelper):
def _send_hello_handshake(self) -> None:
"""Send a ClientHello to the server."""
handshake_frame = b"\x00" + self._proto.write_message()
frame_len = len(handshake_frame)
handshake_frame = self._proto.write_message()
frame_len = len(handshake_frame) + 1
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
hello_handshake = NOISE_HELLO + header + handshake_frame
self._write_bytes(hello_handshake, _LOGGER.isEnabledFor(logging.DEBUG))
self._write_bytes(
b"".join((NOISE_HELLO, header, b"\x00", handshake_frame)),
_LOGGER.isEnabledFor(logging.DEBUG),
)
def _handle_hello(self, server_hello: bytes) -> None:
"""Perform the handshake with the server."""

View File

@ -83,9 +83,9 @@ class APIPlaintextFrameHelper(APIFrameHelper):
while self._buffer:
# Read preamble, which should always 0x00
# Also try to get the length and msg type
# to avoid multiple calls to _read_exactly
# to avoid multiple calls to _read
self._pos = 0
if (init_bytes := self._read_exactly(3)) is None:
if (init_bytes := self._read(3)) is None:
return
msg_type_int: int | None = None
length_int = 0
@ -100,7 +100,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
# Length is only 1 byte
#
# This is the most common case needing a single byte for
# length and type which means we avoid 2 calls to _read_exactly
# length and type which means we avoid 2 calls to _read
length_int = length_high
if maybe_msg_type & 0x80 != 0x80:
# Message type is also only 1 byte
@ -113,13 +113,13 @@ class APIPlaintextFrameHelper(APIFrameHelper):
length = init_bytes[1:3]
# If the message is long, we need to read the rest of the length
while length[-1] & 0x80 == 0x80:
if (add_length := self._read_exactly(1)) is None:
if (add_length := self._read(1)) is None:
return
length += add_length
length_int = bytes_to_varuint(length) or 0
# Since the length is longer than 1 byte we do not have the
# message type yet.
if (msg_type_byte := self._read_exactly(1)) is None:
if (msg_type_byte := self._read(1)) is None:
return
msg_type = msg_type_byte
if msg_type[-1] & 0x80 != 0x80:
@ -131,7 +131,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
# to read the (rest) of the message type
if msg_type_int is None:
while msg_type[-1] & 0x80 == 0x80:
if (add_msg_type := self._read_exactly(1)) is None:
if (add_msg_type := self._read(1)) is None:
return
msg_type += add_msg_type
msg_type_int = bytes_to_varuint(msg_type)
@ -147,7 +147,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
# been called yet the buffer will not be cleared and the next
# call to data_received will continue processing the packet
# at the start of the frame.
if (maybe_packet_data := self._read_exactly(length_int)) is None:
if (maybe_packet_data := self._read(length_int)) is None:
return
packet_data = maybe_packet_data

View File

@ -11,7 +11,7 @@ from noise.connection import NoiseConnection # type: ignore[import-untyped]
from aioesphomeapi import APIConnection
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
from aioesphomeapi._frame_helper.noise import ESPHOME_NOISE_BACKEND, NOISE_HELLO
from aioesphomeapi._frame_helper.noise import ESPHOME_NOISE_BACKEND
from aioesphomeapi._frame_helper.plain_text import _bytes_to_varuint as bytes_to_varuint
from aioesphomeapi._frame_helper.plain_text import (
_cached_bytes_to_varuint as cached_bytes_to_varuint,
@ -40,6 +40,8 @@ from .conftest import get_mock_connection_params
PREAMBLE = b"\x00"
NOISE_HELLO = b"\x01\x00\x00"
def _extract_encrypted_payload_from_handshake(handshake_pkt: bytes) -> bytes:
noise_hello = handshake_pkt[0:3]