diff --git a/aioesphomeapi/_frame_helper.py b/aioesphomeapi/_frame_helper.py index 447a66c..d727603 100644 --- a/aioesphomeapi/_frame_helper.py +++ b/aioesphomeapi/_frame_helper.py @@ -1,9 +1,10 @@ import asyncio import base64 import logging -from abc import ABC, abstractmethod, abstractproperty +from abc import abstractmethod, abstractproperty from dataclasses import dataclass -from typing import Optional +from enum import Enum +from typing import Callable, Optional, Union, cast import async_timeout from noise.connection import NoiseConnection # type: ignore @@ -32,57 +33,93 @@ SOCKET_ERRORS = ( @dataclass class Packet: type: int - data: bytes + data: Union[bytes, bytearray] -class APIFrameHelper(ABC): +class APIFrameHelper(asyncio.Protocol): """Helper class to handle the API frame protocol.""" def __init__( self, - reader: asyncio.StreamReader, - writer: asyncio.StreamWriter, + on_pkt: Callable[[Packet], None], + on_error: Callable[[Exception], None], ) -> None: """Initialize the API frame helper.""" - self._reader = reader - self._writer = writer + self._on_pkt = on_pkt + self._on_error = on_error + self._transport: Optional[asyncio.Transport] = None self.read_lock = asyncio.Lock() self._closed_event = asyncio.Event() + self._connected_event = asyncio.Event() + self._buffer = bytearray() + self._pos = 0 @abstractproperty # pylint: disable=deprecated-decorator def ready(self) -> bool: """Return if the connection is ready.""" + def _init_read(self, length: int) -> Optional[bytearray]: + """Start reading a packet from the buffer.""" + self._pos = 0 + return self._read_exactly(length) + + def _read_exactly(self, length: int) -> Optional[bytearray]: + """Read exactly length bytes from the buffer or None if all the bytes are not yet available.""" + original_pos = self._pos + new_pos = original_pos + length + if len(self._buffer) < new_pos: + return None + self._pos = new_pos + return self._buffer[original_pos:new_pos] + @abstractmethod - async def close(self) -> None: - """Close the connection.""" + async def perform_handshake(self) -> None: + """Perform the handshake.""" @abstractmethod def write_packet(self, packet: Packet) -> None: """Write a packet to the socket.""" - @abstractmethod - async def read_packet_with_lock(self) -> Packet: - """Read a packet from the socket, the caller is responsible for having the lock.""" + def connection_made(self, transport: asyncio.BaseTransport) -> None: + """Handle a new connection.""" + self._transport = cast(asyncio.Transport, transport) + self._connected_event.set() - @abstractmethod - async def wait_for_ready(self) -> None: - """Wait for the connection to be ready.""" + def _handle_error_and_close(self, exc: Exception) -> None: + self._handle_error(exc) + self.close() + + def _handle_error(self, exc: Exception) -> None: + self._closed_event.set() + self._on_error(exc) + + def connection_lost(self, exc: Optional[Exception]) -> None: + self._handle_error(exc or SocketClosedAPIError("Connection lost")) + return super().connection_lost(exc) + + def eof_received(self) -> Optional[bool]: + self._handle_error(SocketClosedAPIError("EOF received")) + return super().eof_received() + + def close(self) -> None: + """Close the connection.""" + self._closed_event.set() + if self._transport: + self._transport.close() class APIPlaintextFrameHelper(APIFrameHelper): """Frame helper for plaintext API connections.""" - async def close(self) -> None: - """Close the connection.""" - self._closed_event.set() - self._writer.close() - @property def ready(self) -> bool: """Return if the connection is ready.""" - # Plaintext is always ready - return True + return self._connected_event.is_set() + + def _callback_packet(self, type_: int, data: Union[bytes, bytearray]) -> None: + """Complete reading a packet from the buffer.""" + del self._buffer[: self._pos] + self._on_pkt(Packet(type_, data)) def write_packet(self, packet: Packet) -> None: """Write a packet to the socket, the caller should not have the lock. @@ -90,6 +127,7 @@ class APIPlaintextFrameHelper(APIFrameHelper): The entire packet must be written in a single call to write to avoid locking. """ + assert self._transport is not None, "Transport should be set" data = ( b"\0" + varuint_to_bytes(len(packet.data)) @@ -99,26 +137,32 @@ class APIPlaintextFrameHelper(APIFrameHelper): _LOGGER.debug("Sending plaintext frame %s", data.hex()) try: - self._writer.write(data) + self._transport.write(data) except (ConnectionResetError, OSError) as err: raise SocketAPIError(f"Error while writing data: {err}") from err - async def wait_for_ready(self) -> None: - """Wait for the connection to be ready.""" - # No handshake for plaintext + async def perform_handshake(self) -> None: + """Perform the handshake.""" + await self._connected_event.wait() - async def read_packet_with_lock(self) -> Packet: - """Read a packet from the socket, the caller is responsible for having the lock.""" - assert self.read_lock.locked(), "read_packet_with_lock called without lock" - try: + def data_received(self, data: bytes) -> None: + self._buffer += data + while len(self._buffer) >= 3: # Read preamble, which should always 0x00 # Also try to get the length and msg type # to avoid multiple calls to readexactly - init_bytes = await self._reader.readexactly(3) + init_bytes = self._init_read(3) + assert init_bytes is not None, "Buffer should have at least 3 bytes" if init_bytes[0] != 0x00: if init_bytes[0] == 0x01: - raise RequiresEncryptionAPIError("Connection requires encryption") - raise ProtocolAPIError(f"Invalid preamble {init_bytes[0]:02x}") + self._handle_error_and_close( + RequiresEncryptionAPIError("Connection requires encryption") + ) + return + self._handle_error_and_close( + ProtocolAPIError(f"Invalid preamble {init_bytes[0]:02x}") + ) + return if init_bytes[1] & 0x80 == 0x80: # Length is longer than 1 byte @@ -133,32 +177,35 @@ class APIPlaintextFrameHelper(APIFrameHelper): # If the message is long, we need to read the rest of the length while length[-1] & 0x80 == 0x80: - length += await self._reader.readexactly(1) + add_length = self._read_exactly(1) + if add_length is None: + return + length += add_length # If the message length was longer than 1 byte, we need to read the # message type while not msg_type or (msg_type[-1] & 0x80) == 0x80: - msg_type += await self._reader.readexactly(1) + add_msg_type = self._read_exactly(1) + if add_msg_type is None: + return + msg_type += add_msg_type - length_int = bytes_to_varuint(length) + length_int = bytes_to_varuint(bytes(length)) assert length_int is not None - msg_type_int = bytes_to_varuint(msg_type) + msg_type_int = bytes_to_varuint(bytes(msg_type)) assert msg_type_int is not None if length_int == 0: - return Packet(type=msg_type_int, data=b"") + self._callback_packet(msg_type_int, b"") + # If we have more data, continue processing + continue - data = await self._reader.readexactly(length_int) - return Packet(type=msg_type_int, data=data) - except SOCKET_ERRORS as err: - if ( - isinstance(err, asyncio.IncompleteReadError) - and self._closed_event.is_set() - ): - raise SocketClosedAPIError( - f"Socket closed while reading data: {err}" - ) from err - raise SocketAPIError(f"Error while reading data: {err}") from err + packet_data = self._read_exactly(length_int) + if packet_data is None: + return + + self._callback_packet(msg_type_int, packet_data) + # If we have more data, continue processing def _decode_noise_psk(psk: str) -> bytes: @@ -176,34 +223,46 @@ def _decode_noise_psk(psk: str) -> bytes: return psk_bytes +class NoiseConnectionState(Enum): + """Noise connection state.""" + + HELLO = 1 + HANDSHAKE = 2 + READY = 3 + CLOSED = 4 + + class APINoiseFrameHelper(APIFrameHelper): """Frame helper for noise encrypted connections.""" def __init__( self, - reader: asyncio.StreamReader, - writer: asyncio.StreamWriter, + on_pkt: Callable[[Packet], None], + on_error: Callable[[Exception], None], noise_psk: str, + expected_name: Optional[str], ) -> None: """Initialize the API frame helper.""" - super().__init__(reader, writer) + super().__init__(on_pkt, on_error) self._ready_event = asyncio.Event() - self._proto: Optional[NoiseConnection] = None self._noise_psk = noise_psk + self._expected_name = expected_name + self._state = NoiseConnectionState.HELLO + self._setup_proto() @property def ready(self) -> bool: """Return if the connection is ready.""" return self._ready_event.is_set() - async def close(self) -> None: + def close(self) -> None: """Close the connection.""" # Make sure we set the ready event if its not already set # so that we don't block forever on the ready event if we # are waiting for the handshake to complete. self._ready_event.set() - self._closed_event.set() - self._writer.close() + self._state = NoiseConnectionState.CLOSED + super().close() def _write_frame(self, frame: bytes) -> None: """Write a packet to the socket, the caller should not have the lock. @@ -212,6 +271,7 @@ class APINoiseFrameHelper(APIFrameHelper): to avoid locking. """ _LOGGER.debug("Sending frame %s", frame.hex()) + assert self._transport is not None, "Transport is not set" try: header = bytes( @@ -221,39 +281,46 @@ class APINoiseFrameHelper(APIFrameHelper): len(frame) & 0xFF, ] ) - self._writer.write(header + frame) + self._transport.write(header + frame) except OSError as err: raise SocketAPIError(f"Error while writing data: {err}") from err - async def _read_frame_with_lock(self) -> bytes: - """Read a frame from the socket, the caller is responsible for having the lock.""" - assert self.read_lock.locked(), "_read_frame_with_lock called without lock" + async def perform_handshake(self) -> None: + """Perform the handshake with the server.""" + self._send_hello() try: - header = await self._reader.readexactly(3) + async with async_timeout.timeout(60.0): + await self._ready_event.wait() + except asyncio.TimeoutError as err: + raise HandshakeAPIError("Timeout during handshake") from err + + def data_received(self, data: bytes) -> None: + self._buffer += data + while len(self._buffer) >= 3: + header = self._init_read(3) + assert header is not None, "Buffer should have at least 3 bytes" if header[0] != 0x01: - raise ProtocolAPIError(f"Marker byte invalid: {header[0]}") + self._handle_error_and_close( + ProtocolAPIError(f"Marker byte invalid: {header[0]}") + ) msg_size = (header[1] << 8) | header[2] - frame = await self._reader.readexactly(msg_size) - except SOCKET_ERRORS as err: - if ( - isinstance(err, asyncio.IncompleteReadError) - and self._closed_event.is_set() - ): - raise SocketClosedAPIError( - f"Socket closed while reading data: {err}" - ) from err - raise SocketAPIError(f"Error while reading data: {err}") from err + frame = self._read_exactly(msg_size) + if frame is None: + return - _LOGGER.debug("Received frame %s", frame.hex()) - return frame + try: + self.STATE_TO_CALLABLE[self._state](self, frame) + except Exception as err: # pylint: disable=broad-except + self._handle_error_and_close(err) + finally: + del self._buffer[: self._pos] - async def _perform_handshake(self, expected_name: Optional[str]) -> None: - """Perform the handshake with the server, the caller is responsible for having the lock.""" - assert self.read_lock.locked(), "_perform_handshake called without lock" + def _send_hello(self) -> None: + """Send a ClientHello to the server.""" self._write_frame(b"") # ClientHello - prologue = b"NoiseAPIInit" + b"\x00\x00" - server_hello = await self._read_frame_with_lock() # ServerHello + def _handle_hello(self, server_hello: bytearray) -> None: + """Perform the handshake with the server, the caller is responsible for having the lock.""" if not server_hello: raise HandshakeAPIError("ServerHello is empty") @@ -273,76 +340,60 @@ class APINoiseFrameHelper(APIFrameHelper): if server_name_i != -1: # server name found, this extension was added in 2022.2 server_name = server_hello[1:server_name_i].decode() - if expected_name is not None and expected_name != server_name: + if self._expected_name is not None and self._expected_name != server_name: raise BadNameAPIError( f"Server sent a different name '{server_name}'", server_name ) + self._state = NoiseConnectionState.HANDSHAKE + self._send_handshake() + + def _setup_proto(self) -> None: + """Set up the noise protocol.""" self._proto = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_SHA256") self._proto.set_as_initiator() self._proto.set_psks(_decode_noise_psk(self._noise_psk)) - self._proto.set_prologue(prologue) + self._proto.set_prologue(b"NoiseAPIInit" + b"\x00\x00") self._proto.start_handshake() + def _send_handshake(self) -> None: + """Send the handshake message.""" + self._write_frame(b"\x00" + self._proto.write_message()) + + def _handle_handshake(self, msg: bytearray) -> None: _LOGGER.debug("Starting handshake...") - do_write = True - while not self._proto.handshake_finished: - if do_write: - msg = self._proto.write_message() - self._write_frame(b"\x00" + msg) - else: - msg = await self._read_frame_with_lock() - if not msg: - raise HandshakeAPIError("Handshake message too short") - if msg[0] != 0: - explanation = msg[1:].decode() - if explanation == "Handshake MAC failure": - raise InvalidEncryptionKeyAPIError("Invalid encryption key") - raise HandshakeAPIError(f"Handshake failure: {explanation}") - self._proto.read_message(msg[1:]) - - do_write = not do_write - - _LOGGER.debug("Handshake complete!") + if msg[0] != 0: + explanation = msg[1:].decode() + if explanation == "Handshake MAC failure": + raise InvalidEncryptionKeyAPIError("Invalid encryption key") + raise HandshakeAPIError(f"Handshake failure: {explanation}") + self._proto.read_message(msg[1:]) + _LOGGER.debug("Handshake complete") + self._state = NoiseConnectionState.READY self._ready_event.set() - async def perform_handshake(self, expected_name: Optional[str]) -> None: - """Perform the handshake with the server.""" - # Allow up to 60 seconds for handhsake - try: - async with self.read_lock, async_timeout.timeout(60.0): - await self._perform_handshake(expected_name) - except asyncio.TimeoutError as err: - raise HandshakeAPIError("Timeout during handshake") from err - def write_packet(self, packet: Packet) -> None: """Write a packet to the socket.""" - padding = 0 - data = ( - bytes( - [ - (packet.type >> 8) & 0xFF, - (packet.type >> 0) & 0xFF, - (len(packet.data) >> 8) & 0xFF, - (len(packet.data) >> 0) & 0xFF, - ] + self._write_frame( + self._proto.encrypt( + ( + bytes( + [ + (packet.type >> 8) & 0xFF, + (packet.type >> 0) & 0xFF, + (len(packet.data) >> 8) & 0xFF, + (len(packet.data) >> 0) & 0xFF, + ] + ) + + packet.data + ) ) - + packet.data - + b"\x00" * padding ) - assert self._proto is not None - frame = self._proto.encrypt(data) - self._write_frame(frame) - async def wait_for_ready(self) -> None: - """Wait for the connection to be ready.""" - await self._ready_event.wait() - - async def read_packet_with_lock(self) -> Packet: - """Read a packet from the socket, the caller is responsible for having the lock.""" - frame = await self._read_frame_with_lock() + def _handle_frame(self, frame: bytearray) -> None: + """Handle an incoming frame.""" assert self._proto is not None - msg = self._proto.decrypt(frame) + msg = self._proto.decrypt(bytes(frame)) if len(msg) < 4: raise ProtocolAPIError(f"Bad packet frame: {msg}") pkt_type = (msg[0] << 8) | msg[1] @@ -350,4 +401,17 @@ class APINoiseFrameHelper(APIFrameHelper): if data_len + 4 > len(msg): raise ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}") data = msg[4 : 4 + data_len] - return Packet(type=pkt_type, data=data) + return self._on_pkt(Packet(pkt_type, data)) + + def _handle_closed( # pylint: disable=unused-argument + self, frame: bytearray + ) -> None: + """Handle a closed frame.""" + self._handle_error(ProtocolAPIError("Connection closed")) + + STATE_TO_CALLABLE = { + NoiseConnectionState.HELLO: _handle_hello, + NoiseConnectionState.HANDSHAKE: _handle_handshake, + NoiseConnectionState.READY: _handle_frame, + NoiseConnectionState.CLOSED: _handle_closed, + } diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index 44a9320..be49263 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -373,7 +373,7 @@ class APIClient: image_stream[msg.key] = data assert self._connection is not None - await self._connection.send_message_callback_response( + self._connection.send_message_callback_response( SubscribeStatesRequest(), on_msg, msg_types ) @@ -394,7 +394,7 @@ class APIClient: if dump_config is not None: req.dump_config = dump_config assert self._connection is not None - await self._connection.send_message_callback_response( + self._connection.send_message_callback_response( req, on_msg, (SubscribeLogsResponse,) ) @@ -407,7 +407,7 @@ class APIClient: on_service_call(HomeassistantServiceCall.from_pb(msg)) assert self._connection is not None - await self._connection.send_message_callback_response( + self._connection.send_message_callback_response( SubscribeHomeassistantServicesRequest(), on_msg, (HomeassistantServiceResponse,), @@ -451,7 +451,7 @@ class APIClient: on_bluetooth_le_advertisement(BluetoothLEAdvertisement.from_pb(msg)) # type: ignore[misc] assert self._connection is not None - await self._connection.send_message_callback_response( + self._connection.send_message_callback_response( SubscribeBluetoothLEAdvertisementsRequest(), on_msg, msg_types ) @@ -472,7 +472,7 @@ class APIClient: on_bluetooth_connections_free_update(resp.free, resp.limit) assert self._connection is not None - await self._connection.send_message_callback_response( + self._connection.send_message_callback_response( SubscribeBluetoothConnectionsFreeRequest(), on_msg, msg_types ) @@ -518,7 +518,7 @@ class APIClient: _LOGGER.debug("%s: Using connection version 1", address) request_type = BluetoothDeviceRequestType.CONNECT - await self._connection.send_message_callback_response( + self._connection.send_message_callback_response( BluetoothDeviceRequest( address=address, request_type=request_type, @@ -581,7 +581,7 @@ class APIClient: self._check_authenticated() assert self._connection is not None - await self._connection.send_message( + self._connection.send_message( BluetoothDeviceRequest( address=address, request_type=BluetoothDeviceRequestType.DISCONNECT, @@ -661,7 +661,7 @@ class APIClient: if not response: assert self._connection is not None - await self._connection.send_message(req) + self._connection.send_message(req) return await self._send_bluetooth_message_await_response( @@ -709,7 +709,7 @@ class APIClient: if not wait_for_response: assert self._connection is not None - await self._connection.send_message(req) + self._connection.send_message(req) return await self._send_bluetooth_message_await_response( @@ -762,7 +762,7 @@ class APIClient: self._check_authenticated() - await self._connection.send_message( + self._connection.send_message( BluetoothGATTNotifyRequest(address=address, handle=handle, enable=False) ) @@ -777,7 +777,7 @@ class APIClient: on_state_sub(msg.entity_id, msg.attribute) assert self._connection is not None - await self._connection.send_message_callback_response( + self._connection.send_message_callback_response( SubscribeHomeAssistantStatesRequest(), on_msg, (SubscribeHomeAssistantStateResponse,), @@ -789,7 +789,7 @@ class APIClient: self._check_authenticated() assert self._connection is not None - await self._connection.send_message( + self._connection.send_message( HomeAssistantStateResponse( entity_id=entity_id, state=state, @@ -829,7 +829,7 @@ class APIClient: req.legacy_command = LegacyCoverCommand.CLOSE req.has_legacy_command = True assert self._connection is not None - await self._connection.send_message(req) + self._connection.send_message(req) async def fan_command( self, @@ -860,7 +860,7 @@ class APIClient: req.has_direction = True req.direction = direction assert self._connection is not None - await self._connection.send_message(req) + self._connection.send_message(req) async def light_command( self, @@ -921,7 +921,7 @@ class APIClient: req.has_effect = True req.effect = effect assert self._connection is not None - await self._connection.send_message(req) + self._connection.send_message(req) async def switch_command(self, key: int, state: bool) -> None: self._check_authenticated() @@ -930,7 +930,7 @@ class APIClient: req.key = key req.state = state assert self._connection is not None - await self._connection.send_message(req) + self._connection.send_message(req) async def climate_command( self, @@ -982,7 +982,7 @@ class APIClient: req.has_custom_preset = True req.custom_preset = custom_preset assert self._connection is not None - await self._connection.send_message(req) + self._connection.send_message(req) async def number_command(self, key: int, state: float) -> None: self._check_authenticated() @@ -991,7 +991,7 @@ class APIClient: req.key = key req.state = state assert self._connection is not None - await self._connection.send_message(req) + self._connection.send_message(req) async def select_command(self, key: int, state: str) -> None: self._check_authenticated() @@ -1000,7 +1000,7 @@ class APIClient: req.key = key req.state = state assert self._connection is not None - await self._connection.send_message(req) + self._connection.send_message(req) async def siren_command( self, @@ -1027,7 +1027,7 @@ class APIClient: req.duration = duration req.has_duration = True assert self._connection is not None - await self._connection.send_message(req) + self._connection.send_message(req) async def button_command(self, key: int) -> None: self._check_authenticated() @@ -1035,7 +1035,7 @@ class APIClient: req = ButtonCommandRequest() req.key = key assert self._connection is not None - await self._connection.send_message(req) + self._connection.send_message(req) async def lock_command( self, @@ -1051,7 +1051,7 @@ class APIClient: if code is not None: req.code = code assert self._connection is not None - await self._connection.send_message(req) + self._connection.send_message(req) async def media_player_command( self, @@ -1075,7 +1075,7 @@ class APIClient: req.media_url = media_url req.has_media_url = True assert self._connection is not None - await self._connection.send_message(req) + self._connection.send_message(req) async def execute_service( self, service: UserService, data: ExecuteServiceDataType @@ -1113,7 +1113,7 @@ class APIClient: # pylint: disable=no-member req.args.extend(args) assert self._connection is not None - await self._connection.send_message(req) + self._connection.send_message(req) async def _request_image( self, *, single: bool = False, stream: bool = False @@ -1122,7 +1122,7 @@ class APIClient: req.single = single req.stream = stream assert self._connection is not None - await self._connection.send_message(req) + self._connection.send_message(req) async def request_single_image(self) -> None: await self._request_image(single=True) diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index 4c96fc5..43012fc 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -5,7 +5,7 @@ import socket import time from contextlib import suppress from dataclasses import astuple, dataclass -from typing import Any, Callable, Coroutine, Dict, Iterable, List, Optional, Type +from typing import Any, Callable, Coroutine, Dict, Iterable, List, Optional, Type, Union import async_timeout from google.protobuf import message @@ -40,7 +40,6 @@ from .core import ( ReadFailedAPIError, ResolveAPIError, SocketAPIError, - SocketClosedAPIError, TimeoutAPIError, ) from .model import APIVersion @@ -112,10 +111,6 @@ class APIConnection: self._ping_stop_event = asyncio.Event() - self._to_process: asyncio.Queue[Optional[Packet]] = asyncio.Queue() - - self._process_task: Optional[asyncio.Task[None]] = None - self._connect_lock: asyncio.Lock = asyncio.Lock() self._cleanup_task: Optional[asyncio.Task[None]] = None @@ -138,27 +133,10 @@ class APIConnection: async with self._connect_lock: _LOGGER.debug("Cleaning up connection to %s", self.log_name) - # Tell the process loop to stop - self._to_process.put_nowait(None) - if self._frame_helper is not None: - await self._frame_helper.close() + self._frame_helper.close() self._frame_helper = None - if self._process_task is not None: - self._process_task.cancel() - try: - await self._process_task - except asyncio.CancelledError: - pass - except Exception as err: # pylint: disable=broad-except - _LOGGER.error( - "Unexpected exception in process task: %s", - err, - exc_info=err, - ) - self._process_task = None - if self._socket is not None: self._socket.close() self._socket = None @@ -231,24 +209,31 @@ class APIConnection: async def _connect_init_frame_helper(self) -> None: """Step 3 in connect process: initialize the frame helper and init read loop.""" - reader, writer = await asyncio.open_connection( - sock=self._socket, limit=BUFFER_SIZE - ) # Set buffer limit to 1MB + fh: Union[APIPlaintextFrameHelper, APINoiseFrameHelper] + loop = asyncio.get_event_loop() if self._params.noise_psk is None: - self._frame_helper = APIPlaintextFrameHelper(reader, writer) - else: - fh = self._frame_helper = APINoiseFrameHelper( - reader, writer, self._params.noise_psk + _, fh = await loop.create_connection( + lambda: APIPlaintextFrameHelper( + on_pkt=self._process_packet, + on_error=self._report_fatal_error_and_cleanup_task, + ), + sock=self._socket, + ) + else: + _, fh = await loop.create_connection( + lambda: APINoiseFrameHelper( + noise_psk=self._params.noise_psk, + expected_name=self._params.expected_name, + on_pkt=self._process_packet, + on_error=self._report_fatal_error_and_cleanup_task, + ), + sock=self._socket, ) - await fh.perform_handshake(self._params.expected_name) + self._frame_helper = fh self._connection_state = ConnectionState.SOCKET_OPENED - - # Create read loop - asyncio.create_task(self._read_loop()) - # Create process loop - self._process_task = asyncio.create_task(self._process_loop()) + await fh.perform_handshake() async def _connect_hello(self) -> None: """Step 4 in connect process: send hello and get api version.""" @@ -292,10 +277,7 @@ class APIConnection: """Step 5 in connect process: start the ping loop.""" async def _keep_alive_loop() -> None: - while True: - if not self._is_socket_open: - return - + while self._is_socket_open: # Wait for keepalive seconds, or ping stop event, whichever happens first try: async with async_timeout.timeout(self._params.keepalive): @@ -404,22 +386,20 @@ class APIConnection: def is_authenticated(self) -> bool: return self.is_connected and self._is_authenticated - async def send_message(self, msg: message.Message) -> None: + def send_message(self, msg: message.Message) -> None: """Send a protobuf message to the remote.""" if not self._is_socket_open: raise APIConnectionError("Connection isn't established yet") + frame_helper = self._frame_helper + assert frame_helper is not None + assert frame_helper.ready, "Frame helper not ready" message_type = PROTO_TO_MESSAGE_TYPE.get(type(msg)) if not message_type: raise ValueError(f"Message type id not found for type {type(msg)}") encoded = msg.SerializeToString() _LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg)) - frame_helper = self._frame_helper - assert frame_helper is not None - if not frame_helper.ready: - await frame_helper.wait_for_ready() - try: frame_helper.write_packet( Packet( @@ -431,7 +411,7 @@ class APIConnection: # If writing packet fails, we don't know what state the frames # are in anymore and we have to close the connection _LOGGER.info("%s: Error writing packet: %s", self.log_name, err) - await self._report_fatal_error(err) + self._report_fatal_error_and_cleanup_task(err) raise def add_message_callback( @@ -454,7 +434,7 @@ class APIConnection: for msg_type in msg_types: self._message_handlers[msg_type].remove(on_message) - async def send_message_callback_response( + def send_message_callback_response( self, send_msg: message.Message, on_message: Callable[[Any], None], @@ -464,7 +444,7 @@ class APIConnection: for msg_type in msg_types: self._message_handlers.setdefault(msg_type, []).append(on_message) try: - await self.send_message(send_msg) + self.send_message(send_msg) except (asyncio.CancelledError, Exception): for msg_type in msg_types: self._message_handlers[msg_type].remove(on_message) @@ -514,7 +494,7 @@ class APIConnection: # the await is cancelled try: - await self.send_message(send_msg) + self.send_message(send_msg) async with async_timeout.timeout(timeout): await fut except asyncio.TimeoutError as err: @@ -545,6 +525,25 @@ class APIConnection: return res[0] + def _handle_fatal_error(self, err: Exception) -> None: + """Handle a fatal error that occurred during an operation.""" + self._connection_state = ConnectionState.CLOSED + for handler in self._read_exception_handlers[:]: + handler(err) + self._read_exception_handlers.clear() + + def _report_fatal_error_and_cleanup_task(self, err: Exception) -> None: + """Handle a fatal error that occurred during an operation. + + This should only be called for errors that mean the connection + can no longer be used. + + The connection will be closed, all exception handlers notified. + This method does not log the error, the call site should do so. + """ + self._handle_fatal_error(err) + asyncio.create_task(self._cleanup()) + async def _report_fatal_error(self, err: Exception) -> None: """Report a fatal error that occurred during an operation. @@ -554,106 +553,55 @@ class APIConnection: The connection will be closed, all exception handlers notified. This method does not log the error, the call site should do so. """ - self._connection_state = ConnectionState.CLOSED - for handler in self._read_exception_handlers[:]: - handler(err) + self._handle_fatal_error(err) await self._cleanup() - async def _process_loop(self) -> None: - to_process = self._to_process - while True: - try: - pkt = await to_process.get() - except RuntimeError: - break + def _process_packet(self, pkt: Packet) -> None: + """Process a packet from the socket.""" + msg_type_proto = pkt.type + if msg_type_proto not in MESSAGE_TYPE_TO_PROTO: + _LOGGER.debug("%s: Skipping message type %s", self.log_name, msg_type_proto) + return - if pkt is None: - # Socket closed but task isn't cancelled yet - break - - msg_type_proto = pkt.type - if msg_type_proto not in MESSAGE_TYPE_TO_PROTO: - _LOGGER.debug( - "%s: Skipping message type %s", self.log_name, msg_type_proto - ) - continue - - msg = MESSAGE_TYPE_TO_PROTO[msg_type_proto]() - try: - msg.ParseFromString(pkt.data) - except Exception as e: - _LOGGER.info( - "%s: Invalid protobuf message: type=%s data=%s: %s", - self.log_name, - pkt.type, - pkt.data, - e, - exc_info=True, - ) - await self._report_fatal_error( - ProtocolAPIError(f"Invalid protobuf message: {e}") - ) - raise - - msg_type = type(msg) - - _LOGGER.debug( - "%s: Got message of type %s: %s", self.log_name, msg_type, msg - ) - - for handler in self._message_handlers.get(msg_type, [])[:]: - handler(msg) - - # Pre-check the message type to avoid awaiting - # since most messages are not internal messages - if msg_type in INTERNAL_MESSAGE_TYPES: - await self._handle_internal_messages(msg) - - async def _read_loop(self) -> None: - frame_helper = self._frame_helper - assert frame_helper is not None - await frame_helper.wait_for_ready() - to_process = self._to_process + msg = MESSAGE_TYPE_TO_PROTO[msg_type_proto]() try: - # Once its ready, we hold the lock for the duration of the - # connection so we don't have to keep locking/unlocking - async with frame_helper.read_lock: - while True: - to_process.put_nowait(await frame_helper.read_packet_with_lock()) - except SocketClosedAPIError as err: - # don't log with info, if closed the site that closed the connection should log - _LOGGER.debug( - "%s: Socket closed, stopping read loop", - self.log_name, - ) - await self._report_fatal_error(err) - except APIConnectionError as err: + msg.ParseFromString(pkt.data) + except Exception as e: _LOGGER.info( - "%s: Error while reading incoming messages: %s", + "%s: Invalid protobuf message: type=%s data=%s: %s", self.log_name, - err, - ) - await self._report_fatal_error(err) - except Exception as err: # pylint: disable=broad-except - _LOGGER.warning( - "%s: Unexpected error while reading incoming messages: %s", - self.log_name, - err, + pkt.type, + pkt.data, + e, exc_info=True, ) - await self._report_fatal_error(err) + self._report_fatal_error_and_cleanup_task( + ProtocolAPIError(f"Invalid protobuf message: {e}") + ) + raise + + msg_type = type(msg) + + _LOGGER.debug("%s: Got message of type %s: %s", self.log_name, msg_type, msg) + + for handler in self._message_handlers.get(msg_type, [])[:]: + handler(msg) + + # Pre-check the message type to avoid awaiting + # since most messages are not internal messages + if msg_type not in INTERNAL_MESSAGE_TYPES: + return - async def _handle_internal_messages(self, msg: Any) -> None: if isinstance(msg, DisconnectRequest): - await self.send_message(DisconnectResponse()) + self.send_message(DisconnectResponse()) self._connection_state = ConnectionState.CLOSED - await self._cleanup() + asyncio.create_task(self._cleanup()) elif isinstance(msg, PingRequest): - await self.send_message(PingResponse()) + self.send_message(PingResponse()) elif isinstance(msg, GetTimeRequest): resp = GetTimeResponse() resp.epoch_seconds = int(time.time()) - await self.send_message(resp) + self.send_message(resp) async def _ping(self) -> None: self._check_connected() diff --git a/aioesphomeapi/util.py b/aioesphomeapi/util.py index 8190c41..90402a6 100644 --- a/aioesphomeapi/util.py +++ b/aioesphomeapi/util.py @@ -1,7 +1,9 @@ import math +from functools import lru_cache from typing import Optional +@lru_cache(maxsize=1024) def varuint_to_bytes(value: int) -> bytes: if value <= 0x7F: return bytes([value]) @@ -18,6 +20,7 @@ def varuint_to_bytes(value: int) -> bytes: return ret +@lru_cache(maxsize=1024) def bytes_to_varuint(value: bytes) -> Optional[int]: result = 0 bitpos = 0 diff --git a/tests/test__frame_helper.py b/tests/test__frame_helper.py index 5a8ac26..57b6d70 100644 --- a/tests/test__frame_helper.py +++ b/tests/test__frame_helper.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock import pytest -from aioesphomeapi._frame_helper import APIPlaintextFrameHelper +from aioesphomeapi._frame_helper import APIPlaintextFrameHelper, Packet from aioesphomeapi.util import varuint_to_bytes PREAMBLE = b"\x00" @@ -46,17 +46,20 @@ PREAMBLE = b"\x00" ) async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type): - stream_reader = asyncio.StreamReader() - stream_writer = MagicMock() - for _ in range(5): + packets = [] - stream_reader.feed_data(in_bytes) + def _packet(pkt: Packet): + packets.append(pkt) - helper = APIPlaintextFrameHelper(stream_reader, stream_writer) + def _on_error(exc: Exception): + raise exc - async with helper.read_lock: - pkt = await helper.read_packet_with_lock() + helper = APIPlaintextFrameHelper(on_pkt=_packet, on_error=_on_error) + + helper.data_received(in_bytes) + + pkt = packets.pop() assert pkt.type == pkt_type assert pkt.data == pkt_data diff --git a/tests/test_client.py b/tests/test_client.py index 83ee959..b98bfd0 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -74,7 +74,7 @@ def patch_response_complex(client: APIClient, messages): def patch_response_callback(client: APIClient): on_message = None - async def patched(req, callback, msg_types): + def patched(req, callback, msg_types): nonlocal on_message on_message = callback diff --git a/tests/test_connection.py b/tests/test_connection.py index 83b4934..c483f1c 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -4,6 +4,7 @@ import socket import pytest from mock import AsyncMock, MagicMock, Mock, patch +from aioesphomeapi._frame_helper import APIPlaintextFrameHelper, Packet from aioesphomeapi.api_pb2 import ConnectResponse, HelloResponse from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState from aioesphomeapi.core import APIConnectionError, RequiresEncryptionAPIError @@ -50,13 +51,26 @@ def socket_socket(): yield func +def _get_mock_protocol(): + def _on_packet(pkt: Packet): + pass + + def _on_error(exc: Exception): + raise exc + + protocol = APIPlaintextFrameHelper(on_pkt=_on_packet, on_error=_on_error) + protocol._connected_event.set() + protocol._transport = MagicMock() + return protocol + + @pytest.mark.asyncio async def test_connect(conn, resolve_host, socket_socket, event_loop): - with patch.object(event_loop, "sock_connect"), patch( - "asyncio.open_connection", return_value=(None, None) - ), patch.object(conn, "_read_loop"), patch.object( - conn, "_connect_start_ping" - ), patch.object( + loop = asyncio.get_event_loop() + protocol = _get_mock_protocol() + with patch.object(event_loop, "sock_connect"), patch.object( + loop, "create_connection", return_value=(MagicMock(), protocol) + ), patch.object(conn, "_connect_start_ping"), patch.object( conn, "send_message_await_response", return_value=HelloResponse() ): await conn.connect(login=False) @@ -66,14 +80,15 @@ async def test_connect(conn, resolve_host, socket_socket, event_loop): @pytest.mark.asyncio async def test_requires_encryption_propagates(conn): - with patch("asyncio.open_connection") as openc: - reader = MagicMock() - writer = MagicMock() - openc.return_value = (reader, writer) - writer.drain = AsyncMock() - reader.readexactly = AsyncMock() - reader.readexactly.return_value = b"\x01" + loop = asyncio.get_event_loop() + protocol = _get_mock_protocol() + with patch.object(loop, "create_connection") as create_connection, patch.object( + protocol, "perform_handshake" + ): + create_connection.return_value = (MagicMock(), protocol) await conn._connect_init_frame_helper() + with pytest.raises(RequiresEncryptionAPIError): + protocol.data_received(b"\x01\x00\x00") await conn._connect_hello()