diff --git a/aioesphomeapi/_frame_helper/base.pxd b/aioesphomeapi/_frame_helper/base.pxd index db554d5..e746509 100644 --- a/aioesphomeapi/_frame_helper/base.pxd +++ b/aioesphomeapi/_frame_helper/base.pxd @@ -23,7 +23,7 @@ cdef class APIFrameHelper: cpdef set_log_name(self, str log_name) @cython.locals(original_pos="unsigned int", new_pos="unsigned int") - cdef bytes _read_exactly(self, int length) + cdef bytes _read(self, int length) @cython.locals(bytes_data=bytes) cdef _add_to_buffer(self, object data) diff --git a/aioesphomeapi/_frame_helper/base.py b/aioesphomeapi/_frame_helper/base.py index c405888..9adb521 100644 --- a/aioesphomeapi/_frame_helper/base.py +++ b/aioesphomeapi/_frame_helper/base.py @@ -108,7 +108,7 @@ class APIFrameHelper: # is blocked and we cannot pull the data out of the buffer fast enough. self._buffer = self._buffer[end_of_frame_pos:] - def _read_exactly(self, length: _int) -> bytes | None: + def _read(self, length: _int) -> bytes | None: """Read exactly length bytes from the buffer or None if all the bytes are not yet available.""" original_pos = self._pos new_pos = original_pos + length diff --git a/aioesphomeapi/_frame_helper/noise.pxd b/aioesphomeapi/_frame_helper/noise.pxd index 70366e6..5895e09 100644 --- a/aioesphomeapi/_frame_helper/noise.pxd +++ b/aioesphomeapi/_frame_helper/noise.pxd @@ -11,10 +11,12 @@ cdef unsigned int NOISE_STATE_HANDSHAKE cdef unsigned int NOISE_STATE_READY cdef unsigned int NOISE_STATE_CLOSED +cdef bytes NOISE_HELLO + cdef class APINoiseFrameHelper(APIFrameHelper): cdef object _noise_psk - cdef object _expected_name + cdef str _expected_name cdef unsigned int _state cdef object _dispatch cdef object _server_name @@ -37,12 +39,24 @@ cdef class APINoiseFrameHelper(APIFrameHelper): ) cdef _handle_frame(self, bytes frame) + @cython.locals( + chosen_proto=char, + server_name_i="unsigned int" + ) cdef _handle_hello(self, bytes server_hello) cdef _handle_handshake(self, bytes msg) cdef _handle_closed(self, bytes frame) + @cython.locals(handshake_frame=bytearray, frame_len="unsigned int") + cdef _send_hello_handshake(self) + + cdef _setup_proto(self) + + @cython.locals(psk_bytes=bytes) + cdef _decode_noise_psk(self) + @cython.locals( type_="unsigned int", data=bytes, @@ -50,7 +64,6 @@ cdef class APINoiseFrameHelper(APIFrameHelper): data_len=cython.uint, frame=bytes, frame_len=cython.uint, - type_=object ) cpdef write_packets(self, list packets, bint debug_enabled) diff --git a/aioesphomeapi/_frame_helper/noise.py b/aioesphomeapi/_frame_helper/noise.py index 3801849..2b8d98e 100644 --- a/aioesphomeapi/_frame_helper/noise.py +++ b/aioesphomeapi/_frame_helper/noise.py @@ -138,7 +138,7 @@ class APINoiseFrameHelper(APIFrameHelper): self._add_to_buffer(data) while self._buffer: self._pos = 0 - if (header := self._read_exactly(3)) is None: + if (header := self._read(3)) is None: return preamble = header[0] msg_size_high = header[1] @@ -150,13 +150,12 @@ class APINoiseFrameHelper(APIFrameHelper): ) ) return - frame = self._read_exactly((msg_size_high << 8) | msg_size_low) - # The complete frame is not yet available, wait for more data - # to arrive before continuing, since callback_packet has not - # been called yet the buffer will not be cleared and the next - # call to data_received will continue processing the packet - # at the start of the frame. - if frame is None: + if (frame := self._read((msg_size_high << 8) | msg_size_low)) is None: + # The complete frame is not yet available, wait for more data + # to arrive before continuing, since callback_packet has not + # been called yet the buffer will not be cleared and the next + # call to data_received will continue processing the packet + # at the start of the frame. return # asyncio already runs data_received in a try block @@ -174,11 +173,13 @@ class APINoiseFrameHelper(APIFrameHelper): def _send_hello_handshake(self) -> None: """Send a ClientHello to the server.""" - handshake_frame = b"\x00" + self._proto.write_message() - frame_len = len(handshake_frame) + handshake_frame = self._proto.write_message() + frame_len = len(handshake_frame) + 1 header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF)) - hello_handshake = NOISE_HELLO + header + handshake_frame - self._write_bytes(hello_handshake, _LOGGER.isEnabledFor(logging.DEBUG)) + self._write_bytes( + b"".join((NOISE_HELLO, header, b"\x00", handshake_frame)), + _LOGGER.isEnabledFor(logging.DEBUG), + ) def _handle_hello(self, server_hello: bytes) -> None: """Perform the handshake with the server.""" diff --git a/aioesphomeapi/_frame_helper/plain_text.py b/aioesphomeapi/_frame_helper/plain_text.py index c45857f..b0d9995 100644 --- a/aioesphomeapi/_frame_helper/plain_text.py +++ b/aioesphomeapi/_frame_helper/plain_text.py @@ -83,9 +83,9 @@ class APIPlaintextFrameHelper(APIFrameHelper): while self._buffer: # Read preamble, which should always 0x00 # Also try to get the length and msg type - # to avoid multiple calls to _read_exactly + # to avoid multiple calls to _read self._pos = 0 - if (init_bytes := self._read_exactly(3)) is None: + if (init_bytes := self._read(3)) is None: return msg_type_int: int | None = None length_int = 0 @@ -100,7 +100,7 @@ class APIPlaintextFrameHelper(APIFrameHelper): # Length is only 1 byte # # This is the most common case needing a single byte for - # length and type which means we avoid 2 calls to _read_exactly + # length and type which means we avoid 2 calls to _read length_int = length_high if maybe_msg_type & 0x80 != 0x80: # Message type is also only 1 byte @@ -113,13 +113,13 @@ class APIPlaintextFrameHelper(APIFrameHelper): length = init_bytes[1:3] # If the message is long, we need to read the rest of the length while length[-1] & 0x80 == 0x80: - if (add_length := self._read_exactly(1)) is None: + if (add_length := self._read(1)) is None: return length += add_length length_int = bytes_to_varuint(length) or 0 # Since the length is longer than 1 byte we do not have the # message type yet. - if (msg_type_byte := self._read_exactly(1)) is None: + if (msg_type_byte := self._read(1)) is None: return msg_type = msg_type_byte if msg_type[-1] & 0x80 != 0x80: @@ -131,7 +131,7 @@ class APIPlaintextFrameHelper(APIFrameHelper): # to read the (rest) of the message type if msg_type_int is None: while msg_type[-1] & 0x80 == 0x80: - if (add_msg_type := self._read_exactly(1)) is None: + if (add_msg_type := self._read(1)) is None: return msg_type += add_msg_type msg_type_int = bytes_to_varuint(msg_type) @@ -147,7 +147,7 @@ class APIPlaintextFrameHelper(APIFrameHelper): # been called yet the buffer will not be cleared and the next # call to data_received will continue processing the packet # at the start of the frame. - if (maybe_packet_data := self._read_exactly(length_int)) is None: + if (maybe_packet_data := self._read(length_int)) is None: return packet_data = maybe_packet_data diff --git a/tests/test__frame_helper.py b/tests/test__frame_helper.py index c0cbf56..b69a073 100644 --- a/tests/test__frame_helper.py +++ b/tests/test__frame_helper.py @@ -11,7 +11,7 @@ from noise.connection import NoiseConnection # type: ignore[import-untyped] from aioesphomeapi import APIConnection from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper -from aioesphomeapi._frame_helper.noise import ESPHOME_NOISE_BACKEND, NOISE_HELLO +from aioesphomeapi._frame_helper.noise import ESPHOME_NOISE_BACKEND from aioesphomeapi._frame_helper.plain_text import _bytes_to_varuint as bytes_to_varuint from aioesphomeapi._frame_helper.plain_text import ( _cached_bytes_to_varuint as cached_bytes_to_varuint, @@ -40,6 +40,8 @@ from .conftest import get_mock_connection_params PREAMBLE = b"\x00" +NOISE_HELLO = b"\x01\x00\x00" + def _extract_encrypted_payload_from_handshake(handshake_pkt: bytes) -> bytes: noise_hello = handshake_pkt[0:3]