mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-09-27 04:22:46 +02:00
Optimize the frame helpers by improving cython typing (#691)
This commit is contained in:
parent
33d1d3d8c4
commit
7a57f1fa89
@ -23,7 +23,7 @@ cdef class APIFrameHelper:
|
|||||||
cpdef set_log_name(self, str log_name)
|
cpdef set_log_name(self, str log_name)
|
||||||
|
|
||||||
@cython.locals(original_pos="unsigned int", new_pos="unsigned int")
|
@cython.locals(original_pos="unsigned int", new_pos="unsigned int")
|
||||||
cdef bytes _read_exactly(self, int length)
|
cdef bytes _read(self, int length)
|
||||||
|
|
||||||
@cython.locals(bytes_data=bytes)
|
@cython.locals(bytes_data=bytes)
|
||||||
cdef _add_to_buffer(self, object data)
|
cdef _add_to_buffer(self, object data)
|
||||||
|
@ -108,7 +108,7 @@ class APIFrameHelper:
|
|||||||
# is blocked and we cannot pull the data out of the buffer fast enough.
|
# is blocked and we cannot pull the data out of the buffer fast enough.
|
||||||
self._buffer = self._buffer[end_of_frame_pos:]
|
self._buffer = self._buffer[end_of_frame_pos:]
|
||||||
|
|
||||||
def _read_exactly(self, length: _int) -> bytes | None:
|
def _read(self, length: _int) -> bytes | None:
|
||||||
"""Read exactly length bytes from the buffer or None if all the bytes are not yet available."""
|
"""Read exactly length bytes from the buffer or None if all the bytes are not yet available."""
|
||||||
original_pos = self._pos
|
original_pos = self._pos
|
||||||
new_pos = original_pos + length
|
new_pos = original_pos + length
|
||||||
|
@ -11,10 +11,12 @@ cdef unsigned int NOISE_STATE_HANDSHAKE
|
|||||||
cdef unsigned int NOISE_STATE_READY
|
cdef unsigned int NOISE_STATE_READY
|
||||||
cdef unsigned int NOISE_STATE_CLOSED
|
cdef unsigned int NOISE_STATE_CLOSED
|
||||||
|
|
||||||
|
cdef bytes NOISE_HELLO
|
||||||
|
|
||||||
cdef class APINoiseFrameHelper(APIFrameHelper):
|
cdef class APINoiseFrameHelper(APIFrameHelper):
|
||||||
|
|
||||||
cdef object _noise_psk
|
cdef object _noise_psk
|
||||||
cdef object _expected_name
|
cdef str _expected_name
|
||||||
cdef unsigned int _state
|
cdef unsigned int _state
|
||||||
cdef object _dispatch
|
cdef object _dispatch
|
||||||
cdef object _server_name
|
cdef object _server_name
|
||||||
@ -37,12 +39,24 @@ cdef class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
)
|
)
|
||||||
cdef _handle_frame(self, bytes frame)
|
cdef _handle_frame(self, bytes frame)
|
||||||
|
|
||||||
|
@cython.locals(
|
||||||
|
chosen_proto=char,
|
||||||
|
server_name_i="unsigned int"
|
||||||
|
)
|
||||||
cdef _handle_hello(self, bytes server_hello)
|
cdef _handle_hello(self, bytes server_hello)
|
||||||
|
|
||||||
cdef _handle_handshake(self, bytes msg)
|
cdef _handle_handshake(self, bytes msg)
|
||||||
|
|
||||||
cdef _handle_closed(self, bytes frame)
|
cdef _handle_closed(self, bytes frame)
|
||||||
|
|
||||||
|
@cython.locals(handshake_frame=bytearray, frame_len="unsigned int")
|
||||||
|
cdef _send_hello_handshake(self)
|
||||||
|
|
||||||
|
cdef _setup_proto(self)
|
||||||
|
|
||||||
|
@cython.locals(psk_bytes=bytes)
|
||||||
|
cdef _decode_noise_psk(self)
|
||||||
|
|
||||||
@cython.locals(
|
@cython.locals(
|
||||||
type_="unsigned int",
|
type_="unsigned int",
|
||||||
data=bytes,
|
data=bytes,
|
||||||
@ -50,7 +64,6 @@ cdef class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
data_len=cython.uint,
|
data_len=cython.uint,
|
||||||
frame=bytes,
|
frame=bytes,
|
||||||
frame_len=cython.uint,
|
frame_len=cython.uint,
|
||||||
type_=object
|
|
||||||
)
|
)
|
||||||
cpdef write_packets(self, list packets, bint debug_enabled)
|
cpdef write_packets(self, list packets, bint debug_enabled)
|
||||||
|
|
||||||
|
@ -138,7 +138,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
self._add_to_buffer(data)
|
self._add_to_buffer(data)
|
||||||
while self._buffer:
|
while self._buffer:
|
||||||
self._pos = 0
|
self._pos = 0
|
||||||
if (header := self._read_exactly(3)) is None:
|
if (header := self._read(3)) is None:
|
||||||
return
|
return
|
||||||
preamble = header[0]
|
preamble = header[0]
|
||||||
msg_size_high = header[1]
|
msg_size_high = header[1]
|
||||||
@ -150,13 +150,12 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
frame = self._read_exactly((msg_size_high << 8) | msg_size_low)
|
if (frame := self._read((msg_size_high << 8) | msg_size_low)) is None:
|
||||||
# The complete frame is not yet available, wait for more data
|
# The complete frame is not yet available, wait for more data
|
||||||
# to arrive before continuing, since callback_packet has not
|
# to arrive before continuing, since callback_packet has not
|
||||||
# been called yet the buffer will not be cleared and the next
|
# been called yet the buffer will not be cleared and the next
|
||||||
# call to data_received will continue processing the packet
|
# call to data_received will continue processing the packet
|
||||||
# at the start of the frame.
|
# at the start of the frame.
|
||||||
if frame is None:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# asyncio already runs data_received in a try block
|
# asyncio already runs data_received in a try block
|
||||||
@ -174,11 +173,13 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
|
|
||||||
def _send_hello_handshake(self) -> None:
|
def _send_hello_handshake(self) -> None:
|
||||||
"""Send a ClientHello to the server."""
|
"""Send a ClientHello to the server."""
|
||||||
handshake_frame = b"\x00" + self._proto.write_message()
|
handshake_frame = self._proto.write_message()
|
||||||
frame_len = len(handshake_frame)
|
frame_len = len(handshake_frame) + 1
|
||||||
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
|
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
|
||||||
hello_handshake = NOISE_HELLO + header + handshake_frame
|
self._write_bytes(
|
||||||
self._write_bytes(hello_handshake, _LOGGER.isEnabledFor(logging.DEBUG))
|
b"".join((NOISE_HELLO, header, b"\x00", handshake_frame)),
|
||||||
|
_LOGGER.isEnabledFor(logging.DEBUG),
|
||||||
|
)
|
||||||
|
|
||||||
def _handle_hello(self, server_hello: bytes) -> None:
|
def _handle_hello(self, server_hello: bytes) -> None:
|
||||||
"""Perform the handshake with the server."""
|
"""Perform the handshake with the server."""
|
||||||
|
@ -83,9 +83,9 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
|||||||
while self._buffer:
|
while self._buffer:
|
||||||
# Read preamble, which should always 0x00
|
# Read preamble, which should always 0x00
|
||||||
# Also try to get the length and msg type
|
# Also try to get the length and msg type
|
||||||
# to avoid multiple calls to _read_exactly
|
# to avoid multiple calls to _read
|
||||||
self._pos = 0
|
self._pos = 0
|
||||||
if (init_bytes := self._read_exactly(3)) is None:
|
if (init_bytes := self._read(3)) is None:
|
||||||
return
|
return
|
||||||
msg_type_int: int | None = None
|
msg_type_int: int | None = None
|
||||||
length_int = 0
|
length_int = 0
|
||||||
@ -100,7 +100,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
|||||||
# Length is only 1 byte
|
# Length is only 1 byte
|
||||||
#
|
#
|
||||||
# This is the most common case needing a single byte for
|
# This is the most common case needing a single byte for
|
||||||
# length and type which means we avoid 2 calls to _read_exactly
|
# length and type which means we avoid 2 calls to _read
|
||||||
length_int = length_high
|
length_int = length_high
|
||||||
if maybe_msg_type & 0x80 != 0x80:
|
if maybe_msg_type & 0x80 != 0x80:
|
||||||
# Message type is also only 1 byte
|
# Message type is also only 1 byte
|
||||||
@ -113,13 +113,13 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
|||||||
length = init_bytes[1:3]
|
length = init_bytes[1:3]
|
||||||
# If the message is long, we need to read the rest of the length
|
# If the message is long, we need to read the rest of the length
|
||||||
while length[-1] & 0x80 == 0x80:
|
while length[-1] & 0x80 == 0x80:
|
||||||
if (add_length := self._read_exactly(1)) is None:
|
if (add_length := self._read(1)) is None:
|
||||||
return
|
return
|
||||||
length += add_length
|
length += add_length
|
||||||
length_int = bytes_to_varuint(length) or 0
|
length_int = bytes_to_varuint(length) or 0
|
||||||
# Since the length is longer than 1 byte we do not have the
|
# Since the length is longer than 1 byte we do not have the
|
||||||
# message type yet.
|
# message type yet.
|
||||||
if (msg_type_byte := self._read_exactly(1)) is None:
|
if (msg_type_byte := self._read(1)) is None:
|
||||||
return
|
return
|
||||||
msg_type = msg_type_byte
|
msg_type = msg_type_byte
|
||||||
if msg_type[-1] & 0x80 != 0x80:
|
if msg_type[-1] & 0x80 != 0x80:
|
||||||
@ -131,7 +131,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
|||||||
# to read the (rest) of the message type
|
# to read the (rest) of the message type
|
||||||
if msg_type_int is None:
|
if msg_type_int is None:
|
||||||
while msg_type[-1] & 0x80 == 0x80:
|
while msg_type[-1] & 0x80 == 0x80:
|
||||||
if (add_msg_type := self._read_exactly(1)) is None:
|
if (add_msg_type := self._read(1)) is None:
|
||||||
return
|
return
|
||||||
msg_type += add_msg_type
|
msg_type += add_msg_type
|
||||||
msg_type_int = bytes_to_varuint(msg_type)
|
msg_type_int = bytes_to_varuint(msg_type)
|
||||||
@ -147,7 +147,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
|||||||
# been called yet the buffer will not be cleared and the next
|
# been called yet the buffer will not be cleared and the next
|
||||||
# call to data_received will continue processing the packet
|
# call to data_received will continue processing the packet
|
||||||
# at the start of the frame.
|
# at the start of the frame.
|
||||||
if (maybe_packet_data := self._read_exactly(length_int)) is None:
|
if (maybe_packet_data := self._read(length_int)) is None:
|
||||||
return
|
return
|
||||||
packet_data = maybe_packet_data
|
packet_data = maybe_packet_data
|
||||||
|
|
||||||
|
@ -11,7 +11,7 @@ from noise.connection import NoiseConnection # type: ignore[import-untyped]
|
|||||||
|
|
||||||
from aioesphomeapi import APIConnection
|
from aioesphomeapi import APIConnection
|
||||||
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
|
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
|
||||||
from aioesphomeapi._frame_helper.noise import ESPHOME_NOISE_BACKEND, NOISE_HELLO
|
from aioesphomeapi._frame_helper.noise import ESPHOME_NOISE_BACKEND
|
||||||
from aioesphomeapi._frame_helper.plain_text import _bytes_to_varuint as bytes_to_varuint
|
from aioesphomeapi._frame_helper.plain_text import _bytes_to_varuint as bytes_to_varuint
|
||||||
from aioesphomeapi._frame_helper.plain_text import (
|
from aioesphomeapi._frame_helper.plain_text import (
|
||||||
_cached_bytes_to_varuint as cached_bytes_to_varuint,
|
_cached_bytes_to_varuint as cached_bytes_to_varuint,
|
||||||
@ -40,6 +40,8 @@ from .conftest import get_mock_connection_params
|
|||||||
|
|
||||||
PREAMBLE = b"\x00"
|
PREAMBLE = b"\x00"
|
||||||
|
|
||||||
|
NOISE_HELLO = b"\x01\x00\x00"
|
||||||
|
|
||||||
|
|
||||||
def _extract_encrypted_payload_from_handshake(handshake_pkt: bytes) -> bytes:
|
def _extract_encrypted_payload_from_handshake(handshake_pkt: bytes) -> bytes:
|
||||||
noise_hello = handshake_pkt[0:3]
|
noise_hello = handshake_pkt[0:3]
|
||||||
|
Loading…
Reference in New Issue
Block a user