diff --git a/aioesphomeapi/_frame_helper.py b/aioesphomeapi/_frame_helper.py index ed2040f..c9452c1 100644 --- a/aioesphomeapi/_frame_helper.py +++ b/aioesphomeapi/_frame_helper.py @@ -4,7 +4,7 @@ import logging from abc import abstractmethod from enum import Enum from functools import partial -from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Callable, Optional, cast import async_timeout from chacha20poly1305_reuseable import ChaCha20Poly1305Reusable @@ -59,6 +59,7 @@ class APIFrameHelper(asyncio.Protocol): "_transport", "_connected_event", "_buffer", + "_buffer_len", "_pos", ) @@ -73,18 +74,14 @@ class APIFrameHelper(asyncio.Protocol): self._transport: Optional[asyncio.Transport] = None self._connected_event = asyncio.Event() self._buffer = bytearray() + self._buffer_len = 0 self._pos = 0 - def _init_read(self, length: int) -> Optional[bytearray]: - """Start reading a packet from the buffer.""" - self._pos = 0 - return self._read_exactly(length) - def _read_exactly(self, length: int) -> Optional[bytearray]: """Read exactly length bytes from the buffer or None if all the bytes are not yet available.""" original_pos = self._pos new_pos = original_pos + length - if len(self._buffer) < new_pos: + if self._buffer_len < new_pos: return None self._pos = new_pos return self._buffer[original_pos:new_pos] @@ -126,11 +123,6 @@ class APIFrameHelper(asyncio.Protocol): class APIPlaintextFrameHelper(APIFrameHelper): """Frame helper for plaintext API connections.""" - def _callback_packet(self, type_: int, data: Union[bytes, bytearray]) -> None: - """Complete reading a packet from the buffer.""" - del self._buffer[: self._pos] - self._on_pkt(type_, data) - def write_packet(self, type_: int, data: bytes) -> None: """Write a packet to the socket, the caller should not have the lock. @@ -153,11 +145,13 @@ class APIPlaintextFrameHelper(APIFrameHelper): def data_received(self, data: bytes) -> None: # pylint: disable=too-many-branches self._buffer += data + self._buffer_len += len(data) while self._buffer: # Read preamble, which should always 0x00 # Also try to get the length and msg type # to avoid multiple calls to _read_exactly - init_bytes = self._init_read(3) + self._pos = 0 + init_bytes = self._read_exactly(3) if init_bytes is None: return msg_type_int: Optional[int] = None @@ -216,20 +210,22 @@ class APIPlaintextFrameHelper(APIFrameHelper): assert msg_type_int is not None if length_int == 0: - self._callback_packet(msg_type_int, b"") - # If we have more data, continue processing - continue + packet_data = b"" + else: + packet_data_bytearray = self._read_exactly(length_int) + # The packet data is not yet available, wait for more data + # to arrive before continuing, since callback_packet has not + # been called yet the buffer will not be cleared and the next + # call to data_received will continue processing the packet + # at the start of the frame. + if packet_data_bytearray is None: + return + packet_data = bytes(packet_data_bytearray) - packet_data = self._read_exactly(length_int) - # The packet data is not yet available, wait for more data - # to arrive before continuing, since callback_packet has not - # been called yet the buffer will not be cleared and the next - # call to data_received will continue processing the packet - # at the start of the frame. - if packet_data is None: - return - - self._callback_packet(msg_type_int, bytes(packet_data)) + end_of_frame_pos = self._pos + del self._buffer[:end_of_frame_pos] + self._buffer_len -= end_of_frame_pos + self._on_pkt(msg_type_int, packet_data) # If we have more data, continue processing @@ -340,8 +336,10 @@ class APINoiseFrameHelper(APIFrameHelper): def data_received(self, data: bytes) -> None: self._buffer += data + self._buffer_len += len(data) while self._buffer: - header = self._init_read(3) + self._pos = 0 + header = self._read_exactly(3) if header is None: return preamble, msg_size_high, msg_size_low = header @@ -364,7 +362,9 @@ class APINoiseFrameHelper(APIFrameHelper): except Exception as err: # pylint: disable=broad-except self._handle_error_and_close(err) finally: - del self._buffer[: self._pos] + end_of_frame_pos = self._pos + del self._buffer[:end_of_frame_pos] + self._buffer_len -= end_of_frame_pos def _send_hello(self) -> None: """Send a ClientHello to the server.""" @@ -469,9 +469,9 @@ class APINoiseFrameHelper(APIFrameHelper): bytes( [ (type_ >> 8) & 0xFF, - (type_ >> 0) & 0xFF, + type_ & 0xFF, (data_len >> 8) & 0xFF, - (data_len >> 0) & 0xFF, + data_len & 0xFF, ] ) + data @@ -489,19 +489,11 @@ class APINoiseFrameHelper(APIFrameHelper): ProtocolAPIError(f"Bad encryption frame: {ex!r}") ) return - msg_len = len(msg) - if msg_len < 4: - self._handle_error_and_close(ProtocolAPIError(f"Bad packet frame: {msg!r}")) - return - msg_type_high, msg_type_low, data_len_high, data_len_low = msg[:4] - msg_type = (msg_type_high << 8) | msg_type_low - data_len = (data_len_high << 8) | data_len_low - if data_len + 4 != msg_len: - self._handle_error_and_close( - ProtocolAPIError(f"Bad data len: {data_len} vs {msg_len}") - ) - return - self._on_pkt(msg_type, msg[4:]) + # Message layout is + # 2 bytes: message type + # 2 bytes: message length + # N bytes: message data + self._on_pkt((msg[0] << 8) | msg[1], msg[4:]) def _handle_closed( # pylint: disable=unused-argument self, frame: bytearray diff --git a/tests/test__frame_helper.py b/tests/test__frame_helper.py index 32bcd61..dac08d2 100644 --- a/tests/test__frame_helper.py +++ b/tests/test__frame_helper.py @@ -59,7 +59,7 @@ PREAMBLE = b"\x00" ], ) async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type): - for _ in range(5): + for _ in range(3): packets = [] def _packet(type_: int, data: bytes): @@ -78,6 +78,16 @@ async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type): assert type_ == pkt_type assert data == pkt_data + # Make sure we correctly handle fragments + for i in range(len(in_bytes)): + helper.data_received(in_bytes[i : i + 1]) + + pkt = packets.pop() + type_, data = pkt + + assert type_ == pkt_type + assert data == pkt_data + @pytest.mark.asyncio async def test_noise_frame_helper_incorrect_key(): @@ -117,6 +127,46 @@ async def test_noise_frame_helper_incorrect_key(): await helper.perform_handshake() +@pytest.mark.asyncio +async def test_noise_frame_helper_incorrect_key_fragments(): + """Test that the noise frame helper raises InvalidEncryptionKeyAPIError on bad key with fragmented packets.""" + outgoing_packets = [ + "010000", # hello packet + "010031001ed7f7bb0b74085418258ed5928931bc36ade7cf06937fcff089044d4ab142643f1b2c9935bb77696f23d930836737a4", + ] + incoming_packets = [ + "01000d01736572766963657465737400", + "0100160148616e647368616b65204d4143206661696c757265", + ] + packets = [] + + def _packet(type_: int, data: bytes): + packets.append((type_, data)) + + def _on_error(exc: Exception): + raise exc + + helper = APINoiseFrameHelper( + on_pkt=_packet, + on_error=_on_error, + noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=", + expected_name="servicetest", + ) + helper._transport = MagicMock() + + for pkt in outgoing_packets: + helper._write_frame(bytes.fromhex(pkt)) + + with pytest.raises(InvalidEncryptionKeyAPIError): + for pkt in incoming_packets: + in_pkt = bytes.fromhex(pkt) + for i in range(len(in_pkt)): + helper.data_received(in_pkt[i : i + 1]) + + with pytest.raises(InvalidEncryptionKeyAPIError): + await helper.perform_handshake() + + @pytest.mark.asyncio async def test_noise_incorrect_name(): """Test we raise on bad name."""