Optimize the frame helpers by improving cython typing (#691)

This commit is contained in:
J. Nick Koston 2023-11-24 12:12:32 -06:00 committed by GitHub
parent 33d1d3d8c4
commit 7a57f1fa89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 40 additions and 24 deletions

View File

@ -23,7 +23,7 @@ cdef class APIFrameHelper:
cpdef set_log_name(self, str log_name) cpdef set_log_name(self, str log_name)
@cython.locals(original_pos="unsigned int", new_pos="unsigned int") @cython.locals(original_pos="unsigned int", new_pos="unsigned int")
cdef bytes _read_exactly(self, int length) cdef bytes _read(self, int length)
@cython.locals(bytes_data=bytes) @cython.locals(bytes_data=bytes)
cdef _add_to_buffer(self, object data) cdef _add_to_buffer(self, object data)

View File

@ -108,7 +108,7 @@ class APIFrameHelper:
# is blocked and we cannot pull the data out of the buffer fast enough. # is blocked and we cannot pull the data out of the buffer fast enough.
self._buffer = self._buffer[end_of_frame_pos:] self._buffer = self._buffer[end_of_frame_pos:]
def _read_exactly(self, length: _int) -> bytes | None: def _read(self, length: _int) -> bytes | None:
"""Read exactly length bytes from the buffer or None if all the bytes are not yet available.""" """Read exactly length bytes from the buffer or None if all the bytes are not yet available."""
original_pos = self._pos original_pos = self._pos
new_pos = original_pos + length new_pos = original_pos + length

View File

@ -11,10 +11,12 @@ cdef unsigned int NOISE_STATE_HANDSHAKE
cdef unsigned int NOISE_STATE_READY cdef unsigned int NOISE_STATE_READY
cdef unsigned int NOISE_STATE_CLOSED cdef unsigned int NOISE_STATE_CLOSED
cdef bytes NOISE_HELLO
cdef class APINoiseFrameHelper(APIFrameHelper): cdef class APINoiseFrameHelper(APIFrameHelper):
cdef object _noise_psk cdef object _noise_psk
cdef object _expected_name cdef str _expected_name
cdef unsigned int _state cdef unsigned int _state
cdef object _dispatch cdef object _dispatch
cdef object _server_name cdef object _server_name
@ -37,12 +39,24 @@ cdef class APINoiseFrameHelper(APIFrameHelper):
) )
cdef _handle_frame(self, bytes frame) cdef _handle_frame(self, bytes frame)
@cython.locals(
chosen_proto=char,
server_name_i="unsigned int"
)
cdef _handle_hello(self, bytes server_hello) cdef _handle_hello(self, bytes server_hello)
cdef _handle_handshake(self, bytes msg) cdef _handle_handshake(self, bytes msg)
cdef _handle_closed(self, bytes frame) cdef _handle_closed(self, bytes frame)
@cython.locals(handshake_frame=bytearray, frame_len="unsigned int")
cdef _send_hello_handshake(self)
cdef _setup_proto(self)
@cython.locals(psk_bytes=bytes)
cdef _decode_noise_psk(self)
@cython.locals( @cython.locals(
type_="unsigned int", type_="unsigned int",
data=bytes, data=bytes,
@ -50,7 +64,6 @@ cdef class APINoiseFrameHelper(APIFrameHelper):
data_len=cython.uint, data_len=cython.uint,
frame=bytes, frame=bytes,
frame_len=cython.uint, frame_len=cython.uint,
type_=object
) )
cpdef write_packets(self, list packets, bint debug_enabled) cpdef write_packets(self, list packets, bint debug_enabled)

View File

@ -138,7 +138,7 @@ class APINoiseFrameHelper(APIFrameHelper):
self._add_to_buffer(data) self._add_to_buffer(data)
while self._buffer: while self._buffer:
self._pos = 0 self._pos = 0
if (header := self._read_exactly(3)) is None: if (header := self._read(3)) is None:
return return
preamble = header[0] preamble = header[0]
msg_size_high = header[1] msg_size_high = header[1]
@ -150,13 +150,12 @@ class APINoiseFrameHelper(APIFrameHelper):
) )
) )
return return
frame = self._read_exactly((msg_size_high << 8) | msg_size_low) if (frame := self._read((msg_size_high << 8) | msg_size_low)) is None:
# The complete frame is not yet available, wait for more data # The complete frame is not yet available, wait for more data
# to arrive before continuing, since callback_packet has not # to arrive before continuing, since callback_packet has not
# been called yet the buffer will not be cleared and the next # been called yet the buffer will not be cleared and the next
# call to data_received will continue processing the packet # call to data_received will continue processing the packet
# at the start of the frame. # at the start of the frame.
if frame is None:
return return
# asyncio already runs data_received in a try block # asyncio already runs data_received in a try block
@ -174,11 +173,13 @@ class APINoiseFrameHelper(APIFrameHelper):
def _send_hello_handshake(self) -> None: def _send_hello_handshake(self) -> None:
"""Send a ClientHello to the server.""" """Send a ClientHello to the server."""
handshake_frame = b"\x00" + self._proto.write_message() handshake_frame = self._proto.write_message()
frame_len = len(handshake_frame) frame_len = len(handshake_frame) + 1
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF)) header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
hello_handshake = NOISE_HELLO + header + handshake_frame self._write_bytes(
self._write_bytes(hello_handshake, _LOGGER.isEnabledFor(logging.DEBUG)) b"".join((NOISE_HELLO, header, b"\x00", handshake_frame)),
_LOGGER.isEnabledFor(logging.DEBUG),
)
def _handle_hello(self, server_hello: bytes) -> None: def _handle_hello(self, server_hello: bytes) -> None:
"""Perform the handshake with the server.""" """Perform the handshake with the server."""

