diff --git a/aioesphomeapi/_frame_helper.py b/aioesphomeapi/_frame_helper.py index 84ca614..61e0f2c 100644 --- a/aioesphomeapi/_frame_helper.py +++ b/aioesphomeapi/_frame_helper.py @@ -6,9 +6,11 @@ from enum import Enum from typing import Callable, Optional, Union, cast import async_timeout +from cryptography.exceptions import InvalidTag from noise.connection import NoiseConnection # type: ignore from .core import ( + APIConnectionError, BadNameAPIError, HandshakeAPIError, InvalidEncryptionKeyAPIError, @@ -219,29 +221,38 @@ class APINoiseFrameHelper(APIFrameHelper): ) -> None: """Initialize the API frame helper.""" super().__init__(on_pkt, on_error) - self._ready_event = asyncio.Event() + self._ready_future = asyncio.get_event_loop().create_future() self._noise_psk = noise_psk self._expected_name = expected_name self._state = NoiseConnectionState.HELLO self._setup_proto() + def _set_ready_future_exception(self, exc: Exception) -> None: + if not self._ready_future.done(): + self._ready_future.set_exception(exc) + def close(self) -> None: """Close the connection.""" # Make sure we set the ready event if its not already set # so that we don't block forever on the ready event if we # are waiting for the handshake to complete. - self._ready_event.set() + self._set_ready_future_exception(APIConnectionError("Connection closed")) self._state = NoiseConnectionState.CLOSED super().close() + def _handle_error_and_close(self, exc: Exception) -> None: + self._set_ready_future_exception(exc) + super()._handle_error_and_close(exc) + def _write_frame(self, frame: bytes) -> None: """Write a packet to the socket, the caller should not have the lock. The entire packet must be written in a single call to write to avoid locking. """ - _LOGGER.debug("Sending frame %s", frame.hex()) assert self._transport is not None, "Transport is not set" + if _LOGGER.isEnabledFor(logging.DEBUG): + _LOGGER.debug("Sending frame: [%s]", frame.hex()) try: header = bytes( @@ -260,7 +271,7 @@ class APINoiseFrameHelper(APIFrameHelper): self._send_hello() try: async with async_timeout.timeout(60.0): - await self._ready_event.wait() + await self._ready_future except asyncio.TimeoutError as err: raise HandshakeAPIError("Timeout during handshake") from err @@ -273,8 +284,10 @@ class APINoiseFrameHelper(APIFrameHelper): self._handle_error_and_close( ProtocolAPIError(f"Marker byte invalid: {header[0]}") ) + return msg_size = (header[1] << 8) | header[2] frame = self._read_exactly(msg_size) + if frame is None: return @@ -292,16 +305,18 @@ class APINoiseFrameHelper(APIFrameHelper): def _handle_hello(self, server_hello: bytearray) -> None: """Perform the handshake with the server, the caller is responsible for having the lock.""" if not server_hello: - raise HandshakeAPIError("ServerHello is empty") + self._handle_error_and_close(HandshakeAPIError("ServerHello is empty")) + return # First byte of server hello is the protocol the server chose # for this session. Currently only 0x01 (Noise_NNpsk0_25519_ChaChaPoly_SHA256) # exists. chosen_proto = server_hello[0] if chosen_proto != 0x01: - raise HandshakeAPIError( - f"Unknown protocol selected by client {chosen_proto}" + self._handle_error_and_close( + HandshakeAPIError(f"Unknown protocol selected by client {chosen_proto}") ) + return # Check name matches expected name (for noise sessions, this is done # during hello phase before a connection is set up) @@ -311,9 +326,12 @@ class APINoiseFrameHelper(APIFrameHelper): # server name found, this extension was added in 2022.2 server_name = server_hello[1:server_name_i].decode() if self._expected_name is not None and self._expected_name != server_name: - raise BadNameAPIError( - f"Server sent a different name '{server_name}'", server_name + self._handle_error_and_close( + BadNameAPIError( + f"Server sent a different name '{server_name}'", server_name + ) ) + return self._state = NoiseConnectionState.HANDSHAKE self._send_handshake() @@ -335,12 +353,24 @@ class APINoiseFrameHelper(APIFrameHelper): if msg[0] != 0: explanation = msg[1:].decode() if explanation == "Handshake MAC failure": - raise InvalidEncryptionKeyAPIError("Invalid encryption key") - raise HandshakeAPIError(f"Handshake failure: {explanation}") - self._proto.read_message(msg[1:]) + self._handle_error_and_close( + InvalidEncryptionKeyAPIError("Invalid encryption key") + ) + return + self._handle_error_and_close( + HandshakeAPIError(f"Handshake failure: {explanation}") + ) + return + try: + self._proto.read_message(msg[1:]) + except InvalidTag as invalid_tag_exc: + ex = InvalidEncryptionKeyAPIError("Invalid encryption key") + ex.__cause__ = invalid_tag_exc + self._handle_error_and_close(ex) + return _LOGGER.debug("Handshake complete") self._state = NoiseConnectionState.READY - self._ready_event.set() + self._ready_future.set_result(None) def write_packet(self, type_: int, data: bytes) -> None: """Write a packet to the socket.""" @@ -367,13 +397,17 @@ class APINoiseFrameHelper(APIFrameHelper): assert self._proto is not None msg = self._proto.decrypt(bytes(frame)) if len(msg) < 4: - raise ProtocolAPIError(f"Bad packet frame: {msg}") + self._handle_error_and_close(ProtocolAPIError(f"Bad packet frame: {msg}")) + return pkt_type = (msg[0] << 8) | msg[1] data_len = (msg[2] << 8) | msg[3] if data_len + 4 > len(msg): - raise ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}") + self._handle_error_and_close( + ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}") + ) + return data = msg[4 : 4 + data_len] - return self._on_pkt(pkt_type, data) + self._on_pkt(pkt_type, data) def _handle_closed( # pylint: disable=unused-argument self, frame: bytearray diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index 1821cd2..9b97144 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -30,6 +30,7 @@ from .core import ( MESSAGE_TYPE_TO_PROTO, APIConnectionError, BadNameAPIError, + ConnectionNotEstablishedAPIError, HandshakeAPIError, InvalidAuthAPIError, PingFailedAPIError, @@ -432,7 +433,7 @@ class APIConnection: self._cleanup() raise self._fatal_exception or APIConnectionError("Connection cancelled") except Exception: # pylint: disable=broad-except - # Always clean up the connection if an error occured during connect + # Always clean up the connection if an error occurred during connect self._connection_state = ConnectionState.CLOSED self._cleanup() raise @@ -493,7 +494,12 @@ class APIConnection: def send_message(self, msg: message.Message) -> None: """Send a protobuf message to the remote.""" if not self._is_socket_open: - raise APIConnectionError( + if in_do_connect.get(False): + # If we are in the do_connect task, we can't raise an error + # because it would obscure the original exception (ie encrypt error). + _LOGGER.debug("%s: Connection isn't established yet", self.log_name) + return + raise ConnectionNotEstablishedAPIError( f"Connection isn't established yet ({self._connection_state})" ) diff --git a/aioesphomeapi/core.py b/aioesphomeapi/core.py index efe36f0..2f3bb0a 100644 --- a/aioesphomeapi/core.py +++ b/aioesphomeapi/core.py @@ -183,6 +183,10 @@ class HandshakeAPIError(APIConnectionError): pass +class ConnectionNotEstablishedAPIError(APIConnectionError): + pass + + class BadNameAPIError(APIConnectionError): """Raised when a name received from the remote but does not much the expected name.""" diff --git a/aioesphomeapi/reconnect_logic.py b/aioesphomeapi/reconnect_logic.py index f56e1fa..1129158 100644 --- a/aioesphomeapi/reconnect_logic.py +++ b/aioesphomeapi/reconnect_logic.py @@ -5,11 +5,17 @@ from typing import Awaitable, Callable, List, Optional import zeroconf from .client import APIClient -from .core import APIConnectionError +from .core import ( + APIConnectionError, + InvalidAuthAPIError, + InvalidEncryptionKeyAPIError, + RequiresEncryptionAPIError, +) _LOGGER = logging.getLogger(__name__) EXPECTED_DISCONNECT_COOLDOWN = 3.0 +MAXIMUM_BACKOFF_TRIES = 100 class ReconnectLogic(zeroconf.RecordUpdateListener): @@ -103,13 +109,26 @@ class ReconnectLogic(zeroconf.RecordUpdateListener): level = logging.WARNING if self._tries == 0 else logging.DEBUG _LOGGER.log( level, - "Can't connect to ESPHome API for %s: %s", + "Can't connect to ESPHome API for %s: %s (%s)", self._log_name, err, + type(err).__name__, # Print stacktrace if unhandled (not APIConnectionError) exc_info=not isinstance(err, APIConnectionError), ) - self._tries += 1 + if isinstance( + err, + ( + RequiresEncryptionAPIError, + InvalidEncryptionKeyAPIError, + InvalidAuthAPIError, + ), + ): + # If we get an encryption or password error, + # backoff for the maximum amount of time + self._tries = MAXIMUM_BACKOFF_TRIES + else: + self._tries += 1 return False _LOGGER.info("Successfully connected to %s", self._log_name) self._connected = True diff --git a/tests/test__frame_helper.py b/tests/test__frame_helper.py index 05517fd..4eda656 100644 --- a/tests/test__frame_helper.py +++ b/tests/test__frame_helper.py @@ -3,7 +3,8 @@ from unittest.mock import MagicMock import pytest -from aioesphomeapi._frame_helper import APIPlaintextFrameHelper +from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper +from aioesphomeapi.core import BadNameAPIError, InvalidEncryptionKeyAPIError from aioesphomeapi.util import varuint_to_bytes PREAMBLE = b"\x00" @@ -63,3 +64,79 @@ async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type): assert type_ == pkt_type assert data == pkt_data + + +@pytest.mark.asyncio +async def test_noise_frame_helper_incorrect_key(): + """Test that the noise frame helper raises InvalidEncryptionKeyAPIError on bad key.""" + outgoing_packets = [ + "010000", # hello packet + "010031001ed7f7bb0b74085418258ed5928931bc36ade7cf06937fcff089044d4ab142643f1b2c9935bb77696f23d930836737a4", + ] + incoming_packets = [ + "01000d01736572766963657465737400", + "0100160148616e647368616b65204d4143206661696c757265", + ] + packets = [] + + def _packet(type_: int, data: bytes): + packets.append((type_, data)) + + def _on_error(exc: Exception): + raise exc + + helper = APINoiseFrameHelper( + on_pkt=_packet, + on_error=_on_error, + noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=", + expected_name="servicetest", + ) + helper._transport = MagicMock() + + for pkt in outgoing_packets: + helper._write_frame(bytes.fromhex(pkt)) + + with pytest.raises(InvalidEncryptionKeyAPIError): + for pkt in incoming_packets: + helper.data_received(bytes.fromhex(pkt)) + + with pytest.raises(InvalidEncryptionKeyAPIError): + await helper.perform_handshake() + + +@pytest.mark.asyncio +async def test_noise_incorrect_name(): + """Test we raise on bad name.""" + outgoing_packets = [ + "010000", # hello packet + "010031001ed7f7bb0b74085418258ed5928931bc36ade7cf06937fcff089044d4ab142643f1b2c9935bb77696f23d930836737a4", + ] + incoming_packets = [ + "01000d01736572766963657465737400", + "0100160148616e647368616b65204d4143206661696c757265", + ] + packets = [] + + def _packet(type_: int, data: bytes): + packets.append((type_, data)) + + def _on_error(exc: Exception): + raise exc + + helper = APINoiseFrameHelper( + on_pkt=_packet, + on_error=_on_error, + noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=", + expected_name="wrongname", + ) + helper._transport = MagicMock() + + for pkt in outgoing_packets: + helper._write_frame(bytes.fromhex(pkt)) + + with pytest.raises(BadNameAPIError): + for pkt in incoming_packets: + helper.data_received(bytes.fromhex(pkt)) + + with pytest.raises(BadNameAPIError): + await helper.perform_handshake()