View File

@ -83,9 +83,9 @@ class APIPlaintextFrameHelper(APIFrameHelper):
while self._buffer: while self._buffer:
# Read preamble, which should always 0x00 # Read preamble, which should always 0x00
# Also try to get the length and msg type # Also try to get the length and msg type
# to avoid multiple calls to _read_exactly # to avoid multiple calls to _read
self._pos = 0 self._pos = 0
if (init_bytes := self._read_exactly(3)) is None: if (init_bytes := self._read(3)) is None:
return return
msg_type_int: int | None = None msg_type_int: int | None = None
length_int = 0 length_int = 0
@ -100,7 +100,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
# Length is only 1 byte # Length is only 1 byte
# #
# This is the most common case needing a single byte for # This is the most common case needing a single byte for
# length and type which means we avoid 2 calls to _read_exactly # length and type which means we avoid 2 calls to _read
length_int = length_high length_int = length_high
if maybe_msg_type & 0x80 != 0x80: if maybe_msg_type & 0x80 != 0x80:
# Message type is also only 1 byte # Message type is also only 1 byte
@ -113,13 +113,13 @@ class APIPlaintextFrameHelper(APIFrameHelper):
length = init_bytes[1:3] length = init_bytes[1:3]
# If the message is long, we need to read the rest of the length # If the message is long, we need to read the rest of the length
while length[-1] & 0x80 == 0x80: while length[-1] & 0x80 == 0x80:
if (add_length := self._read_exactly(1)) is None: if (add_length := self._read(1)) is None:
return return
length += add_length length += add_length
length_int = bytes_to_varuint(length) or 0 length_int = bytes_to_varuint(length) or 0
# Since the length is longer than 1 byte we do not have the # Since the length is longer than 1 byte we do not have the
# message type yet. # message type yet.
if (msg_type_byte := self._read_exactly(1)) is None: if (msg_type_byte := self._read(1)) is None:
return return
msg_type = msg_type_byte msg_type = msg_type_byte
if msg_type[-1] & 0x80 != 0x80: if msg_type[-1] & 0x80 != 0x80:
@ -131,7 +131,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
# to read the (rest) of the message type # to read the (rest) of the message type
if msg_type_int is None: if msg_type_int is None:
while msg_type[-1] & 0x80 == 0x80: while msg_type[-1] & 0x80 == 0x80:
if (add_msg_type := self._read_exactly(1)) is None: if (add_msg_type := self._read(1)) is None:
return return
msg_type += add_msg_type msg_type += add_msg_type
msg_type_int = bytes_to_varuint(msg_type) msg_type_int = bytes_to_varuint(msg_type)
@ -147,7 +147,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
# been called yet the buffer will not be cleared and the next # been called yet the buffer will not be cleared and the next
# call to data_received will continue processing the packet # call to data_received will continue processing the packet
# at the start of the frame. # at the start of the frame.
if (maybe_packet_data := self._read_exactly(length_int)) is None: if (maybe_packet_data := self._read(length_int)) is None:
return return
packet_data = maybe_packet_data packet_data = maybe_packet_data

View File

@ -11,7 +11,7 @@ from noise.connection import NoiseConnection # type: ignore[import-untyped]
from aioesphomeapi import APIConnection from aioesphomeapi import APIConnection
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
from aioesphomeapi._frame_helper.noise import ESPHOME_NOISE_BACKEND, NOISE_HELLO from aioesphomeapi._frame_helper.noise import ESPHOME_NOISE_BACKEND
from aioesphomeapi._frame_helper.plain_text import _bytes_to_varuint as bytes_to_varuint from aioesphomeapi._frame_helper.plain_text import _bytes_to_varuint as bytes_to_varuint
from aioesphomeapi._frame_helper.plain_text import ( from aioesphomeapi._frame_helper.plain_text import (
_cached_bytes_to_varuint as cached_bytes_to_varuint, _cached_bytes_to_varuint as cached_bytes_to_varuint,
@ -40,6 +40,8 @@ from .conftest import get_mock_connection_params
PREAMBLE = b"\x00" PREAMBLE = b"\x00"
NOISE_HELLO = b"\x01\x00\x00"
def _extract_encrypted_payload_from_handshake(handshake_pkt: bytes) -> bytes: def _extract_encrypted_payload_from_handshake(handshake_pkt: bytes) -> bytes:
noise_hello = handshake_pkt[0:3] noise_hello = handshake_pkt[0:3]